Multi-Head Self Attention Implementation

This post summarizes the implementation of attention for the encoder and the decoder in the Transformer model.

The attention equation:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$

Figure: scaled dot-product attention & multi-head attention
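Written out (as in the original Transformer paper), the multi-head block runs $h$ scaled dot-product attentions in parallel on linearly projected inputs and concatenates the results:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$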

PyTorch Implementation

  • Attention in transformer encoder (multi-head self attention)
    Note: the reason we divide by $\sqrt{d_k}$ when computing the attention scores: for query and key entries with unit variance, the dot product of two $d_k$-dimensional vectors has variance about $d_k$, so without scaling the softmax saturates and its gradients become very small; dividing by $\sqrt{d_k}$ brings the scores back to roughly unit variance.
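    A quick numerical check (my own sketch, not part of the original listing): with unit-variance entries in $q$ and $k$, the raw dot products have standard deviation about $\sqrt{d_k}$, which pushes the softmax toward a one-hot distribution, while the scaled scores stay well behaved.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k, seq_len = 512, 10
q = torch.randn(seq_len, d_k)  # entries roughly N(0, 1)
k = torch.randn(seq_len, d_k)

raw = q @ k.T              # each entry has variance about d_k
scaled = raw / d_k ** 0.5  # variance brought back to about 1

print(raw.std().item(), scaled.std().item())         # ~sqrt(512) vs ~1
print(F.softmax(raw, dim=-1).max(dim=-1).values)     # close to one-hot -> tiny gradients
print(F.softmax(scaled, dim=-1).max(dim=-1).values)  # much smoother attention weights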

from torch import Tensor
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, mask=None):
    # query, key, value shape: (batch_size, head, seq_len, d_k) when using multi-head attention
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # (batch_size, head, seq_len, seq_len)
    if mask is not None:  # the difference between encoder attention and decoder attention
        scores = scores.masked_fill(mask == 0, -1e9)
    attn_scores = F.softmax(scores, dim=-1)  # (batch_size, head, seq_len, seq_len)
    out = torch.matmul(attn_scores, value)  # (batch_size, head, seq_len, d_k)
    return out, attn_scores

def multi_head_attention(query: Tensor, key: Tensor, value: Tensor, d_model: int, head: int, mask=None) -> Tensor:
    # query, key, value projection matrices (created inside the function for illustration;
    # in practice they would live in an nn.Module so they can be trained)
    w_q = nn.Linear(d_model, d_model, bias=False)  # d_model is the embedding size
    w_k = nn.Linear(d_model, d_model, bias=False)
    w_v = nn.Linear(d_model, d_model, bias=False)
    w_o = nn.Linear(d_model, d_model, bias=False)  # output projection
    query = w_q(query)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
    key = w_k(key)
    value = w_v(value)
    d_k = d_model // head  # split the embedding size across the heads
    # (batch_size, seq_len, d_model) -> (batch_size, seq_len, head, d_k) -> (batch_size, head, seq_len, d_k)
    query = query.view(query.shape[0], query.shape[1], head, d_k).transpose(1, 2)
    key = key.view(key.shape[0], key.shape[1], head, d_k).transpose(1, 2)
    value = value.view(value.shape[0], value.shape[1], head, d_k).transpose(1, 2)
    out, attn_scores = scaled_dot_product_attention(query, key, value, mask)
    # combine all heads
    # (batch_size, head, seq_len, d_k) -> (batch_size, seq_len, head, d_k) -> (batch_size, seq_len, d_model)
    out = out.transpose(1, 2).contiguous().view(out.shape[0], -1, head * d_k)
    out = w_o(out)  # (batch_size, seq_len, d_model)
    return out

# The difference between self attention and cross attention: in self attention query = key = value
# (they all come from the same input), while in cross attention usually only key = value.
def multi_head_self_attention(x: Tensor, head: int, mask=None) -> Tensor:
    d_model = x.shape[-1]
    # query, key, value projection matrices
    w_q = nn.Linear(d_model, d_model, bias=False)  # d_model is the embedding size
    w_k = nn.Linear(d_model, d_model, bias=False)
    w_v = nn.Linear(d_model, d_model, bias=False)
    w_o = nn.Linear(d_model, d_model, bias=False)  # output projection
    query = w_q(x)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
    key = w_k(x)
    value = w_v(x)
    d_k = d_model // head  # split the embedding size across the heads
    # (batch_size, seq_len, d_model) -> (batch_size, seq_len, head, d_k) -> (batch_size, head, seq_len, d_k)
    query = query.view(query.shape[0], query.shape[1], head, d_k).transpose(1, 2)
    key = key.view(key.shape[0], key.shape[1], head, d_k).transpose(1, 2)
    value = value.view(value.shape[0], value.shape[1], head, d_k).transpose(1, 2)
    # calculate the attention
    out, attn_scores = scaled_dot_product_attention(query, key, value, mask)
    # combine all heads
    # (batch_size, head, seq_len, d_k) -> (batch_size, seq_len, head, d_k) -> (batch_size, seq_len, d_model)
    out = out.transpose(1, 2).contiguous().view(out.shape[0], -1, head * d_k)
    out = w_o(out)  # (batch_size, seq_len, d_model)
    return out
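
For reference, a minimal usage sketch (my addition, with arbitrary sizes): because the projection layers are created inside the function, the output comes from randomly initialized weights, but it shows the expected shapes.

x = torch.randn(2, 5, 512)                  # (batch_size, seq_len, d_model)
out = multi_head_self_attention(x, head=8)
print(out.shape)                            # torch.Size([2, 5, 512])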
  • Attention in transformer decoder (masked multi-head self attention)

import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

def causal_mask(seq):  # seq: (batch_size, seq_len, d_model)
    mask_attn_shape = [seq.size(0), seq.size(1), seq.size(1)]  # (batch_size, seq_len, seq_len)
    # 1 for positions that may be attended to (current and earlier), 0 for future positions,
    # matching the mask == 0 convention used in scaled_dot_product_attention
    mask_attn = 1 - np.triu(np.ones(mask_attn_shape), k=1)
    mask_attn = torch.from_numpy(mask_attn).byte()  # convert from numpy to tensor
    return mask_attn

def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor, mask):
    # query, key, value shape: (batch_size, head, seq_len, d_k) when using multi-head attention
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # (batch_size, head, seq_len, seq_len)
    # the difference between encoder attention and decoder attention: the mask is always applied
    scores = scores.masked_fill(mask == 0, -1e9)
    attn_scores = F.softmax(scores, dim=-1)  # (batch_size, head, seq_len, seq_len)
    out = torch.matmul(attn_scores, value)  # (batch_size, head, seq_len, d_k)
    return out, attn_scores

def masked_multi_head_self_attention(x: Tensor, head: int) -> Tensor:
    d_model = x.shape[-1]
    # query, key, value projection matrices
    w_q = nn.Linear(d_model, d_model, bias=False)  # d_model is the embedding size
    w_k = nn.Linear(d_model, d_model, bias=False)
    w_v = nn.Linear(d_model, d_model, bias=False)
    w_o = nn.Linear(d_model, d_model, bias=False)  # output projection
    query = w_q(x)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)
    key = w_k(x)
    value = w_v(x)
    mask = causal_mask(x)  # attention mask: (batch_size, seq_len, seq_len)
    d_k = d_model // head  # split the embedding size across the heads
    # (batch_size, seq_len, d_model) -> (batch_size, seq_len, head, d_k) -> (batch_size, head, seq_len, d_k)
    query = query.view(query.shape[0], query.shape[1], head, d_k).transpose(1, 2)
    key = key.view(key.shape[0], key.shape[1], head, d_k).transpose(1, 2)
    value = value.view(value.shape[0], value.shape[1], head, d_k).transpose(1, 2)
    # broadcast the mask over all heads
    mask = mask.unsqueeze(1).repeat(1, head, 1, 1)  # (batch_size, seq_len, seq_len) -> (batch_size, head, seq_len, seq_len)
    # calculate the attention
    out, attn_scores = scaled_dot_product_attention(query, key, value, mask)
    # combine all heads
    # (batch_size, head, seq_len, d_k) -> (batch_size, seq_len, head, d_k) -> (batch_size, seq_len, d_model)
    out = out.transpose(1, 2).contiguous().view(out.shape[0], -1, head * d_k)
    out = w_o(out)  # (batch_size, seq_len, d_model)
    return out
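
A small sanity check (my addition): the causal mask is lower triangular, so each position can only attend to itself and earlier positions.

x = torch.randn(1, 4, 64)                   # (batch_size, seq_len, d_model)
out = masked_multi_head_self_attention(x, head=4)
print(out.shape)                            # torch.Size([1, 4, 64])
print(causal_mask(x)[0])                    # 4x4 lower-triangular matrix of ones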

NumPy Implementation

  • Attention in transformer encoder (multi-head self attention)

    import numpy as np

    def softmax(x):
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
        return x @ w + b

    def attention(q, k, v):  # [n_q, d_k], [n_k, d_k], [n_k, d_v] -> [n_q, d_v]
        return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

    def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
        # qkv projection
        x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]
        # split into qkv
        qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
        # split into heads
        qkv_heads = list(map(lambda t: np.split(t, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
        # perform attention over each head
        out_heads = [attention(q, k, v) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
        # merge heads
        x = np.hstack(out_heads)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
        # out projection
        x = linear(x, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]
        return x
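
    A usage sketch (my addition) with randomly initialized weight dictionaries, driving the mha above; c_attn and c_proj just need "w" and "b" entries of the right shapes:

    n_seq, n_embd, n_head = 5, 64, 4
    rng = np.random.default_rng(0)
    c_attn = {"w": rng.normal(size=(n_embd, 3 * n_embd)), "b": np.zeros(3 * n_embd)}
    c_proj = {"w": rng.normal(size=(n_embd, n_embd)), "b": np.zeros(n_embd)}
    x = rng.normal(size=(n_seq, n_embd))
    print(mha(x, c_attn, c_proj, n_head).shape)  # (5, 64)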
  • Attention in transformer decoder (masked multi-head self attention)

    import numpy as np

    def softmax(x):
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def linear(x, w, b):  # [m, in], [in, out], [out] -> [m, out]
        return x @ w + b

    def attention(q, k, v, mask):  # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
        return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

    def mha(x, c_attn, c_proj, n_head):  # [n_seq, n_embd] -> [n_seq, n_embd]
        # qkv projection
        x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]
        # split into qkv
        qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
        # split into heads
        qkv_heads = list(map(lambda t: np.split(t, n_head, axis=-1), qkv))  # [3, n_seq, n_embd] -> [3, n_head, n_seq, n_embd/n_head]
        # causal mask to hide future inputs from being attended to
        causal_mask = (1 - np.tri(x.shape[0], dtype=x.dtype)) * -1e10  # [n_seq, n_seq]
        # perform attention over each head
        out_heads = [attention(q, k, v, causal_mask) for q, k, v in zip(*qkv_heads)]  # [3, n_head, n_seq, n_embd/n_head] -> [n_head, n_seq, n_embd/n_head]
        # merge heads
        x = np.hstack(out_heads)  # [n_head, n_seq, n_embd/n_head] -> [n_seq, n_embd]
        # out projection
        x = linear(x, **c_proj)  # [n_seq, n_embd] -> [n_seq, n_embd]
        return x
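
    To make the additive mask concrete (my addition, using the softmax defined above): strictly upper-triangular entries get $-10^{10}$, so after the softmax the corresponding attention weights are effectively zero.

    mask = (1 - np.tri(3)) * -1e10  # 0 on and below the diagonal, -1e10 above it
    print(softmax(np.zeros((3, 3)) + mask))
    # row 0 attends only to position 0, row 1 to positions 0-1, row 2 to all three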
