🎉 75% of content is free forever — Unlock Premium from $10/mo →
CW
Search courses…
💼 Services · ℹ️ About · ✉️ Contact · View Pricing Plans (from $10)

LLM Caching Strategies

AI Safety & Guardrails · Caching · 🟢 Free Lesson

Advertisement

Why Cache LLM Results?

LLM inference is expensive. Caching allows reuse of previous computations, reducing both latency and cost for repeated or similar queries.

Cache Taxonomy

| Cache Type | Granularity | Hit Rate | Latency Reduction |
|---|---|---|---|
| Exact match | Exact query | Low (1-5%) | 100% |
| Semantic | Similar queries | Medium (20-40%) | 95-99% |
| Prefix | Shared prefixes | High (50-80%) | 50-80% |
| KV Cache | Attention states | N/A (GPU) | 30-60% |
| Prompt Cache | Repeated prompts | Medium (15-30%) | 90-95% |

Semantic Caching

Semantic caching matches queries by meaning rather than exact text, catching paraphrases and variations.

import json
import time
from dataclasses import dataclass
from typing import Optional

import numpy as np

@dataclass
class CacheEntry:
    """One cached LLM response plus the bookkeeping SemanticCache needs."""

    query: str  # original query text; also the key in SemanticCache.cache
    response: str  # response returned when a lookup hits this entry
    embedding: np.ndarray  # embedder output for ``query``
    cost: float  # caller-supplied cost of producing ``response`` (units unspecified)
    timestamp: float  # time.time() at the moment the entry was stored
    hit_count: int = 0  # number of lookups this entry has served


