
Commit: Fix merge
jlamypoirier committed Jan 15, 2025
2 parents a73acf6 + 4496a40 commit bb1b87f
Showing 80 changed files with 1,102 additions and 928 deletions.
30 changes: 15 additions & 15 deletions docs/developer_guide/conversion.md
@@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define:
 
 ```python
 def _create_weight_converters(self) -> list[WeightConverter]:
-    converters = []
-    # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
-    num_layers = self._model.base_model_config.transformer.num_layers
-
-    # A simple renaming example, for the word embeddings.
-    converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))
-
-    # We usually want to loop dynamically over layers
-    for i in range(num_layers):
-        # A `SplitWeightConverter` example, splitting a weight in two.
-        converters.append(SplitWeightConverter(
-            f"layers.{i+1}.weight",
-            (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
-        ))
-    return converters
+    converters = []
+    # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
+    num_layers = self._model.config.base_model.transformer.num_layers
+
+    # A simple renaming example, for the word embeddings.
+    converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))
+
+    # We usually want to loop dynamically over layers
+    for i in range(num_layers):
+        # A `SplitWeightConverter` example, splitting a weight in two.
+        converters.append(SplitWeightConverter(
+            f"layers.{i + 1}.weight",
+            (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
+        ))
+    return converters
 ```
 
 And that's it! We're ready to use the new checkpoint format in Fast-LLM.
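
For context, here is a minimal, self-contained sketch of the conversion idea in the documentation hunk above: a list of converters either renames a weight one-to-one or splits it into several exported weights. The simplified `export_weight`/`export_state_dict` interface and the split along dim 0 are invented for illustration; the real Fast-LLM `WeightConverter` classes have a richer API.

```python
import torch


class WeightConverter:
    """Illustrative stand-in: rename a single weight."""

    def __init__(self, fast_llm_name: str, export_name: str):
        self.fast_llm_name, self.export_name = fast_llm_name, export_name

    def export_weight(self, weight: torch.Tensor) -> dict[str, torch.Tensor]:
        return {self.export_name: weight}


class SplitWeightConverter(WeightConverter):
    """Illustrative stand-in: split one weight into several exported weights."""

    def __init__(self, fast_llm_name: str, export_names: tuple[str, ...]):
        self.fast_llm_name, self.export_names = fast_llm_name, export_names

    def export_weight(self, weight: torch.Tensor) -> dict[str, torch.Tensor]:
        # Split into equal chunks along dim 0; the real split axis depends on the model.
        chunks = weight.chunk(len(self.export_names), dim=0)
        return dict(zip(self.export_names, chunks))


def export_state_dict(
    state_dict: dict[str, torch.Tensor], converters: list[WeightConverter]
) -> dict[str, torch.Tensor]:
    exported = {}
    for converter in converters:
        exported.update(converter.export_weight(state_dict[converter.fast_llm_name]))
    return exported


# Applying the embedding converter from the example above to a dummy state dict.
converters = [WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")]
print(export_state_dict({"layers.0.word_embeddings_weight": torch.zeros(8, 4)}, converters).keys())
```
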
7 changes: 3 additions & 4 deletions examples/mistral.yaml
@@ -7,9 +7,9 @@ training:
     iterations: null
   test_iters: 0
 batch:
-  sequence_length: 8192
-  micro_batch_size: 1
-  batch_size: 32
+  sequence_length: 4096
+  micro_batch_size: 2
+  batch_size: 64
 data:
   datasets:
     Training:
@@ -50,7 +50,6 @@ model:
     zero_stage: 2
   distributed:
     training_dtype: bf16
-    distributed_timeout: 3600
     seed: 984059
 run:
   experiment_dir: mistral_example
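
Note that the batch change is token-neutral: halving the sequence length while doubling both the micro-batch size and the batch size keeps the tokens per optimizer step constant. A quick check (the gradient-accumulation formula below is a hypothetical illustration, not necessarily how Fast-LLM derives it):

```python
# Tokens per optimizer step, before and after (values from the diff above).
old_tokens = 32 * 8192  # batch_size * sequence_length
new_tokens = 64 * 4096
assert old_tokens == new_tokens == 262_144

# Hypothetical illustration: if accumulation = batch_size / (micro_batch_size * data_parallel),
# then with 8 data-parallel ranks both settings need 4 micro-steps per rank.
data_parallel = 8
print(32 // (1 * data_parallel), 64 // (2 * data_parallel))
```
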
41 changes: 29 additions & 12 deletions fast_llm/config.py
@@ -666,14 +666,12 @@ def _get_class_name(cls) -> str:
         return get_type_name(cls)
 
     @classmethod
-    def from_dict[
-        T
-    ](
-        cls: type[T],
+    def from_dict(
+        cls,
         default: typing.Union["Config", dict[str, typing.Any]],
         *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
         strict: bool = True,
-    ) -> T:
+    ) -> typing.Self:
         if isinstance(default, Config):
             default = default._to_dict()
         for update in updates:
@@ -685,16 +683,21 @@ def from_dict[
         return cls._from_dict(default, strict)
 
     @classmethod
-    def from_flat_dict[
-        T
-    ](cls: type[T], default: dict[str, typing.Any], strict: bool = True,) -> T:
+    def from_flat_dict(
+        cls,
+        default: dict[str, typing.Any],
+        strict: bool = True,
+    ) -> typing.Self:
         # TODO v0.3: Remove flat format
         return cls._from_dict(default, strict, True)
 
     @classmethod
-    def _from_dict[
-        T
-    ](cls: type[T], default: dict[str, typing.Any], strict: bool = True, flat: bool = False,) -> T:
+    def _from_dict(
+        cls,
+        default: dict[str, typing.Any],
+        strict: bool = True,
+        flat: bool = False,
+    ) -> typing.Self:
         # TODO v0.3: Remove flat format
         out_arg_dict = {}
 
@@ -843,7 +846,7 @@ def compare(self, other: "Config", log_fn: typing.Union[type[BaseException], typ
             )
 
     @classmethod
-    def _check_abstract(cls):
+    def _check_abstract(cls) -> None:
         if cls._abstract:
             raise ValidationError(f"{cls.__name__} is abstract")
         if not cls.__class_validated__:
@@ -899,3 +902,17 @@ def __init_subclass__(cls):
             else:
                 # dataclasses expects an annotation, so we use the one from the base class.
                 cls.__annotations__[name] = base_class_field.type
+
+
+class Configurable[ConfigType: Config]:
+    config_class: typing.ClassVar[type[Config]] = Config
+
+    def __init__(self, config: ConfigType, *args, **kwargs):
+        Assert.custom(isinstance, config, self.config_class)
+        self._config = config
+        # Handle multiple inheritance.
+        super().__init__(*args, **kwargs)
+
+    @property
+    def config(self) -> ConfigType:
+        return self._config
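
The `fast_llm/config.py` changes above replace the old `def from_dict[T](cls: type[T], ...) -> T` workaround with `typing.Self` (Python 3.11+), which types each classmethod constructor as returning whatever subclass it was called on, and add a generic `Configurable` wrapper using PEP 695 syntax (Python 3.12+). Below is a minimal, self-contained sketch of both patterns — not the actual Fast-LLM classes (the real `Configurable` relies on `Assert.custom` and a validated `Config`):

```python
import typing


class Config:
    @classmethod
    def from_dict(cls, default: dict[str, typing.Any]) -> typing.Self:
        # `typing.Self` means MyConfig.from_dict(...) is typed as MyConfig,
        # with no need for the old `[T](cls: type[T], ...) -> T` workaround.
        instance = cls()
        for key, value in default.items():
            setattr(instance, key, value)
        return instance


class Configurable[ConfigType: Config]:
    config_class: typing.ClassVar[type[Config]] = Config

    def __init__(self, config: ConfigType, *args, **kwargs):
        assert isinstance(config, self.config_class)
        self._config = config
        # Handle multiple inheritance.
        super().__init__(*args, **kwargs)

    @property
    def config(self) -> ConfigType:
        return self._config


class ModelConfig(Config):
    hidden_size: int = 1024


class Model(Configurable[ModelConfig]):
    config_class: typing.ClassVar[type[Config]] = ModelConfig


model = Model(ModelConfig.from_dict({"hidden_size": 2048}))
print(model.config.hidden_size)  # 2048; the checker sees `config` as ModelConfig
```
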
22 changes: 12 additions & 10 deletions fast_llm/core/distributed.py
@@ -8,8 +8,10 @@
 
 import contextlib
 import logging
+import typing
 
 import torch
+from torch._C._distributed_c10d import Work
 from torch.distributed import (  # noqa
     ProcessGroup,
     ReduceOp,
@@ -23,7 +25,7 @@
 logger = logging.getLogger(__name__)
 
 
-def broadcast(tensor, src, group, async_op=False):
+def broadcast(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False) -> Work | None:
     """Same as torch.distributed.broadcast, but without the complication of going through the global rank."""
     assert group is not None
     opts = torch.distributed.BroadcastOptions()
@@ -36,7 +38,7 @@ def broadcast(tensor, src, group, async_op=False):
         work.wait()
 
 
-def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: str):
+def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: str) -> None:
     # A utility function to check for tensor-parallel (or other) mismatches.
     all_tensors = tensor.new_empty((group.size(),) + tensor.shape)
     all_gather_into_tensor(all_tensors, tensor, group)
@@ -51,7 +53,7 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
         )
 
 
-def safe_barrier(group: ProcessGroup | None, value: int | str = 1):
+def safe_barrier(group: ProcessGroup | None, value: int | str = 1) -> None:
     if group:
         hashed = hash(value) % 2**32
         out = allreduce_scalar(hashed, dtype=torch.int64, group=group)
@@ -60,11 +62,11 @@ def safe_barrier(group: ProcessGroup | None, value: int | str = 1):
 
 
 def allreduce_scalar(
-    value,
+    value: float | int,
     dtype: torch.dtype = torch.float64,
     group: torch.distributed.ProcessGroup | None = None,
     op=ReduceOp.SUM,
-):
+) -> float | int:
     if group:
         value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
         torch.distributed.all_reduce(value, op=op, group=group)
@@ -74,11 +76,11 @@ def allreduce_scalar(
 
 
 def broadcast_scalar(
-    value,
+    value: float | int,
     dtype: torch.dtype = torch.float64,
     group: torch.distributed.ProcessGroup | None = None,
     src: int = 0,
-):
+) -> float | int:
     if not group:
         return value
     tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device()))
@@ -88,7 +90,7 @@ def broadcast_scalar(
     return tensor.item()
 
 
-def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0):
+def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
     assert group is not None
     work = group.send([tensor], dst, tag)
     if async_op:
@@ -97,7 +99,7 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta
         work.wait()
 
 
-def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0):
+def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
     assert group is not None
     work = group.recv([tensor], src, tag)
     if async_op:
@@ -107,7 +109,7 @@ def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, ta
 
 
 @contextlib.contextmanager
-def set_generator(generator: torch.Generator):
+def set_generator(generator: torch.Generator) -> typing.Generator[None, None, None]:
     """Use the generator as default, for ops that don't support a generator argument."""
     default_generator: torch.Generator = torch.cuda.default_generators[torch.cuda.current_device()]
     assert generator is not default_generator
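
The new return annotations make the async contract of these wrappers explicit: with `async_op=True` they return a `Work` handle to wait on later, otherwise they wait internally and return `None`. A usage sketch under the assumption that `torch.distributed` is already initialized with a CUDA backend and that `group` and `peer` are set up elsewhere (it will not run standalone):

```python
import torch
from torch.distributed import ProcessGroup

from fast_llm.core.distributed import broadcast_scalar, safe_barrier, send


def exchange_and_sync(tensor: torch.Tensor, group: ProcessGroup, peer: int) -> float:
    # Start the send without blocking and overlap it with some local compute.
    handle = send(tensor, peer, group, async_op=True)
    local_value = float(tensor.float().mean())  # stand-in for useful work
    if handle is not None:
        handle.wait()  # make sure the send finished before reusing the buffer
    # Agree on a scalar across the group, then check that everyone reached this point.
    value = broadcast_scalar(local_value, group=group, src=0)
    safe_barrier(group, "exchange_and_sync")
    return value
```
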
8 changes: 4 additions & 4 deletions fast_llm/core/kernels.py
@@ -26,7 +26,7 @@
     _flash_available = False
 
 
-def l2_norm(tensors: list[torch.Tensor], noop_flag: torch.Tensor):
+def l2_norm(tensors: list[torch.Tensor], noop_flag: torch.Tensor) -> torch.Tensor:
     assert _apex_available
     norm, _ = _multi_tensor_applier(
         _multi_tensor_l2norm,
@@ -37,7 +37,7 @@ def l2_norm(tensors: list[torch.Tensor], noop_flag: torch.Tensor):
     return norm
 
 
-def scale_(tensors: list[torch.Tensor], noop_flag: torch.Tensor, scale: torch.Tensor | float):
+def scale_(tensors: list[torch.Tensor], noop_flag: torch.Tensor, scale: torch.Tensor | float) -> None:
     assert _apex_available
     _multi_tensor_applier(
         _multi_tensor_scale,
@@ -60,7 +60,7 @@ def fused_adam(
     wd: float,
     eps: float,
     step: int,
-):
+) -> None:
     _multi_tensor_applier(
         _multi_tensor_adam,
         noop_flag,
@@ -86,7 +86,7 @@ def flash_attn(
     causal: bool = False,
     generator: torch.Generator | None,
     softmax_scale: float | None = None,
-):
+) -> torch.Tensor:
     assert _flash_available
     with set_generator(generator):
         return _flash_attn_func(
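
`l2_norm` now advertises that it returns a tensor: the global L2 norm over all listed tensors, computed through Apex's fused multi-tensor kernel. Below is an unfused plain-PyTorch reference for the same quantity (a sketch equivalent, not the library's code path):

```python
import torch


def l2_norm_reference(tensors: list[torch.Tensor]) -> torch.Tensor:
    # Global L2 norm of the concatenation = L2 norm of the per-tensor L2 norms.
    per_tensor = torch.stack([t.detach().float().norm(2) for t in tensors])
    return per_tensor.norm(2)


grads = [torch.randn(10, 10), torch.randn(5)]
print(l2_norm_reference(grads))
```
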