Llama 3 + Mamba join forces! Distilled into a linear RNN, inference is up to 1.6 times faster
Cressy from Aofei Temple
Quantum Bit | Public Account QbitAI
Distilling Llama 3 to Mamba can increase inference speed by up to 1.6 times!
And performance doesn't suffer; the distilled model can even beat the original.
This is new work from Together AI that combines the Transformer and Mamba models through distillation, and it also includes an inference acceleration algorithm designed for the hybrid model.
Tri Dao, who proposed the Mamba architecture and authored FlashAttention, also took part in the project.
The founder and CEO of Together AI said that the combination of Transformer and Mamba is a major development direction for future large models.
Distilling Transformer into Mamba
Before distillation proper begins, the linear RNN (Mamba) needs to be initialized from the Transformer.
The authors observed that there are certain similarities between the Transformer's attention mechanism and the computation of RNN.
Therefore, the Transformer's attention can be linearized to establish a connection between the two.
Using this correspondence, the parameters of the pre-trained Transformer model can be copied to the Mamba model.
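To make the correspondence concrete, here is a minimal sketch of the weight-copying step. All module and attribute names (attn.q_proj, mamba.c_proj, and so on) are hypothetical; the mapping simply mirrors the linear-attention correspondence described above, where the linear RNN's readout matrix C plays the role of the queries, B plays the role of the keys, and the input x plays the role of the values.

```python
# Hedged sketch: initialize a Mamba-style linear-RNN block from a
# pre-trained attention block. All module names here are hypothetical.
import torch

@torch.no_grad()
def init_mamba_from_attention(attn, mamba):
    mamba.c_proj.weight.copy_(attn.q_proj.weight)    # C takes the role of Q
    mamba.b_proj.weight.copy_(attn.k_proj.weight)    # B takes the role of K
    mamba.x_proj.weight.copy_(attn.v_proj.weight)    # x takes the role of V
    mamba.out_proj.weight.copy_(attn.o_proj.weight)  # reuse output projection
```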
After completing the parameter initialization, the authors adopted a three-stage distillation process to further improve the performance of the Mamba model, enabling it to better learn Transformer knowledge.
The first stage is pseudo-label distillation: a pre-trained Transformer teacher model generates pseudo-labels on unlabeled data, and the Mamba student model is trained on them.
The loss for this stage combines a KL divergence term, which imitates the teacher's output distribution, with a cross-entropy term, which fits the pseudo-labels.
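A minimal sketch of what such a combined objective could look like, assuming PyTorch-style logits; the weighting alpha and temperature T are illustrative choices, not values from the paper:

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, pseudo_labels,
                 alpha=0.5, T=2.0):
    """Stage-1 objective: KL term imitates the teacher's distribution,
    CE term fits the teacher-generated pseudo-labels."""
    # Temperature-softened KL divergence against the teacher.
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    # Cross entropy against the hard pseudo-labels.
    ce = F.cross_entropy(
        student_logits.reshape(-1, student_logits.size(-1)),
        pseudo_labels.reshape(-1),
    )
    return alpha * kl + (1.0 - alpha) * ce
```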
The second stage is supervised fine-tuning on labeled instruction data, such as the OpenHermes 2.5 dataset.
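A minimal sketch of one stage-2 training step, assuming a HuggingFace-style model interface and a batch in which prompt positions are masked out of the labels with -100; these conventions are assumptions, not details from the paper:

```python
import torch.nn.functional as F

def sft_step(student, batch, optimizer):
    """One supervised fine-tuning step on instruction data."""
    logits = student(batch["input_ids"]).logits
    # Shift so each position predicts the next token; -100 masks the prompt.
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        batch["labels"][:, 1:].reshape(-1),
        ignore_index=-100,
    )
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```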
The final stage optimizes the model with human feedback via a reward model.
The authors collected human feedback data on the model’s output, then built a reward model based on it and used RL algorithms (such as PPO) to optimize the model’s performance under this reward model.
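As a sketch of the reward-model part of this stage: a common choice (an assumption here, not confirmed by the article) is a Bradley-Terry pairwise loss that pushes preferred responses to score higher than rejected ones; the trained reward model then supplies the reward signal for the RL step.

```python
import torch.nn.functional as F

def reward_model_loss(reward_model, chosen_ids, rejected_ids):
    """Pairwise preference loss: chosen responses should outscore rejected
    ones. A `reward_model` returning one scalar per sequence is an assumption."""
    r_chosen = reward_model(chosen_ids)
    r_rejected = reward_model(rejected_ids)
    return -F.logsigmoid(r_chosen - r_rejected).mean()
```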
On eight 80GB A100 GPUs, the entire distillation process for each hybrid model takes less than five days.
Through the above distillation process, the authors obtained Transformer-Mamba hybrid models, and then proposed a speculative decoding algorithm to accelerate the inference process.
Hybrid model inference acceleration algorithm
The basic idea of speculative decoding is to use a lightweight draft model to predict multiple tokens ahead, and then use a verification model (verifier) to check those predictions.
This can significantly improve the parallelism of decoding and speed up the generation process.
The Draft model is usually a small Transformer that predicts the next K tokens based on the current context.
For the K predicted tokens, the Transformer layers can process them all in parallel and compute their hidden states directly; the Mamba layers, however, must process each token in sequence, computing the current token's hidden state from the previous one before verifying the draft token:
- If the current token is correct, it is added to the accepted sequence and the latest hidden state is updated (but intermediate states are not saved).
- If the current token is wrong, processing of subsequent tokens stops and the latest hidden state rolls back to the last accepted token.
If all K tokens in the sequence are accepted, they are added to the output sequence and the draft model continues predicting the next batch of tokens. If any token is rejected, the draft sequence is truncated at the first rejected token and the process returns to the initial step, re-drafting from that position.
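Putting the pieces above together, here is a minimal sketch of the hybrid speculative-decoding loop. Every interface (draft.propose, verifier.prefill, verifier.scan) is hypothetical, and the greedy token match stands in for the probabilistic acceptance rule used in practice; the real algorithm also relies on a fused multi-step Mamba kernel for the verification pass.

```python
def speculative_decode(draft, verifier, prompt, max_new_tokens, K=4):
    """Sketch of draft-then-verify decoding with Mamba state rollback.
    Invariant: `state` has consumed every token in `output` but the last."""
    output = list(prompt)
    state = verifier.prefill(prompt[:-1])
    while len(output) - len(prompt) < max_new_tokens:
        draft_tokens = draft.propose(output, K)       # K speculative tokens
        # Verify the last confirmed token plus the K drafts in one pass.
        # Attention layers handle the chunk in parallel; Mamba layers scan
        # it sequentially and keep only the final state.
        chunk = [output[-1]] + draft_tokens
        preds, new_state = verifier.scan(state, chunk)
        n = 0
        while n < K and preds[n] == draft_tokens[n]:
            n += 1                                    # accept matching prefix
        if n == K:
            # All K drafts accepted: keep them plus the verifier's bonus
            # token, and carry the new hidden state forward.
            output.extend(draft_tokens)
            output.append(preds[K])
            state = new_state
        else:
            # First mismatch at position n: truncate the draft, take the
            # verifier's token, and roll back by recomputing the state from
            # the last accepted point (intermediate states were not cached).
            accepted = draft_tokens[:n]
            output.extend(accepted)
            output.append(preds[n])
            _, state = verifier.scan(state, [chunk[0]] + accepted)
    return output
```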
Llama 3 inference speed increased by 1.6 times
Test results show that the hybrid model performs comparably to or better than Llama-3 in single-turn (AlpacaEval) and multi-turn (MT-Bench) chat dialogue tasks.
The authors also tested models with different mixing ratios and found that the 1:1 hybrid performed best.
In the zero-shot general NLP task evaluation, the hybrid model outperforms the RNN model of the same size on average.
On the few-shot OpenLLM Leaderboard, the hybrid model performs on par with the best open source RNN model and outperforms the corresponding Instruct model on the GSM8K and CRUX tasks.
In addition to model performance, the authors also tested the acceleration effect brought by the speculative decoding algorithm.
The first test was on pure Mamba models: on the 2.8B and 7B models, inference speed increased by 1.7 to 2.6 times compared with standard decoding.
Furthermore, the authors tested the distilled Zephyr and Llama hybrid models and found that the Zephyr hybrid's inference sped up by more than 1.8 times, while the Llama hybrid's sped up by about 1.6 times.
Paper address:
https://www.together.ai/blog/the-mamba-in-the-llama-distilling-and-accelerating-hybrid-models
-over-