Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: rename stage block types #8

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 25 additions & 25 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ MBLM can be used with the default Transformer Decoder or Mamba block. The below

```py
import torch

from mblm import (
MBLM,
MambaBlockConfig,
MambaBlock,
MBLMModelConfig,
MBLMReturnType,
TransformerBlockConfig,
TransformerBlock,
)

mblm = MBLM(
Expand All @@ -63,14 +64,14 @@ mblm = MBLM(
pad_token_id=256,
train_checkpoint_chunks=None,
block=[
MambaBlockConfig(
MambaBlock(
d_state=128,
d_conv=4,
expand=2,
headdim=64,
pos_emb_type=None,
),
TransformerBlockConfig(
TransformerBlock(
attn_head_dims=64,
attn_num_heads=16,
attn_use_rot_embs=True,
Expand All @@ -97,6 +98,7 @@ Alternatively, you can read configuration from a YAML string (or file):
```py
import torch
import yaml

from mblm import MBLM, MBLMModelConfig, MBLMReturnType

yml_model_config = """
Expand Down Expand Up @@ -131,17 +133,13 @@ You can define custom stage blocks for MBLM as follows. A stageblock must provid

```py
import torch
from mblm import (
MBLM,
MBLMModelConfig,
MBLMReturnType,
TransformerBlockConfig,
)
from pydantic import Field

from mblm import MBLM, MBLMModelConfig, MBLMReturnType, TransformerBlock
from mblm.model.block import StageBlock
from pydantic import BaseModel, Field

# Define any custom model
class MyLSTM(torch.nn.Module):
class LSTM(torch.nn.Module):
def __init__(self, lstm: torch.nn.LSTM):
super().__init__()
self.lstm = lstm
Expand All @@ -151,15 +149,15 @@ class MyLSTM(torch.nn.Module):
out, _ = self.lstm(input_ids)
return out

# Add a block config and inherit from StageBlock and BaseModel
class LSTMBlockConfig(StageBlock, BaseModel):
# Add a block config and inherit from StageBlock
class LSTMBlock(StageBlock):
block_type: str = Field(init=False, default="lstm")

# Add whatever is needed
dropout: float

def to_model(self, model_dim: int, num_layers: int) -> torch.nn.Module:
return MyLSTM(
return LSTM(
torch.nn.LSTM(
input_size=model_dim,
hidden_size=model_dim,
Expand All @@ -178,11 +176,11 @@ mblm = MBLM(
pad_token_id=256,
train_checkpoint_chunks=None,
block=[
LSTMBlockConfig(
LSTMBlock(
dropout=0.1,
pos_emb_type=None,
),
TransformerBlockConfig(
TransformerBlock(
attn_head_dims=64,
attn_num_heads=16,
attn_use_rot_embs=True,
Expand All @@ -202,10 +200,11 @@ If you want to parse a YAML config to a custom block, **register the block** bef
```py
import torch
import yaml
from pydantic import Field

from mblm import MBLM, MBLMModelConfig, MBLMReturnType
from mblm.model.block import StageBlock
from mblm.model.config import block_registry # Add this!
from pydantic import BaseModel, Field

# Define any custom model
class MyLSTM(torch.nn.Module):
Expand All @@ -218,8 +217,8 @@ class MyLSTM(torch.nn.Module):
out, _ = self.lstm(input_ids)
return out

# Add a block config and inherit from StageBlock and BaseModel
class LSTMBlockConfig(StageBlock, BaseModel):
# Add a block config and inherit from StageBlock
class LSTMBlockConfig(StageBlock):
block_type: str = Field(init=False, default="lstm")

# Add whatever is needed
Expand Down Expand Up @@ -269,10 +268,11 @@ If you want to use the MBLM trainer with [torchrun](https://pytorch.org/docs/sta
# Filename: train_my_mblm.py

import torch
from typing_extensions import Unpack

from mblm import MambaBlock, TransformerBlock
from mblm.data.datasets import DistributedDataset, DistributedDatasetConfig
from mblm.data.types import BatchWithLossMask, ModelMode
from mblm.model.mamba import MambaBlockConfig
from mblm.model.transformer import TransformerBlockConfig
from mblm.train.core.config import CoreTrainConfig
from mblm.train.mblm import (
TrainEntryConfig,
Expand All @@ -281,7 +281,6 @@ from mblm.train.mblm import (
dataset_registry,
train_mblm,
)
from typing_extensions import Unpack


class MyDataset(DistributedDataset[BatchWithLossMask]):
Expand Down Expand Up @@ -372,14 +371,14 @@ config = TrainEntryConfig(
pad_token_id=256,
train_checkpoint_chunks=None,
block=[
MambaBlockConfig(
MambaBlock(
d_state=128,
d_conv=4,
expand=2,
headdim=64,
pos_emb_type=None,
),
TransformerBlockConfig(
TransformerBlock(
attn_head_dims=64,
attn_num_heads=16,
attn_use_rot_embs=True,
Expand All @@ -392,6 +391,7 @@ config = TrainEntryConfig(

if __name__ == "__main__":
train_mblm(config)

```

Then, run the above file with:
Expand Down
8 changes: 4 additions & 4 deletions src/mblm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@


from mblm.model.config import MBLMModelConfig, MBLMReturnType
from mblm.model.mamba import MambaBlockConfig
from mblm.model.mamba import MambaBlock
from mblm.model.mblm import MBLM
from mblm.model.transformer import TransformerBlockConfig
from mblm.model.transformer import TransformerBlock

__all__ = [
"MBLM",
"MBLMModelConfig",
"MBLMReturnType",
"TransformerBlockConfig",
"MambaBlockConfig",
"TransformerBlock",
"MambaBlock",
]
8 changes: 4 additions & 4 deletions src/mblm/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
from pydantic import BaseModel, computed_field, field_validator, model_validator

from mblm.model.block import StageBlock, StageBlockRegistry
from mblm.model.mamba import MambaBlockConfig
from mblm.model.transformer import TransformerBlockConfig
from mblm.model.mamba import MambaBlock
from mblm.model.transformer import TransformerBlock

block_registry = StageBlockRegistry()
block_registry.register(TransformerBlockConfig)
block_registry.register(MambaBlockConfig)
block_registry.register(TransformerBlock)
block_registry.register(MambaBlock)


class MBLMReturnType(str, Enum):
Expand Down
4 changes: 2 additions & 2 deletions src/mblm/model/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
SOFTWARE."""


