Ruixiang's blog



KV Cache in Transformer Inference

KV cache is an important technique used during inference in decoder-only transformer models. It avoids recomputing the attention keys and values of earlier tokens at every generation step, which saves a large amount of computation and makes inference faster.

A decoder-only transformer is an auto-regressive model, i.e. it generates new tokens one at a time. Each time we get a new token, we append it to the model's input tokens and run the model again to produce a sequence of output tokens; the last output token is the next new token. We repeat this generation process until the model produces an end-of-text token or the sequence reaches the maximum length.

At each inference step we only need the last position of the model output, which gives the new token. Without caching, however, the model reprocesses every earlier token again and again, because each forward pass runs over the entire sequence. The model does need access to all previous tokens during the masked self-attention stage, since they form the context for choosing the next token. But with a causal mask, the keys and values of earlier tokens in every layer do not change when new tokens are appended, so recomputing them is wasted work.
Therefore, the KV cache stores the keys (K) and values (V) of previous tokens in every attention layer and reuses them, so each step only computes the query, key and value of the newest token. This saves computation at the cost of extra GPU memory to hold the cache, which grows with the sequence length.
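To make the repeated work concrete, the sketch below counts how many token positions pass through the model during greedy generation, assuming a prompt of `prompt_len` tokens and `n_new` generated tokens (the helper and the example numbers are hypothetical). Without a cache, step t re-processes all `prompt_len + t` tokens; with a cache, the first step processes the prompt and every later step processes only the newest token.

```python
def positions_processed(prompt_len: int, n_new: int) -> tuple[int, int]:
    """Count token positions fed through the model over n_new decode steps.

    Hypothetical helper for illustration only.
    """
    # Without a cache, step t runs the model on the prompt plus the t tokens generated so far.
    without_cache = sum(prompt_len + t for t in range(n_new))
    # With a cache, step 0 processes the whole prompt; each later step feeds one new token.
    with_cache = prompt_len + (n_new - 1) if n_new > 0 else 0
    return without_cache, with_cache


if __name__ == "__main__":
    no_cache, cache = positions_processed(prompt_len=100, n_new=50)
    print(no_cache, cache)  # 6225 149
```

This counts positions only. With the cache, each step still attends over all previous keys, so the attention cost per step grows with sequence length, but the projections and MLP layers run on one token instead of the whole sequence.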

An illustration of the motivation for the KV cache (source):
[Figure: KV_cache_reason]

Generating a new token after applying the KV cache (source):
[Figure: KV_cache]
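The idea in the figure can be checked numerically. The sketch below uses a single attention head with random weights (`W_q`, `W_k`, `W_v` and the hidden states `x` are hypothetical, not GPT-2 parameters). It compares the last row of full causal attention, recomputed from scratch, with the result of projecting only the new token and attending over cached keys and values of the earlier tokens.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # subtract row max for numerical stability
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v, mask):
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

rng = np.random.default_rng(0)
n_seq, d = 5, 8
x = rng.standard_normal((n_seq, d))  # hypothetical hidden states for 5 tokens
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))  # hypothetical single-head weights

# Without cache: project every token, apply the full causal mask, keep only the last row.
causal_mask = (1 - np.tri(n_seq)) * -1e10
full = attention(x @ W_q, x @ W_k, x @ W_v, causal_mask)[-1:]

# With cache: K and V for the first n_seq - 1 tokens were stored at earlier steps.
k_cache, v_cache = x[:-1] @ W_k, x[:-1] @ W_v
x_new = x[-1:]  # only the newest token is projected now
k = np.vstack([k_cache, x_new @ W_k])
v = np.vstack([v_cache, x_new @ W_v])
cached = attention(x_new @ W_q, k, v, np.zeros((1, n_seq)))  # the new query may see every token

print(np.allclose(full, cached))  # True
```

No mask is needed in the cached path because the single new query sits at the last position, and the last row of the causal mask is all zeros.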

A Python (NumPy) implementation of the KV cache for a simplified GPT-2 model:


```python
import numpy as np

# given a decoder-only GPT-2 model
def gpt2(inputs: list[int], **params) -> list[list[float]]:
    # inputs: [seq_len], outputs: [seq_len, vocab_size]
    # For an input of any length, the output has the same length; row i holds the
    # logits (unnormalized scores) for the token that follows position i.
    # Pseudo_Model is a placeholder for the full GPT-2 forward pass.
    output = Pseudo_Model(inputs, **params)
    return output

# inference without KV cache
for _ in range(n_tokens_to_generate):
    logits = gpt2(inputs, **params)  # model forward pass over the WHOLE sequence
    next_id = np.argmax(logits[-1])  # greedy sampling: only the last row is used
    inputs.append(int(next_id))  # append prediction to input
generated_tokens = inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids
```


To optimize the inference process above, we apply the KV cache, starting with the multi-head attention computation:

```python
import numpy as np

def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))  # subtract max for numerical stability
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
    return x @ w + b

def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

def mha(x, c_attn, c_proj, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # qkv projection
    # when a kvcache is passed, n_seq = 1, so we only compute new_q, new_k and new_v for the newest token
    x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]
    # split into qkv
    qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
    if kvcache is not None:
        new_q, new_k, new_v = qkv  # each [1, n_embd]
        old_k, old_v = kvcache  # each [prev_n_seq, n_embd]
        k = np.vstack([old_k, new_k])  # [n_seq, n_embd], where n_seq = prev_n_seq + 1
        v = np.vstack([old_v, new_v])  # [n_seq, n_embd], where n_seq = prev_n_seq + 1
        qkv = [new_q, k, v]  # the query is still only the new token: [1, n_embd]
    # save the full K and V (on the first pass as well as cached passes) for the next step
    current_cache = [qkv[1], qkv[2]]
    # split into heads
    qkv_heads = list(map(lambda t: np.split(t, n_head, axis=-1), qkv))  # -> [3, n_head, n_seq, n_embd/n_head]
    # causal mask to hide future inputs from being attended to
    if kvcache is not None:
        # the single new query may attend to every cached token, so nothing is masked
        causal_mask = np.zeros((1, k.shape[0]))  # [1, n_seq]
    else:
        causal_mask = (1 - np.tri(x.shape[0])) * -1e10  # [n_seq, n_seq]
    # perform attention over each head
    out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # -> [n_head, n_q, n_embd/n_head]
    # merge heads
    x = np.hstack(out_heads)  # [n_head, n_q, n_embd/n_head] -> [n_q, n_embd]
    # out projection
    x = linear(x, **c_proj)  # [n_q, n_embd] -> [n_q, n_embd]
    return x, current_cache
```
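The `np.vstack` calls in `mha` allocate a new array and copy the entire cache at every step, so the copying work grows with the sequence length. A common alternative, not part of the original code, is to preallocate a buffer sized for the maximum context and write each new row in place. The class below is a hypothetical sketch of that idea for one layer; `PreallocatedKVCache`, `max_len` and the method names are illustrative.

```python
import numpy as np

class PreallocatedKVCache:
    """Hypothetical fixed-capacity K/V cache for one attention layer."""

    def __init__(self, max_len: int, n_embd: int, dtype=np.float32):
        self.k = np.zeros((max_len, n_embd), dtype=dtype)
        self.v = np.zeros((max_len, n_embd), dtype=dtype)
        self.length = 0  # number of positions filled so far

    def append(self, new_k: np.ndarray, new_v: np.ndarray) -> None:
        """Write [n_new, n_embd] keys/values in place instead of re-stacking the whole cache."""
        n_new = new_k.shape[0]
        if self.length + n_new > self.k.shape[0]:
            raise ValueError("KV cache is full: sequence exceeds max_len")
        self.k[self.length : self.length + n_new] = new_k
        self.v[self.length : self.length + n_new] = new_v
        self.length += n_new

    def get(self) -> tuple[np.ndarray, np.ndarray]:
        """Return views (no copy) of the filled part, shaped [length, n_embd]."""
        return self.k[: self.length], self.v[: self.length]


cache = PreallocatedKVCache(max_len=8, n_embd=4)
cache.append(np.ones((3, 4)), np.ones((3, 4)))            # prompt of 3 tokens
cache.append(np.full((1, 4), 2.0), np.full((1, 4), 2.0))  # one decoded token
k, v = cache.get()
print(k.shape)  # (4, 4)
```

In `mha`, calling `cache.append(new_k, new_v)` followed by `k, v = cache.get()` could take the place of the two `np.vstack` lines. The trade-off is that memory for the full `max_len` is reserved up front.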

```python
# Then we apply the KV-cached multi-head attention to the whole GPT-2 model (all transformer blocks)
def kv_cached_gpt2(inputs: list[int], kvcache, **params):
    # inputs: [seq_len]; outputs: [seq_len, vocab_size] on the first call, [1, vocab_size] afterwards
    if not kvcache:
        # first call (the prompt): one empty cache slot per transformer block
        # (assumes params carries "n_layer"; Pseudo_Model is a placeholder for the full forward pass)
        kvcache = [None] * params["n_layer"]
        positions = list(range(len(inputs)))
    else:
        # cache already available: only send the last token as input for predicting the next token
        positions = [len(inputs) - 1]  # the new token keeps its true position for the positional embedding
        inputs = [inputs[-1]]  # we feed a single query into the model to compute attention
    # block i is expected to call mha(..., kvcache=kvcache[i]) and return its updated cache
    output, new_kvcache = Pseudo_Model(inputs, positions, kvcache, **params)
    return output, new_kvcache

# inference with KV cache
kvcache = None
for _ in range(n_tokens_to_generate):  # auto-regressive decode loop
    logits, kvcache = kv_cached_gpt2(inputs, kvcache=kvcache, **params)  # prompt first, then one token per step
    next_id = np.argmax(logits[-1])  # greedy sampling
    inputs.append(int(next_id))  # append prediction to input
generated_tokens = inputs[len(inputs) - n_tokens_to_generate :]  # only return generated ids
```
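Because each layer stores one key row and one value row per token, the cache size can be estimated directly. The sketch below assumes standard multi-head attention as in `mha` above (K and V each `[n_seq, n_embd]` per layer) and plugs in the GPT-2 small configuration (12 layers, `n_embd = 768`, context length 1024) with float32 values; the helper `kv_cache_bytes` is hypothetical.

```python
def kv_cache_bytes(n_layer: int, n_seq: int, n_embd: int,
                   bytes_per_value: int = 4, batch: int = 1) -> int:
    """Estimate KV cache size: K and V (factor 2) for every layer, position and channel."""
    return 2 * n_layer * n_seq * n_embd * bytes_per_value * batch


size = kv_cache_bytes(n_layer=12, n_seq=1024, n_embd=768)  # GPT-2 small, full context, float32
print(size, size / 2**20)  # 75497472 72.0
```

The estimate grows linearly with both sequence length and batch size, which is why KV cache memory becomes a bottleneck for long contexts and large batches.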

