The biggest lesson we’ve learned in the past few years is that scaling is the key to model performance. Scaling model capacity without sacrificing inference speed is what makes Mixture of Experts (MoE) models so appealing for large language model training. Today we’ll take a close look at the Mixtral model to study how MoE models work.

Introduction

Most of today’s MoE models follow an architecture similar to the Switch Transformer [2], shown below:

Figure 1. Switch MoE model

In these models, the sparsity lies in the feed-forward layer of the Transformer block. Switch routes each token to a single expert (one of the four shown in Figure 1), while models such as Mixtral and Grok-1 route each token to two out of eight experts. The router dynamically chooses the experts for each token. Could we instead route whole samples to different experts? Yes, but such coarse-grained routing gives the model less flexibility to learn patterns and usually leads to worse performance. A sketch of this token-level sparse layer follows below.
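To make the structure concrete, here is a minimal PyTorch sketch of a Mixtral-style sparse feed-forward layer with top-2 routing over 8 experts. The names (`SwiGLUExpert`, `MoEFeedForward`, `num_experts`, `top_k`) are illustrative; this is a sketch, not the reference Mixtral implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUExpert(nn.Module):
    """One expert: the same SwiGLU feed-forward block a dense model would use."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class MoEFeedForward(nn.Module):
    """Sparse MoE feed-forward layer: each token is processed by top_k of num_experts experts."""
    def __init__(self, dim, hidden_dim, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(dim, num_experts, bias=False)  # the router
        self.experts = nn.ModuleList(
            [SwiGLUExpert(dim, hidden_dim) for _ in range(num_experts)]
        )

    def forward(self, x):                                     # x: [sequence_len, dim]
        logits = self.gate(x)                                 # [sequence_len, num_experts]
        weights, selected = torch.topk(logits, self.top_k, dim=-1)
        weights = F.softmax(weights, dim=-1, dtype=torch.float).to(x.dtype)

        out = torch.zeros_like(x)
        for expert_id, expert in enumerate(self.experts):
            token_idx, slot = torch.where(selected == expert_id)  # tokens routed to this expert
            if token_idx.numel() == 0:
                continue
            out[token_idx] += weights[token_idx, slot, None] * expert(x[token_idx])
        return out

With top_k=1 this degenerates into Switch-style routing; with top_k=2 it matches the Mixtral/Grok-1 setup described above.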

Dynamic Routing

There are several ways to design a router that assigns tokens to experts. Ideally, we would like each expert to specialize in one domain or task, but there is no straightforward way to enforce this. Mixtral uses a softmax top-k gating mechanism to select the experts.

For any input $x$ of shape $[\text{sequence\_len}, \text{dim}]$, we multiply it by a gate matrix $W$ of shape $[\text{dim}, 8]$ to obtain router logits of shape $[\text{sequence\_len}, 8]$. The top $k$ logits (where $k$ is the number of experts per token) are selected and then normalized with a softmax to produce the $k$ expert weights. In Mixtral, $k = 2$.
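A short shape walkthrough of just the gating step (the tensor sizes and variable names here are illustrative); note that the softmax is applied after the top-k selection, so it only normalizes over the selected logits:

import torch

sequence_len, dim, num_experts, k = 6, 4096, 8, 2
x = torch.randn(sequence_len, dim)
W = torch.randn(dim, num_experts)

logits = x @ W                                        # [sequence_len, 8]
topk_logits, expert_ids = torch.topk(logits, k)       # both [sequence_len, 2]
expert_weights = torch.softmax(topk_logits, dim=-1)   # weights of the 2 chosen experts, each row sums to 1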

MoE training is prone to instability because the router introduces extra exponential functions (the softmax over the gating logits). To mitigate round-off errors under mixed precision, a z-loss is commonly applied to the router logits to keep them from growing too large:

$$ L_z = \frac{1}{B} \sum_{i=1}^{B} \left( \log \sum_{j=1}^{N} e^{x_j^{(i)}} \right)^2 $$

where $B$ is the number of tokens, $N$ is the number of experts, and $x^{(i)}$ are the router logits for token $i$.

In PyTorch (assuming `torch` is imported, `logits` holds the router logits, and `z_loss_coeff` is the loss coefficient),

z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff

Ref [7] uses a similar approach to stabilize training: $$ L_{\text{max-z}} = 2 \times 10^{-4} \, z^2 $$ where $z$ is the maximum logit value.
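A minimal sketch of this max-z variant in PyTorch; averaging over tokens is an assumption here, since the formula above only defines the per-token term:

# logits: the logits being regularized, shape [sequence_len, num_classes]
max_z = logits.max(dim=-1).values            # maximum logit per token
max_z_loss = 2e-4 * torch.mean(max_z ** 2)   # averaged over the batch of tokens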

Load Balancing

With dynamic routing, which expert each token is routed to is unknown up front, so the load can become unbalanced: a few experts may receive most of the tokens while the others are rarely used. The common solution is to add an auxiliary load-balancing loss that pushes the router toward a uniform assignment, as sketched below.
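For instance, the Switch Transformer [2] adds an auxiliary loss that multiplies, per expert, the fraction of tokens dispatched to that expert by the mean router probability for that expert; the loss is smallest when both are uniform. A sketch in PyTorch, with illustrative names (`selected` are the per-token expert indices from the top-k step above):

import torch
import torch.nn.functional as F

def load_balancing_loss(logits, selected, num_experts, coeff=1e-2):
    # logits:   [num_tokens, num_experts] router logits
    # selected: [num_tokens, top_k] indices of the experts chosen for each token
    probs = F.softmax(logits, dim=-1)                               # router probabilities
    dispatch = F.one_hot(selected, num_experts).float().sum(dim=1)  # [num_tokens, num_experts]
    tokens_per_expert = dispatch.mean(dim=0)    # fraction of tokens sent to each expert
    prob_per_expert = probs.mean(dim=0)         # mean router probability per expert
    return coeff * num_experts * torch.sum(tokens_per_expert * prob_per_expert)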

Training

Training an MoE model directly from scratch can be challenging due to low training efficiency. One popular approach, called sparse upcycling [5], uses a pretrained dense model to initialize the sparse model and then continues training it for a certain number of steps; see the sketch after Figure 2.

Figure 2. Training an MoE model via sparse upcycling
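A rough sketch of the upcycling initialization, reusing the `MoEFeedForward`/`SwiGLUExpert` classes sketched earlier and assuming `dense_ffn` has the same architecture as a single expert (the function name and copying scheme are illustrative, following the idea in [5]):

def upcycle_ffn(dense_ffn, dim, hidden_dim, num_experts=8, top_k=2):
    """Sparse upcycling: every expert starts as a copy of the pretrained dense FFN."""
    moe = MoEFeedForward(dim, hidden_dim, num_experts=num_experts, top_k=top_k)
    for expert in moe.experts:
        expert.load_state_dict(dense_ffn.state_dict())  # copy dense weights into each expert
    # The router (moe.gate) keeps its fresh random initialization; attention,
    # embedding, and norm weights of the dense model are reused unchanged.
    return moe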

(To be continued)

Public Implementations

References

  1. Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer
  2. Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
  3. BASE Layers: Simplifying Training of Large, Sparse Models
  4. Mixtral of Experts
  5. Sparse Upcycling: Training Mixture-of-Experts from Dense Checkpoints
  6. Beyond Distillation: Task-level Mixture-of-Experts for Efficient Inference
  7. Baichuan 2: Open Large-scale Language Models