LLM Architecture
Mixture of Experts: Scaling LLMs Without Scaling Compute
MoE models conditionally route inputs to a subset of experts, achieving better performance per FLOP than dense models.
- Sparse Routing: A gating network selects the top-k experts for each token, activating only a fraction of the parameters
- Load Balancing: An auxiliary loss prevents expert collapse and encourages uniform expert utilization
- Mixtral 8x7B: Uses 8 experts per layer with only 2 active per token; Mistral AI reports it matches or outperforms LLaMA-2 70B on most benchmarks with 6x faster inference
"Not every input needs every parameter β learning to route different inputs to different experts is the key insight."
Mixture of Experts (MoE)
Mixture of Experts is an architecture that conditionally routes inputs to a subset of parameters, enabling models to scale total parameters while keeping per-token computation roughly fixed. This achieves better performance per FLOP than dense models.
The MoE Intuition
In an MoE layer, multiple "expert" neural networks process different inputs, and a learned "gating" network determines which experts process each input. Only a subset of experts is activated per input, which makes the computation sparse.
The key insight: not every input needs every parameter. By learning to route different inputs to different experts, MoE models achieve better performance with less computation per token.
Gating Function
The gating function determines expert selection:
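Gating Network
$$G(x) = \operatorname{softmax}(W_g x + b_g)$$
Here,
- $x$ = input token representation
- $W_g$ = learnable gating weight matrix
- $b_g$ = learnable gating bias
- $G(x)$ = probability distribution over experts
The layer output is the sum of the expert outputs, weighted by the gate:
$$y = \sum_{i=1}^{N} G(x)_i \, E_i(x)$$
Here,
- $N$ = total number of experts
- $G(x)_i$ = gating weight for expert i
- $E_i(x)$ = output of expert i
- $y$ = final output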
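As a worked instance of the gating network G(x) = softmax(W_g x + b_g) and the gate-weighted sum y = Σ_i G(x)_i E_i(x), the plain-Python sketch below evaluates both for a toy layer. The function `dense_moe` and the two hand-written experts `double` and `zero` are hypothetical, chosen only so the arithmetic is easy to check; real experts are learned MLPs.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def softmax(z: Sequence[float]) -> Vector:
    """Numerically stable softmax."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def dense_moe(
    x: Vector,
    W_g: List[Vector],
    b_g: Vector,
    experts: List[Callable[[Vector], Vector]],
) -> Tuple[Vector, Vector]:
    """Return (y, gates): gates = softmax(W_g x + b_g), y = sum_i gates[i] * E_i(x)."""
    logits = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(W_g, b_g)]
    gates = softmax(logits)
    outputs = [expert(x) for expert in experts]  # every expert runs: no sparsity yet
    y = [sum(g * out[d] for g, out in zip(gates, outputs)) for d in range(len(outputs[0]))]
    return y, gates


def double(v: Vector) -> Vector:  # hypothetical toy expert 0
    return [2.0 * vi for vi in v]


def zero(v: Vector) -> Vector:  # hypothetical toy expert 1
    return [0.0 for _ in v]


# Zero gating weights and bias -> uniform gates, so y is the plain average of the experts
y, gates = dense_moe([1.0, 2.0], [[0.0, 0.0], [0.0, 0.0]], [0.0, 0.0], [double, zero])
print(gates, y)  # [0.5, 0.5] [1.0, 2.0]
```

Every expert is evaluated here, so the cost grows linearly with N; top-k routing exists to avoid exactly this.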
In practice, only the top-k experts are activated:
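Top-k Routing
$$y = \sum_{i \in \mathcal{T}} G(x)_i \, E_i(x), \qquad \mathcal{T} = \operatorname{TopK}\big(G(x), k\big)$$
Here,
- $\mathcal{T}$ = set of k experts with highest gating scores
- $k$ = number of experts to activate (typically 1 or 2)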
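The top-k selection can be sketched in PyTorch as follows. The helper name `top_k_gating` is hypothetical. It renormalizes the k kept probabilities so they sum to 1, as the MoE code in this section does; this is equivalent to a softmax over only the top-k logits. Whether to renormalize is a design choice: DeepSeekMoE, for example, keeps the raw softmax scores.

```python
import torch
import torch.nn.functional as F


def top_k_gating(logits: torch.Tensor, k: int):
    """Return dense (tokens, N) routing weights with exactly k non-zero entries per row."""
    probs = F.softmax(logits, dim=-1)
    top_vals, top_idx = torch.topk(probs, k, dim=-1)
    # Renormalize the kept probabilities so each token's weights sum to 1
    top_vals = top_vals / top_vals.sum(dim=-1, keepdim=True)
    # Scatter the k weights back into an otherwise-zero (tokens, N) matrix
    weights = torch.zeros_like(probs).scatter(-1, top_idx, top_vals)
    return weights, top_idx


torch.manual_seed(0)
logits = torch.randn(4, 8)  # 4 tokens, N = 8 experts
weights, idx = top_k_gating(logits, k=2)
print((weights > 0).sum(dim=-1))  # tensor([2, 2, 2, 2])
print(weights.sum(dim=-1))  # each row sums to 1 (up to float rounding)
```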
Load Balancing
Without explicit balancing, the gating network tends to converge to selecting the same few experts for all inputs, leaving most of the model's capacity unused. This is known as the "rich get richer" phenomenon: experts that are selected early receive more gradient updates, become better, and are therefore selected even more often.
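Load Balancing Loss
$$\mathcal{L}_{\text{balance}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$
Here,
- $N$ = number of experts
- $f_i$ = fraction of tokens in the batch routed to expert i
- $P_i$ = average gating probability for expert i over the batch
- $\alpha$ = balancing coefficient (a small constant, e.g. 0.01 in the Switch Transformer)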
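To see why L_balance = α · N · Σ f_i · P_i favors uniform routing, this plain-Python sketch evaluates it on two hypothetical batches. Here f_i is computed from the hard top-k assignments (the share of all assignments that go to expert i, so the f_i sum to 1) and P_i from the soft router probabilities. Because f_i is a count it carries no gradient; during training the loss is differentiated through P_i. The function name and toy inputs are illustrative.

```python
from typing import List


def load_balance_loss(
    assignments: List[List[int]],
    probs: List[List[float]],
    num_experts: int,
    alpha: float = 0.01,
) -> float:
    """alpha * N * sum_i f_i * P_i for one batch.

    assignments[t] lists the expert ids chosen for token t (its top-k);
    probs[t] is token t's full router probability vector.
    """
    total_assignments = sum(len(a) for a in assignments)
    # f_i: share of all routing assignments that went to expert i
    f = [0.0] * num_experts
    for chosen in assignments:
        for e in chosen:
            f[e] += 1.0 / total_assignments
    # P_i: mean router probability for expert i across tokens
    P = [sum(p[i] for p in probs) / len(probs) for i in range(num_experts)]
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, P))


N = 4
# Balanced: each token goes to a different expert and the router is uniform
balanced = load_balance_loss([[0], [1], [2], [3]], [[0.25] * N] * 4, N)
# Collapsed: every token goes to expert 0 with probability 1
collapsed = load_balance_loss([[0]] * 4, [[1.0, 0.0, 0.0, 0.0]] * 4, N)
print(balanced, collapsed)  # about 0.01 (= alpha) and 0.04 (= alpha * N)
```

Uniform routing yields α, while complete collapse onto one expert yields α · N, so minimizing the loss pushes the router away from collapse.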
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MoELayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        balance_loss_coeff: float = 0.01,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.output_dim = output_dim
        self.balance_loss_coeff = balance_loss_coeff
        # Expert networks: independent two-layer MLPs
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim),
            )
            for _ in range(num_experts)
        ])
        # Gating network (router): one logit per expert
        self.gate = nn.Linear(input_dim, num_experts, bias=False)

    def compute_load_balance_loss(
        self, gate_probs: torch.Tensor, top_k_indices: torch.Tensor
    ) -> torch.Tensor:
        """L_balance = alpha * N * sum_i f_i * P_i."""
        # f_i: fraction of routing assignments sent to expert i (hard counts, no gradient)
        assignments = F.one_hot(top_k_indices, self.num_experts).float()  # (tokens, k, N)
        f = assignments.sum(dim=(0, 1)) / assignments.sum()
        # P_i: average gating probability for expert i (differentiable)
        P = gate_probs.mean(dim=0)
        return self.balance_loss_coeff * self.num_experts * (f * P).sum()

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, input_dim = x.shape
        x_flat = x.reshape(-1, input_dim)  # (batch*seq, input_dim)
        # Compute gating scores
        gate_logits = self.gate(x_flat)  # (batch*seq, num_experts)
        gate_probs = F.softmax(gate_logits, dim=-1)
        # Select top-k experts and renormalize their weights to sum to 1
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # Auxiliary load balance loss (add it to the task loss during training)
        balance_loss = self.compute_load_balance_loss(gate_probs, top_k_indices)
        # Route tokens to experts; output_dim may differ from input_dim
        output = x_flat.new_zeros(x_flat.size(0), self.output_dim)
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, k]  # (batch*seq,)
            expert_weights = top_k_probs[:, k]  # (batch*seq,)
            for i in range(self.num_experts):
                mask = expert_indices == i
                if mask.any():
                    expert_output = self.experts[i](x_flat[mask])
                    output[mask] += expert_weights[mask].unsqueeze(-1) * expert_output
        output = output.view(batch_size, seq_len, self.output_dim)
        return output, balance_loss
```
Mixtral Architecture
Mixtral (by Mistral AI) is a prominent MoE model with 8 experts and top-2 routing:
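Mixtral 8x7B has 46.7B total parameters but uses only ~12.9B per token, because each token is processed by 2 of the 8 experts in every layer. The total is below the 8 × 7B = 56B the name suggests because only the feed-forward blocks are replicated per expert; attention, embedding and normalization weights are shared. Mistral AI reports performance matching or exceeding LLaMA-2 70B on most benchmarks with 6x faster inference.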
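The 6x figure is Mistral AI's reported measurement. A rough sanity check uses the common rule of thumb that a forward pass costs about 2 FLOPs per active parameter per token, ignoring attention's dependence on context length. The helper below is a hypothetical back-of-envelope sketch, not a latency model: memory bandwidth, batching and uneven expert load all affect real speedups.

```python
def forward_flops_per_token(active_params: float) -> float:
    """Rule-of-thumb forward-pass cost: ~2 FLOPs (one multiply-add) per active parameter."""
    return 2.0 * active_params


mixtral = forward_flops_per_token(12.9e9)  # ~12.9B active parameters
llama2_70b = forward_flops_per_token(70e9)  # dense: every parameter is active
print(f"Mixtral:     {mixtral / 1e9:.1f} GFLOPs/token")
print(f"LLaMA-2 70B: {llama2_70b / 1e9:.1f} GFLOPs/token")
print(f"ratio:       {llama2_70b / mixtral:.1f}x")
```

This gives about 25.8 vs 140 GFLOPs per token, a compute ratio of roughly 5.4x, in the same range as the reported speedup.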
Mixtral Architecture Details
| Component | Specification |
|---|---|
| Total Parameters | 46.7B |
| Active Parameters | ~12.9B |
| Experts | 8 per layer |
| Top-k | 2 |
| Expert FFN Hidden Size | 14336 |
| Hidden Size | 4096 |
| Layers | 32 |
| Attention Heads | 32 query, 8 key-value (grouped-query attention) |
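The table's parameter totals can be reproduced from the listed dimensions. This sketch assumes details that are not in the table: SwiGLU experts with three bias-free weight matrices, 8 key-value heads of size 4096 / 32 = 128, a 32,000-token vocabulary, and separate (untied) input-embedding and output matrices. The function name `mixtral_param_counts` is hypothetical.

```python
def mixtral_param_counts(
    hidden: int = 4096,
    ffn: int = 14336,
    layers: int = 32,
    n_heads: int = 32,
    n_kv_heads: int = 8,  # assumption: grouped-query attention
    n_experts: int = 8,
    top_k: int = 2,
    vocab: int = 32000,  # assumption: vocabulary size
):
    head_dim = hidden // n_heads
    # Attention: Q and O are hidden x hidden; K and V project to n_kv_heads * head_dim
    attn = 2 * hidden * hidden + 2 * hidden * n_kv_heads * head_dim
    expert = 3 * hidden * ffn  # SwiGLU expert: w1, w3 (up) and w2 (down)
    router = hidden * n_experts  # gating matrix W_g
    norms = 2 * hidden  # two norm weight vectors per block
    shared_per_layer = attn + router + norms
    # Untied input embedding + output head, plus the final norm
    outside_layers = 2 * vocab * hidden + hidden
    total = layers * (shared_per_layer + n_experts * expert) + outside_layers
    active = layers * (shared_per_layer + top_k * expert) + outside_layers
    return total, active


total, active = mixtral_param_counts()
print(f"total ~ {total / 1e9:.1f}B, active ~ {active / 1e9:.1f}B")
# total ~ 46.7B, active ~ 12.9B
```

Almost all parameters sit in the experts (8 × ~176M per layer), while attention, router, norms and embeddings exist only once; that sharing is why the total is below 56B.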
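```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUExpert(nn.Module):
    """Mixtral-style expert FFN: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, ffn_dim, bias=False)
        self.w3 = nn.Linear(dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class MixtralBlock(nn.Module):
    """Simplified Mixtral decoder block. For brevity it omits rotary embeddings,
    grouped-query attention and KV caching, and uses LayerNorm where Mixtral uses RMSNorm."""

    def __init__(self, dim: int = 4096, ffn_dim: int = 14336, num_heads: int = 32,
                 num_experts: int = 8, top_k: int = 2):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, bias=False, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # MoE FFN: num_experts SwiGLU experts, top_k of them active per token
        self.experts = nn.ModuleList([SwiGLUExpert(dim, ffn_dim) for _ in range(num_experts)])
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.size(1)
        # Causal mask: True marks future positions a token may not attend to
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        # Self-attention with residual (pre-norm)
        residual = x
        h = self.norm1(x)
        h, _ = self.attention(h, h, h, attn_mask=causal_mask, need_weights=False)
        x = residual + h
        # MoE FFN with residual
        residual = x
        h = self.norm2(x)
        # Gating: softmax, keep top-k, renormalize (same as a softmax over the top-k logits)
        gate_probs = F.softmax(self.gate(h), dim=-1)
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # Expert computation: the (batch, seq) mask selects tokens routed to expert i
        output = torch.zeros_like(h)
        for k in range(self.top_k):
            for i, expert in enumerate(self.experts):
                mask = top_k_indices[:, :, k] == i
                if mask.any():
                    weights = top_k_probs[:, :, k][mask].unsqueeze(-1)  # (n_routed, 1)
                    output[mask] += weights * expert(h[mask])
        return residual + output
```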
DeepSeek-MoE
DeepSeek-MoE introduces finer-grained expert specialization:
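DeepSeek-MoE splits the feed-forward capacity into many more, smaller experts: DeepSeekMoE 16B uses 64 routed experts plus 2 shared experts, with 6 routed experts active per token, and DeepSeek-V2 scales this to 160 routed experts. "Shared experts" are always activated, while "routed experts" are selected by the gating network.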
DeepSeek-MoE Design Principles
- Finer-grained experts: More, smaller experts for better specialization
- Shared experts: Always-active experts for common knowledge
- Routed experts: Conditionally activated for specialized knowledge
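The case for finer-grained experts can be made with a quick count. Splitting each of N experts into m smaller ones (each with 1/m of the FFN hidden size) and activating m·k of them leaves the active FFN parameters unchanged, while the number of distinct expert combinations a token can use grows sharply. The configuration below (16 experts with top-2, each split into 4) is illustrative, the experts are modeled as plain two-matrix MLPs, and the function names are hypothetical.

```python
from math import comb


def active_ffn_params(hidden: int, ffn_dim: int, k: int) -> int:
    """Parameters touched per token by k two-matrix MLP experts of width ffn_dim."""
    return k * 2 * hidden * ffn_dim


def segment(num_experts: int, top_k: int, ffn_dim: int, m: int):
    """Split each expert into m pieces: m x more experts, 1/m width, m x more active."""
    return num_experts * m, top_k * m, ffn_dim // m


hidden, ffn_dim = 1024, 4096
coarse = (16, 2, ffn_dim)
fine = segment(*coarse, m=4)  # (64, 8, 1024)

for name, (n, k, width) in [("coarse", coarse), ("fine", fine)]:
    print(name, "combinations:", comb(n, k),
          "active FFN params:", active_ffn_params(hidden, width, k))
# coarse combinations: 120 active FFN params: 16777216
# fine combinations: 4426165368 active FFN params: 16777216
```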
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepSeekMoELayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_shared_experts: int = 2,
        num_routed_experts: int = 64,
        top_k: int = 6,
        norm_topk_prob: bool = False,
    ):
        super().__init__()
        self.num_shared_experts = num_shared_experts
        self.num_routed_experts = num_routed_experts
        self.top_k = top_k
        self.output_dim = output_dim
        # DeepSeekMoE weights routed experts by their raw softmax scores;
        # norm_topk_prob=True renormalizes them as the earlier layers in this section do
        self.norm_topk_prob = norm_topk_prob

        def make_expert() -> nn.Module:
            # Simplified GELU MLP; the original experts are gated FFNs
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim),
            )

        # Shared experts (always active)
        self.shared_experts = nn.ModuleList([make_expert() for _ in range(num_shared_experts)])
        # Routed experts (conditionally active)
        self.routed_experts = nn.ModuleList([make_expert() for _ in range(num_routed_experts)])
        # Gating network scores the routed experts only
        self.gate = nn.Linear(input_dim, num_routed_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, input_dim = x.shape
        x_flat = x.reshape(-1, input_dim)
        # Shared experts: every token uses all of them and their outputs are summed
        shared_output = x_flat.new_zeros(x_flat.size(0), self.output_dim)
        for expert in self.shared_experts:
            shared_output = shared_output + expert(x_flat)
        # Routed experts: top-k selection over the gating distribution
        gate_probs = F.softmax(self.gate(x_flat), dim=-1)
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        if self.norm_topk_prob:
            top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        routed_output = x_flat.new_zeros(x_flat.size(0), self.output_dim)
        for k in range(self.top_k):
            expert_indices = top_k_indices[:, k]
            expert_weights = top_k_probs[:, k]
            for i in range(self.num_routed_experts):
                mask = expert_indices == i
                if mask.any():
                    expert_output = self.routed_experts[i](x_flat[mask])
                    routed_output[mask] += expert_weights[mask].unsqueeze(-1) * expert_output
        # Combine shared and routed outputs
        output = shared_output + routed_output
        return output.view(batch_size, seq_len, self.output_dim)
```
Load Balancing Analysis
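Expert Utilization Variance
$$\operatorname{Var} = \frac{1}{N} \sum_{i=1}^{N} \left(f_i - \frac{1}{N}\right)^2$$
Here,
- $f_i$ = fraction of tokens routed to expert i
- $1/N$ = ideal uniform utilization
- $N$ = number of experts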
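A direct implementation of the utilization variance Var = (1/N) Σ_i (f_i - 1/N)^2, taking measured fractions f_i that sum to 1. Expanding the square gives Var = (Σ_i f_i^2 - 1/N) / N, so the variance is 0 for perfectly uniform routing and peaks at (N-1)/N^2 when every token goes to one expert. That maximum shrinks as N grows, so values are only comparable between models measured with the same N and the same normalization of f_i. The function name is hypothetical.

```python
from typing import Sequence


def utilization_variance(f: Sequence[float]) -> float:
    """Var = (1/N) * sum_i (f_i - 1/N)^2, where f_i is the fraction of tokens routed to expert i."""
    n = len(f)
    ideal = 1.0 / n
    return sum((fi - ideal) ** 2 for fi in f) / n


print(utilization_variance([0.125] * 8))  # 0.0: perfectly balanced, N = 8
print(utilization_variance([1.0] + [0.0] * 7))  # 0.109375 = 7/64: full collapse, N = 8
```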
A well-balanced MoE model has low variance in expert utilization:
| Model | Experts | Top-k | Utilization Variance |
|---|---|---|---|
| Mixtral 8x7B | 8 | 2 | 0.002 |
| Switch Transformer | 128 | 1 | 0.015 |
| GShard | 2048 | 2 | 0.008 |
Practical: Deploying MoE Models
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class MoEInferenceEngine:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = self.load_moe_model(model_name)

    def load_moe_model(self, model_name: str):
        """Load an MoE model; device_map="auto" spreads weights over the available GPUs
        and can offload what does not fit to CPU."""
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        model.eval()
        return model

    @torch.no_grad()
    def generate_with_expert_stats(self, prompt: str, max_tokens: int = 100):
        """Greedy generation (batch size 1) that counts which experts the last MoE layer
        selects for each newly generated token."""
        # Mixtral-style config fields; other MoE architectures may name these differently
        num_experts = self.model.config.num_local_experts
        top_k = self.model.config.num_experts_per_tok
        expert_counts = {i: 0 for i in range(num_experts)}
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        for _ in range(max_tokens):
            # Simple but slow: the whole sequence is re-run every step (no KV cache)
            outputs = self.model(input_ids, output_router_logits=True)
            router_logits = getattr(outputs, "router_logits", None)
            if router_logits is not None:
                # Last layer: (batch*seq_len, num_experts); the last row is the newest token
                newest = router_logits[-1][-1]
                for idx in torch.topk(newest, top_k).indices:
                    expert_counts[idx.item()] += 1
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            input_ids = torch.cat([input_ids, next_token.unsqueeze(1)], dim=-1)
            if next_token.item() == self.tokenizer.eos_token_id:
                break
        return {
            "text": self.tokenizer.decode(input_ids[0], skip_special_tokens=True),
            "expert_utilization": expert_counts,
        }
```
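When deploying MoE models, either all experts must fit in memory or experts must be offloaded (for example to CPU memory) and loaded on demand. Memory scales with total parameters while per-token compute scales with active parameters: Mixtral 8x7B has to store all 46.7B parameters even though each token uses only ~12.9B. MoE models therefore need several times more memory than a dense model with the same per-token compute, yet run faster than a dense model of similar quality.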
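A rough weight-memory estimate makes this gap concrete: bytes ≈ parameters × bytes per parameter. The sketch counts weights only (no KV cache, activations or framework overhead), and the 4-bit row assumes ideal 0.5-byte packing with no quantization metadata. The helper name `weight_memory_gb` is hypothetical.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Weights-only memory footprint in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


for label, bpp in [("fp16/bf16", 2.0), ("int8", 1.0), ("4-bit", 0.5)]:
    stored = weight_memory_gb(46.7e9, bpp)  # all experts must be resident
    used = weight_memory_gb(12.9e9, bpp)  # weights actually used for any one token
    print(f"{label:>9}: store {stored:6.1f} GB, use {used:5.1f} GB per token")
```

At fp16/bf16, Mixtral 8x7B needs roughly 93 GB just for its weights, although the weights used for any single token amount to about 26 GB.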
Summary
- MoE models route inputs to a subset of expert networks via a gating function
- Top-k routing activates only k experts per input, enabling sparse computation
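- Load balancing loss prevents expert collapse: L_balance = α · N · Σ f_i · P_i
- Mixtral 8x7B matches or outperforms LLaMA-2 70B on most benchmarks with 6x faster inference, as reported by Mistral AI
- DeepSeek-MoE uses shared + routed experts for better specialization
- MoE models must keep all experts in memory (Mixtral: 46.7B total vs ~12.9B active parameters) but compute only with the active subset, so they run faster than dense models of similar quality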
- Expert utilization variance measures load balancing quality
Practice Exercises
- Gating Analysis: Visualize the gating network's expert selection patterns. Do different experts specialize on different types of inputs?
- Load Balancing: Train an MoE model with and without load balancing loss. Compare expert utilization.
- Expert Specialization: Analyze what each expert learns. Do experts specialize on different linguistic phenomena?
- MoE vs Dense: Compare an MoE model with a dense model of equal computational budget. Which achieves better performance?
- Deployment Optimization: Implement expert offloading to reduce memory usage. Measure the impact on inference speed.
What to Learn Next
-> Multimodal LLMs: Combining MoE with multimodal architectures for efficient scaling.
-> Long Context and Context Window: MoE models must handle long contexts efficiently across experts.
-> LLM Architecture Deep Dive: Understanding the transformer blocks whose feed-forward layers MoE replaces with routed experts.
-> Scaling Laws and Chinchilla: How MoE relates to compute-optimal scaling strategies.
-> Building Production LLM Applications: Deploying MoE models requires handling large parameter counts in memory.
-> Open Source LLM Ecosystem: Mixtral and other open-source MoE models available for deployment.