Why Cache LLM Results?
LLM inference is expensive. Caching allows reuse of previous computations, reducing both latency and cost for repeated or similar queries.
Cache Taxonomy
| Cache Type | Granularity | Hit Rate | Latency Reduction |
|---|---|---|---|
| Exact match | Exact query | Low (1-5%) | 100% |
| Semantic | Similar queries | Medium (20-40%) | 95-99% |
| Prefix | Shared prefixes | High (50-80%) | 50-80% |
| KV Cache | Attention states | N/A (GPU) | 30-60% |
| Prompt Cache | Repeated prompts | Medium (15-30%) | 90-95% |
Semantic Caching
Semantic caching matches queries by meaning rather than exact text, catching paraphrases and variations.
import json
import time
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
@dataclass
class CacheEntry:
    """One cached LLM response plus the metadata a semantic cache needs.

    ``embedding`` is excluded from equality and repr. numpy compares arrays
    element-wise, so a generated ``__eq__`` that included it would raise
    ``ValueError`` for any two entries with multi-element embeddings.
    Printing the full vector would also make reprs unreadable.
    """

    query: str  # exact query text the response was produced for
    response: str  # cached model output
    embedding: np.ndarray = field(compare=False, repr=False)  # embedding of `query`
    cost: float  # cost of producing `response` (units as supplied by the caller)
    timestamp: float  # creation time as passed by the caller (e.g. time.time())
    hit_count: int = 0  # number of cache hits served from this entry
class SemanticCache:
    """In-memory cache that matches queries by embedding similarity.

    A lookup embeds the incoming query and linearly scans all entries. It
    returns the stored response whose query embedding has the highest cosine
    similarity, provided that similarity is at least ``threshold``.

    Entries are keyed by their exact query text. When the cache is full and a
    new key is stored, the entry with the fewest hits is evicted (ties go to
    the earliest-inserted entry).

    Args:
        embedder: Object exposing ``embed(list[str])`` that returns one vector
            per input text. Assumed to return equal-length 1-D numpy arrays;
            confirm against the embedder in use.
        threshold: Minimum cosine similarity for a lookup to count as a hit.
        max_size: Maximum number of entries kept.
    """

    def __init__(self, embedder, threshold: float = 0.92, max_size: int = 10000):
        self.embedder = embedder
        self.threshold = threshold
        self.max_size = max_size
        self.cache: dict[str, CacheEntry] = {}
        self.stats = {"hits": 0, "misses": 0}
        # Running total of cost avoided by hits. It is kept separately so that
        # evicting or overwriting an entry does not erase savings already made.
        self._cost_saved = 0.0

    def lookup(self, query: str) -> Optional[dict]:
        """Return the best cached match for ``query``, or ``None`` on a miss.

        On a hit, returns a dict with the keys ``response``, ``cached``
        (always True), ``similarity`` and ``age_seconds``. Hit and miss
        counters are updated either way.
        """
        if not self.cache:
            # An empty cache is still a miss. Skipping the counter here would
            # overstate hit_rate().
            self.stats["misses"] += 1
            return None

        query_embedding = self.embedder.embed([query])[0]
        best_match: Optional[CacheEntry] = None
        best_score = 0.0
        for entry in self.cache.values():
            similarity = self._cosine_similarity(query_embedding, entry.embedding)
            if similarity > best_score and similarity >= self.threshold:
                best_score = similarity
                best_match = entry

        if best_match is None:
            self.stats["misses"] += 1
            return None

        self.stats["hits"] += 1
        best_match.hit_count += 1
        self._cost_saved += best_match.cost
        return {
            "response": best_match.response,
            "cached": True,
            "similarity": best_score,
            "age_seconds": time.time() - best_match.timestamp,
        }

    def store(self, query: str, response: str, cost: float):
        """Cache ``response`` for ``query``, replacing any entry for the same query."""
        # Evict only when a *new* key is added. Overwriting an existing query
        # does not grow the cache and must not push out an unrelated entry.
        if query not in self.cache and len(self.cache) >= self.max_size:
            self._evict_least_used()
        embedding = self.embedder.embed([query])[0]
        self.cache[query] = CacheEntry(
            query=query,
            response=response,
            embedding=embedding,
            cost=cost,
            timestamp=time.time(),
        )

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of ``a`` and ``b``; 0.0 if either has zero norm."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            # A zero vector has no direction. Treat it as dissimilar to
            # everything rather than producing NaN and a RuntimeWarning.
            return 0.0
        return float(np.dot(a, b) / denom)

    def _evict_least_used(self):
        """Drop the entry with the lowest hit_count (earliest-inserted on ties)."""
        if self.cache:
            least_used = min(self.cache.values(), key=lambda e: e.hit_count)
            del self.cache[least_used.query]

    def hit_rate(self) -> float:
        """Fraction of lookups that were hits (0.0 before any lookup)."""
        total = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / total if total > 0 else 0.0

    def savings(self) -> dict:
        """Summarize hits, misses, hit rate and the estimated cost saved.

        ``estimated_cost_saved`` adds up the ``cost`` of the matched entry on
        every hit, including hits served by entries that were later evicted or
        overwritten.
        """
        return {
            "total_hits": self.stats["hits"],
            "total_misses": self.stats["misses"],
            "hit_rate": self.hit_rate(),
            "estimated_cost_saved": self._cost_saved,
        }
