Ruixiang's blog

work harder, study better, do faster, become stronger


Coding Transformer model from scratch

Here is a simple PyTorch implementation of the Transformer model.
The implementation has the following components:

  • Building Blocks: Multi-Head Attention, Position-Wise Feed-Forward Networks, Positional Encoding
  • Building Encoder and Decoder layers
  • Combining Encoder and Decoder layers to create the complete Transformer model

[Figure: Transformer architecture]

[Figure: Sinusoidal positional encoding]
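
The figure above shows the sinusoidal positional encoding. For reference, the standard formulation from the original Transformer paper ("Attention Is All You Need") is:

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$

The PositionalEncoding module in the code below computes exactly these values; div_term holds the $10000^{-2i/d_{model}}$ factors.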

PyTorch implementation:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy


###### Basic components in transformer block
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        # d_model: size of the hidden dimension
        # num_heads: number of heads for multi-head attention
        super(MultiHeadAttention, self).__init__()
        # since we split the hidden dimension across heads, it must be divisible by the number of heads
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # size of the hidden dimension in each head

        # init weight matrices for q, k, v and the output projection
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # mask is None (or a padding mask) in the encoder; the decoder self-attention
        # additionally uses a lower-triangular causal mask
        # softmax((Q * K.T) / sqrt(d_k)) * V
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_prob = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_prob, V)
        return output

    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        batch_size, num_heads, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        # project the concatenated heads to get the final output
        output = self.W_o(self.combine_heads(attn_output))
        return output
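
# Quick shape check (illustrative only, not part of the original code):
# multi-head attention preserves the (batch_size, seq_length, d_model) shape.
#   mha = MultiHeadAttention(d_model=512, num_heads=8)
#   x = torch.randn(2, 10, 512)
#   mha(x, x, x).shape  # torch.Size([2, 10, 512])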


class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        # sin on even dimensions, cos on odd dimensions, for all positions
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

## Positional encoding implementation in NumPy, for reference
# import numpy as np
# def get_pos_encoding(max_seq_length, d_model):
#     angles = np.fromfunction(lambda i, j: i / 10000 ** (2 * j / d_model), (max_seq_length, int(d_model / 2)))
#     pos_enc = np.ones((max_seq_length, d_model))
#     pos_enc[:, 0::2] = np.sin(angles)
#     pos_enc[:, 1::2] = np.cos(angles)
#     return pos_enc
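
# Illustrative check (toy sizes are my own assumption): the sin/cos table is
# built once and broadcast over the batch dimension in forward().
#   pos_enc = PositionalEncoding(d_model=16, max_seq_length=10)
#   pos_enc.pe.shape                       # torch.Size([1, 10, 16])
#   pos_enc(torch.zeros(2, 10, 16)).shape  # torch.Size([2, 10, 16])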


###### Build transformer encoder block
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        # self attention over the encoder input
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x


###### Build transformer decoder block
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        # masked self attention over the decoder input
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        # cross attention between decoder queries and encoder output
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x


###### Build complete transformer model
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers,
                 d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout)
                                             for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout)
                                             for _ in range(num_layers)])
        # the output layer projects the hidden dimension to the target vocab size to generate tokens
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        # padding masks: token id 0 is treated as padding
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(2)
        seq_length = tgt.size(1)
        # nopeak_mask is a boolean causal mask of shape [1, seq_length, seq_length],
        # e.g. for seq_length = 3:
        # tensor([[[ True, False, False],
        #          [ True,  True, False],
        #          [ True,  True,  True]]])
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
        tgt_mask = tgt_mask & nopeak_mask  # combine the padding mask with the lower-triangular causal mask
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.position_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.position_encoding(self.decoder_embedding(tgt)))

        encoder_output = src_embedded
        for encoder_layer in self.encoder_layers:
            encoder_output = encoder_layer(encoder_output, src_mask)

        decoder_output = tgt_embedded
        for decoder_layer in self.decoder_layers:
            decoder_output = decoder_layer(decoder_output, encoder_output, src_mask, tgt_mask)

        output = self.fc(decoder_output)
        return output
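
As a quick sanity check (a minimal sketch with toy hyperparameters of my own choosing, not from the post), the model maps target tokens of shape (batch_size, tgt_seq_length) to logits of shape (batch_size, tgt_seq_length, tgt_vocab_size), which is exactly what the training loss below flattens:

# Toy-sized sanity check (illustrative; the sizes here are assumptions)
model = Transformer(src_vocab_size=100, tgt_vocab_size=100, d_model=32, num_heads=4,
                    num_layers=2, d_ff=64, max_seq_length=20, dropout=0.1)
src = torch.randint(1, 100, (2, 12))  # (batch_size, src_seq_length)
tgt = torch.randint(1, 100, (2, 15))  # (batch_size, tgt_seq_length)
print(model(src, tgt).shape)          # torch.Size([2, 15, 100])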

Here is a simple training loop for the Transformer model:

src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length)) # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length)) # (batch_size, seq_length)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # ignore padding (token id 0) in the loss, consistent with generate_mask above; PyTorch's default ignore_index is -100
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
transformer.train()
for epoch in range(10):
    optimizer.zero_grad()
    # teacher forcing: feed the target shifted right, predict the next token
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch + 1}, Loss: {loss.item()}")
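
After training, a matching evaluation pass can reuse the same loss. The sketch below is illustrative only; it assumes randomly generated validation data in place of a real dataset:

# Evaluation sketch (random validation data assumed, same shapes as the training data)
transformer.eval()
val_src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))
val_tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))
with torch.no_grad():
    val_output = transformer(val_src_data, val_tgt_data[:, :-1])
    val_loss = criterion(val_output.contiguous().view(-1, tgt_vocab_size),
                         val_tgt_data[:, 1:].contiguous().view(-1))
    print(f"Validation Loss: {val_loss.item()}")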
