KV caching yields enormous savings when running inference with LLMs. One thing that fundamentally enables KV caching is that modern LLMs like the GPT series tend to be pure Transformer decoders rather than encoder-decoder architectures. Consider the operations involved in a decoder:
- positional encodings
- positionwise feedforward networks
- LayerNorms (also positionwise)
- Masked Self-Attention

All operations except the Masked Self-Attention are purely positionwise; the only place where different positions interact is in the Masked Self-Attention,

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V, \qquad M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i, \end{cases}$$

where the softmax is taken row-wise and the score matrix is masked with $-\infty$ above the diagonal to ensure that the resulting softmax matrix is causal, i.e., lower-triangular.
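To see why this matters for caching, here is a minimal single-head sketch of incremental decoding. Everything in it (the toy weights `W_q`, `W_k`, `W_v`, the `decode_step` helper, and the shapes) is made up for illustration; the point is only that, because attention is causal and every other operation is positionwise, each new token needs just its own query, while the keys and values of earlier positions can be computed once, cached, and reused.

```python
import torch
import math

# Toy single-head setup (hypothetical sizes and random weights, for illustration only)
d_model = 8
torch.manual_seed(0)
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

def decode_step(x_t, k_cache, v_cache):
    """Attend from the newest token x_t (shape (1, d_model)) over all cached positions."""
    q_t = x_t @ W_q                                    # query for the new position only
    k_cache = torch.cat([k_cache, x_t @ W_k], dim=0)   # append the new key to the cache
    v_cache = torch.cat([v_cache, x_t @ W_v], dim=0)   # append the new value to the cache
    scores = (q_t @ k_cache.T) / math.sqrt(d_model)    # shape (1, t); no mask needed,
    attn = torch.softmax(scores, dim=-1)               # since the cache only holds past positions
    out = attn @ v_cache                               # shape (1, d_model)
    return out, k_cache, v_cache

# Usage: feed tokens one at a time, reusing the growing cache
k_cache = torch.empty(0, d_model)
v_cache = torch.empty(0, d_model)
for x_t in torch.randn(5, 1, d_model):   # five toy "token embeddings"
    out, k_cache, v_cache = decode_step(x_t, k_cache, v_cache)
```

Without the cache, every step would have to recompute keys and values for the entire prefix; causality guarantees those cached entries never change once written.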
PyTorch Implementation
In PyTorch, this causal masking is typically implemented using `torch.tril()` and `masked_fill()`:
```python
import torch
import math

def causal_attention(q, k, v):
    seq_len = q.size(-2)
    d_k = q.size(-1)

    # Compute attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

    # Create causal mask (on the same device as the scores)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)).bool()

    # Apply mask (set future positions to -inf)
    scores = scores.masked_fill(~causal_mask, float('-inf'))

    # Apply softmax and compute output
    attn_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v)
    return output, attn_weights
```

Key functions:
- `torch.tril(input, diagonal=0)`: returns the lower triangular part of a matrix; `diagonal=0` includes the main diagonal.
- `tensor.masked_fill(mask, value)`: fills elements of the tensor with `value` where `mask` is `True`.
- The `~` operator inverts the boolean mask, so `~causal_mask` selects the upper-triangle positions to mask.
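As a quick sanity check (assuming the imports and the `causal_attention` function above, with arbitrary toy shapes), the attention weights above the diagonal should come out exactly zero, i.e., position i never attends to positions j > i:

```python
q = torch.randn(2, 5, 16)   # toy shapes: (batch, seq_len, d_k)
k = torch.randn(2, 5, 16)
v = torch.randn(2, 5, 16)

out, attn = causal_attention(q, k, v)
print(out.shape)            # torch.Size([2, 5, 16])

# The upper triangle of the attention matrix is exactly zero:
# the -inf scores become 0 after the softmax.
print(torch.all(attn.triu(diagonal=1) == 0))   # tensor(True)
```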