Prefix Caching
Prefix caching stores the KV cache for common prompt prefixes, avoiding redundant prefill computation.
class PrefixCache:
    """Registry of token sequences and their KV caches, used for prefix matching.

    Entries are keyed by ``hash`` of (at most) the first 100 tokens. Storing a
    sequence whose first 100 tokens equal those of an existing entry therefore
    replaces that entry.
    """

    def __init__(self):
        # hash(first <=100 tokens) -> {"tokens": full token list, "kv_cache": ...}
        self.prefix_map = {}

    def get_prefix(self, messages: list[dict], tokenizer) -> tuple[int, int]:
        """Find the longest common prefix with cached content.

        Args:
            messages: Chat messages, rendered via ``tokenizer.apply_chat_template``.
            tokenizer: Object providing ``apply_chat_template(messages, tokenize=False)``
                and ``encode(text)``. Presumably a Hugging Face tokenizer; confirm.
                NOTE(review): HF ``encode`` adds special tokens by default, which
                may duplicate a BOS already emitted by the chat template.

        Returns:
            ``(matched, total)``: the number of leading tokens shared with the
            best-matching cached entry (0 if none), and the token length of the
            rendered prompt.
        """
        full_text = tokenizer.apply_chat_template(messages, tokenize=False)
        full_tokens = tokenizer.encode(full_text)
        best_match_len = max(
            (
                self._find_common_prefix(full_tokens, cached["tokens"])
                for cached in self.prefix_map.values()
            ),
            default=0,
        )
        return best_match_len, len(full_tokens)

    def store_prefix(self, tokens: list[int], kv_cache):
        """Record ``tokens`` and their ``kv_cache`` for later prefix matching."""
        # Slicing already clamps to len(tokens), so short sequences key on all tokens.
        prefix_hash = hash(tuple(tokens[:100]))
        self.prefix_map[prefix_hash] = {
            "tokens": tokens,
            "kv_cache": kv_cache,
        }

    @staticmethod
    def _find_common_prefix(a: list[int], b: list[int]) -> int:
        """Return how many leading tokens ``a`` and ``b`` have in common."""
        length = 0
        for x, y in zip(a, b):
            if x != y:
                break
            length += 1
        return length
KV Cache Optimization
Prompt Caching in vLLM
# Once prefix caching is switched on, vLLM reuses KV-cache blocks for shared
# prompt prefixes on its own; no per-request bookkeeping is needed here.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    gpu_memory_utilization=0.9,
    # KV-cache block size in tokens. Per the vLLM docs, prefix reuse works on
    # whole blocks; confirm for the vLLM version in use.
    block_size=16,
    enable_prefix_caching=True,
)

system_prompt = "You are a helpful assistant with expertise in data engineering."
# Every prompt starts with the same system prompt, giving the engine a common
# prefix whose KV cache can be shared between requests.
prompts = [
    f"{system_prompt}\n\n{question}"
    for question in (
        "What is data partitioning?",
        "Explain data sharding strategies.",
        "How does bucketing work in Hive?",
    )
]

