MoE (Mixture-of-Experts) Model

This image shows the basic structure of an MoE. (source)

Recently the Mixtral 8x7B MoE model has dominated the open-source scene, showing on-par or better performance than open-source LLMs with more parameters.
MoE models have the following features:

  • They reach quality comparable to their dense counterparts while being much faster to pretrain
  • They offer faster inference than a dense model with the same total number of parameters, since only a few experts are active per token (see the rough count sketch after this list)
  • They require high VRAM, because all experts must be kept in memory
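
To make the last two points concrete, here is a toy back-of-the-envelope sketch (arbitrary layer sizes, purely for intuition, not Mixtral's real dimensions): with 8 experts, an MoE layer has to store roughly 8x the FFN parameters, but only the 2 selected experts' parameters do work for any given token.

import torch.nn as nn

def count_params(m):
    return sum(p.numel() for p in m.parameters())

dim, hidden = 32, 64  # toy sizes chosen for illustration only

# one dense FFN vs. 8 FFN "experts" of the same shape
dense_ffn = nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, dim))
experts = [nn.Sequential(nn.Linear(dim, hidden), nn.SiLU(), nn.Linear(hidden, dim))
           for _ in range(8)]

total_moe = sum(count_params(e) for e in experts)  # all 8 experts sit in VRAM
active_moe = 2 * count_params(experts[0])          # only 2 experts run per token
print(count_params(dense_ffn), total_moe, active_moe)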

What exactly is an MoE? In the context of transformer models, an MoE consists of two main elements:

  • Sparse MoE layers, used instead of dense feed-forward network (FFN) layers. An MoE layer contains a certain number of “experts”, where each expert is itself a neural network
  • A gate network (router) that determines which tokens are sent to which “expert”

Simple PyTorch Implementation of MoE

import torch
import torch.nn as nn

# single expert: a small feed-forward network (here just one linear layer)
class ExpertModel(nn.Module):
    def __init__(self, input_dim):
        super(ExpertModel, self).__init__()
        self.fc = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.fc(x)


class MixtureOfExperts(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(MixtureOfExperts, self).__init__()
        # one expert FFN per slot
        self.experts = nn.ModuleList(
            [ExpertModel(input_dim) for _ in range(num_experts)]
        )
        # gating network: maps each input to a probability distribution over experts
        self.gating_network = nn.Sequential(
            nn.Linear(input_dim, num_experts),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        gating_weights = self.gating_network(x)               # (batch, num_experts)
        expert_outputs = [expert(x) for expert in self.experts]
        expert_outputs = torch.stack(expert_outputs, dim=1)   # (batch, num_experts, 1)
        # weighted sum of the expert outputs
        final_output = torch.sum(expert_outputs * gating_weights.unsqueeze(2), dim=1)
        return final_output
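
As a quick sanity check, the module above can be used like any other nn.Module. The batch size, input dimension and expert count below are arbitrary examples:

model = MixtureOfExperts(input_dim=16, num_experts=4)
x = torch.randn(8, 16)   # a batch of 8 inputs with 16 features each
y = model(x)             # gating weights softly mix the 4 expert outputs
print(y.shape)           # torch.Size([8, 1])

Note that this toy version computes every expert and mixes them with soft weights; the Mixtral layer below is sparse, selecting only the top-2 experts per token.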

Mixtral 8x7B MoE

In Mixtral 8x7B MoE, every FFN layer of the transformer model is replaced by an MoE layer composed of a gating network and 8 experts. The learned gating network decides which experts each part of the input is routed to. During both training and inference, only 2 experts are selected per token by the gating network.

from dataclasses import dataclass
from typing import List

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class MoeArgs:
    num_experts: int          # total number of experts (8 for Mixtral 8x7B)
    num_experts_per_tok: int  # experts selected per token (2 for Mixtral 8x7B)


class MoeLayer(nn.Module):
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)  # a list containing num_experts FFNs
        self.gate = gate  # a linear layer: nn.Linear(dim, num_experts, bias=False)
        self.args = moe_args

    def forward(self, inputs: torch.Tensor):
        # inputs: (num_tokens, dim); score every expert for every token
        gate_logits = self.gate(inputs)
        # keep only the top-k experts per token
        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
        # renormalize the selected scores into mixing weights
        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
        results = torch.zeros_like(inputs)
        for i, expert in enumerate(self.experts):
            # find the tokens that routed to expert i, and which of their k slots it occupies
            batch_idx, nth_expert = torch.where(selected_experts == i)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs[batch_idx]
            )
        return results
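
A minimal usage sketch, assuming the MoeLayer above and toy dimensions (not Mixtral's real configuration, with plain linear layers standing in for the FFN experts):

dim, num_experts, top_k = 32, 8, 2

moe = MoeLayer(
    experts=[nn.Linear(dim, dim) for _ in range(num_experts)],  # stand-in experts
    gate=nn.Linear(dim, num_experts, bias=False),
    moe_args=MoeArgs(num_experts=num_experts, num_experts_per_tok=top_k),
)

tokens = torch.randn(10, dim)  # 10 token hidden states
out = moe(tokens)              # each token is processed by only its top-2 experts
print(out.shape)               # torch.Size([10, 32])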

Different model architectures of the Llama model and the Mixtral 8x7B MoE model (source)
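
In code terms, the difference the figure shows is confined to the feed-forward slot of each transformer block: Llama applies a single dense FFN, while Mixtral puts an MoE layer in that slot. A rough, hypothetical sketch reusing the MoeLayer above (the FeedForward class and the sizes are placeholders, not the actual Mistral implementation):

class FeedForward(nn.Module):
    # stand-in dense FFN; the real models use a gated (SwiGLU-style) FFN
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden_dim), nn.SiLU(),
                                 nn.Linear(hidden_dim, dim))

    def forward(self, x):
        return self.net(x)

dim, hidden_dim = 32, 64  # toy sizes

# Llama-style block: one dense FFN per layer
llama_ffn = FeedForward(dim, hidden_dim)

# Mixtral-style block: the same slot holds 8 FFN experts behind a top-2 gate
mixtral_ffn = MoeLayer(
    experts=[FeedForward(dim, hidden_dim) for _ in range(8)],
    gate=nn.Linear(dim, 8, bias=False),
    moe_args=MoeArgs(num_experts=8, num_experts_per_tok=2),
)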

Resources