class SemanticCache:
    """Response cache that matches queries by embedding similarity.

    A lookup embeds the incoming query and compares it, by cosine similarity,
    with the stored embedding of every cached query. The most similar entry
    whose similarity is at least ``threshold`` is returned as a hit.

    Args:
        embedder: Object exposing ``embed(list_of_texts)``; only
            ``embed([text])[0]`` is used, and it must return a vector
            usable with ``np.dot`` / ``np.linalg.norm``.
        threshold: Minimum cosine similarity for a cached entry to count as a hit.
        max_size: Entry count at which storing a *new* query triggers eviction.
    """

    def __init__(self, embedder, threshold: float = 0.92, max_size: int = 10000):
        self.embedder = embedder
        self.threshold = threshold
        self.max_size = max_size
        self.cache: dict[str, CacheEntry] = {}
        self.stats = {"hits": 0, "misses": 0}
        # Cost avoided by hits, accumulated at hit time so the total survives
        # eviction or overwriting of the entry that produced it.
        self._cost_saved = 0.0

    def lookup(self, query: str) -> Optional[dict]:
        """Return the best cached match for ``query``, or ``None`` on a miss.

        On a hit the result dict has keys ``response``, ``cached`` (always
        True), ``similarity`` and ``age_seconds``. Hit/miss counters are
        updated either way.
        """
        if not self.cache:
            # Still a miss: skipping the count would inflate hit_rate().
            self.stats["misses"] += 1
            return None

        query_embedding = self.embedder.embed([query])[0]
        best_match: Optional[CacheEntry] = None
        # Start below any possible similarity so thresholds <= 0 still work.
        best_score = float("-inf")

        for entry in self.cache.values():
            similarity = self._cosine_similarity(query_embedding, entry.embedding)
            if similarity >= self.threshold and similarity > best_score:
                best_score = similarity
                best_match = entry

        if best_match is None:
            self.stats["misses"] += 1
            return None

        self.stats["hits"] += 1
        best_match.hit_count += 1
        self._cost_saved += best_match.cost
        return {
            "response": best_match.response,
            "cached": True,
            "similarity": best_score,
            "age_seconds": time.time() - best_match.timestamp,
        }

    def store(self, query: str, response: str, cost: float):
        """Cache ``response`` for ``query``, evicting one entry if needed.

        Re-storing an existing query overwrites it in place (resetting its
        hit count) and never evicts, because the cache does not grow.
        """
        # Embed first so a failing embedder cannot cost us an evicted entry.
        embedding = self.embedder.embed([query])[0]
        if query not in self.cache and len(self.cache) >= self.max_size:
            self._evict_least_used()

        self.cache[query] = CacheEntry(
            query=query,
            response=response,
            embedding=embedding,
            cost=cost,
            timestamp=time.time(),
        )

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of ``a`` and ``b``; 0.0 if either has zero norm."""
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            # Avoid 0/0 -> NaN, which would never compare >= threshold anyway.
            return 0.0
        return float(np.dot(a, b) / denom)

    def _evict_least_used(self):
        """Drop the entry with the fewest hits (earliest-inserted on ties)."""
        if self.cache:
            least_used = min(self.cache.values(), key=lambda e: e.hit_count)
            del self.cache[least_used.query]

    def hit_rate(self) -> float:
        """Fraction of lookups that were hits; 0.0 before any lookup."""
        total = self.stats["hits"] + self.stats["misses"]
        return self.stats["hits"] / total if total > 0 else 0.0

    def savings(self) -> dict:
        """Summary of hit/miss counts, hit rate and cumulative cost avoided."""
        return {
            "total_hits": self.stats["hits"],
            "total_misses": self.stats["misses"],
            "hit_rate": self.hit_rate(),
            "estimated_cost_saved": self._cost_saved,
        }

Prefix Caching

Prefix caching stores the KV cache for common prompt prefixes, avoiding redundant prefill computation.

class PrefixCache:
    """Toy registry of tokenized prompts and their KV caches, for prefix matching.

    Entries are keyed by a hash of their first ``_HASH_PREFIX_TOKENS`` tokens,
    so storing a second sequence with the same leading tokens overwrites the
    first.
    """

    _HASH_PREFIX_TOKENS = 100

    def __init__(self):
        self.prefix_map = {}  # prefix_hash -> {"tokens": list[int], "kv_cache": ...}

    def get_prefix(self, messages: list[dict], tokenizer) -> tuple[int, int]:
        """Find the longest common prefix with cached content.

        Args:
            messages: Chat messages rendered via ``tokenizer.apply_chat_template``.
            tokenizer: Object exposing ``apply_chat_template(messages, tokenize=False)``
                and ``encode(text)``.

        Returns:
            ``(matched_tokens, total_tokens)``: the length of the longest
            token prefix shared with any stored entry (0 if none), and the
            token length of the rendered prompt.
        """
        full_text = tokenizer.apply_chat_template(messages, tokenize=False)
        full_tokens = tokenizer.encode(full_text)

        best_match_len = max(
            (
                self._find_common_prefix(full_tokens, cached["tokens"])
                for cached in self.prefix_map.values()
            ),
            default=0,
        )
        # NOTE(review): the matching entry's kv_cache is not returned; callers
        # only learn how many prefill tokens could in principle be skipped.
        return best_match_len, len(full_tokens)

    def store_prefix(self, tokens: list[int], kv_cache):
        """Store ``tokens`` and their ``kv_cache``, keyed by the leading tokens' hash."""
        prefix_hash = hash(tuple(tokens[: self._HASH_PREFIX_TOKENS]))
        self.prefix_map[prefix_hash] = {
            "tokens": tokens,
            "kv_cache": kv_cache,
        }

    @staticmethod
    def _find_common_prefix(a: list[int], b: list[int]) -> int:
        """Return how many leading tokens ``a`` and ``b`` have in common."""
        length = 0
        for x, y in zip(a, b):
            if x != y:
                break
            length += 1
        return length

KV Cache Optimization

Prompt Caching in vLLM

from vllm import LLM, SamplingParams

# enable_prefix_caching asks vLLM to reuse KV-cache entries across requests
# that begin with the same tokens, so shared prompt prefixes skip prefill.
# NOTE(review): vLLM reuses cached KV in units of ``block_size`` tokens, so a
# system prompt shorter than one block may see little or no reuse; confirm
# against the vLLM automatic prefix caching documentation.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    block_size=16,
    enable_prefix_caching=True,
    gpu_memory_utilization=0.9,
)

# Every prompt opens with this identical text; that shared prefix is what
# the cache can reuse.
system_prompt = "You are a helpful assistant with expertise in data engineering."

prompts = [
    f"{system_prompt}\n\n{question}"
    for question in (
        "What is data partitioning?",
        "Explain data sharding strategies.",
        "How does bucketing work in Hive?",
    )
]

# Prompts 2 and 3 are expected to reuse the system-prompt KV cache.
# NOTE(review): when all prompts go out in one generate() call, whether reuse
# happens depends on the scheduler; verify before relying on it.
outputs = llm.generate(prompts, SamplingParams(max_tokens=256))

Continuous Batching with Cache Awareness