# Later requests can reuse the system prompt's cached KV blocks.
# NOTE(review): whether requests prefilled in the same scheduler step share
# blocks depends on vLLM's scheduling; confirm before relying on it.
outputs = llm.generate(prompts, SamplingParams(max_tokens=256))
Continuous Batching with Cache Awareness
class CacheAwareBatcher:
    """Queue requests and hand out batches grouped by shared prompt prefix.

    Requests carrying the same ``"prefix_hash"`` value are kept together, and
    larger groups are taken first. Requests that share a prefix therefore tend
    to land in the same batch.
    """

    def __init__(self, max_batch_size: int = 32, max_wait_ms: float = 100):
        self.max_batch_size = max_batch_size
        # Stored for schedulers that drive this batcher; get_batch() does not consult it.
        self.max_wait_ms = max_wait_ms
        self.pending = []

    def add_request(self, request: dict):
        """Queue ``request`` (a dict, optionally with a ``"prefix_hash"`` key)."""
        self.pending.append(request)

    def get_batch(self) -> list[dict]:
        """Remove and return up to ``max_batch_size`` pending requests.

        Requests are grouped by ``request.get("prefix_hash", "default")``.
        Groups are visited largest-first (ties keep first-seen order), and
        requests within a group keep arrival order. Requests that are not
        selected stay pending in their original order.
        """
        if not self.pending:
            return []

        # Group by prefix for cache efficiency.
        prefix_groups = {}
        for req in self.pending:
            prefix_groups.setdefault(req.get("prefix_hash", "default"), []).append(req)

        # Prioritize larger groups for better cache reuse (sorted() is stable).
        sorted_groups = sorted(prefix_groups.values(), key=len, reverse=True)

        batch = []
        for group in sorted_groups:
            room = self.max_batch_size - len(batch)
            if room <= 0:
                break
            batch.extend(group[:room])

        # Remove by identity, not equality: two distinct but equal request dicts
        # must not both be dropped when only one was batched. A set of ids also
        # makes this O(n) instead of O(n * batch_size).
        taken = {id(req) for req in batch}
        self.pending = [req for req in self.pending if id(req) not in taken]
        return batch
Response Caching Patterns
class ResponseCache:
    """Exact-key LLM response cache stored in Redis as JSON with a TTL.

    All keys are namespaced under ``"llm_cache:"``.
    """

    def __init__(self, redis_client):
        self.redis = redis_client
        self.default_ttl = 3600  # seconds (1 hour)

    def get_or_compute(self, key: str, compute_fn, ttl: Optional[int] = None) -> dict:
        """Return the cached response for ``key``, computing and storing it on a miss.

        Args:
            key: Cache key, without the ``llm_cache:`` namespace.
            compute_fn: Zero-argument callable producing a JSON-serializable response.
            ttl: Expiry in seconds. ``None`` (or any falsy value) uses ``default_ttl``.

        Returns:
            ``{"response": ..., "cached": bool}``.
        """
        redis_key = f"llm_cache:{key}"
        cached = self.redis.get(redis_key)
        if cached:  # redis-py returns None for a missing key
            return {"response": json.loads(cached), "cached": True}

        response = compute_fn()
        # `or` (not `is None`) also maps ttl=0 to the default, because Redis
        # SETEX rejects non-positive expiry times.
        self.redis.setex(redis_key, ttl or self.default_ttl, json.dumps(response))
        return {"response": response, "cached": False}

    def invalidate_pattern(self, pattern: str):
        """Delete all cached entries whose namespaced key matches glob ``pattern``."""
        # SCAN walks the keyspace incrementally. KEYS is a single blocking O(N)
        # command that Redis warns against using in production.
        keys = list(self.redis.scan_iter(match=f"llm_cache:{pattern}"))
        if keys:
            self.redis.delete(*keys)
Cache Metrics
| Metric | Formula | Target |
|---|---|---|
| Hit Rate | hits / (hits + misses) | >30% |
| Latency Savings | (avg_miss_latency - avg_hit_latency) / avg_miss_latency | >90% |
| Cost Savings | sum(cost of requests served from cache) / sum(cost of all requests if uncached) | >20% |
| Freshness | avg(age_of_cached_responses) | Use case dependent |
Cache ROI
The return on investment for caching:
$$\text{ROI}_{\text{cache}} = \frac{(\text{Cost}_{\text{without}} - \text{Cost}_{\text{with}}) - \text{Cost}_{\text{cache\_infra}}}{\text{Cost}_{\text{cache\_infra}}}$$
Effective caching strategies can reduce LLM costs by 30-60% while improving response latency for repeated or similar queries.