from pydantic import BaseModel, Field
from pydantic import Field

from mblm.model.block import StageBlock
from mblm.model.mamba_shim import Mamba1, Mamba1Config, Mamba2Mixer


class MambaBlockConfig(StageBlock, BaseModel):
class MambaBlock(StageBlock):
"""
General config for creating a Mamba block inside MBLM.
Uses roughly 3 * expand * d_model^2 parameters.
Expand Down
4 changes: 2 additions & 2 deletions src/mblm/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@

import torch
from MEGABYTE_pytorch.megabyte import Attention, FeedForward, RMSNorm, RotaryEmbedding, token_shift
from pydantic import BaseModel, Field
from pydantic import Field

from mblm.model.block import StageBlock


class TransformerBlockConfig(StageBlock, BaseModel):
class TransformerBlock(StageBlock):
"""
General config for creating a Transformer Decoder block inside MBLM.
"""
Expand Down
8 changes: 3 additions & 5 deletions tests/integration/config/test_sample_config_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@

import pytest

from mblm import MBLM, MBLMModelConfig
from mblm import MBLM, MambaBlock, MBLMModelConfig, TransformerBlock
from mblm.data.dataset.clevr import ClevrOptionalArgs
from mblm.model.mamba import MambaBlockConfig
from mblm.model.transformer import TransformerBlockConfig
from mblm.train.mblm import TrainEntryConfig
from mblm.utils.io import load_yml

Expand All @@ -31,8 +29,8 @@ def ensure_dataset_args_are_valid(self, config: TrainEntryConfig) -> None:

def ensure_model_is_created(self, config: TrainEntryConfig) -> None:
for b in config.params.stage_blocks:
assert isinstance(b, (TransformerBlockConfig, MambaBlockConfig))
if isinstance(b, TransformerBlockConfig):
assert isinstance(b, (TransformerBlock, MambaBlock))
if isinstance(b, TransformerBlock):
assert b.block_type == "transformer"
else:
# mamba1, can be mamba2 (only if tested on Linux with mamba_ssm installed)
Expand Down
29 changes: 13 additions & 16 deletions tests/integration/install/test_custom_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@

def test_from_config():
import torch
from mblm import (
MBLM,
MBLMModelConfig,
MBLMReturnType,
TransformerBlockConfig,
)
from pydantic import Field

from mblm import MBLM, MBLMModelConfig, MBLMReturnType, TransformerBlock
from mblm.model.block import StageBlock
from pydantic import BaseModel, Field

# Define any custom model
class MyLSTM(torch.nn.Module):
class LSTM(torch.nn.Module):
def __init__(self, lstm: torch.nn.LSTM):
super().__init__()
self.lstm = lstm
Expand All @@ -25,15 +21,15 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
out, _ = self.lstm(input_ids)
return out

# Add a block config and inherit from StageBlock and BaseModel
class LSTMBlockConfig(StageBlock, BaseModel):
# Add a block config and inherit from StageBlock
class LSTMBlock(StageBlock):
block_type: str = Field(init=False, default="lstm")

# Add whatever is needed
dropout: float

def to_model(self, model_dim: int, num_layers: int) -> torch.nn.Module:
return MyLSTM(
return LSTM(
torch.nn.LSTM(
input_size=model_dim,
hidden_size=model_dim,
Expand All @@ -52,11 +48,11 @@ def to_model(self, model_dim: int, num_layers: int) -> torch.nn.Module:
pad_token_id=256,
train_checkpoint_chunks=None,
block=[
LSTMBlockConfig(
LSTMBlock(
dropout=0.1,
pos_emb_type=None,
),
TransformerBlockConfig(
TransformerBlock(
attn_head_dims=64,
attn_num_heads=16,
attn_use_rot_embs=True,
Expand All @@ -74,10 +70,11 @@ def to_model(self, model_dim: int, num_layers: int) -> torch.nn.Module:
def test_from_yaml():
import torch
import yaml
from pydantic import Field

from mblm import MBLM, MBLMModelConfig, MBLMReturnType
from mblm.model.block import StageBlock
from mblm.model.config import block_registry # Add this!
from pydantic import BaseModel, Field

# Define any custom model
class MyLSTM(torch.nn.Module):
Expand All @@ -90,8 +87,8 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
out, _ = self.lstm(input_ids)
return out

# Add a block config and inherit from StageBlock and BaseModel
class LSTMBlockConfig(StageBlock, BaseModel):
# Add a block config and inherit from StageBlock
class LSTMBlockConfig(StageBlock):
block_type: str = Field(init=False, default="lstm")

# Add whatever is needed
Expand Down
10 changes: 5 additions & 5 deletions tests/integration/install/test_custom_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Filename: train_my_mblm.py

import torch
from typing_extensions import Unpack

from mblm import MambaBlock, TransformerBlock
from mblm.data.datasets import DistributedDataset, DistributedDatasetConfig
from mblm.data.types import BatchWithLossMask, ModelMode
from mblm.model.mamba import MambaBlockConfig
from mblm.model.transformer import TransformerBlockConfig
from mblm.train.core.config import CoreTrainConfig
from mblm.train.mblm import (
TrainEntryConfig,
Expand All @@ -13,7 +14,6 @@
dataset_registry,
train_mblm,
)
from typing_extensions import Unpack


class MyDataset(DistributedDataset[BatchWithLossMask]):
Expand Down Expand Up @@ -104,14 +104,14 @@ def supports_test_mode() -> bool:
pad_token_id=256,
train_checkpoint_chunks=None,
block=[
MambaBlockConfig(
MambaBlock(
d_state=128,
d_conv=4,
expand=2,
headdim=64,
pos_emb_type=None,
),
TransformerBlockConfig(
TransformerBlock(
attn_head_dims=64,
attn_num_heads=16,
attn_use_rot_embs=True,
Expand Down
Loading
Loading