class CacheAwareBatcher:
    """Queue of request dicts that are batched largest-prefix-group first.

    Requests sharing a ``"prefix_hash"`` value (missing -> ``"default"``) are
    kept together, so a batch favors requests whose prefixes can be reused.

    Args:
        max_batch_size: Maximum number of requests returned per batch.
        max_wait_ms: Stored for callers; not used by any logic in this class.
    """

    def __init__(self, max_batch_size: int = 32, max_wait_ms: float = 100):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.pending = []

    def add_request(self, request: dict):
        """Append ``request`` to the pending queue."""
        self.pending.append(request)

    def get_batch(self) -> list[dict]:
        """Pop and return up to ``max_batch_size`` pending requests.

        Groups are taken in descending size (ties keep first-seen order).
        Requests inside a group keep their arrival order. Returned
        requests are removed from ``pending``; everything else stays queued
        in its original order.
        """
        if not self.pending:
            return []

        # Group by prefix for cache efficiency.
        prefix_groups: dict = {}
        for req in self.pending:
            prefix_groups.setdefault(req.get("prefix_hash", "default"), []).append(req)

        # Prioritize larger groups for better cache reuse (sort is stable).
        sorted_groups = sorted(prefix_groups.values(), key=len, reverse=True)

        batch = []
        for group in sorted_groups:
            room = self.max_batch_size - len(batch)
            if room <= 0:
                break
            batch.extend(group[:room])

        # Remove by identity, not equality: two distinct but equal request
        # dicts must not both be dropped when only one was batched.
        taken = {id(req) for req in batch}
        self.pending = [req for req in self.pending if id(req) not in taken]
        return batch

Response Caching Patterns

class ResponseCache:
    """Redis-backed cache of JSON-serialisable LLM responses.

    Each value is stored JSON-encoded under ``"llm_cache:<key>"`` with an
    expiry set via ``SETEX``.

    Args:
        redis_client: Client exposing ``get``, ``setex``, ``scan_iter`` and
            ``delete`` (the redis-py ``Redis`` API).
    """

    def __init__(self, redis_client):
        self.redis = redis_client
        self.default_ttl = 3600  # seconds (1 hour)

    def get_or_compute(self, key: str, compute_fn, ttl: Optional[int] = None) -> dict:
        """Return the cached response for ``key``, computing and caching it on a miss.

        Args:
            key: Caller-chosen cache key (namespaced with ``llm_cache:``).
            compute_fn: Zero-argument callable returning a JSON-serialisable value.
            ttl: Expiry in seconds. ``None`` or ``0`` falls back to
                ``default_ttl``; Redis rejects a zero SETEX expiry anyway.

        Returns:
            ``{"response": value, "cached": bool}``.
        """
        redis_key = f"llm_cache:{key}"
        cached = self.redis.get(redis_key)
        # An empty stored value is treated as a miss and recomputed.
        if cached:
            return {"response": json.loads(cached), "cached": True}

        response = compute_fn()
        self.redis.setex(redis_key, ttl or self.default_ttl, json.dumps(response))
        return {"response": response, "cached": False}

    def invalidate_pattern(self, pattern: str, batch_size: int = 500):
        """Delete every cached entry whose key matches the glob ``pattern``.

        Iterates with incremental ``SCAN`` (``scan_iter``) rather than
        ``KEYS``, which Redis documents as unsafe on large production
        databases. Matches are deleted in chunks of ``batch_size``.
        """
        batch = []
        for redis_key in self.redis.scan_iter(match=f"llm_cache:{pattern}"):
            batch.append(redis_key)
            if len(batch) >= batch_size:
                self.redis.delete(*batch)
                batch = []
        if batch:
            self.redis.delete(*batch)

Cache Metrics

| Metric | Formula | Target |
|---|---|---|
| Hit Rate | hits / (hits + misses) | >30% |
| Latency Savings | (avg_miss_latency - avg_hit_latency) / avg_miss_latency | >90% |
| Cost Savings | sum(cached_costs) / total_costs | >20% |
| Freshness | avg(age_of_cached_responses) | Use case dependent |

Definition: Cache ROI

The return on investment for caching:

\text{ROI}_{\text{cache}} = \frac{(\text{Cost}_{\text{without}} - \text{Cost}_{\text{with}}) - \text{Cost}_{\text{cache\_infra}}}{\text{Cost}_{\text{cache\_infra}}}

Effective caching strategies can reduce LLM costs by 30-60% while improving response latency for repeated or similar queries.

⭐

Premium Content

LLM Caching Strategies

Unlock this lesson and 900+ advanced tutorials with a Premium plan.

🎯 End-to-end Projects
💼 Interview Prep
📜 Certificates
🤝 Community Access

Already a member? Log in

Need Expert AI Ops & LLM Ops Help?

Get personalized tutoring, project support, or professional consulting.

Advertisement