LLM Inference Optimization — Making Language Models Faster and Cheaper
Efficient inference is the key to deploying LLMs at scale, requiring optimization across memory, compute, and scheduling.
- KV Cache — Eliminates redundant attention computation but requires careful memory management
- Speculative Decoding — A small draft model proposes tokens verified in parallel by the target model
- Continuous Batching — Dynamic request scheduling maximizes GPU utilization and throughput
"Start with quantization for immediate gains, then add continuous batching for multi-user scenarios."
Efficient inference is critical for deploying language models in production. The autoregressive nature of LLMs—generating one token at a time—creates unique optimization challenges. This tutorial covers the core techniques for maximizing throughput and minimizing latency.
The Inference Challenge
LLM inference has two distinct phases:
- Prefill phase: Process the entire input prompt in parallel (compute-bound)
- Decode phase: Generate tokens one at a time (memory-bandwidth-bound)
Throughput is the number of tokens generated per second, measured as total output tokens divided by wall-clock time. During the decode phase it is limited by memory bandwidth, because every step has to read the model weights and the KV cache from GPU memory.
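To make the memory-bandwidth limit concrete, here is a rough back-of-the-envelope sketch. It assumes that each decode step for a single sequence streams all model weights from GPU memory once, so tokens per second cannot exceed bandwidth divided by weight size. The function name and the example figures (7B parameters, float16, 1 TB/s) are illustrative assumptions, and the estimate ignores KV cache reads and compute time.

```python
def decode_tokens_per_second_bound(
    n_params: float, bytes_per_param: float, bandwidth_bytes_per_s: float
) -> float:
    """Upper bound on batch-size-1 decode speed when weight reads dominate."""
    weight_bytes = n_params * bytes_per_param
    return bandwidth_bytes_per_s / weight_bytes


# Hypothetical example: 7B parameters in float16 on a GPU with 1 TB/s bandwidth
bound = decode_tokens_per_second_bound(7e9, 2, 1e12)
print(f"at most ~{bound:.0f} tokens/s for a single sequence")
```

Batching several sequences amortizes the same weight reads over more tokens, which is why the batching techniques later in this tutorial raise throughput.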
KV Cache
The KV cache is the most fundamental optimization for autoregressive generation. Instead of recomputing the keys and values of all previous tokens at each step, we cache them and compute keys and values only for the newest token.
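KV Cache Memory
$$\text{Memory}_{KV} = 2 \times L \times n_{heads} \times d_{head} \times s \times b \times \text{bytes per element}$$
Here,
- $L$ = number of transformer layers
- $n_{heads}$ = number of attention heads
- $d_{head}$ = dimension per head
- $s$ = sequence length
- $b$ = batch size
- $2$ = factor for both keys and values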
For a 7B parameter model with 32 layers, 32 heads, 128 head dimension, and sequence length 2048, the KV cache requires approximately 2 × 32 × 32 × 128 × 2048 × 2 bytes ≈ 1GB per sequence in float16.
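The arithmetic above can be wrapped in a small helper. This is a minimal sketch under the same assumptions as the worked example (standard multi-head attention, 2 bytes per element for float16); the function name is illustrative. For models that use grouped-query attention, `n_heads` would be the number of key/value heads instead.

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, batch_size=1, bytes_per_elem=2):
    """KV cache size in bytes: 2 (keys and values) x layers x heads x head_dim x seq x batch."""
    return 2 * n_layers * n_heads * d_head * seq_len * batch_size * bytes_per_elem


size = kv_cache_bytes(n_layers=32, n_heads=32, d_head=128, seq_len=2048)
print(f"{size / 2**30:.2f} GiB per sequence")  # prints "1.00 GiB per sequence"
```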
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class KVCacheModel:
    def __init__(self, model_name: str):
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.kv_cache = None

    def prefill(self, prompt: str):
        """Process entire prompt and cache KV pairs."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**inputs, use_cache=True)
        self.kv_cache = outputs.past_key_values
        return outputs.logits[:, -1, :]  # next-token logits, shape (batch, vocab)

    def decode_step(self, input_ids: torch.Tensor):
        """Generate next token using cached KV pairs."""
        with torch.no_grad():
            outputs = self.model(
                input_ids,  # only the newest token; earlier ones live in the cache
                past_key_values=self.kv_cache,
                use_cache=True,
            )
        self.kv_cache = outputs.past_key_values
        next_token_logits = outputs.logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        return next_token

    def generate(self, prompt: str, max_new_tokens: int = 100):
        """Full greedy generation with KV cache (batch size 1)."""
        self.kv_cache = None
        logits = self.prefill(prompt)
        # prefill() already returns last-position logits, so argmax runs over the vocab
        next_token = torch.argmax(logits, dim=-1, keepdim=True)  # shape (1, 1)
        generated = [next_token.item()]
        for _ in range(max_new_tokens - 1):
            if generated[-1] == self.tokenizer.eos_token_id:
                break  # stop at end-of-sequence
            next_token = self.decode_step(next_token)
            generated.append(next_token.item())
        return self.tokenizer.decode(generated, skip_special_tokens=True)
```
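To illustrate what the cache saves, the following toy sketch compares decoding with and without cached keys and values in plain Python. Everything here is a simplifying assumption made for illustration: a single attention head, 2-dimensional tokens, a made-up `project_kv` standing in for the learned key/value projections, and the raw input used as the query. Both paths produce identical attention outputs, but the cached path projects each token only once.

```python
import math


def project_kv(x):
    """Toy stand-in for one token's key and value projections."""
    key = [0.5 * x[0] + x[1], x[0] - 0.25 * x[1]]
    value = [x[0] + x[1], 2.0 * x[0]]
    return key, value


def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [
        sum(w * v[j] for w, v in zip(weights, values)) / total
        for j in range(len(values[0]))
    ]


def decode_without_cache(inputs):
    """Recompute keys and values for the whole prefix at every step."""
    outputs, n_projections = [], 0
    for t in range(len(inputs)):
        kv = [project_kv(x) for x in inputs[: t + 1]]
        n_projections += len(kv)
        keys = [k for k, _ in kv]
        values = [v for _, v in kv]
        outputs.append(attend(inputs[t], keys, values))
    return outputs, n_projections


def decode_with_cache(inputs):
    """Project only the newest token and reuse cached keys and values."""
    keys, values, outputs, n_projections = [], [], [], 0
    for x in inputs:
        key, value = project_kv(x)
        n_projections += 1
        keys.append(key)
        values.append(value)
        outputs.append(attend(x, keys, values))
    return outputs, n_projections


inputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -1.0]]
slow_out, slow_count = decode_without_cache(inputs)
fast_out, fast_count = decode_with_cache(inputs)
print(slow_count, fast_count)  # 10 projections without the cache, 4 with it
```

The projection count grows quadratically with sequence length without the cache and linearly with it, at the price of keeping every key and value in memory.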
Quantization
Quantization reduces model size and accelerates inference by using lower-precision representations.
Quantization Methods Comparison
| Method | Type | Bits | Typical speedup vs FP16 | Quality impact |
|---|---|---|---|---|
| FP16 | Baseline (no quantization) | 16 | 1x | None |
| INT8 | Post-training | 8 | 1.5-2x | Minimal |
| INT4 (GPTQ) | Post-training | 4 | 2-3x | Small |
| INT4 (AWQ) | Post-training | 4 | 2-3x | Small |
| GGUF | File format (llama.cpp) covering several quantization schemes | 2-8 | Variable | Variable |
| QLoRA | Fine-tuning on a 4-bit base model | 4 | N/A | Minimal |
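The size reduction behind the table is simple arithmetic: weight storage scales with bits per weight. The sketch below assumes a 7B-parameter model and ignores the small overhead of quantization scales and zero points; the function name is illustrative.

```python
def weight_memory_gb(n_params: float, bits_per_weight: float) -> float:
    """Approximate weight storage in GB (10^9 bytes), ignoring scales and zero points."""
    return n_params * bits_per_weight / 8 / 1e9


for bits in (16, 8, 4):
    print(f"{bits:>2}-bit: {weight_memory_gb(7e9, bits):.1f} GB")
```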
GPTQ Quantization
GPTQ (GPT Quantization) uses second-order information to quantize weights with minimal accuracy loss:
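GPTQ Quantization Objective
$$\hat{W} = \arg\min_{\hat{W} \in \mathcal{Q}_b^{m \times n}} \lVert W X - \hat{W} X \rVert_F^2$$
Here,
- $W$ = original weight matrix (m × n)
- $X$ = calibration data activations (n × c)
- $\hat{W}$ = quantized weight matrix
- $b$ = number of bits (typically 4), so every entry of $\hat{W}$ lies on a $b$-bit grid $\mathcal{Q}_b$
- $c$ = number of calibration samples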
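To show what this objective measures, here is a small sketch that evaluates the layer-wise reconstruction error for a toy weight matrix. Note the assumption: the weights are quantized with plain symmetric round-to-nearest, which is the naive baseline and not GPTQ itself. GPTQ goes further by adjusting the not-yet-quantized weights to compensate for the error each rounding step introduces. All names and numbers are illustrative.

```python
def quantize_rtn(row, bits=4):
    """Symmetric round-to-nearest quantization of one weight row.

    Returns the dequantized values, i.e. the floats the quantized row represents.
    """
    qmax = 2 ** (bits - 1) - 1                      # 7 for 4-bit
    scale = max(abs(w) for w in row) / qmax or 1.0  # avoid a zero scale
    return [max(-qmax - 1, min(qmax, round(w / scale))) * scale for w in row]


def matmul(a, b):
    """Plain-Python matrix product of a (m x n) and b (n x c)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def reconstruction_error(w, w_hat, x):
    """Squared Frobenius norm of W X - W_hat X."""
    wx, w_hat_x = matmul(w, x), matmul(w_hat, x)
    return sum((p - q) ** 2 for r1, r2 in zip(wx, w_hat_x) for p, q in zip(r1, r2))


W = [[0.7, -0.36, 0.0]]    # toy 1 x 3 weight matrix
X = [[1.0], [1.0], [1.0]]  # toy 3 x 1 calibration activations
W_hat = [quantize_rtn(row, bits=4) for row in W]
print(W_hat, reconstruction_error(W, W_hat, X))
```

Here the 4-bit grid has a step of 0.1, so -0.36 is rounded to about -0.4 and the layer output shifts from 0.34 to about 0.30.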
AWQ (Activation-Aware Weight Quantization)
AWQ identifies salient weight channels based on activation magnitudes and protects them during quantization by rescaling those channels, rather than by keeping them in higher precision:
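```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer


def quantize_with_awq(model_path: str, output_path: str):
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    quant_config = {
        "zero_point": True,   # asymmetric quantization with a zero point
        "q_group_size": 128,  # number of weights sharing one scale
        "w_bit": 4,           # 4-bit weights
        "version": "GEMM",    # kernel variant
    }

    model.quantize(
        tokenizer,
        quant_config=quant_config,
        # Calibration data: a dataset name ("pileval" is AutoAWQ's default)
        # or a list of your own text samples
        calib_data="pileval",
    )

    model.save_quantized(output_path)
    tokenizer.save_pretrained(output_path)
```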
Speculative Decoding
Speculative decoding uses a small, fast "draft" model to propose several candidate tokens, which the larger target model then verifies in a single parallel forward pass. Each accepted draft token saves one sequential target-model step, so the target model runs fewer forward passes than the number of tokens it emits.
Speculative Decoding Acceptance
$$\alpha(x) = \min\left(1, \frac{p(x)}{q(x)}\right)$$
Here,
- $p(x)$ = probability from target model
- $q(x)$ = probability from draft model
- $\alpha(x)$ = acceptance probability
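How much draft-then-verify saves depends on how often drafts are accepted. As a rough sketch, suppose each draft token is accepted independently with the same probability. This is a simplifying assumption that the text above does not make, and `accept_rate`, `gamma` (the number of draft tokens per step) and the function name are illustrative. Under that assumption, the expected number of tokens produced per target-model forward pass is (1 - accept_rate^(gamma + 1)) / (1 - accept_rate).

```python
def expected_tokens_per_step(accept_rate: float, gamma: int) -> float:
    """Expected tokens emitted per target forward pass.

    Assumes each of the gamma draft tokens is accepted independently with
    probability accept_rate, and one extra token always comes from the target.
    """
    if accept_rate >= 1.0:
        return gamma + 1.0
    return (1 - accept_rate ** (gamma + 1)) / (1 - accept_rate)


for rate in (0.5, 0.7, 0.9):
    print(f"accept_rate={rate}: {expected_tokens_per_step(rate, gamma=5):.2f} tokens/step")
```

With no accepted drafts the value is 1 (ordinary decoding); with every draft accepted it is gamma + 1.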
```python
import torch


class SpeculativeDecoder:
    def __init__(self, target_model, draft_model, tokenizer, gamma: int = 5):
        self.target = target_model
        self.draft = draft_model
        self.tokenizer = tokenizer
        self.gamma = gamma  # max draft tokens per step

    @torch.no_grad()
    def generate(self, prompt: str, max_new_tokens: int = 100):
        # Assumes batch size 1, one shared tokenizer, and both models on one device
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        input_ids = input_ids.to(self.target.device)
        generated = input_ids.clone()
        prompt_len = input_ids.shape[1]

        while generated.shape[1] - prompt_len < max_new_tokens:
            n = generated.shape[1]  # sequence length before drafting

            # Draft phase: sample gamma tokens from the draft model
            draft_tokens, draft_probs = [], []
            x = generated.clone()
            for _ in range(self.gamma):
                logits = self.draft(x).logits[:, -1, :]
                probs = torch.softmax(logits.float(), dim=-1)
                token = torch.multinomial(probs, 1)
                draft_tokens.append(token)
                draft_probs.append(probs)
                x = torch.cat([x, token], dim=-1)

            # Verification phase: one target forward pass over all draft tokens
            target_probs = torch.softmax(self.target(x).logits.float(), dim=-1)

            # Accept or reject each draft token in order.
            # Logits at position n - 1 + i predict the token at position n + i.
            n_accepted = 0
            next_dist = target_probs[:, -1, :]  # used if every draft is accepted
            for i in range(self.gamma):
                p = target_probs[:, n - 1 + i, :]
                q = draft_probs[i]
                tok = draft_tokens[i][0, 0]
                accept_prob = torch.clamp(p[0, tok] / q[0, tok], max=1.0)
                if torch.rand(1, device=p.device) < accept_prob:
                    n_accepted += 1
                else:
                    # On rejection, resample from the residual max(0, p - q)
                    residual = torch.clamp(p - q, min=0.0)
                    next_dist = residual / residual.sum(dim=-1, keepdim=True)
                    break

            # Append accepted drafts plus one token from the target (or residual)
            new_token = torch.multinomial(next_dist, 1)
            generated = torch.cat(
                [generated, *draft_tokens[:n_accepted], new_token], dim=-1
            )

        # A step can emit up to gamma + 1 tokens, so trim any overshoot
        generated = generated[:, : prompt_len + max_new_tokens]
        return self.tokenizer.decode(generated[0])
```
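A useful property of the acceptance rule is that the emitted tokens follow the target model's distribution exactly, however poor the draft model is. This holds as long as a rejected token is resampled from the normalized residual max(0, p - q). The following toy check is a sketch under stated assumptions: a three-token vocabulary with made-up distributions `p` and `q`, and plain Python's `random` in place of real models.

```python
import random


def speculative_sample(p, q, rng):
    """Draw one token: propose from draft q, accept with min(1, p/q), else resample."""
    vocab = range(len(p))
    x = rng.choices(vocab, weights=q)[0]
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    # Rejected: resample from the residual distribution max(0, p - q)
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(vocab, weights=residual)[0]


p = [0.6, 0.3, 0.1]  # toy target distribution
q = [0.2, 0.5, 0.3]  # toy draft distribution
rng = random.Random(0)
n = 50_000
counts = [0] * len(p)
for _ in range(n):
    counts[speculative_sample(p, q, rng)] += 1
freqs = [c / n for c in counts]
print([round(f, 3) for f in freqs])  # empirical frequencies, close to p
```

The empirical frequencies match `p`, not `q`: the draft model affects only how many tokens are accepted per step, not which distribution they come from.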
Continuous Batching
Continuous batching (also called in-flight batching) allows dynamic request scheduling instead of static batch processing.
Continuous Batching Throughput
$$\text{Throughput} = \frac{\sum_{i=1}^{N} l_i}{T}$$
Here,
- $N$ = number of requests in the batch
- $l_i$ = output length of request $i$
- $T$ = total processing time
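```python
from collections import deque

import torch


class ContinuousBatchScheduler:
    def __init__(self, model, max_batch_size: int = 32, max_tokens: int = 4096,
                 pad_token_id: int = 0):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_tokens = max_tokens      # token budget across all active requests
        self.pad_token_id = pad_token_id  # assumption: used only for padding
        self.pending_requests = deque()
        self.active_requests = []

    def add_request(self, request):
        self.pending_requests.append(request)

    def schedule_step(self):
        # Admit pending requests in FIFO order while capacity remains.
        # Peek before popping so a request that does not fit stays queued.
        while (len(self.active_requests) < self.max_batch_size
               and self.pending_requests
               and self._can_add(self.pending_requests[0])):
            self.active_requests.append(self.pending_requests.popleft())

        # Run one decode step for all active requests
        if self.active_requests:
            self._decode_step()

        # Retire completed requests, freeing their slots for the next step
        completed = [r for r in self.active_requests if r.is_done]
        self.active_requests = [r for r in self.active_requests if not r.is_done]
        return completed

    def _can_add(self, request):
        total_tokens = sum(r.current_length for r in self.active_requests)
        return total_tokens + request.current_length <= self.max_tokens

    def _prepare_batch(self):
        # Assumption: each request exposes its tokens as a list `token_ids`.
        # Left-pad so the last position of every row is a real token.
        max_len = max(len(r.token_ids) for r in self.active_requests)
        ids, mask = [], []
        for r in self.active_requests:
            pad = max_len - len(r.token_ids)
            ids.append([self.pad_token_id] * pad + list(r.token_ids))
            mask.append([0] * pad + [1] * len(r.token_ids))
        return torch.tensor(ids), torch.tensor(mask)

    def _decode_step(self):
        # Simplified: recomputes the full prefix each step; a real engine
        # would reuse each request's KV cache.
        input_ids, attention_mask = self._prepare_batch()
        with torch.no_grad():
            logits = self.model(input_ids, attention_mask=attention_mask).logits
        for i, request in enumerate(self.active_requests):
            next_token = torch.argmax(logits[i, -1, :]).item()
            request.append_token(next_token)
```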
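The throughput gain from refilling slots can be seen in a small simulation. This is a sketch under simplifying assumptions: every decode step costs the same time regardless of batch contents, prefill is ignored, and all requests are queued up front. The function names and the request lengths are illustrative.

```python
def static_batching_steps(output_lengths, batch_size):
    """Decode steps when each batch runs until its longest request finishes."""
    steps = 0
    for start in range(0, len(output_lengths), batch_size):
        steps += max(output_lengths[start : start + batch_size])
    return steps


def continuous_batching_steps(output_lengths, batch_size):
    """Decode steps when a finished request's slot is refilled immediately."""
    pending = list(output_lengths)
    active = []  # remaining tokens for each in-flight request
    steps = 0
    while pending or active:
        while pending and len(active) < batch_size:
            active.append(pending.pop(0))
        steps += 1
        active = [r - 1 for r in active if r > 1]  # drop requests that just finished
    return steps


lengths = [8, 2, 2, 2, 8, 2, 2, 2]  # output length l_i of each request
total_tokens = sum(lengths)
for name, fn in [("static", static_batching_steps),
                 ("continuous", continuous_batching_steps)]:
    steps = fn(lengths, batch_size=2)
    print(f"{name}: {steps} steps, {total_tokens / steps:.2f} tokens/step")
```

With these lengths, static batching needs 20 steps because short requests wait for the long one in their batch, while continuous batching needs 14, the minimum possible for 28 tokens at batch size 2.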
Deployment Frameworks
vLLM
vLLM pioneered PagedAttention for efficient KV cache management:
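PagedAttention stores the KV cache in fixed-size blocks that need not be contiguous in memory (like virtual memory pages). This cuts the KV cache memory wasted through fragmentation and over-reservation from the 60-80% reported for earlier serving systems to a few percent, which allows more concurrent requests to be served.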
```python
from vllm import LLM, SamplingParams

# Initialize vLLM engine
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
    max_model_len=4096,
)

# Batch inference
prompts = [
    "Explain quantum computing in simple terms.",
    "Write a Python function to sort a list.",
    "What are the benefits of exercise?",
]

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=256,
)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```
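The paging idea behind PagedAttention can be sketched with a toy block table. This is not vLLM's implementation: the class, its methods and the block size are hypothetical, and only the bookkeeping is modeled (no tensors). Each sequence gets physical blocks on demand, so blocks of different sequences interleave freely and the only unused space is the tail of each sequence's last block.

```python
class PagedKVCache:
    """Toy block table: maps each sequence's logical blocks to physical blocks."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.lengths = {}       # seq_id -> number of tokens stored

    def append_token(self, seq_id: str) -> int:
        """Reserve space for one more token; returns its physical block id."""
        table = self.block_tables.setdefault(seq_id, [])
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:  # last block is full (or none yet)
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            table.append(self.free_blocks.pop(0))
        self.lengths[seq_id] = length + 1
        return table[-1]

    def free(self, seq_id: str) -> None:
        """Return a finished sequence's blocks to the pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

    def wasted_slots(self) -> int:
        """Allocated but unused token slots (the tail of each last block)."""
        allocated = sum(len(t) for t in self.block_tables.values()) * self.block_size
        return allocated - sum(self.lengths.values())


cache = PagedKVCache(num_blocks=8, block_size=4)
for _ in range(5):  # two sequences growing in lockstep
    cache.append_token("a")
    cache.append_token("b")
print(cache.block_tables)    # {'a': [0, 2], 'b': [1, 3]}
print(cache.wasted_slots())  # 6
```

Because blocks are allocated only when needed, no sequence reserves memory for its maximum possible length in advance, and freed blocks can be reused by any other sequence.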
TensorRT-LLM
TensorRT-LLM provides NVIDIA-optimized inference:
```python
import tensorrt_llm

# Schematic outline only: the engine-building API has changed between
# TensorRT-LLM releases, so check these calls against the docs for your version.

# Build engine
builder = tensorrt_llm.Builder()
network = builder.create_network()
# ... build network architecture ...

# Optimize for inference
config = builder.create_builder_config()
config.max_batch_size = 32
config.max_input_len = 2048
config.max_output_len = 512
engine = builder.build_serialized_network(network, config)
```
Latency vs Throughput Tradeoffs
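Cost per Token
$$\text{Cost per token} = \frac{C_{GPU} \times T_{gen}}{N_{tokens}}$$
Here,
- $C_{GPU}$ = cloud GPU hourly rate / 3600 (cost per second)
- $T_{gen}$ = time to generate response, in seconds
- $N_{tokens}$ = number of output tokens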
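As a worked example of this cost formula, the sketch below plugs in hypothetical figures: a GPU billed at $3.60 per hour that spends 2 seconds generating 100 tokens. The function name and all numbers are illustrative, not measurements.

```python
def cost_per_token(gpu_hourly_rate: float, generation_seconds: float,
                   n_output_tokens: int) -> float:
    """Cost of one output token: per-second GPU cost x generation time / tokens."""
    cost_per_second = gpu_hourly_rate / 3600
    return cost_per_second * generation_seconds / n_output_tokens


# Hypothetical figures: $3.60/hour GPU, 2 s to generate 100 tokens
cost = cost_per_token(3.60, 2.0, 100)
print(f"${cost * 1000:.4f} per 1,000 tokens")
```

If the same 2 seconds serve a batch of requests, the output tokens of all requests go into the denominator, which is how batching lowers the cost per token.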
↓ marks a decrease and ↑ an increase; more arrows indicate a larger effect.
| Optimization | Latency Impact | Throughput Impact | Use Case |
|---|---|---|---|
| KV Cache | ↓↓↓ | ↑ | Always use |
| Quantization (INT4) | ↓ | ↑↑↑ | Memory-limited |
| Speculative Decoding | ↓↓ | ↑ | Latency-sensitive apps |
| Continuous Batching | ↓ | ↑↑↑ | Multi-user serving |
| Tensor Parallelism | ↓↓ | ↑↑ | Large models |
For production deployments, start with quantization (INT4/GPTQ) for immediate memory and throughput gains, then add continuous batching for multi-user scenarios. Speculative decoding is most beneficial for latency-sensitive applications with single-user serving.
Practical Deployment Example
```python
from vllm import LLM, SamplingParams
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn

app = FastAPI()

# Initialize optimized model
llm = LLM(
    model="TheBloke/Llama-2-7B-Chat-GPTQ",
    quantization="gptq",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.85,
    max_model_len=2048,
    # Disables CUDA graphs: saves memory and startup time, but usually adds
    # some decode latency. Set to False to keep CUDA graphs enabled.
    enforce_eager=True,
)


class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 256
    temperature: float = 0.7


@app.post("/generate")
async def generate(request: GenerationRequest):
    params = SamplingParams(
        temperature=request.temperature,
        top_p=0.95,
        max_tokens=request.max_tokens,
    )
    # llm.generate() is synchronous, so this handler serves one request at a
    # time; concurrent serving needs vLLM's async engine or its API server.
    outputs = llm.generate([request.prompt], params)
    completion = outputs[0].outputs[0]
    return {
        "text": completion.text,
        "tokens": len(completion.token_ids),
        "usage": {
            "prompt_tokens": len(outputs[0].prompt_token_ids),
            "completion_tokens": len(completion.token_ids),
        },
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
```
Summary
- KV cache eliminates redundant attention computation, at a memory cost of 2 × L × n_heads × d_head × s × b × bytes per element
- GPTQ and AWQ typically provide a 2-3x speedup with INT4 quantization at a small quality loss
- Speculative decoding lets the target model emit several tokens per forward pass via draft-then-verify
- Continuous batching maximizes throughput by dynamically scheduling requests
- vLLM's PagedAttention reduces wasted KV cache memory to a few percent
- Latency vs throughput tradeoffs depend on use case; optimize accordingly
Practice Exercises
- KV Cache Analysis: Calculate KV cache memory for LLaMA-2-7B, LLaMA-2-13B, and LLaMA-2-70B at sequence length 4096. What batch sizes are feasible on a 24GB GPU?
- Quantization Benchmark: Compare inference speed and quality between FP16, INT8, and INT4 (GPTQ) for a 7B model. Measure perplexity on WikiText-2 and tokens/second.
- Speculative Decoding: Implement speculative decoding with a 7B target model and 1.5B draft model. Measure acceptance rate and speedup.
- Throughput Optimization: Deploy a model with vLLM and measure throughput at different batch sizes. Find the optimal configuration for your hardware.
- Latency Optimization: Optimize a model for minimum first-token latency using quantization, KV cache, and CUDA graphs. Measure the impact of each optimization.
What to Learn Next
- Building Production LLM Applications: Putting inference optimization into practice with monitoring and cost tracking.
- LLM Safety and Red Teaming: Ensuring optimized inference doesn't compromise safety guarantees.
- QLoRA and Quantization: Deeper dive into quantization techniques for memory and speed optimization.
- LoRA and PEFT: Parameter-efficient methods that complement inference optimization strategies.
- LLM Evaluation Benchmarks: Measuring the impact of optimization on model quality across benchmarks.
- Open Source LLM Ecosystem: Available tools and frameworks for deploying optimized models.