Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
jlamypoirier committed Jan 13, 2025
1 parent 09640d8 commit a73acf6
Showing 1 changed file with 31 additions and 34 deletions.
65 changes: 31 additions & 34 deletions fast_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def __post_init__(self):
if _AUTO_VALIDATE:
self.validate()

def __setattr__(self, key, value):
def __setattr__(self, key: str, value: typing.Any) -> None:
"""
Make the class read-only after validation.
"""
Expand All @@ -307,7 +307,7 @@ def __setattr__(self, key, value):
)
super().__setattr__(key, value)

def __delattr__(self, key):
def __delattr__(self, key: str) -> None:
"""
Make the class read-only after validation.
"""
Expand All @@ -318,7 +318,7 @@ def __delattr__(self, key):
)
super().__delattr__(key)

def validate(self, *, _is_validating=False):
def validate[T](self: T, *, _is_validating: bool = False) -> T:
"""
Validate a class and mark it as read-only
This should not be overridden in derived classes.
Expand All @@ -334,7 +334,7 @@ def validate(self, *, _is_validating=False):
self._validated = True
return self

def _validate(self):
def _validate(self) -> None:
"""
Verify that the type hints are respected,
and fix some know entries compatible with the type hint (ex. `int -> float`, `str -> pathlib.Path`)
Expand Down Expand Up @@ -522,7 +522,7 @@ def fields(cls) -> typing.Iterable[tuple[str, Field]]:
return cls.__dataclass_fields__.items() # noqa

@classmethod
def get_field(cls, name) -> Field:
def get_field(cls, name: str) -> Field:
return cls.__dataclass_fields__[name] # noqa

def _to_dict(
Expand All @@ -531,7 +531,7 @@ def _to_dict(
all_fields: bool = False,
format_: _ConfigDictFormat = _ConfigDictFormat.nested,
serializable: bool = False,
):
) -> dict[str, typing.Any]:
"""
Serialize the config to a dict that can (generally) be used to reconstruct an identical `Config`.
When not flat, the dict includes a `__class__` entry which allows support for derived classes.
Expand Down Expand Up @@ -561,12 +561,12 @@ def _add_field_to_args(
args: dict | list,
name: str | None,
field: Field | None,
value,
value: typing.Any,
verbose: int | None = None,
all_fields: bool = False,
format_: _ConfigDictFormat = _ConfigDictFormat.nested,
serializable: bool = False,
):
) -> None:
if (
field is not None
and (not field.init or field._field_type == dataclasses._FIELD_CLASSVAR)
Expand Down Expand Up @@ -622,7 +622,7 @@ def _add_field_to_args(
raise NotImplementedError(format_)

@classmethod
def _serialize_value(cls, value):
def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None:
value = value
if hasattr(value, "__fast_llm_serialize__"):
value = value.__fast_llm_serialize__()
Expand All @@ -634,24 +634,24 @@ def _serialize_value(cls, value):
value = str(value)
return value

def to_copy(
self,
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
strict: bool = True,
):
def to_copy[
T
](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T:
return self.from_dict(self, *updates, strict=strict)

def to_serialized(self, verbose: int | None = FieldVerboseLevel.core):
def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]:
return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True)

def to_logs(
def to_logs[
T
](
self,
verbose: int | None = FieldVerboseLevel.core,
log_fn=logger.info,
log_fn: typing.Callable[[str], T] = logger.info,
title: str | None = None,
width: int = 80,
fill_char: str = "-",
):
) -> T:
arg_dict = self.to_serialized(verbose=verbose)
if title is None:
title = self._get_class_name()
Expand All @@ -662,16 +662,18 @@ def to_logs(
)

@classmethod
def _get_class_name(cls):
def _get_class_name(cls) -> str:
return get_type_name(cls)

@classmethod
def from_dict(
cls,
def from_dict[
T
](
cls: type[T],
default: typing.Union["Config", dict[str, typing.Any]],
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
strict: bool = True,
):
) -> T:
if isinstance(default, Config):
default = default._to_dict()
for update in updates:
Expand All @@ -683,21 +685,16 @@ def from_dict(
return cls._from_dict(default, strict)

@classmethod
def from_flat_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
):
def from_flat_dict[
T
](cls: type[T], default: dict[str, typing.Any], strict: bool = True,) -> T:
# TODO v0.3: Remove flat format
return cls._from_dict(default, strict, True)

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
):
def _from_dict[
T
](cls: type[T], default: dict[str, typing.Any], strict: bool = True, flat: bool = False,) -> T:
# TODO v0.3: Remove flat format
out_arg_dict = {}

Expand Down Expand Up @@ -814,7 +811,7 @@ def _handle_renamed_field(
old_name: str | tuple[str, ...],
new_name: str | tuple[str, ...],
fn: typing.Callable | None = None,
):
) -> None:
if old_name in default:
warnings.warn(f"Field `{old_name}` is deprecated in class {get_type_name(cls)}, use `{new_name}` instead.")
value = pop_nested_dict_value(default, old_name)
Expand Down

0 comments on commit a73acf6

Please sign in to comment.