Merge pull request karpathy#205 from djlisbonne/add_moe
Add Mixture of Experts (MoE) support
gkielian authored Aug 7, 2024
2 parents 0fea89b + 9383bec commit f336f90
Showing 5 changed files with 128 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .gitignore
@@ -11,3 +11,5 @@ csv_logs/
# checkpoint directories
out*/
.aider*

venv/*
6 changes: 6 additions & 0 deletions gpt_conf.py
@@ -13,6 +13,12 @@ class GPTConfig:
    dropout: float = 0.0
    window_size: int = 128
    gate: bool = False
    use_moe: bool = False
    moe_layer_freq: int = 2
    n_experts: int = 8
    moe_top_k: int = 2
    moe_router_scheme: str = "softmax"


    # Training options
    ## Gradient Checkpointing - More memory efficient (can do long contexts), but is slower
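To see how the new fields fit together, here is a minimal sketch (an illustration only: it assumes GPTConfig can be constructed with defaults for its other fields and, as in standard nanoGPT configs, exposes n_layer):

from gpt_conf import GPTConfig

config = GPTConfig(
    use_moe=True,                   # enable Mixture of Experts
    moe_layer_freq=2,               # replace every 2nd FFN position with an MoE layer
    n_experts=8,                    # experts per MoE layer
    moe_top_k=2,                    # experts selected per token
    moe_router_scheme="softmax",    # or "noisy_top_k"
)

# Mirrors the `i % config.moe_layer_freq == 0` check in model.py:
# with moe_layer_freq=2, even-numbered layer indices get an MoE block.
moe_layer_indices = [i for i in range(config.n_layer) if i % config.moe_layer_freq == 0]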
62 changes: 61 additions & 1 deletion model.py
@@ -28,8 +28,12 @@
from variations.position_encoding_variations import QuantizedEmbedding, RotaryEmbedding, ShortRope, SymmetricalOverlapAngularPositions, FIRE
from variations.activation_variations import activation_dictionary
from variations.linear_variations import linear_dictionary
from variations.router_variations import router_dictionary

def create_shared_param_group(layer_type, config):

    # TODO: explore reflecting MoE layers symmetrically across the model

    shared_size = None
    shared_sym = None # if true, output array is symmetrical
    layer_block = None
@@ -54,7 +54,11 @@ def create_shared_param_group(layer_type, config):
        # Create new layer block every "shared_size"
        if i % shared_size == 0:
            if layer_type == "mlp":
                if config.use_moe and i % config.moe_layer_freq == 0:
                    # at this frequency, the layer gets an MoE block instead of a plain MLP
                    layer_block = MoELayer(config)
                else:
                    layer_block = MLP(config)
            elif layer_type == "attn":
                layer_block = CausalSelfAttention(config, fire_pos_enc=fire_pos_enc)
            else:
@@ -648,3 +656,55 @@ def generate_with_stop(self, idx, max_new_tokens, stop_string, decode, temperatu
                    break

        return idx, generated_text


class MoELayer(nn.Module):
    """ Mixture of Experts layer to replace FFN (or every other FFN) """

    def __init__(self, config):
        super().__init__()
        self.top_k = config.moe_top_k
        # TODO: implement expert capacity throttling
        # self.expert_capacity = config.expert_capacity
        self.num_experts = config.n_experts
        self.router = router_dictionary[config.moe_router_scheme](config)
        self.experts = nn.ModuleList([MLP(config) for _ in range(config.n_experts)])

    def forward(self, x):
        # Assuming x has shape [batch_size, seq_len, n_embd]
        batch_size, seq_len, _ = x.shape
        gating_output, indices = self.router(x)
        # print(f"gating_output.shape: {gating_output.shape}")
        # print(f"indices 1 count: {indices}")
        final_output = torch.zeros_like(x)

        # Flatten the batch and sequence dimensions to treat each token independently
        flat_x = x.view(-1, x.size(-1))
        # print(f"x.shape() = {x.shape}")
        # print(f"flat_x = {flat_x.shape}")
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))
        # print(f"flat_gating_output.shape = {flat_gating_output.shape}")

        # Process the tokens routed to each expert in turn
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)
            # print(f"expert_mask shape = {expert_mask.shape}")
            # print(f"flat_mask shape = {flat_mask.shape}")

            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                # Extract and apply gating scores
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Update final output additively by indexing and adding
                final_output[expert_mask] += weighted_output.squeeze(1)
        # print(f"final_output.shape = {final_output.shape}\n")
        return final_output
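To make the routing arithmetic above easier to follow in isolation, here is a self-contained toy sketch (not the repository's classes: plain nn.Linear modules stand in for the MLP experts and an inline top-k softmax stands in for the router) showing how the per-expert masks and gating scores recombine into an output with the same shape as the input:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, n_embd, n_experts, top_k = 2, 4, 8, 4, 2

x = torch.randn(batch_size, seq_len, n_embd)
experts = nn.ModuleList([nn.Linear(n_embd, n_embd) for _ in range(n_experts)])

# Stand-in router: dense logits, keep only the top-k per token, softmax over those.
logits = nn.Linear(n_embd, n_experts)(x)                    # [batch, seq, n_experts]
top_k_logits, indices = logits.topk(top_k, dim=-1)          # [batch, seq, top_k]
sparse_logits = torch.full_like(logits, float('-inf')).scatter(-1, indices, top_k_logits)
gating_output = F.softmax(sparse_logits, dim=-1)            # zero outside the top-k

# Same dispatch pattern as MoELayer.forward: mask the tokens routed to each expert,
# run only those tokens through it, scale by the gate value, and accumulate.
final_output = torch.zeros_like(x)
flat_x = x.view(-1, n_embd)
flat_gating_output = gating_output.view(-1, n_experts)
flat_out = final_output.view(-1, n_embd)                    # shares storage with final_output
for i, expert in enumerate(experts):
    flat_mask = (indices == i).any(dim=-1).view(-1)         # tokens whose top-k includes expert i
    if flat_mask.any():
        gate = flat_gating_output[flat_mask, i].unsqueeze(1)
        flat_out[flat_mask] += expert(flat_x[flat_mask]) * gate

print(final_output.shape)   # torch.Size([2, 4, 8]) -- same shape as the input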



5 changes: 5 additions & 0 deletions train.py
@@ -58,6 +58,11 @@ def parse_args():
    model_group.add_argument('--use_post_ln', default=False, action=argparse.BooleanOptionalAction)
    model_group.add_argument('--window_size', default=None, type=int, help="Sliding window size, note this cannot be greater than block size")
    model_group.add_argument('--gate', default=False, action=argparse.BooleanOptionalAction, help="option for gated attention see https://arxiv.org/abs/2306.12929")
    model_group.add_argument('--use_moe', default=False, action=argparse.BooleanOptionalAction, help="option for Mixture of Experts (MoE) architecture")
    model_group.add_argument('--moe_layer_freq', default=2, type=int, help="set frequency for replacing FFNs with MoE layers")
    model_group.add_argument('--n_experts', default=8, type=int, help="set number of experts per MoE layer")
    model_group.add_argument('--moe_top_k', default=2, type=int, help="set number of experts each token is routed to per MoE layer")
    model_group.add_argument('--moe_router_scheme', default="softmax", type=str, help="option to set routing scheme for MoE layer, defaults to softmax")

    ## MLP Options
    model_group.add_argument('--use_parallel_mlp', default=False, action=argparse.BooleanOptionalAction)
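A typical invocation with the new flags might look like the following (a hypothetical example; the dataset and remaining training flags are whatever you would normally pass to train.py):

python train.py --use_moe --moe_layer_freq 2 --n_experts 8 --moe_top_k 2 --moe_router_scheme noisy_top_k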
54 changes: 54 additions & 0 deletions variations/router_variations.py
@@ -0,0 +1,54 @@
import math
import torch
import torch.nn as nn
from torch.nn import functional as F

class TopKRouter(nn.Module):
    """ Conventional Softmax Top_k Gating network (router) NN for MoE layers """
    def __init__(self, config):
        super().__init__()
        self.top_k = config.moe_top_k
        self.moe_router_scheme = config.moe_router_scheme
        self.route_linear = nn.Linear(config.n_embd, config.n_experts)

    def forward(self, x):
        logits = self.route_linear(x)

        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(logits, float('-inf'))

        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)

        return router_output, indices


class NoisyTopKRouter(nn.Module):
    """ Noisy Top_k Gating network (router) NN for MoE layers """
    def __init__(self, config):
        super().__init__()
        self.top_k = config.moe_top_k
        self.moe_router_scheme = config.moe_router_scheme
        self.route_linear = nn.Linear(config.n_embd, config.n_experts)
        self.noise_linear = nn.Linear(config.n_embd, config.n_experts)

    def forward(self, x):
        logits = self.route_linear(x)

        # learned, input-dependent noise scale (noisy top-k gating)
        noise_logits = self.noise_linear(x)
        noise = torch.randn_like(logits) * F.softplus(noise_logits)

        # perturb the routing logits with the noise, then select the top-k experts
        noisy_logits = logits + noise
        top_k_logits, indices = noisy_logits.topk(self.top_k, dim=-1)

        zeros = torch.full_like(noisy_logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)

        router_output = F.softmax(sparse_logits, dim=-1)

        return router_output, indices

router_dictionary = {
    "softmax": TopKRouter,
    "noisy_top_k": NoisyTopKRouter,
}
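As a rough usage sketch (an illustration, not code from the PR: it assumes this file is importable as variations.router_variations and uses a SimpleNamespace in place of the real GPTConfig), the dictionary lets MoELayer look a router up by its scheme name; the router returns a dense per-token gating tensor that is zero outside the selected experts, together with the chosen expert indices:

from types import SimpleNamespace
import torch
from variations.router_variations import router_dictionary

cfg = SimpleNamespace(n_embd=8, n_experts=4, moe_top_k=2, moe_router_scheme="softmax")
router = router_dictionary[cfg.moe_router_scheme](cfg)      # -> TopKRouter

x = torch.randn(2, 5, cfg.n_embd)                           # [batch, seq, n_embd]
gating_output, indices = router(x)
print(gating_output.shape, indices.shape)                   # [2, 5, 4] and [2, 5, 2]
# Each token's gating row sums to 1 over its top_k experts and is 0 elsewhere.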
