KV caching yields enormous savings when running inference with LLMs. One of the things that fundamentally enables KV caching to work is that modern LLMs like the GPT series tend to be just Transformer decoders, rather than an encoder-decoder architecture. Consider the operations involved in a decoder:

  • positional encodings
  • positionwise feedforward networks
  • LayerNorms (also positionwise)
  • Masked Self-Attention

All operations except the Masked Self-Attention are purely positionwise; the only place where different positions interact is in the Masked Self-Attention, which computes

$$\operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}} + M\right)V,$$

where the softmax is taken row-wise and the mask matrix is

$$M_{ij} = \begin{cases} 0 & \text{if } j \le i, \\ -\infty & \text{if } j > i, \end{cases}$$

to ensure that the resulting softmax matrix is causal, i.e., lower-triangular.
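
Concretely, for a sequence of length 3 the mask is

$$M = \begin{pmatrix} 0 & -\infty & -\infty \\ 0 & 0 & -\infty \\ 0 & 0 & 0 \end{pmatrix},$$

so row $i$ of the softmax assigns zero weight to every position $j > i$.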

PyTorch Implementation

In PyTorch, this causal masking is typically implemented using torch.tril() and masked_fill():

import torch
import math
 
def causal_attention(q, k, v):
    seq_len = q.size(-2)
    d_k = q.size(-1)
    
    # Compute attention scores
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Create causal mask on the same device as the inputs (True on and below the diagonal)
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device))
    
    # Apply mask (set future positions to -inf)
    scores = scores.masked_fill(~causal_mask, float('-inf'))
    
    # Apply softmax and compute output
    attn_weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v)
    
    return output, attn_weights
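
As a quick check of the function above (the shapes here are arbitrary toy values), the returned attention weights should be exactly zero everywhere above the diagonal, i.e., no position attends to the future:

q = torch.randn(2, 4, 5, 8)   # toy (batch, heads, seq_len, d_k) tensors
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

output, attn_weights = causal_attention(q, k, v)

print(attn_weights.shape)                             # torch.Size([2, 4, 5, 5])
print(torch.all(attn_weights.triu(diagonal=1) == 0))  # tensor(True): no weight on future positions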

Key functions:

  • torch.tril(input, diagonal=0): Returns the lower triangular part of a matrix. diagonal=0 includes the main diagonal.
  • tensor.masked_fill(mask, value): Fills elements of the tensor with value where mask is True.
  • The ~ operator inverts the boolean mask, so ~causal_mask selects the strictly upper-triangular positions (the future tokens) to mask out.
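
This causality is exactly what makes KV caching pay off: all the other decoder operations are positionwise, and position i only attends to positions at or before it, so the keys and values of earlier tokens never change when a new token is appended and can be computed once and reused. Below is a minimal sketch of that idea for a single head, reusing causal_attention and the imports from above; the function name causal_attention_step, the list-based cache, and the toy shapes are illustrative choices, not part of any library API. At each step only the newest token's query is needed, and it attends over all cached keys and values with no mask at all:

def causal_attention_step(q_new, k_cache, v_cache):
    # q_new: (..., 1, d_k) query for the newest token only
    # k_cache, v_cache: (..., t, d_k) keys and values of every token seen so far
    d_k = q_new.size(-1)
    scores = torch.matmul(q_new, k_cache.transpose(-2, -1)) / math.sqrt(d_k)
    # No causal mask needed: the newest token is allowed to attend to every cached position
    attn_weights = torch.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v_cache)

# Toy check that incremental decoding with a cache matches full causal attention
seq_len, d_k = 6, 8
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_k)

full_output, _ = causal_attention(q, k, v)

k_cache, v_cache, step_outputs = [], [], []
for t in range(seq_len):
    # The new token's key and value are computed once and cached for all later steps
    k_cache.append(k[t:t+1])
    v_cache.append(v[t:t+1])
    step_outputs.append(
        causal_attention_step(q[t:t+1], torch.cat(k_cache), torch.cat(v_cache))
    )

print(torch.allclose(torch.cat(step_outputs), full_output, atol=1e-6))  # True

In a real decoder the cache would be preallocated tensors per layer and head rather than Python lists, but the saving is the same: each new token costs one query row against the cache instead of recomputing attention over the entire prefix.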