
Coding Llama2 model from scratch

This post implements the decoder-only Llama2 model from scratch in PyTorch, to build a better understanding of the model structure and of the tricks used in it.
[figure: llama]

RMSNorm implementation

[figure: RMSNorm]

from typing import Optional, Tuple
from dataclasses import dataclass
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)
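
As a quick sanity check, a minimal sketch (the tensor sizes below are made up for illustration):

# Hypothetical smoke test: after RMSNorm, the mean of squares per vector is ~1
norm = RMSNorm(dim=64)
x = torch.randn(2, 8, 64)
y = norm(x)
print(y.shape)              # torch.Size([2, 8, 64])
print(y.pow(2).mean(-1))    # values close to 1.0, since weight is initialized to ones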

Rotary Position Embedding

[figures: rotary, rotary_matrix]
Background Knowledge:

  • Can we find an inner product over the two vectors q (query) and k (key) used in the attention mechanism that depends only on the two vectors and on the relative distance of the tokens they represent?
  • The "intensity" of the relationship between two tokens encoded with Rotary Positional Embeddings becomes numerically smaller as the distance between them grows.
  • The rotary position embeddings are applied only to the query and the keys, not to the values.
  • The rotary position embeddings are applied after the vectors q and k have been multiplied by the W matrix in the attention mechanism, while in the vanilla transformer they are applied before. (A sketch of the per-pair rotation follows this list.)
    [figure: q_k_rotary]
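
For reference, this is the rotation that RoPE applies to each consecutive pair of dimensions of q and k at position m (the standard RoPE formulation, written out here for convenience):

\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix}
=
\begin{pmatrix}
\cos(m\,\theta_i) & -\sin(m\,\theta_i) \\
\sin(m\,\theta_i) & \cos(m\,\theta_i)
\end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
\qquad
\theta_i = 10000^{-2i/d},\quad i = 0, 1, \dots, d/2 - 1

Equivalently, each pair (x_{2i}, x_{2i+1}) can be treated as a complex number and multiplied by e^{i m \theta_i}, which is exactly what view_as_complex and torch.polar do in the code below.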
import torch
from typing import Tuple

def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    # absolute position encoding: according to the formula theta_i = 10000^(-2*(i-1)/dim) for i = [1, 2, ..., dim/2]
    theta_numerator = torch.arange(0, head_dim, 2).float() # (head_dim/2)
    theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device) # (head_dim/2)
    # construct the positions (the "m" parameter)
    m = torch.arange(seq_len, device=device) # (seq_len)
    # multiply each theta by each position using the outer product
    # (seq_len) outer_product (head_dim/2) -> (seq_len, head_dim/2)
    freqs = torch.outer(m, theta).float()
    # compute complex numbers in polar form c = R * exp(i * m * theta), where R = 1:
    # (seq_len, head_dim/2) -> (seq_len, head_dim/2)
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # xq.shape  = [batch_size, seq_len, n_heads, head_dim]
    # xq_.shape = [batch_size, seq_len, n_heads, head_dim // 2, 2]
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)

    # view each pair of values as one complex number
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)

    # reshape freqs_cis from (seq_len, head_dim/2) to (1, seq_len, 1, head_dim/2)
    # so it broadcasts over the batch and head dimensions
    freqs_cis = freqs_cis.reshape(1, xq_.shape[1], 1, xq_.shape[-1])

    # apply the rotation, then convert the result back to the real domain
    # xq_out.shape = [batch_size, seq_len, n_heads, head_dim]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# During attention calculation, Q and K have the rotary embedding applied:
# freqs_cis = precompute_theta_pos_frequencies(head_dim, max_seq_len * 2, "cuda")
# Q, K = apply_rotary_emb(Q, K, freqs_cis)
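
A quick shape check for the two functions above, a minimal sketch with made-up sizes:

# Hypothetical smoke test: 2 heads of dimension 8, 4 positions
freqs = precompute_theta_pos_frequencies(head_dim=8, seq_len=4, device="cpu")
q = torch.randn(1, 4, 2, 8)   # (batch, seq_len, n_heads, head_dim)
k = torch.randn(1, 4, 2, 8)
q_rot, k_rot = apply_rotary_emb(q, k, freqs)
print(q_rot.shape, k_rot.shape)   # torch.Size([1, 4, 2, 8]) for both
# The rotation preserves vector norms
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)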

Grouped Query Attention

Motivation: the goal is not only to reduce the number of operations we perform, but also to minimize the memory accesses and transfers, since data movement, not computation, is usually the bottleneck on modern GPUs.
Grouped query attention is a compromise between quality (multi-head attention) and speed (multi-query attention).
[figures: attention_comparison, multi_query_attention]

@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1 # Later set in the build method
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    # Needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048
    device: str = None

# used for grouped/multi-query attention: repeat each KV head n_rep times
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    x = x[:, :, :, None, :]
    x = x.expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
    x = x.reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
    return x


class SelfAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # the number of heads for K and V
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # the number of heads for Q
        self.n_heads_q = args.n_heads
        # how many times the K and V heads should be repeated
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # the dimension of each head, i.e. the part of the embedding that each head will handle
        self.head_dim = args.dim // args.n_heads
        # Q, K, V, output projection matrices
        self.wq = nn.Linear(args.dim, self.n_heads_q * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads_q * self.head_dim, args.dim, bias=False)
        # KV cache, allocated on the model's device
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim), device=args.device)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        # freqs_complex is used to add the rotary embedding to Q and K
        # start_pos is used to update the KV cache
        batch_size, seq_len, _ = x.shape # (batch_size, seq_len, dim); with the KV cache, seq_len is 1
        xq = self.wq(x) # (batch_size, seq_len, dim) -> (batch_size, seq_len, heads_Q * head_dim)
        xk = self.wk(x) # (batch_size, seq_len, dim) -> (batch_size, seq_len, heads_KV * head_dim)
        xv = self.wv(x) # (batch_size, seq_len, dim) -> (batch_size, seq_len, heads_KV * head_dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim) # -> (batch_size, seq_len, heads_Q, head_dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        # apply the rotary embedding; shapes are unchanged
        xq, xk = apply_rotary_emb(xq, xk, freqs_complex)
        # write the new entries into the KV cache
        self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv
        keys = self.cache_k[:batch_size, :start_pos + seq_len] # (batch_size, seq_len_kv, heads_KV, head_dim): seq_len_kv covers all previous tokens
        values = self.cache_v[:batch_size, :start_pos + seq_len] # (batch_size, seq_len_kv, heads_KV, head_dim)
        keys = repeat_kv(keys, self.n_rep) # (batch_size, seq_len_kv, heads_KV, head_dim) -> (batch_size, seq_len_kv, heads_Q, head_dim)
        values = repeat_kv(values, self.n_rep)
        xq = xq.transpose(1, 2) # (batch_size, seq_len, heads_Q, head_dim) -> (batch_size, heads_Q, seq_len, head_dim)
        keys = keys.transpose(1, 2) # (batch_size, seq_len_kv, heads_Q, head_dim) -> (batch_size, heads_Q, seq_len_kv, head_dim)
        values = values.transpose(1, 2)
        # scaled dot-product attention
        # (batch_size, heads_Q, seq_len=1, head_dim) @ (batch_size, heads_Q, head_dim, seq_len_kv) -> (batch_size, heads_Q, seq_len=1, seq_len_kv)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        # (batch_size, heads_Q, seq_len=1, seq_len_kv) @ (batch_size, heads_Q, seq_len_kv, head_dim) -> (batch_size, heads_Q, seq_len=1, head_dim)
        output = torch.matmul(scores, values)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output) # (batch_size, seq_len, dim) -> (batch_size, seq_len, dim)
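
To see the grouped-query attention and the KV cache in action, here is a minimal decoding-style sketch; the tiny configuration and the single-token loop are illustrative assumptions, not the real model sizes:

# Hypothetical smoke test: 4 query heads sharing 2 KV heads
args = ModelArgs(dim=32, n_layers=1, n_heads=4, n_kv_heads=2, vocab_size=100,
                 max_batch_size=1, max_seq_len=16, device="cpu")
attn = SelfAttention(args)
freqs = precompute_theta_pos_frequencies(args.dim // args.n_heads, args.max_seq_len * 2, args.device)
for pos in range(3):                       # feed one token at a time, like autoregressive decoding
    x = torch.randn(1, 1, args.dim)        # (batch_size, seq_len=1, dim)
    out = attn(x, start_pos=pos, freqs_complex=freqs[pos:pos + 1])
    print(pos, out.shape)                  # (1, 1, 32) at each step; the cache grows internally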

SwiGLU activation function and FeedForward

[figures: SwiGLU, silu, silu_func]
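
For reference, the SiLU activation and the SwiGLU-style feed-forward that the code below implements are:

\mathrm{SiLU}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}},
\qquad
\mathrm{FFN}_{\mathrm{SwiGLU}}(x) = W_2 \big( \mathrm{SiLU}(W_1 x) \odot W_3 x \big)

where \odot is element-wise multiplication.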

class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        hidden_dim = 4 * args.dim
        hidden_dim = int(2 * hidden_dim / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
        # round hidden_dim up to the nearest multiple of args.multiple_of
        hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)
        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        # SwiGLU: SiLU-gated linear unit followed by the down projection
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
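
A quick shape check with a made-up configuration, showing the hidden size being rounded up to a multiple of args.multiple_of:

# Hypothetical smoke test: dim=32 gives hidden_dim 4*32=128 -> 85 -> rounded up to 96 (multiple_of=16)
ffn = FeedForward(ModelArgs(dim=32, multiple_of=16, vocab_size=100))
x = torch.randn(1, 1, 32)
print(ffn(x).shape)         # torch.Size([1, 1, 32])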

Llama2 model

Put all the components above together to build the Llama2 model.

  • Build blocks

    class EncoderBlock(nn.Module):

        def __init__(self, args: ModelArgs):
            super().__init__()

            self.n_heads = args.n_heads
            self.dim = args.dim
            self.head_dim = args.dim // args.n_heads

            self.attention = SelfAttention(args)
            self.feed_forward = FeedForward(args)

            # Normalization BEFORE the attention block
            self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
            # Normalization BEFORE the feed-forward block
            self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

        def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
            h = x + self.attention.forward(
                self.attention_norm(x), start_pos, freqs_complex
            )
            out = h + self.feed_forward.forward(self.ffn_norm(h))
            return out
  • Build the whole model (an end-to-end usage sketch follows the code below)

    class Transformer(nn.Module):

        def __init__(self, args: ModelArgs):
            super().__init__()

            assert args.vocab_size != -1, "Vocab size must be set"

            self.args = args
            self.vocab_size = args.vocab_size
            self.n_layers = args.n_layers
            self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

            self.layers = nn.ModuleList()
            for layer_id in range(args.n_layers):
                self.layers.append(EncoderBlock(args))

            self.norm = RMSNorm(args.dim, eps=args.norm_eps)
            self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

            self.freqs_complex = precompute_theta_pos_frequencies(self.args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device)

        def forward(self, tokens: torch.Tensor, start_pos: int):
            batch_size, seq_len = tokens.shape
            assert seq_len == 1, "Only one token at a time can be processed"
            h = self.tok_embeddings(tokens)
            # Retrieve the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len]
            freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]
            # Consecutively apply all the encoder layers
            for layer in self.layers:
                h = layer(h, start_pos, freqs_complex)
            h = self.norm(h)
            output = self.output(h).float()
            return output
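
Finally, the whole stack can be exercised with a tiny made-up configuration. This is a minimal sketch with random weights (no pretrained checkpoint is loaded), decoding one token at a time through the KV cache:

# Hypothetical end-to-end smoke test
args = ModelArgs(dim=64, n_layers=2, n_heads=4, n_kv_heads=2, vocab_size=1000,
                 max_batch_size=1, max_seq_len=32, device="cpu")
model = Transformer(args)
tokens = torch.randint(0, args.vocab_size, (1, 1))    # (batch_size, seq_len=1)
for pos in range(4):                                   # decode a few positions
    logits = model(tokens, start_pos=pos)              # (1, 1, vocab_size)
    tokens = logits.argmax(dim=-1)                     # greedy pick of the next token, just for illustration
print(logits.shape)                                    # torch.Size([1, 1, 1000])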
