
Mixture of Experts


LLM Architecture

Mixture of Experts: Scaling LLMs Without Scaling Compute

MoE models conditionally route inputs to a subset of experts, achieving better performance per FLOP than dense models.

  • Sparse Routing: A gating network selects the top-k experts for each token, so only a fraction of the parameters is active
  • Load Balancing: An auxiliary loss discourages expert collapse and pushes expert utilization toward uniform
  • Mixtral 8x7B: Uses 8 experts per layer with 2 active per token, matching or outperforming LLaMA-2 70B on most benchmarks with about 6x faster inference

"Not every input needs every parameter: learning to route different inputs to different experts is the key insight."

Mixture of Experts (MoE)

Mixture of Experts is an architecture that conditionally routes each input to a subset of its parameters, so a model can grow its total parameter count while keeping the computation per token roughly fixed. For a given compute budget (FLOPs per token), this typically yields better quality than a dense model.

The MoE Intuition

In an MoE layer, several "expert" networks (usually feed-forward blocks) sit side by side, and a learned gating network, also called the router, decides which experts process each input token. Only a subset of experts is activated per token, which makes the computation sparse.

The key insight: not every input needs every parameter. By learning to route different inputs to different experts, an MoE model can hold many more parameters than a dense model while spending about the same computation per token.

Gating Function

The gating function determines expert selection:

Gating Network

$$G(x) = \text{softmax}(W_g \cdot x + b_g)$$

Here,

  • $x$ = input token representation
  • $W_g$ = learnable gating weight matrix
  • $b_g$ = learnable gating bias
  • $G(x)$ = probability distribution over experts

MoE Forward Pass

$$y = \sum_{i=1}^{N} g_i(x) \cdot E_i(x)$$

Here,

  • $N$ = total number of experts
  • $g_i(x)$ = gating weight for expert $i$
  • $E_i(x)$ = output of expert $i$
  • $y$ = final output
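A minimal PyTorch sketch of the two equations above, using a hypothetical toy setup in which each expert $E_i$ is a single linear layer. In this dense form every expert runs and the outputs are blended by $G(x)$:

```python
import torch

torch.manual_seed(0)

num_experts, d_in, d_out = 4, 3, 2

# Hypothetical toy experts: each E_i is one linear layer
experts = [torch.nn.Linear(d_in, d_out) for _ in range(num_experts)]
W_g = torch.randn(num_experts, d_in)  # gating weight matrix
b_g = torch.zeros(num_experts)        # gating bias

x = torch.randn(d_in)                 # one token representation

# G(x) = softmax(W_g x + b_g): a probability distribution over the experts
G = torch.softmax(W_g @ x + b_g, dim=-1)

# y = sum_i g_i(x) * E_i(x): every expert runs, outputs are weighted by G(x)
y = sum(G[i] * experts[i](x) for i in range(num_experts))

print(G.sum())   # sums to 1 (up to floating-point rounding)
print(y.shape)   # torch.Size([2])
```

Evaluating all $N$ experts for every token is exactly the cost that top-k routing avoids.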

In practice, only the top-k experts are activated:

Top-k Routing

$$g_i(x) = \begin{cases} \frac{\exp((W_g \cdot x)_i)}{\sum_{j \in \text{Top-k}} \exp((W_g \cdot x)_j)} & \text{if } i \in \text{Top-k} \\ 0 & \text{otherwise} \end{cases}$$

Here,

  • $\text{Top-k}$ = set of the $k$ experts with the highest gating scores
  • $k$ = number of experts to activate per token (typically 1 or 2; fine-grained designs such as DeepSeek-MoE use more)
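The top-k formula can be checked numerically. This sketch (PyTorch, with hypothetical random router logits) applies the softmax to the k highest logits only, then confirms that this gives the same weights as taking the full softmax, keeping the top-k probabilities, and renormalizing them, which is how the code in this lesson computes them:

```python
import torch

torch.manual_seed(0)

num_experts, k = 8, 2
logits = torch.randn(5, num_experts)  # hypothetical router logits (W_g x) for 5 tokens

# Formula above: softmax restricted to the k highest-scoring experts
top_logits, top_idx = torch.topk(logits, k, dim=-1)
weights = torch.softmax(top_logits, dim=-1)  # (5, k)

# Scatter into a full (5, num_experts) matrix; unselected experts get g_i = 0
g = torch.zeros_like(logits).scatter(-1, top_idx, weights)

# Equivalent route: full softmax, keep the top-k probabilities, renormalize
probs = torch.softmax(logits, dim=-1)
top_probs, top_idx2 = torch.topk(probs, k, dim=-1)
renorm = top_probs / top_probs.sum(dim=-1, keepdim=True)

print((g > 0).sum(dim=-1))              # tensor([2, 2, 2, 2, 2])
print(torch.allclose(weights, renorm))  # True
```

The two routes agree because softmax is monotonic (the same experts are selected) and its shared normalizer cancels when the selected probabilities are renormalized.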

Load Balancing

Without explicit balancing, the gating network tends to collapse onto a few experts. Experts that happen to be selected early receive more gradient updates, become more useful, and are then selected even more often, so most of the model's capacity goes unused. This self-reinforcing imbalance is often described as the "rich get richer" phenomenon.

Load Balancing Loss

$$\mathcal{L}_{\text{balance}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

Here,

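  • $N$ = number of experts
  • $f_i$ = fraction of tokens routed to expert $i$ (counted from the hard top-k choices, so it carries no gradient)
  • $P_i$ = average gating probability for expert $i$ over the batch (differentiable, so the router is trained through this term)
  • $\alpha$ = balancing coefficient (small; Switch Transformer uses 0.01)
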
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        balance_loss_coeff: float = 0.01
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.output_dim = output_dim
        self.balance_loss_coeff = balance_loss_coeff
        
        # Expert networks: independent two-layer MLPs
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim)
            ) for _ in range(num_experts)
        ])
        
        # Gating network (router): one logit per expert, no bias
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
    
    def compute_load_balance_loss(
        self, gate_probs: torch.Tensor, top_k_indices: torch.Tensor
    ) -> torch.Tensor:
        """Compute alpha * N * sum_i f_i * P_i over the tokens in this batch."""
        # f_i: fraction of routing assignments dispatched to expert i.
        # Built from the hard top-k choices, so no gradient flows through it.
        dispatch = F.one_hot(top_k_indices, self.num_experts).float()  # (tokens, k, N)
        f = dispatch.sum(dim=1).mean(dim=0) / self.top_k  # (N,), sums to 1
        
        # P_i: average gating probability for each expert (differentiable)
        P = gate_probs.mean(dim=0)  # (N,)
        
        # Scaled balance loss; equals alpha when routing is perfectly uniform
        return self.balance_loss_coeff * self.num_experts * (f * P).sum()
    
    def forward(self, x: torch.Tensor):
        batch_size, seq_len, input_dim = x.shape
        x_flat = x.reshape(-1, input_dim)  # (tokens, input_dim)
        
        # Compute gating scores
        gate_logits = self.gate(x_flat)  # (tokens, num_experts)
        gate_probs = F.softmax(gate_logits, dim=-1)
        
        # Select top-k experts and renormalize their weights to sum to 1
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        
        # Compute load balance loss
        balance_loss = self.compute_load_balance_loss(gate_probs, top_k_indices)
        
        # Route tokens: each expert runs once on all tokens that selected it
        output = x_flat.new_zeros(x_flat.size(0), self.output_dim)
        for i, expert in enumerate(self.experts):
            # (token, slot) pairs where expert i was chosen
            token_idx, slot_idx = torch.where(top_k_indices == i)
            if token_idx.numel() == 0:
                continue
            expert_output = expert(x_flat[token_idx])
            weights = top_k_probs[token_idx, slot_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, weights * expert_output)
        
        output = output.view(batch_size, seq_len, self.output_dim)
        return output, balance_loss
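To see why the balancing loss works, this standalone sketch evaluates it with the $f_i$ and $P_i$ definitions above for two hypothetical routing patterns (8 experts, 64 tokens, top-1). The loss equals $\alpha$ when routing is perfectly balanced and rises to $\alpha \cdot N$ when every token collapses onto one expert:

```python
import torch
import torch.nn.functional as F

def balance_loss(gate_probs, top_k_indices, alpha=0.01):
    """alpha * N * sum_i f_i * P_i for one batch of routing decisions."""
    num_experts = gate_probs.size(-1)
    k = top_k_indices.size(-1)
    # f_i: share of routing assignments that went to expert i (no gradient)
    f = F.one_hot(top_k_indices, num_experts).float().sum(dim=1).mean(dim=0) / k
    # P_i: mean router probability for expert i (carries the gradient)
    P = gate_probs.mean(dim=0)
    return alpha * num_experts * (f * P).sum()

N, T = 8, 64  # 8 experts, 64 tokens, top-1 routing

# Balanced: uniform probabilities, tokens spread evenly (T / N per expert)
uniform_probs = torch.full((T, N), 1.0 / N)
even_idx = (torch.arange(T) % N).unsqueeze(-1)
print(balance_loss(uniform_probs, even_idx))  # alpha * 1 = 0.01

# Collapsed: the router is certain about expert 0 and sends every token there
collapsed_probs = F.one_hot(torch.zeros(T, dtype=torch.long), N).float()
collapsed_idx = torch.zeros(T, 1, dtype=torch.long)
print(balance_loss(collapsed_probs, collapsed_idx))  # alpha * N = 0.08
```

Because $f$ comes from hard assignments, gradients reach the router only through $P$, which pushes probability mass away from overloaded experts.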

Mixtral Architecture

Mixtral (by Mistral AI) is a prominent MoE model with 8 experts and top-2 routing:

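Mixtral 8x7B has 46.7B total parameters but uses only about 12.9B per token, because each token activates 2 of the 8 experts in every layer. The total is well below 8 × 7B = 56B because only the feed-forward blocks are replicated per expert; attention layers and embeddings are shared. According to Mistral AI, Mixtral matches or outperforms LLaMA-2 70B on most benchmarks while being about 6x faster at inference.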

Mixtral Architecture Details

| Component | Specification |
|---|---|
| Total parameters | 46.7B |
| Active parameters (per token) | ~12.9B |
| Experts | 8 per layer |
| Top-k | 2 |
| Expert FFN hidden size | 14336 |
| Hidden size | 4096 |
| Layers | 32 |
| Attention heads | 32 |

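class SwiGLUExpert(nn.Module):
    """Gated feed-forward expert, w2(SiLU(w1 x) * w3 x), the form Mixtral's experts use."""
    def __init__(self, dim: int, ffn_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, ffn_dim, bias=False)
        self.w3 = nn.Linear(dim, ffn_dim, bias=False)
        self.w2 = nn.Linear(ffn_dim, dim, bias=False)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class MixtralBlock(nn.Module):
    """Simplified Mixtral decoder block. The real model also uses RMSNorm,
    rotary position embeddings, and grouped-query attention (8 key/value heads)."""
    def __init__(
        self,
        dim: int = 4096,
        ffn_dim: int = 14336,
        num_heads: int = 32,
        num_experts: int = 8,
        top_k: int = 2
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, bias=False, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        
        # MoE FFN: num_experts SwiGLU experts and a bias-free router
        self.experts = nn.ModuleList([SwiGLUExpert(dim, ffn_dim) for _ in range(num_experts)])
        self.gate = nn.Linear(dim, num_experts, bias=False)
        self.top_k = top_k
    
    def forward(self, x: torch.Tensor):
        # Causal mask (True = blocked): position t attends only to positions <= t
        seq_len = x.size(1)
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        
        # Self-attention with residual (pre-norm)
        residual = x
        x = self.norm1(x)
        x, _ = self.attention(x, x, x, attn_mask=causal_mask, need_weights=False)
        x = residual + x
        
        # MoE FFN with residual
        residual = x
        x = self.norm2(x)
        
        # Gating: top-k of the softmax, renormalized over the selected experts
        gate_scores = F.softmax(self.gate(x), dim=-1)
        top_k_probs, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        
        # Expert computation: each (token, slot) is sent to its chosen expert
        output = torch.zeros_like(x)
        for k in range(self.top_k):
            for i, expert in enumerate(self.experts):
                mask = top_k_indices[:, :, k] == i  # (batch, seq)
                if mask.any():
                    expert_output = expert(x[mask])
                    output[mask] += top_k_probs[:, :, k][mask].unsqueeze(-1) * expert_output
        
        return residual + output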
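The 46.7B total and ~12.9B active figures can be reproduced from the configuration table. This arithmetic sketch takes hidden size, expert FFN size, layer count, attention heads, and expert count from the table; the remaining values are assumptions drawn from the published Mixtral 8x7B configuration (8 key/value heads for grouped-query attention, a 32,000-token vocabulary, untied input and output embeddings, and SwiGLU experts with three weight matrices):

```python
# From the table: hidden size, expert FFN size, layers, attention heads, experts, top-k.
# ASSUMED from the published config: 8 KV heads, 32,000-token vocab, untied embeddings.
hidden, ffn, layers = 4096, 14336, 32
num_heads, num_kv_heads, vocab = 32, 8, 32000
num_experts, top_k = 8, 2
head_dim = hidden // num_heads  # 128

expert = 3 * hidden * ffn  # SwiGLU: w1, w3 (hidden -> ffn) and w2 (ffn -> hidden)
attention = 2 * hidden * hidden + 2 * hidden * num_kv_heads * head_dim  # q, o + k, v
router = hidden * num_experts
norms = 2 * hidden                        # two norm weight vectors per layer
embeddings = 2 * vocab * hidden + hidden  # input embedding, LM head, final norm

shared_per_layer = attention + router + norms
total = layers * (shared_per_layer + num_experts * expert) + embeddings
active = layers * (shared_per_layer + top_k * expert) + embeddings

print(f"total:  {total / 1e9:.1f}B")   # total:  46.7B
print(f"active: {active / 1e9:.1f}B")  # active: 12.9B
```

Almost all of the total (about 45.1B) sits in the expert FFNs, while attention, router, norms, and embeddings are counted only once.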

DeepSeek-MoE

DeepSeek-MoE introduces finer-grained expert specialization:

DeepSeek-MoE splits experts into many smaller ones (for example, DeepSeekMoE 16B uses 64 routed experts with top-6 routing, and DeepSeek-V2 uses 160 routed experts) and adds "shared experts" that are always activated alongside the "routed experts" selected by the gating network.

DeepSeek-MoE Design Principles

  1. Finer-grained experts: More, smaller experts for better specialization
  2. Shared experts: Always-active experts for common knowledge
  3. Routed experts: Conditionally activated for specialized knowledge
class DeepSeekMoELayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_shared_experts: int = 2,
        num_routed_experts: int = 64,
        top_k: int = 6
    ):
        super().__init__()
        self.num_shared_experts = num_shared_experts
        self.num_routed_experts = num_routed_experts
        self.top_k = top_k
        self.output_dim = output_dim
        
        # Shared experts (always active)
        self.shared_experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim)
            ) for _ in range(num_shared_experts)
        ])
        
        # Routed experts (conditionally active)
        self.routed_experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, output_dim)
            ) for _ in range(num_routed_experts)
        ])
        
        # Gating network: scores only the routed experts
        self.gate = nn.Linear(input_dim, num_routed_experts, bias=False)
    
    def forward(self, x: torch.Tensor):
        batch_size, seq_len, input_dim = x.shape
        x_flat = x.reshape(-1, input_dim)  # (tokens, input_dim)
        
        # Shared experts: every token passes through all of them; outputs are summed
        shared_output = sum(expert(x_flat) for expert in self.shared_experts)
        
        # Routed experts: top-k selection with weights renormalized to sum to 1
        gate_logits = self.gate(x_flat)
        gate_probs = F.softmax(gate_logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        
        routed_output = x_flat.new_zeros(x_flat.size(0), self.output_dim)
        for i, expert in enumerate(self.routed_experts):
            # (token, slot) pairs where routed expert i was chosen
            token_idx, slot_idx = torch.where(top_k_indices == i)
            if token_idx.numel() == 0:
                continue
            expert_output = expert(x_flat[token_idx])
            weights = top_k_probs[token_idx, slot_idx].unsqueeze(-1)
            routed_output.index_add_(0, token_idx, weights * expert_output)
        
        # Combine shared and routed outputs
        output = shared_output + routed_output
        return output.view(batch_size, seq_len, self.output_dim)
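The arithmetic behind finer-grained experts can be sketched with hypothetical sizes: split each of 16 experts into 4 experts a quarter as wide, and activate 4 times as many per token. The FFN weights touched per token stay the same, while the number of expert combinations a token can choose from grows enormously:

```python
from math import comb

d_model, d_ffn = 4096, 14336  # hypothetical sizes for illustration

def active_ffn_params(num_active: int, hidden: int) -> int:
    """Weights touched per token by num_active two-layer experts of width hidden."""
    return num_active * 2 * d_model * hidden

# Coarse: 16 experts, top-2 routing
coarse_combos = comb(16, 2)
coarse_active = active_ffn_params(2, d_ffn)

# Fine-grained: each expert split into 4 narrower ones (64 experts, width d_ffn / 4),
# with 4x as many activated (top-8), so per-token compute is unchanged
fine_combos = comb(64, 8)
fine_active = active_ffn_params(8, d_ffn // 4)

print(coarse_combos, fine_combos)    # 120 4426165368
print(coarse_active == fine_active)  # True
```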

Load Balancing Analysis

Expert Utilization Variance

$$\text{Var}(f) = \frac{1}{N} \sum_{i=1}^{N} (f_i - \bar{f})^2$$

Here,

  • $f_i$ = fraction of tokens routed to expert $i$
  • $\bar{f} = 1/N$ = ideal uniform utilization
  • $N$ = number of experts
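A small PyTorch sketch of this metric, computing $f_i$ from a flat tensor of routed expert ids (both routing patterns are hypothetical):

```python
import torch

def utilization_variance(expert_ids: torch.Tensor, num_experts: int) -> float:
    """Var(f) = (1/N) * sum_i (f_i - 1/N)^2 from a flat tensor of routed expert ids."""
    counts = torch.bincount(expert_ids.flatten(), minlength=num_experts).float()
    f = counts / counts.sum()  # f_i: fraction of assignments per expert
    return ((f - 1.0 / num_experts) ** 2).mean().item()

N = 4
balanced = torch.arange(100) % N                # 25 assignments per expert
collapsed = torch.zeros(100, dtype=torch.long)  # every assignment to expert 0

print(utilization_variance(balanced, N))   # 0.0
print(utilization_variance(collapsed, N))  # 0.1875, i.e. (N - 1) / N**2
```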

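A well-balanced MoE model has low variance in expert utilization. Because the $f_i$ sum to 1, $\text{Var}(f)$ is at most $(N-1)/N^2$ (every token sent to a single expert), so raw variances are only directly comparable between models with the same number of experts: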

| Model | Experts | Top-k | Utilization variance |
|---|---|---|---|
| Mixtral 8x7B | 8 | 2 | 0.002 |
| Switch Transformer | 128 | 1 | 0.015 |
| GShard | 2048 | 2 | 0.008 |

Practical: Deploying MoE Models

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class MoEInferenceEngine:
    def __init__(self, model_name: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = self.load_moe_model(model_name)
    
    def load_moe_model(self, model_name: str):
        """Load an MoE model. device_map="auto" lets Accelerate spread layers across
        the available GPUs and, if they do not fit, offload the rest to CPU memory."""
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16
        )
        return model
    
    def generate_with_expert_stats(self, prompt: str, max_tokens: int = 100):
        """Generate text greedily, then count top-k expert assignments.

        Assumes a Mixtral-style model: forward() accepts output_router_logits=True
        and the config defines num_local_experts and num_experts_per_tok.
        """
        num_experts = self.model.config.num_local_experts
        top_k = self.model.config.num_experts_per_tok
        
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            sequences = self.model.generate(**inputs, max_new_tokens=max_tokens, do_sample=False)
            # One extra forward pass over prompt + completion to read router decisions
            outputs = self.model(sequences, output_router_logits=True)
        
        # router_logits: one (tokens, num_experts) tensor per MoE layer; counts are
        # summed over all layers
        expert_counts = torch.zeros(num_experts, dtype=torch.long)
        for layer_logits in outputs.router_logits:
            chosen = torch.topk(layer_logits, top_k, dim=-1).indices.flatten().cpu()
            expert_counts += torch.bincount(chosen, minlength=num_experts)
        
        return {
            "text": self.tokenizer.decode(sequences[0], skip_special_tokens=True),
            "expert_utilization": dict(enumerate(expert_counts.tolist()))
        }

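When deploying MoE models, all experts must fit in memory, or experts must be offloaded and fetched on demand. Memory is governed by the total parameter count, while per-token compute is governed by the active parameters: Mixtral 8x7B has to hold all 46.7B parameters (roughly 93 GB of weights in 16-bit precision) yet does about as much computation per token as a 12.9B dense model.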
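A back-of-the-envelope sketch of that trade-off for Mixtral 8x7B, counting weights only (KV cache, activations, and framework overhead come on top):

```python
def weight_memory_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

# Parameter counts for Mixtral 8x7B quoted earlier in this lesson
total_params, active_params = 46.7e9, 12.9e9

for dtype_name, nbytes in [("fp32", 4), ("fp16/bf16", 2)]:
    print(f"{dtype_name:>9}: {weight_memory_gb(total_params, nbytes):.1f} GB of weights")

# Memory follows the total parameters; per-token compute follows the active ones
print(f"active / total parameters: {active_params / total_params:.2f}")
```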

Summary

  • MoE models route inputs to a subset of expert networks via a gating function
  • Top-k routing activates only k experts per input, enabling sparse computation
  • Load balancing loss prevents expert collapse: L_balance = α · N · Σ f_i · P_i
  • Mixtral 8x7B matches or beats LLaMA-2 70B on most benchmarks with about 6x faster inference
  • DeepSeek-MoE uses shared + routed experts for better specialization
  • MoE memory scales with total parameters while per-token compute scales with active parameters, so MoE models need more memory than dense models of equal compute but run faster than dense models of equal total size
  • Expert utilization variance measures load balancing quality

Practice Exercises

  1. Gating Analysis: Visualize the gating network's expert selection patterns. Do different experts specialize on different types of inputs?

  2. Load Balancing: Train an MoE model with and without load balancing loss. Compare expert utilization.

  3. Expert Specialization: Analyze what each expert learns. Do experts specialize on different linguistic phenomena?

  4. MoE vs Dense: Compare an MoE model with a dense model of equal computational budget. Which achieves better performance?

  5. Deployment Optimization: Implement expert offloading to reduce memory usage. Measure the impact on inference speed.


What to Learn Next

-> Multimodal LLMs: Combining MoE with multimodal architectures for efficient scaling.

-> Long Context and Context Window: MoE models must handle long contexts efficiently across experts.

-> LLM Architecture Deep Dive: Understanding the transformer blocks whose feed-forward layers MoE replaces with routed experts.

-> Scaling Laws and Chinchilla: How MoE relates to compute-optimal scaling strategies.

-> Building Production LLM Applications: Deploying MoE models requires handling large parameter counts in memory.

-> Open Source LLM Ecosystem: Mixtral and other open-source MoE models available for deployment.


