diff --git a/traitlets/config/application.py b/traitlets/config/application.py index 9786e224..fb185f0a 100644 --- a/traitlets/config/application.py +++ b/traitlets/config/application.py @@ -446,7 +446,7 @@ def _show_config_changed(self, change): self._save_start = self.start self.start = self.start_show_config # type:ignore[method-assign] - def __init__(self, **kwargs): + def __init__(self, **kwargs: t.Any) -> None: SingletonConfigurable.__init__(self, **kwargs) # Ensure my class is in self.classes, so my attributes appear in command line # options and config files. diff --git a/traitlets/config/configurable.py b/traitlets/config/configurable.py index 1bfa0457..f448e696 100644 --- a/traitlets/config/configurable.py +++ b/traitlets/config/configurable.py @@ -5,6 +5,7 @@ import logging +import typing as t from copy import deepcopy from textwrap import dedent @@ -46,7 +47,7 @@ class Configurable(HasTraits): config = Instance(Config, (), {}) parent = Instance("traitlets.config.configurable.Configurable", allow_none=True) - def __init__(self, **kwargs): + def __init__(self, **kwargs: t.Any) -> None: """Create a configurable given a config config. Parameters diff --git a/traitlets/config/loader.py b/traitlets/config/loader.py index 1c6b8e8c..34d62e5a 100644 --- a/traitlets/config/loader.py +++ b/traitlets/config/loader.py @@ -2,6 +2,7 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations import argparse import copy @@ -236,7 +237,7 @@ class Config(dict): # type:ignore[type-arg] """ - def __init__(self, *args, **kwds): + def __init__(self, *args: t.Any, **kwds: t.Any) -> None: dict.__init__(self, *args, **kwds) self._ensure_subconfig() @@ -273,7 +274,7 @@ def merge(self, other): self.update(to_update) - def collisions(self, other: "Config") -> t.Dict[str, t.Any]: + def collisions(self, other: Config) -> dict[str, t.Any]: """Check for collisions between two config objects. Returns a dict of the form {"Class": {"trait": "collision message"}}`, @@ -281,7 +282,7 @@ def collisions(self, other: "Config") -> t.Dict[str, t.Any]: An empty dict indicates no collisions. """ - collisions: t.Dict[str, t.Any] = {} + collisions: dict[str, t.Any] = {} for section in self: if section not in other: continue @@ -490,7 +491,7 @@ def _log_default(self): return get_logger() - def __init__(self, log=None): + def __init__(self, log: t.Any = None) -> None: """A base class for config loaders. log : instance of :class:`logging.Logger` to use. @@ -532,7 +533,7 @@ class FileConfigLoader(ConfigLoader): here. """ - def __init__(self, filename, path=None, **kw): + def __init__(self, filename: str, path: str | None = None, **kw: t.Any) -> None: """Build a config loader for a filename and path. Parameters @@ -795,12 +796,12 @@ class ArgParseConfigLoader(CommandLineConfigLoader): def __init__( self, - argv: t.Optional[t.List[str]] = None, - aliases: t.Optional[t.Dict[Flags, str]] = None, - flags: t.Optional[t.Dict[Flags, str]] = None, + argv: list[str] | None = None, + aliases: dict[Flags, str] | None = None, + flags: dict[Flags, str] | None = None, log: t.Any = None, - classes: t.Optional[t.List[t.Type[t.Any]]] = None, - subcommands: t.Optional[SubcommandsDict] = None, + classes: list[type[t.Any]] | None = None, + subcommands: SubcommandsDict | None = None, *parser_args: t.Any, **parser_kw: t.Any, ) -> None: @@ -899,9 +900,7 @@ def _create_parser(self): def _add_arguments(self, aliases, flags, classes): raise NotImplementedError("subclasses must implement _add_arguments") - def _argcomplete( - self, classes: t.List[t.Any], subcommands: t.Optional[SubcommandsDict] - ) -> None: + def _argcomplete(self, classes: list[t.Any], subcommands: SubcommandsDict | None) -> None: """If argcomplete is enabled, allow triggering command-line autocompletion""" pass @@ -909,7 +908,7 @@ def _parse_args(self, args): """self.parser->self.parsed_data""" uargs = [cast_unicode(a) for a in args] - unpacked_aliases: t.Dict[str, str] = {} + unpacked_aliases: dict[str, str] = {} if self.aliases: unpacked_aliases = {} for alias, alias_target in self.aliases.items(): @@ -957,7 +956,7 @@ def _convert_to_config(self): class _FlagAction(argparse.Action): """ArgParse action to handle a flag""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: self.flag = kwargs.pop("flag") self.alias = kwargs.pop("alias", None) kwargs["const"] = Undefined @@ -983,8 +982,8 @@ class KVArgParseConfigLoader(ArgParseConfigLoader): parser_class = _KVArgParser # type:ignore[assignment] def _add_arguments(self, aliases, flags, classes): - alias_flags: t.Dict[str, t.Any] = {} - argparse_kwds: t.Dict[str, t.Any] + alias_flags: dict[str, t.Any] = {} + argparse_kwds: dict[str, t.Any] paa = self.parser.add_argument self.parser.set_defaults(_flags=[]) paa("extra_args", nargs="*") @@ -1108,9 +1107,7 @@ def _handle_unrecognized_alias(self, arg: str) -> None: """ self.log.warning("Unrecognized alias: '%s', it will have no effect.", arg) - def _argcomplete( - self, classes: t.List[t.Any], subcommands: t.Optional[SubcommandsDict] - ) -> None: + def _argcomplete(self, classes: list[t.Any], subcommands: SubcommandsDict | None) -> None: """If argcomplete is enabled, allow triggering command-line autocompletion""" try: import argcomplete # noqa @@ -1132,7 +1129,7 @@ class KeyValueConfigLoader(KVArgParseConfigLoader): Use KVArgParseConfigLoader """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: warnings.warn( "KeyValueConfigLoader is deprecated since Traitlets 5.0." " Use KVArgParseConfigLoader instead.", diff --git a/traitlets/log.py b/traitlets/log.py index 016529fc..468c7c3c 100644 --- a/traitlets/log.py +++ b/traitlets/log.py @@ -2,13 +2,14 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations import logging -_logger = None +_logger: logging.Logger | None = None -def get_logger(): +def get_logger() -> logging.Logger: """Grab the global logger instance. If a global Application is instantiated, grab its logger. diff --git a/traitlets/tests/test_typing.py b/traitlets/tests/test_typing.py index eb9df28d..92e5bd24 100644 --- a/traitlets/tests/test_typing.py +++ b/traitlets/tests/test_typing.py @@ -4,7 +4,25 @@ import pytest -from traitlets import Bool, CInt, HasTraits, Instance, Int, TCPAddress +from traitlets import ( + Any, + Bool, + CInt, + Dict, + HasTraits, + Instance, + Int, + List, + Set, + TCPAddress, + Type, + Unicode, + Union, + default, + observe, + validate, +) +from traitlets.config import Config if not typing.TYPE_CHECKING: @@ -12,11 +30,194 @@ def reveal_type(*args, **kwargs): pass +# mypy: disallow-untyped-calls + + class Foo: def __init__(self, c): self.c = c +@pytest.mark.mypy_testing +def mypy_decorator_typing(): + class T(HasTraits): + foo = Unicode("").tag(config=True) + + @default("foo") + def _default_foo(self) -> str: + return "hi" + + @observe("foo") + def _foo_observer(self, change: typing.Any) -> bool: + return True + + @validate("foo") + def _foo_validate(self, commit: typing.Any) -> bool: + return True + + t = T() + reveal_type(t.foo) # R: builtins.str + reveal_type(t._foo_observer) # R: Any + reveal_type(t._foo_validate) # R: Any + + +@pytest.mark.mypy_testing +def mypy_config_typing(): + c = Config( + { + "ExtractOutputPreprocessor": {"enabled": True}, + } + ) + reveal_type(c) # R: traitlets.config.loader.Config + + +@pytest.mark.mypy_testing +def mypy_union_typing(): + class T(HasTraits): + style = Union( + [Unicode("default"), Type(klass=object)], + help="Name of the pygments style to use", + default_value="hi", + ).tag(config=True) + + t = T() + reveal_type(Union("foo")) # R: traitlets.traitlets.Union + reveal_type(Union("").tag(sync=True)) # R: traitlets.traitlets.Union + reveal_type(Union(None, allow_none=True)) # R: traitlets.traitlets.Union + reveal_type(Union(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Union + reveal_type(T.style) # R: traitlets.traitlets.Union + reveal_type(t.style) # R: Any + + +@pytest.mark.mypy_testing +def mypy_list_typing(): + class T(HasTraits): + latex_command = List( + ["xelatex", "{filename}", "-quiet"], help="Shell command used to compile latex." + ).tag(config=True) + + t = T() + reveal_type(List("foo")) # R: traitlets.traitlets.List + reveal_type(List("").tag(sync=True)) # R: traitlets.traitlets.List + reveal_type(List(None, allow_none=True)) # R: traitlets.traitlets.List + reveal_type(List(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.List + reveal_type(T.latex_command) # R: traitlets.traitlets.List + reveal_type(t.latex_command) # R: builtins.list[Any] + + +@pytest.mark.mypy_testing +def mypy_dict_typing(): + class T(HasTraits): + foo = Dict({}, help="Shell command used to compile latex.").tag(config=True) + + t = T() + reveal_type(Dict("foo")) # R: traitlets.traitlets.Dict + reveal_type(Dict("").tag(sync=True)) # R: traitlets.traitlets.Dict + reveal_type(Dict(None, allow_none=True)) # R: traitlets.traitlets.Dict + reveal_type(Dict(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Dict + reveal_type(T.foo) # R: traitlets.traitlets.Dict + reveal_type(t.foo) # R: builtins.dict[Any, Any] + + +@pytest.mark.mypy_testing +def mypy_unicode_typing(): + class T(HasTraits): + export_format = Unicode( + allow_none=False, + help="""The export format to be used, either one of the built-in formats + or a dotted object name that represents the import path for an + ``Exporter`` class""", + ).tag(config=True) + + t = T() + reveal_type( + Unicode( # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]] + "foo" + ) + ) + reveal_type( + Unicode( # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]] + "" + ).tag( + sync=True + ) + ) + reveal_type( + Unicode( # R: traitlets.traitlets.Unicode[Union[builtins.str, None], Union[builtins.str, builtins.bytes, None]] + None, allow_none=True + ) + ) + reveal_type( + Unicode( # R: traitlets.traitlets.Unicode[Union[builtins.str, None], Union[builtins.str, builtins.bytes, None]] + None, allow_none=True + ).tag( + sync=True + ) + ) + reveal_type( + T.export_format # R: traitlets.traitlets.Unicode[builtins.str, Union[builtins.str, builtins.bytes]] + ) + reveal_type(t.export_format) # R: builtins.str + + +@pytest.mark.mypy_testing +def mypy_set_typing(): + class T(HasTraits): + remove_cell_tags = Set( + Unicode(), + default_value=[], + help=( + "Tags indicating which cells are to be removed," + "matches tags in ``cell.metadata.tags``." + ), + ).tag(config=True) + + safe_output_keys = Set( + config=True, + default_value={ + "metadata", # Not a mimetype per-se, but expected and safe. + "text/plain", + "text/latex", + "application/json", + "image/png", + "image/jpeg", + }, + help="Cell output mimetypes to render without modification", + ) + + t = T() + reveal_type(Set("foo")) # R: traitlets.traitlets.Set + reveal_type(Set("").tag(sync=True)) # R: traitlets.traitlets.Set + reveal_type(Set(None, allow_none=True)) # R: traitlets.traitlets.Set + reveal_type(Set(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Set + reveal_type(T.remove_cell_tags) # R: traitlets.traitlets.Set + reveal_type(t.remove_cell_tags) # R: builtins.set[Any] + reveal_type(T.safe_output_keys) # R: traitlets.traitlets.Set + reveal_type(t.safe_output_keys) # R: builtins.set[Any] + + +@pytest.mark.mypy_testing +def mypy_any_typing(): + class T(HasTraits): + attributes = Any( + config=True, + default_value={ + "a": ["href", "title"], + "abbr": ["title"], + "acronym": ["title"], + }, + help="Allowed HTML tag attributes", + ) + + t = T() + reveal_type(Any("foo")) # R: traitlets.traitlets.Any + reveal_type(Any("").tag(sync=True)) # R: traitlets.traitlets.Any + reveal_type(Any(None, allow_none=True)) # R: traitlets.traitlets.Any + reveal_type(Any(None, allow_none=True).tag(sync=True)) # R: traitlets.traitlets.Any + reveal_type(T.attributes) # R: traitlets.traitlets.Any + reveal_type(t.attributes) # R: Any + + @pytest.mark.mypy_testing def mypy_bool_typing(): class T(HasTraits): diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index 8e25f5bd..036f51aa 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -287,7 +287,7 @@ class link: updating = False - def __init__(self, source, target, transform=None): + def __init__(self, source: t.Any, target: t.Any, transform: t.Any = None) -> None: _validate_link(source, target) self.source, self.target = source, target self._transform, self._transform_inv = transform if transform else (lambda x: x,) * 2 @@ -372,7 +372,7 @@ class directional_link: updating = False - def __init__(self, source, target, transform=None): + def __init__(self, source: t.Any, target: t.Any, transform: t.Any = None) -> None: self._transform = transform if transform else lambda x: x _validate_link(source, target) self.source, self.target = source, target @@ -504,7 +504,7 @@ def __init__( help: str | None = None, config: t.Any = None, **kwargs: t.Any, - ): + ) -> None: """Declare a traitlet. If *allow_none* is True, None is a valid value in addition to any @@ -919,7 +919,7 @@ class _CallbackWrapper: callbacks. """ - def __init__(self, cb): + def __init__(self, cb: t.Any) -> None: self.cb = cb # Bound methods have an additional 'self' argument. offset = -1 if isinstance(self.cb, types.MethodType) else 0 @@ -980,7 +980,7 @@ def __new__(mcls, name, bases, classdict): # noqa return super().__new__(mcls, name, bases, classdict) - def __init__(cls, name, bases, classdict): + def __init__(cls, name: str, bases: t.Any, classdict: t.Any) -> None: """Finish initializing the HasDescriptors class.""" super().__init__(name, bases, classdict) cls.setup_class(classdict) @@ -1233,12 +1233,23 @@ def some_other_default(self): # This default generator should not be return DefaultHandler(name) +FuncT = t.TypeVar("FuncT", bound=t.Callable[..., t.Any]) + + class EventHandler(BaseDescriptor): - def _init_call(self, func): + def _init_call(self, func: FuncT) -> EventHandler: self.func = func return self - def __call__(self, *args, **kwargs): + @t.overload + def __call__(self, func: FuncT, *args: t.Any, **kwargs: t.Any) -> FuncT: + ... + + @t.overload + def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: + ... + + def __call__(self, *args: t.Any, **kwargs: t.Any) -> t.Any: """Pass `*args` and `**kwargs` to the handler's function if it exists.""" if hasattr(self, "func"): return self.func(*args, **kwargs) @@ -1252,7 +1263,7 @@ def __get__(self, inst, cls=None): class ObserveHandler(EventHandler): - def __init__(self, names, type): + def __init__(self, names: t.Any, type: t.Any) -> None: self.trait_names = names self.type = type @@ -1261,7 +1272,7 @@ def instance_init(self, inst): class ValidateHandler(EventHandler): - def __init__(self, names): + def __init__(self, names: t.Any) -> None: self.trait_names = names def instance_init(self, inst): @@ -1269,7 +1280,7 @@ def instance_init(self, inst): class DefaultHandler(EventHandler): - def __init__(self, name): + def __init__(self, name: str) -> None: self.trait_name = name def class_init(self, cls, name): @@ -1337,7 +1348,7 @@ def setup_instance(*args, **kwargs): self._cross_validation_lock = False super(HasTraits, self).setup_instance(*args, **kwargs) - def __init__(self, *args, **kwargs): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: # Allow trait values to be set using keyword arguments. # We need to use setattr for this to trigger validation and # notifications. @@ -2024,7 +2035,7 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2037,7 +2048,7 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2050,7 +2061,7 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2063,10 +2074,19 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, default_value=Undefined, klass=None, allow_none=False, **kwargs): + def __init__( + self, + default_value: t.Any = Undefined, + klass: t.Any = None, + allow_none: bool = False, + read_only: bool | None = None, + help: str | None = None, + config: t.Any | None = None, + **kwargs: t.Any, + ) -> None: """Construct a Type trait A Type trait specifies that its values must be subclasses of @@ -2108,7 +2128,14 @@ def __init__(self, default_value=Undefined, klass=None, allow_none=False, **kwar self.klass = klass - super().__init__(new_default_value, allow_none=allow_none, **kwargs) + super().__init__( + new_default_value, + allow_none=allow_none, + read_only=read_only, + help=help, + config=config, + **kwargs, + ) def validate(self, obj, value): """Validates that the value is a valid object instance.""" @@ -2277,7 +2304,7 @@ class or its subclasses. Our implementation is quite different self.default_args = args self.default_kwargs = kw - super().__init__(allow_none=allow_none, **kwargs) + super().__init__(allow_none=allow_none, read_only=read_only, help=help, **kwargs) def validate(self, obj, value): assert self.klass is not None @@ -2359,7 +2386,7 @@ class This(ClassBasedTraitType[t.Optional[T], t.Optional[T]]): info_text = "an instance of the same type as the receiver or None" - def __init__(self, **kwargs): + def __init__(self, **kwargs: t.Any) -> None: super().__init__(None, **kwargs) def validate(self, obj, value): @@ -2376,7 +2403,7 @@ def validate(self, obj, value): class Union(TraitType[t.Any, t.Any]): """A trait type representing a Union type.""" - def __init__(self, trait_types, **kwargs): + def __init__(self, trait_types: t.Any, **kwargs: t.Any) -> None: """Construct a Union trait. This trait allows values that are allowed by at least one of the @@ -2470,52 +2497,52 @@ class Any(TraitType[t.Optional[t.Any], t.Optional[t.Any]]): @t.overload def __init__( self: Any, - default_value: str = ..., + default_value: t.Any = ..., *, allow_none: Literal[False], read_only: bool | None = ..., help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload def __init__( self: Any, - default_value: str = ..., + default_value: t.Any = ..., *, allow_none: Literal[True], read_only: bool | None = ..., help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload def __init__( self: Any, - default_value: str = ..., + default_value: t.Any = ..., *, allow_none: Literal[True, False] = ..., help: str | None = ..., read_only: bool | None = False, config: t.Any = None, **kwargs: t.Any, - ): + ) -> None: ... def __init__( self: Any, - default_value: str = ..., + default_value: t.Any = ..., *, - allow_none: bool | None = False, + allow_none: bool = False, help: str | None = "", read_only: bool | None = False, config: t.Any = None, **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2577,7 +2604,7 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2589,13 +2616,28 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, default_value=Undefined, allow_none=False, **kwargs): + def __init__( + self, + default_value: t.Any = Undefined, + allow_none: bool = False, + read_only: bool | None = None, + help: str | None = None, + config: t.Any | None = None, + **kwargs: t.Any, + ) -> None: self.min = kwargs.pop("min", None) self.max = kwargs.pop("max", None) - super().__init__(default_value=default_value, allow_none=allow_none, **kwargs) + super().__init__( + default_value=default_value, + allow_none=allow_none, + read_only=read_only, + help=help, + config=config, + **kwargs, + ) def validate(self, obj, value): if not isinstance(value, int): @@ -2625,7 +2667,7 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2637,10 +2679,18 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, default_value=Undefined, allow_none=False, **kwargs): + def __init__( + self: CInt[int | None, t.Any], + default_value: t.Any | Sentinel | None = ..., + allow_none: bool = ..., + read_only: bool | None = ..., + help: str | None = ..., + config: t.Any | None = ..., + **kwargs: t.Any, + ) -> None: ... def validate(self, obj, value): @@ -2670,7 +2720,7 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2682,13 +2732,28 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, default_value=Undefined, allow_none=False, **kwargs): + def __init__( + self: Float[int | None, int | float | None], + default_value: float | Sentinel | None = Undefined, + allow_none: bool = False, + read_only: bool | None = False, + help: str | None = None, + config: t.Any | None = None, + **kwargs: t.Any, + ) -> None: self.min = kwargs.pop("min", -float("inf")) self.max = kwargs.pop("max", float("inf")) - super().__init__(default_value=default_value, allow_none=allow_none, **kwargs) + super().__init__( + default_value=default_value, + allow_none=allow_none, + read_only=read_only, + help=help, + config=config, + **kwargs, + ) def validate(self, obj, value): if isinstance(value, int): @@ -2720,7 +2785,7 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2732,10 +2797,18 @@ def __init__( help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, default_value=Undefined, allow_none=False, **kwargs): + def __init__( + self: CFloat[float | None, t.Any], + default_value: t.Any = ..., + allow_none: bool = ..., + read_only: bool | None = ..., + help: str | None = ..., + config: t.Any | None = ..., + **kwargs: t.Any, + ) -> None: ... def validate(self, obj, value): @@ -2841,7 +2914,7 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2853,10 +2926,18 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, **kwargs): + def __init__( + self: Unicode[str | None, str | bytes | None], + default_value: str | Sentinel | None = ..., + allow_none: bool = ..., + read_only: bool | None = ..., + help: str | None = ..., + config: t.Any = ..., + **kwargs: t.Any, + ) -> None: ... def validate(self, obj, value): @@ -2906,7 +2987,7 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2918,10 +2999,18 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, **kwargs): + def __init__( + self: CUnicode[str | None, t.Any], + default_value: str | Sentinel | None = ..., + allow_none: bool = ..., + read_only: bool | None = ..., + help: str | None = ..., + config: t.Any = ..., + **kwargs: t.Any, + ) -> None: ... def validate(self, obj, value): @@ -2981,7 +3070,7 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -2993,10 +3082,18 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, **kwargs): + def __init__( + self: Bool[bool | None, bool | int | None], + default_value: bool | Sentinel | None = ..., + allow_none: bool = ..., + read_only: bool | None = ..., + help: str | None = ..., + config: t.Any = ..., + **kwargs: t.Any, + ) -> None: ... def validate(self, obj, value): @@ -3045,7 +3142,7 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -3057,10 +3154,18 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, **kwargs): + def __init__( + self: CBool[bool | None, t.Any], + default_value: bool | Sentinel | None = ..., + allow_none: bool = ..., + read_only: bool | None = ..., + help: str | None = ..., + config: t.Any = ..., + **kwargs: t.Any, + ) -> None: ... def validate(self, obj, value): @@ -3075,7 +3180,7 @@ class Enum(TraitType[G, S]): def __init__( self: Enum[t.Any, t.Any], values: t.Any, default_value: t.Any = Undefined, **kwargs: t.Any - ): + ) -> None: self.values = values if kwargs.get("allow_none", False) and default_value is Undefined: default_value = None @@ -3128,7 +3233,7 @@ def __init__( values: t.Any, default_value: t.Any = Undefined, **kwargs: t.Any, - ): + ) -> None: super().__init__(values, default_value=default_value, **kwargs) def validate(self, obj, value): @@ -3166,7 +3271,7 @@ def __init__( case_sensitive: bool = False, substring_matching: bool = False, **kwargs: t.Any, - ): + ) -> None: self.case_sensitive = case_sensitive self.substring_matching = substring_matching super().__init__(values, default_value=default_value, **kwargs) @@ -3217,44 +3322,49 @@ class Container(Instance[T]): @t.overload def __init__( self: Container[T], - kind: type[T], *, allow_none: Literal[False], read_only: bool | None = ..., help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload def __init__( self: Container[T | None], - kind: type[T], *, allow_none: Literal[True], read_only: bool | None = ..., help: str | None = ..., config: t.Any | None = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload def __init__( self: Container[T], - kind: type[T], *, + trait: t.Any = ..., + default_value: t.Any = ..., help: str = ..., read_only: bool = ..., config: t.Any = ..., - trait: t.Any = ..., - default_value: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... - def __init__(self, trait=None, default_value=Undefined, **kwargs): + def __init__( + self, + trait: t.Any | None = None, + default_value: t.Any = Undefined, + help: str | None = None, + read_only: bool | None = None, + config: t.Any | None = None, + **kwargs: t.Any, + ) -> None: """Create a container trait type from a list, set, or tuple. The default value is created by doing ``List(default_value)``, @@ -3324,7 +3434,9 @@ def __init__(self, trait=None, default_value=Undefined, **kwargs): elif trait is not None: raise TypeError("`trait` must be a Trait or None, got %s" % repr_type(trait)) - super().__init__(klass=self.klass, args=args, **kwargs) + super().__init__( + klass=self.klass, args=args, help=help, read_only=read_only, config=config, **kwargs + ) def validate(self, obj, value): if isinstance(value, self._cast_types): @@ -3433,12 +3545,12 @@ class List(Container[t.List[t.Any]]): def __init__( self, - trait=None, - default_value=Undefined, - minlen=0, - maxlen=sys.maxsize, - **kwargs, - ): + trait: t.Any = None, + default_value: t.Any = Undefined, + minlen: int = 0, + maxlen: int = sys.maxsize, + **kwargs: t.Any, + ) -> None: """Create a List trait type from a list, set, or tuple. The default value is created by doing ``list(default_value)``, @@ -3465,8 +3577,8 @@ def __init__( maxlen : Int [ default sys.maxsize ] The maximum length of the input list """ - self._minlen = minlen self._maxlen = maxlen + self._minlen = minlen super().__init__(trait=trait, default_value=default_value, **kwargs) def length_error(self, obj, value): @@ -3490,10 +3602,10 @@ def set(self, obj, value): return super().set(obj, value) -class Set(List): +class Set(Container[t.Set[t.Any]]): """An instance of a Python set.""" - klass = set # type:ignore[assignment] + klass = set _cast_types = (tuple, list) _literal_from_string_pairs = ("[]", "()", "{}") @@ -3501,12 +3613,12 @@ class Set(List): # Redefine __init__ just to make the docstring more accurate. def __init__( self, - trait=None, - default_value=Undefined, - minlen=0, - maxlen=sys.maxsize, - **kwargs, - ): + trait: t.Any = None, + default_value: t.Any = Undefined, + minlen: int = 0, + maxlen: int = sys.maxsize, + **kwargs: t.Any, + ) -> None: """Create a Set trait type from a list, set, or tuple. The default value is created by doing ``set(default_value)``, @@ -3533,7 +3645,29 @@ def __init__( maxlen : Int [ default sys.maxsize ] The maximum length of the input list """ - super().__init__(trait, default_value, minlen, maxlen, **kwargs) + self._maxlen = maxlen + self._minlen = minlen + super().__init__(trait=trait, default_value=default_value, **kwargs) + + def length_error(self, obj, value): + e = ( + "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." + % (self.name, class_of(obj), self._minlen, self._maxlen, value) + ) + raise TraitError(e) + + def validate_elements(self, obj, value): + length = len(value) + if length < self._minlen or length > self._maxlen: + self.length_error(obj, value) + + return super().validate_elements(obj, value) + + def set(self, obj, value): + if isinstance(value, str): + return super().set(obj, [value]) + else: + return super().set(obj, value) def default_value_repr(self): # Ensure default value is sorted for a reproducible build @@ -3549,7 +3683,7 @@ class Tuple(Container[t.Tuple[t.Any, ...]]): klass = tuple _cast_types = (list,) - def __init__(self, *traits, **kwargs): + def __init__(self, *traits: t.Any, **kwargs: t.Any) -> None: """Create a tuple from a list, set, or tuple. Create a fixed-type tuple with Traits: @@ -3692,12 +3826,12 @@ class Dict(Instance[t.Dict[t.Any, t.Any]]): def __init__( self, - value_trait=None, - per_key_traits=None, - key_trait=None, - default_value=Undefined, - **kwargs, - ): + value_trait: t.Any = None, + per_key_traits: t.Any = None, + key_trait: t.Any = None, + default_value: t.Any = Undefined, + **kwargs: t.Any, + ) -> None: """Create a dict trait type from a Python dict. The default value is created by doing ``dict(default_value)``, @@ -3963,7 +4097,7 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... @t.overload @@ -3975,7 +4109,7 @@ def __init__( help: str | None = ..., config: t.Any = ..., **kwargs: t.Any, - ): + ) -> None: ... def __init__( @@ -3987,7 +4121,7 @@ def __init__( help: str | None = None, config: t.Any = None, **kwargs: t.Any, - ): + ) -> None: ... def validate(self, obj, value): @@ -4054,7 +4188,9 @@ class MyEntity(HasTraits): default_value: enum.Enum | None = None info_text = "Trait type adapter to a Enum class" - def __init__(self, enum_class, default_value=None, **kwargs): + def __init__( + self, enum_class: type[t.Any], default_value: t.Any = None, **kwargs: t.Any + ) -> None: assert issubclass(enum_class, enum.Enum), "REQUIRE: enum.Enum, but was: %r" % enum_class allow_none = kwargs.get("allow_none", False) if default_value is None and not allow_none: diff --git a/traitlets/utils/importstring.py b/traitlets/utils/importstring.py index 7ac1e9ab..413c2033 100644 --- a/traitlets/utils/importstring.py +++ b/traitlets/utils/importstring.py @@ -3,9 +3,10 @@ """ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from typing import Any -def import_item(name): +def import_item(name: str) -> Any: """Import and return ``bar`` given the string ``foo.bar``. Calling ``bar = import_item("foo.bar")`` is the functional equivalent of diff --git a/traitlets/utils/sentinel.py b/traitlets/utils/sentinel.py index 75e000f8..079443d8 100644 --- a/traitlets/utils/sentinel.py +++ b/traitlets/utils/sentinel.py @@ -2,20 +2,23 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +import typing as t class Sentinel: - def __init__(self, name, module, docstring=None): + def __init__(self, name: str, module: t.Any, docstring: str | None = None) -> None: self.name = name self.module = module if docstring: self.__doc__ = docstring - def __repr__(self): + def __repr__(self) -> str: return str(self.module) + "." + self.name - def __copy__(self): + def __copy__(self) -> Sentinel: return self - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: t.Any) -> Sentinel: return self diff --git a/traitlets/utils/tests/test_importstring.py b/traitlets/utils/tests/test_importstring.py index 4a5471a8..1e5db490 100644 --- a/traitlets/utils/tests/test_importstring.py +++ b/traitlets/utils/tests/test_importstring.py @@ -23,4 +23,4 @@ class NotAString: msg = "import_item accepts strings, not '%s'." % NotAString with self.assertRaisesRegex(TypeError, msg): - import_item(NotAString()) + import_item(NotAString()) # type:ignore[arg-type] diff --git a/traitlets/utils/warnings.py b/traitlets/utils/warnings.py index 216b23dc..f9d52c3a 100644 --- a/traitlets/utils/warnings.py +++ b/traitlets/utils/warnings.py @@ -1,17 +1,20 @@ +from __future__ import annotations + import inspect import os +import typing as t import warnings -def warn(msg, category, *, stacklevel, source=None): +def warn(msg: str, category: t.Any, *, stacklevel: int, source: t.Any = None) -> None: """Like warnings.warn(), but category and stacklevel are required. You pretty much never want the default stacklevel of 1, so this helps encourage setting it explicitly.""" - return warnings.warn(msg, category=category, stacklevel=stacklevel, source=source) + warnings.warn(msg, category=category, stacklevel=stacklevel, source=source) -def deprecated_method(method, cls, method_name, msg): +def deprecated_method(method: t.Any, cls: t.Any, method_name: str, msg: str) -> None: """Show deprecation warning about a magic method definition. Uses warn_explicit to bind warning to method definition instead of triggering code, @@ -45,7 +48,7 @@ def deprecated_method(method, cls, method_name, msg): _deprecations_shown = set() -def should_warn(key): +def should_warn(key: t.Any) -> bool: """Add our own checks for too many deprecation warnings. Limit to once per package.