From d16fa882de56a23e785eaa5ce5269f2d9195b2bc Mon Sep 17 00:00:00 2001 From: FengWen <109639975+ccssu@users.noreply.github.com> Date: Tue, 9 Jan 2024 11:15:08 +0800 Subject: [PATCH] refine_mock_torch (#10396) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ้‡ๆž„ mock_torch ๆจกๅ— --- ci/test/test_mock_script.sh | 9 +- python/oneflow/mock_torch/__init__.py | 393 +-------------------- python/oneflow/mock_torch/__main__.py | 22 +- python/oneflow/mock_torch/dyn_mock_mod.py | 97 +++++ python/oneflow/mock_torch/mock_importer.py | 392 ++++++++++++++++++++ python/oneflow/mock_torch/mock_modules.py | 121 +++++++ python/oneflow/mock_torch/mock_utils.py | 366 +++++++++++++++++++ 7 files changed, 1008 insertions(+), 392 deletions(-) create mode 100644 python/oneflow/mock_torch/dyn_mock_mod.py create mode 100644 python/oneflow/mock_torch/mock_importer.py create mode 100644 python/oneflow/mock_torch/mock_modules.py create mode 100644 python/oneflow/mock_torch/mock_utils.py diff --git a/ci/test/test_mock_script.sh b/ci/test/test_mock_script.sh index 0c525e7dd44..eca0dbe2a31 100644 --- a/ci/test/test_mock_script.sh +++ b/ci/test/test_mock_script.sh @@ -1,5 +1,12 @@ #!/bin/bash set -e +python_version=$(python3 --version 2>&1 | awk '{print $2}') + +if [[ "$python_version" < "3.8" ]]; then + echo "Python version is less than 3.8." + exit 0 +fi + MOCK_TORCH=$PWD/python/oneflow/test/misc/mock_example.py same_or_exit() { @@ -56,4 +63,4 @@ python3 -c "import oneflow as flow; x = flow.load('test.pt'); assert flow.equal( rm test.pt eval $(python3 -m oneflow.mock_torch --lazy --verbose) -python3 -c "import torch.not_exist" | grep -q 'dummy object' +python3 -c "import torch.not_exist" | grep -q 'dummy object' \ No newline at end of file diff --git a/python/oneflow/mock_torch/__init__.py b/python/oneflow/mock_torch/__init__.py index 18ec636e2f7..c3a3adf5b66 100644 --- a/python/oneflow/mock_torch/__init__.py +++ b/python/oneflow/mock_torch/__init__.py @@ -13,393 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. """ -import builtins -import types -from inspect import ismodule, currentframe -from types import ModuleType -from typing import Any, Dict, Optional -from importlib.abc import MetaPathFinder, Loader -from importlib.machinery import ModuleSpec -from importlib.util import find_spec, module_from_spec -import sys -import os -from pathlib import Path -from contextlib import contextmanager -from zipimport import zipimporter - -import oneflow.support.env_var_util as env_var_util - -_first_init = True - -error_msg = """ is not implemented, please submit an issue at -'https://github.com/Oneflow-Inc/oneflow/issues' including the log information of the error, the -minimum reproduction code, and the system information.""" - -hazard_list = [ - "_distutils_hack", - "importlib", - "regex", - "tokenizers", - "safetensors._safetensors_rust", -] - -# patch hasattr so that -# 1. torch.not_exist returns DummyModule object, but -# 2. 
hasattr(torch, "not_exist") still returns False -_builtin_hasattr = builtins.hasattr -if not isinstance(_builtin_hasattr, types.BuiltinFunctionType): - raise Exception("hasattr already patched by someone else!") - - -def hasattr(obj, name): - return _builtin_hasattr(obj, name) - - -builtins.hasattr = hasattr - - -def probably_called_from_hasattr(): - frame = currentframe().f_back.f_back - return frame.f_code is hasattr.__code__ - - -class MockModuleDict: - def __init__(self, mapping=None): - if mapping is not None and not isinstance(mapping, dict): - raise ValueError("Extra mock library must be a dict.") - self.forward = {} - self.inverse = {} - if mapping is not None: - for key, value in mapping.items(): - self.add(key, value) - - def add(self, key, value): - if key in self.forward or value in self.inverse: - raise ValueError("Key or value already exists.") - self.forward[key] = value - self.inverse[value] = key - - def remove(self, key=None, value=None): - if key is not None: - value = self.forward.pop(key) - self.inverse.pop(value) - elif value is not None: - key = self.inverse.pop(value) - self.forward.pop(key) - else: - raise ValueError("Must provide a key or value to remove.") - - def in_forward_dict(self, s): - return s.split(".")[0] in self.forward.keys() - - def in_inverse_dict(self, s): - return s.split(".")[0] in self.inverse.keys() - - -# module wrapper with checks for existence of methods -class ModuleWrapper(ModuleType): - def __init__(self, module): - self.module = module - - def __setattr__(self, name, value): - super().__setattr__(name, value) - if name != "module": - setattr(self.module, name, value) - - def __getattr__(self, name: str) -> Any: - if not hasattr(self.module, name): - if name == "__path__": - return None - if name == "__all__": - return [attr for attr in dir(self.module) if not attr.startswith("_")] - new_name = self.module.__name__ + "." + name - if _importer.lazy and not probably_called_from_hasattr(): - if _importer.verbose: - print( - f'"{new_name}" is not found in oneflow, use dummy object as fallback.' 
- ) - return DummyModule(new_name) - else: - if _importer.lazy and _importer.verbose: - print(f"hasattr({self.module.__name__}, {name}) returns False") - raise AttributeError(new_name + error_msg) - attr = getattr(self.module, name) - if ismodule(attr): - return ModuleWrapper(attr) - else: - return attr - - -class OneflowImporter(MetaPathFinder, Loader): - def __init__(self): - # module_from_spec will try to call the loader's create_module, resulting in infinite recursion - self.in_create_module = False - self.enable = False - # both __init__.py of oneflow and torch can't be executed multiple times, so we use a cache - self.enable_mod_cache = {} - self.disable_mod_cache = {} - self.delete_list = [] - - def find_spec(self, fullname, path, target=None): - if module_dict_global.in_forward_dict( - fullname - ): # don't touch modules other than torch or extra libs module - # for first import of real torch, we use default meta path finders, not our own - if not self.enable and self.disable_mod_cache.get(fullname) is None: - return None - return ModuleSpec(fullname, self) - self.delete_list.append(fullname) - return None - - def find_module(self, fullname, path=None): - spec = self.find_spec(fullname, path) - return spec - - def create_module(self, spec): - if self.in_create_module: - return None - self.in_create_module = True - if self.enable: - if module_dict_global.in_forward_dict(spec.name): - oneflow_mod_fullname = ( - module_dict_global.forward[spec.name.split(".")[0]] - + spec.name[len(spec.name.split(".")[0]) :] - ) - if ( - sys.modules.get(oneflow_mod_fullname) is None - and self.enable_mod_cache.get(spec.name) is None - ): - # get actual oneflow module - try: - real_spec = find_spec(oneflow_mod_fullname) - except ModuleNotFoundError: - real_spec = None - if real_spec is None: - self.in_create_module = False - if self.lazy: - if self.verbose: - print( - f"{oneflow_mod_fullname} is not found in oneflow, use dummy object as fallback." 
- ) - return DummyModule(oneflow_mod_fullname) - else: - raise ModuleNotFoundError(oneflow_mod_fullname + error_msg) - - real_mod = module_from_spec(real_spec) - loader = real_spec.loader - if isinstance(loader, zipimporter): - pass - else: - loader.exec_module(real_mod) - else: - real_mod = sys.modules.get(oneflow_mod_fullname) - if real_mod is None: - real_mod = self.enable_mod_cache[spec.name] - self.in_create_module = False - return real_mod - else: - torch_full_name = spec.name - real_mod = self.disable_mod_cache[torch_full_name] - self.in_create_module = False - return real_mod - - def exec_module(self, module): - if module_dict_global.in_inverse_dict(module.__name__): - fullname = ( - module_dict_global.inverse[module.__name__.split(".")[0]] - + module.__name__[len(module.__name__.split(".")[0]) :] - ) - if self.enable: - if not isinstance(module, DummyModule): - module = ModuleWrapper(module) - sys.modules[fullname] = module - globals()[fullname] = module - - def _enable(self, globals, lazy: bool, verbose: bool, *, from_cli: bool = False): - global _first_init - if _first_init: - _first_init = False - self.enable = False # deal with previously imported torch - sys.meta_path.insert(0, self) - self._enable(globals, lazy, verbose, from_cli=from_cli) - return - self.lazy = lazy - self.verbose = verbose - self.from_cli = from_cli - if self.enable: # already enabled - return - for k, v in sys.modules.copy().items(): - if (not (from_cli and k == "torch")) and module_dict_global.in_forward_dict( - k - ): - aliases = list(filter(lambda alias: globals[alias] is v, globals)) - self.disable_mod_cache.update({k: (v, aliases)}) - del sys.modules[k] - for alias in aliases: - del globals[alias] - for k, (v, aliases) in self.enable_mod_cache.items(): - sys.modules.update({k: v}) - for alias in aliases: - globals.update({alias: v}) - self.enable = True - - def _disable(self, globals): - if not self.enable: # already disabled - return - for k, v in sys.modules.copy().items(): - if module_dict_global.in_forward_dict(k): - aliases = list(filter(lambda alias: globals[alias] is v, globals)) - self.enable_mod_cache.update({k: (v, aliases)}) - del sys.modules[k] - for alias in aliases: - del globals[alias] - name = k if "." not in k else k[: k.find(".")] - if ( - not name in hazard_list - and not k in hazard_list - and k in self.delete_list - ): - aliases = list(filter(lambda alias: globals[alias] is v, globals)) - self.enable_mod_cache.update({k: (v, aliases)}) - del sys.modules[k] - for k, (v, aliases) in self.disable_mod_cache.items(): - sys.modules.update({k: v}) - for alias in aliases: - globals.update({alias: v}) - if self.from_cli: - torch_env = Path(__file__).parent - sys.path.remove(str(torch_env)) - - self.enable = False - - -_importer = OneflowImporter() - - -class DummyModule(ModuleType): - def __getattr__(self, name): - if _importer.verbose: - print( - f'"{self.__name__}" is a dummy object, and its attr "{name}" is accessed.' - ) - if name == "__path__": - return None - if name == "__all__": - return [] - if name == "__file__": - return None - if name == "__mro_entries__": - return lambda x: () - return DummyModule(self.__name__ + "." + name) - - def __getitem__(self, name): - new_name = f"{self.__name__}[{name}]" - if isinstance(name, int): - if _importer.verbose: - print( - f'"{self.__name__}" is a dummy object, and `{new_name}` is called. Raising an IndexError to simulate an empty list.' 
- ) - raise IndexError - if _importer.verbose: - print(f'"{self.__name__}" is a dummy object, and `{new_name}` is called.') - return DummyModule(new_name) - - def __call__(self, *args, **kwargs): - new_name = f'{self.__name__}({", ".join(map(repr, args))}, {", ".join(["{}={}".format(k, repr(v)) for k, v in kwargs.items()])})' - if _importer.verbose: - print(f'"{self.__name__}" is a dummy object, and `{new_name}` is called.') - return DummyModule(new_name) - - def __bool__(self): - if _importer.verbose: - print( - f'"{self.__name__}" is a dummy object, and its bool value is accessed.' - ) - return False - - def __enter__(self): - raise RuntimeError( - f'"{self.__name__}" is a dummy object, and does not support "with" statement.' - ) - - def __exit__(self, exception_type, exception_value, traceback): - raise RuntimeError( - f'"{self.__name__}" is a dummy object, and does not support "with" statement.' - ) - - def __subclasscheck__(self, subclass): - return False - - def __instancecheck__(self, instance): - return False - - -class enable: - def __init__( - self, - lazy: Optional[bool] = None, - verbose: Optional[bool] = None, - extra_dict: Optional[Dict[str, str]] = None, - *, - _from_cli: bool = False, - ): - global module_dict_global - module_dict_global = MockModuleDict(extra_dict) - module_dict_global.add("torch", "oneflow") - self.enable = _importer.enable - forcedly_disabled_by_env_var = env_var_util.parse_boolean_from_env( - "ONEFLOW_DISABLE_MOCK_TORCH", False - ) - globals = currentframe().f_back.f_globals - self.globals = globals - lazy = ( - lazy - if lazy is not None - else env_var_util.parse_boolean_from_env("ONEFLOW_MOCK_TORCH_LAZY", False) - ) - verbose = ( - verbose - if verbose is not None - else env_var_util.parse_boolean_from_env( - "ONEFLOW_MOCK_TORCH_VERBOSE", False - ) - ) - if forcedly_disabled_by_env_var: - return - _importer._enable(globals, lazy, verbose, from_cli=_from_cli) - - def __enter__(self): - pass - - def __exit__(self, exception_type, exception_value, traceback): - if not self.enable: - _importer._disable(self.globals) - - -class disable: - def __init__(self): - self.enable = _importer.enable - if not self.enable: - return - globals = currentframe().f_back.f_globals - self.globals = globals - self.lazy = _importer.lazy - self.verbose = _importer.verbose - _importer._disable(globals) - - def __enter__(self): - pass - - def __exit__(self, exception_type, exception_value, traceback): - if self.enable: - _importer._enable( - # When re-enabling mock torch, from_cli shoule always be False - self.globals, - self.lazy, - self.verbose, - from_cli=False, - ) - - -def is_enabled(): - return _importer.enable +from .mock_importer import ModuleWrapper, enable, disable +from .mock_modules import DummyModule +from .dyn_mock_mod import DynamicMockModule diff --git a/python/oneflow/mock_torch/__main__.py b/python/oneflow/mock_torch/__main__.py index 94c74c63d10..a35892ef614 100644 --- a/python/oneflow/mock_torch/__main__.py +++ b/python/oneflow/mock_torch/__main__.py @@ -16,6 +16,16 @@ import argparse from pathlib import Path import os +import sys + +if sys.version_info < (3, 8): + try: + from importlib_metadata import requires + except ImportError: + import subprocess + + subprocess.check_call("pip install importlib_metadata", shell=True) + subprocess.check_call("pip install packaging", shell=True) parser = argparse.ArgumentParser() parser.add_argument( @@ -33,13 +43,23 @@ def main(): + def is_torch_env(s): + if s.endswith("oneflow/mock_torch"): + return True + return False 
+ if args.mock == "enable": print( f"export ONEFLOW_MOCK_TORCH_LAZY={args.lazy}; export ONEFLOW_MOCK_TORCH_VERBOSE={args.verbose}; export PYTHONPATH={str(torch_env)}:$PYTHONPATH" ) elif args.mock == "disable" and "PYTHONPATH" in os.environ: paths = os.environ["PYTHONPATH"].rstrip(":").split(":") - paths = [x for x in paths if x != str(torch_env)] + paths = [p for p in paths if not is_torch_env(p)] + if len(paths) == 0: + print( + "unset PYTHONPATH; unset ONEFLOW_MOCK_TORCH_LAZY; unset ONEFLOW_MOCK_TORCH_VERBOSE" + ) + return path = ":".join(paths) print( f"export PYTHONPATH={path}; unset ONEFLOW_MOCK_TORCH_LAZY; unset ONEFLOW_MOCK_TORCH_VERBOSE" diff --git a/python/oneflow/mock_torch/dyn_mock_mod.py b/python/oneflow/mock_torch/dyn_mock_mod.py new file mode 100644 index 00000000000..6203420bea4 --- /dev/null +++ b/python/oneflow/mock_torch/dyn_mock_mod.py @@ -0,0 +1,97 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from inspect import ismodule +import importlib +from contextlib import contextmanager +from types import ModuleType +from typing import Dict, List +from .mock_importer import enable + + +class DynamicMockModule(ModuleType): + def __init__( + self, pkg_name: str, obj_entity: ModuleType, main_pkg_enable: callable, + ): + self._pkg_name = pkg_name + self._obj_entity = obj_entity # ModuleType or _LazyModule + self._main_pkg_enable = main_pkg_enable + self._intercept_dict = {} + + def __repr__(self) -> str: + return f"" + + def hijack(self, module_name: str, obj: object): + self._intercept_dict[module_name] = obj + + @classmethod + def from_package( + cls, + main_pkg: str, + *, + lazy: bool = True, + verbose: bool = False, + extra_dict: Dict[str, str] = None, + required_dependencies: List[str] = [], + ): + assert isinstance(main_pkg, str) + + @contextmanager + def main_pkg_enable(): + with enable( + lazy=lazy, + verbose=verbose, + extra_dict=extra_dict, + main_pkg=main_pkg, + mock_version=True, + required_dependencies=required_dependencies, + ): + yield + + with main_pkg_enable(): + obj_entity = importlib.import_module(main_pkg) + return cls(main_pkg, obj_entity, main_pkg_enable) + + def _get_module(self, _name: str): + # Fix Lazy import + # https://github.com/huggingface/diffusers/blob/main/src/diffusers/__init__.py#L728-L734 + module_name = f"{self._obj_entity.__name__}.{_name}" + try: + return importlib.import_module(module_name) + except Exception as e: + raise RuntimeError( + f"Failed to import {module_name} because of the following error (look up to see its" + f" traceback):\n{e}" + ) from e + + def __getattr__(self, name: str): + fullname = f"{self._obj_entity.__name__}.{name}" + if fullname in self._intercept_dict: + return self._intercept_dict[fullname] + + with self._main_pkg_enable(): + obj_entity = getattr(self._obj_entity, name, None) + if obj_entity is None: + obj_entity = self._get_module(name) + + if ismodule(obj_entity): + return DynamicMockModule(self._pkg_name, obj_entity, self._main_pkg_enable) + + return 
obj_entity + + def __all__(self): + with self._main_pkg_enable(): + return dir(self._obj_entity) diff --git a/python/oneflow/mock_torch/mock_importer.py b/python/oneflow/mock_torch/mock_importer.py new file mode 100644 index 00000000000..94e305583f9 --- /dev/null +++ b/python/oneflow/mock_torch/mock_importer.py @@ -0,0 +1,392 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import builtins +from functools import partial +import types +from inspect import ismodule, currentframe +from types import ModuleType +from typing import Any, Dict, Optional +from importlib.abc import MetaPathFinder, Loader +from importlib.machinery import ModuleSpec +from importlib.util import find_spec, module_from_spec +import sys +from typing import List +from zipimport import zipimporter + +import oneflow.support.env_var_util as env_var_util +from .mock_modules import MockModuleDict, DummyModule +from .mock_utils import MockEnableDisableMixin + + +error_msg = """ is not implemented, please submit an issue at +'https://github.com/Oneflow-Inc/oneflow/issues' including the log information of the error, the +minimum reproduction code, and the system information.""" + + +# patch hasattr so that +# 1. torch.not_exist returns DummyModule object, but +# 2. hasattr(torch, "not_exist") still returns False +_builtin_hasattr = builtins.hasattr +if not isinstance(_builtin_hasattr, types.BuiltinFunctionType): + raise Exception("hasattr already patched by someone else!") + + +def hasattr(obj, name): + return _builtin_hasattr(obj, name) + + +builtins.hasattr = hasattr + + +def probably_called_from_hasattr(): + frame = currentframe().f_back.f_back + return frame.f_code is hasattr.__code__ + + +# module wrapper with checks for existence of methods +class ModuleWrapper(ModuleType): + # TODO add selcted methods + def __init__(self, module): + self.module = module + + def __setattr__(self, name, value): + super().__setattr__(name, value) + if name != "module": + setattr(self.module, name, value) + + def __getattr__(self, name: str) -> Any: + if not hasattr(self.module, name): + if name == "__path__": + return None + if name == "__all__": + return [attr for attr in dir(self.module) if not attr.startswith("_")] + new_name = self.module.__name__ + "." + name + if _importer.lazy and not probably_called_from_hasattr(): + if _importer.verbose: + print( + f'"{new_name}" is not found in oneflow, use dummy object as fallback.' 
+ ) + return DummyModule(new_name, verbose=_importer.verbose) + else: + if _importer.lazy and _importer.verbose: + print(f"hasattr({self.module.__name__}, {name}) returns False") + raise AttributeError(new_name + error_msg) + attr = getattr(self.module, name) + if ismodule(attr): + return ModuleWrapper(attr) + else: + return attr + + +class OneflowImporter(MockEnableDisableMixin, MetaPathFinder, Loader): + def __init__(self): + # module_from_spec will try to call the loader's create_module, resulting in infinite recursion + self.in_create_module = False + self.enable = False + # both __init__.py of oneflow and torch can't be executed multiple times, so we use a cache + self.enable_mod_cache = {} + self.disable_mod_cache = {} + # Record modules loaded during mocking for deletion + self.delete_list = [] + + def find_spec(self, fullname, path, target=None): + if module_dict_global.in_forward_dict( + fullname + ): # don't touch modules other than torch or extra libs module + # for first import of real torch, we use default meta path finders, not our own + if not self.enable and self.disable_mod_cache.get(fullname) is None: + return None + return ModuleSpec(fullname, self) + self.delete_list.append(fullname) + return None + + def find_module(self, fullname, path=None): + spec = self.find_spec(fullname, path) + return spec + + def create_module(self, spec): + if self.in_create_module: + return None + self.in_create_module = True + if self.enable: + if module_dict_global.in_forward_dict(spec.name): + oneflow_mod_fullname = module_dict_global.forward_name(spec.name) + if ( + sys.modules.get(oneflow_mod_fullname) is None + and self.enable_mod_cache.get(spec.name) is None + ): + # get actual oneflow module + try: + real_spec = find_spec(oneflow_mod_fullname) + except ModuleNotFoundError: + real_spec = None + if real_spec is None: + self.in_create_module = False + if self.lazy: + if self.verbose: + print( + f"{oneflow_mod_fullname} is not found in oneflow, use dummy object as fallback." 
+ ) + return DummyModule(oneflow_mod_fullname, verbose=self.verbose) + else: + raise ModuleNotFoundError(oneflow_mod_fullname + error_msg) + + real_mod = module_from_spec(real_spec) + loader = real_spec.loader + if isinstance(loader, zipimporter): + # TODO: verify can mock torch as oneflow in zipimporter + pass + else: + loader.exec_module(real_mod) + else: + real_mod = sys.modules.get(oneflow_mod_fullname) + if real_mod is None: + real_mod = self.enable_mod_cache[spec.name] + self.in_create_module = False + return real_mod + else: + torch_full_name = spec.name + real_mod = self.disable_mod_cache[torch_full_name] + self.in_create_module = False + return real_mod + + def exec_module(self, module): + module_name = module.__name__ + if module_dict_global.in_inverse_dict(module_name): + fullname = module_dict_global.inverse_name(module_name) + if self.enable: + if not isinstance(module, DummyModule): + module = ModuleWrapper(module) + sys.modules[fullname] = module + globals()[fullname] = module + + def _enable( + self, + globals=None, + lazy=False, + verbose=False, + *, + main_pkg: str = None, + mock_version: bool = None, + required_dependencies: List[str] = [], + from_cli: bool = False, + ): + + if verbose: + print("enable mock torch", globals["__name__"]) + + if self.enable: # already enabled + of_importer_module_name = self.globals["__name__"] + input_module_name = globals["__name__"] + if of_importer_module_name != input_module_name: + print( + f"Warning: {of_importer_module_name} is already enabled, but {input_module_name} is trying to enable it again. skip." + ) + return + + # record config for re-enabling + self._mock_enable_config = {k: v for k, v in locals().items() if k != "self"} + # insert importer to the first place of meta_path + sys.meta_path.insert(0, self) + + self.lazy = lazy + self.verbose = verbose + self.from_cli = from_cli + self.globals = globals + + self.mock_enable( + globals=globals, + module_dict=module_dict_global, + main_pkg=main_pkg, + mock_version=mock_version, + required_dependencies=required_dependencies, + from_cli=from_cli, + verbose=verbose, + ) + self.enable = True + + def _disable(self, globals, *, verbose=False): + if verbose: + print( + "disable mock torch in", + globals["__name__"], + "\tself.enable: ", + self.enable, + ) + + if not self.enable: # already disabled + return + + of_importer_module_name = self.globals["__name__"] + input_module_name = globals["__name__"] + if of_importer_module_name != input_module_name: + raise RuntimeError( + f"Error: {of_importer_module_name} is enabled, but {input_module_name} is trying to disable it. must disable it in the same module." 
+ ) + + self.mock_disable( + globals=globals, + module_dict=module_dict_global, + delete_list=self.delete_list, + from_cli=self.from_cli, + ) + + sys.meta_path.remove(self) + self.enable = False + self.delete_list = [] + self.globals = None + + +_importer = OneflowImporter() + + +class BaseMockConfig: + def __init__( + self, + lazy: Optional[bool] = None, + verbose: Optional[bool] = None, + extra_dict: Optional[Dict[str, str]] = None, + *, + main_pkg: Optional[str] = None, + mock_version: Optional[str] = None, + required_dependencies: List[str] = [], + _from_cli: bool = False, + ): + global module_dict_global + module_dict_global = MockModuleDict(extra_dict) + module_dict_global.add("torch", "oneflow") + + required_dependencies.extend( + [k for k in extra_dict or {} if k not in required_dependencies] + ) + if "torch" not in required_dependencies: + required_dependencies.append("torch") + + parse_bool_env = partial( + env_var_util.parse_boolean_from_env, defalut_value=False + ) + + forcedly_disabled_by_env_var = parse_bool_env("ONEFLOW_DISABLE_MOCK_TORCH") + + lazy = lazy if lazy is not None else parse_bool_env("ONEFLOW_MOCK_TORCH_LAZY") + verbose = ( + verbose + if verbose is not None + else parse_bool_env("ONEFLOW_MOCK_TORCH_VERBOSE") + ) + + self.lazy = lazy + self.verbose = verbose + self.forcedly_disabled_by_env_var = forcedly_disabled_by_env_var + self.required_dependencies = required_dependencies + self.parse_bool_env = parse_bool_env + self._from_cli = _from_cli + self.main_pkg = main_pkg + self.mock_version = mock_version + + +class enable(BaseMockConfig): + """https://docs.oneflow.org/master/cookies/oneflow_torch.html""" + + def __init__( + self, + lazy: Optional[bool] = None, + verbose: Optional[bool] = None, + extra_dict: Optional[Dict[str, str]] = None, + *, + main_pkg: Optional[str] = None, + mock_version: Optional[str] = None, + required_dependencies: List[str] = [], + _from_cli: bool = False, + ): + super().__init__( + lazy=lazy, + verbose=verbose, + extra_dict=extra_dict, + main_pkg=main_pkg, + mock_version=mock_version, + required_dependencies=required_dependencies, + _from_cli=_from_cli, + ) + + if self.forcedly_disabled_by_env_var: # super().__init__ will set this + return + + self.globals = currentframe().f_back.f_globals + self.skip_processing = False + if getattr(_importer, "globals", None) is not None: + import_name = _importer.globals["__name__"] + if import_name == self.globals["__name__"]: + self.skip_processing = True + return + + self._importer_enable = _importer.enable + if self._importer_enable: + self._mock_enable_config = _importer._mock_enable_config + _importer._disable(_importer.globals, verbose=self.verbose) + + _importer._enable( + self.globals, + lazy=self.lazy, + verbose=self.verbose, + main_pkg=main_pkg, + mock_version=mock_version, + required_dependencies=required_dependencies, + from_cli=_from_cli, + ) + + def __enter__(self): + pass + + def __exit__(self, exception_type, exception_value, traceback): + + if self.forcedly_disabled_by_env_var or self.skip_processing: + return + + _importer._disable(_importer.globals, verbose=self.verbose) + + if self._importer_enable: + _importer._enable( + # When re-enabling mock torch, from_cli shoule always be False + **self._mock_enable_config, + ) + + +class disable: + def __init__(self): + self._importer_enable = _importer.enable + if not self._importer_enable: + return + + self.globals = currentframe().f_back.f_globals + self.lazy = _importer.lazy + self.verbose = _importer.verbose + self._mock_enable_config = 
_importer._mock_enable_config + _importer._disable(_importer.globals, verbose=self.verbose) + + def __enter__(self): + pass + + def __exit__(self, exception_type, exception_value, traceback): + if self._importer_enable: + _importer._enable( + # When re-enabling mock torch, from_cli shoule always be False + **self._mock_enable_config, + ) + + +def is_enabled(): + return _importer.enable diff --git a/python/oneflow/mock_torch/mock_modules.py b/python/oneflow/mock_torch/mock_modules.py new file mode 100644 index 00000000000..99edfa4993f --- /dev/null +++ b/python/oneflow/mock_torch/mock_modules.py @@ -0,0 +1,121 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from types import ModuleType + +__all__ = ["MockModuleDict", "DummyModule"] + + +class MockModuleDict: + def __init__(self, mapping=None): + if mapping is not None and not isinstance(mapping, dict): + raise ValueError("Extra mock library must be a dict.") + self.forward = {} + self.inverse = {} + if mapping is not None: + for key, value in mapping.items(): + self.add(key, value) + + def add(self, key, value): + """mock key thorugh value.""" + if key in self.forward or value in self.inverse: + raise ValueError("Key or value already exists.") + self.forward[key] = value + self.inverse[value] = key + + def remove(self, key=None, value=None): + if key is not None: + value = self.forward.pop(key) + self.inverse.pop(value) + elif value is not None: + key = self.inverse.pop(value) + self.forward.pop(key) + else: + raise ValueError("Must provide a key or value to remove.") + + def in_forward_dict(self, s): + return s.split(".")[0] in self.forward.keys() + + def in_inverse_dict(self, s): + return s.split(".")[0] in self.inverse.keys() + + def inverse_name(self, s: str): # s: spec.name + return self.inverse[s.split(".")[0]] + s[len(s.split(".")[0]) :] + + def forward_name(self, s: str): + return self.forward[s.split(".")[0]] + s[len(s.split(".")[0]) :] + + +class DummyModule(ModuleType): + def __init__(self, name, verbose=False): + super().__init__(name) + self._verbose = verbose + + def __getattr__(self, name): + if self._verbose: + print( + f'"{self.__name__}" is a dummy object, and its attr "{name}" is accessed.' + ) + if name == "__path__": + return None + if name == "__all__": + return [] + if name == "__file__": + return None + if name == "__mro_entries__": + return lambda x: () + + return DummyModule(self.__name__ + "." + name, self._verbose) + + def __getitem__(self, name): + new_name = f"{self.__name__}[{name}]" + if isinstance(name, int): + if self._verbose: + print( + f'"{self.__name__}" is a dummy object, and `{new_name}` is called. Raising an IndexError to simulate an empty list.' 
+ ) + raise IndexError + if self._verbose: + print(f'"{self.__name__}" is a dummy object, and `{new_name}` is called.') + return DummyModule(new_name, self._verbose) + + def __call__(self, *args, **kwargs): + new_name = f'{self.__name__}({", ".join(map(repr, args))}, {", ".join(["{}={}".format(k, repr(v)) for k, v in kwargs.items()])})' + if self._verbose: + print(f'"{self.__name__}" is a dummy object, and `{new_name}` is called.') + return DummyModule(new_name, self._verbose) + + def __bool__(self): + if self._verbose: + print( + f'"{self.__name__}" is a dummy object, and its bool value is accessed.' + ) + return False + + def __enter__(self): + raise RuntimeError( + f'"{self.__name__}" is a dummy object, and does not support "with" statement.' + ) + + def __exit__(self, exception_type, exception_value, traceback): + raise RuntimeError( + f'"{self.__name__}" is a dummy object, and does not support "with" statement.' + ) + + def __subclasscheck__(self, subclass): + return False + + def __instancecheck__(self, instance): + return False diff --git a/python/oneflow/mock_torch/mock_utils.py b/python/oneflow/mock_torch/mock_utils.py new file mode 100644 index 00000000000..64662be79c7 --- /dev/null +++ b/python/oneflow/mock_torch/mock_utils.py @@ -0,0 +1,366 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import sys +import sysconfig +import pkgutil +from collections import deque +from importlib import import_module + +if sys.version_info < (3, 8): + try: + from importlib_metadata import requires + except ImportError: + import subprocess + + subprocess.check_call("pip install importlib_metadata", shell=True) + subprocess.check_call("pip install packaging", shell=True) +else: + from importlib.metadata import requires + +from packaging.requirements import Requirement +from pathlib import Path +from functools import lru_cache +from typing import List, Optional +from types import ModuleType + + +__all__ = ["MockEnableDisableMixin"] + + +class PackageDependencyMixin: + """Get all dependencies of a package filtered by a list of dependencies. 
+ + Example: + >>> import diffusers # version 0.24.0 + >>> op = PackageDependencyMixin() + >>> result = op.has_dependencies("diffusers", ["torch"]) + >>> print(result) + ['huggingface_hub', 'diffusers'] + """ + + pkg_cache = {} # {pkg: [deps]} + + @staticmethod + def find_matching_dependencies( + main_pkg: str, dependencies: List[str], max_visits=1000 + ) -> List[str]: + @lru_cache() + def python_stdlib_packages(): + # current python stdlib path + stdlib_path = sysconfig.get_paths()["stdlib"] + + # use pkgutil to list all modules in the standard library + python_modules = [ + name for _, name, _ in pkgutil.iter_modules([stdlib_path]) + ] + + # combine built-in module names and Python modules + all_modules = list(sys.builtin_module_names) + python_modules + + return all_modules + + def format_package_name(pkg: str): + return Requirement(pkg).name.replace("-", "_") + + @lru_cache() + def get_requirements(pkg: str): + + python_modules = python_stdlib_packages() + if pkg in python_modules: + return [] + try: + direct_dependencies = requires(pkg) + if len(direct_dependencies) == 0: + return [] + + result = set() + for pkg in direct_dependencies: + pkg = format_package_name(pkg) + if pkg == main_pkg: + continue + + if pkg not in python_modules: + result.add(pkg) + + return list(result) + + except: + return [] + + def is_leaf_package(pkg) -> bool: + if pkg in dependencies: + return True + + return len(get_requirements(pkg)) == 0 + + main_pkg = format_package_name(main_pkg) + + # build graph + graph = {} # {dep: [pkg1, pkg2, ...]} + queue = deque([main_pkg]) + visited = set() + stops = set() + while queue: + pkg = queue.popleft() + if is_leaf_package(pkg): + stops.add(pkg) + continue + if pkg in visited: + continue + visited.add(pkg) + if len(visited) > max_visits: + print( + f"\033[1;33mWARNING: max_visits {max_visits} reached, stop searching.\033[0m" + ) + break + + for req in get_requirements(pkg): + graph.setdefault(req, set()).add(pkg) + queue.append(req) + + # init cache and queue + cache = {} + visited.clear() + queue = deque(stops) + for pkg in stops: + cache[pkg] = True if pkg in dependencies else False + + # bfs_from_stops + while queue: + pkg = queue.popleft() + if pkg in visited: + continue + visited.add(pkg) + + for dep in graph.get(pkg, set()): + is_ok = cache.get(dep, False) + if cache[pkg] or is_ok: + is_ok = True + cache[dep] = is_ok + queue.append(dep) + + return [pkg for pkg, is_ok in cache.items() if is_ok] + + @staticmethod + def varify_input(main_pkg, dependencies, callback, verbose=False): + try: + requires(main_pkg) + except: + if verbose: + print( + f"WARNING: main_pkg {main_pkg} has no meta information, please check if it is a valid package." 
+ ) + print("will set it as its own dependency to avoid error.") + PackageDependencyMixin.pkg_cache[main_pkg] = [main_pkg] + dependencies + + if not isinstance(main_pkg, str): + raise ValueError("main_pkg must be a string.") + if not isinstance(dependencies, list): + raise ValueError("dependencies must be a list.") + if not all([isinstance(dep, str) for dep in dependencies]): + raise ValueError("dependencies must be a list of strings.") + if callback is not None and not callable(callback): + raise ValueError("callback must be a callable.") + + @classmethod + def has_dependencies( + self, + main_pkg: str, + dependencies: List[str], + callback: callable = None, + *, + verbose=False, + ) -> List[str]: + """Check if a package has any dependencies in a list of dependencies.""" + PackageDependencyMixin.varify_input(main_pkg, dependencies, callback, verbose) + + deps = PackageDependencyMixin.pkg_cache.get(main_pkg, None) + if deps is None: + deps = PackageDependencyMixin.find_matching_dependencies( + main_pkg, dependencies + ) + PackageDependencyMixin.pkg_cache.update({main_pkg: deps}) + + if verbose: + print("PackageDependencyMixin : main_pkg=", main_pkg, ", deps=", deps) + + if callback: + return callback(deps) + else: + return deps + + +class VersionMixin: + version_cache = {} + + def mock_version(self, module_a: ModuleType, module_b: ModuleType): + """Mock the version of module_a with the version of module_b.""" + if isinstance(module_a, str): + module_a = import_module(module_a) + if isinstance(module_b, str): + module_b = import_module(module_b) + + attr_name = "__version__" + orig_attr = getattr(module_a, attr_name, None) + setattr(module_a, attr_name, getattr(module_b, attr_name, None)) + VersionMixin.version_cache.update({module_a: (attr_name, orig_attr)}) + + def restore_version(self): + for module, (attr_name, orig_attr) in self.version_cache.items(): + setattr(module, attr_name, orig_attr) + VersionMixin.version_cache.clear() + + +class MockEnableDisableMixin(PackageDependencyMixin, VersionMixin): + """Mock torch package using OneFlow.""" + + # list of hazardous modules that may cause issues, handle with care + hazard_list = [ + "_distutils_hack", + "importlib", + "regex", + "tokenizers", + "safetensors._safetensors_rust", + ] + + def is_safe_module(self, module_key): + k = module_key + hazard_list = MockEnableDisableMixin.hazard_list + + name = k if "." not in k else k[: k.find(".")] + if name in hazard_list or k in hazard_list: + return False + return True + + def mock_enable( + self, + globals, # parent_globals + module_dict, # MockModuleDict object + *, + main_pkg: Optional[str] = None, + mock_version: Optional[str] = None, + required_dependencies: List[str], + from_cli=False, + verbose=False, + **kwargs, + ): + """Mock torch package using OneFlow. + + Args: + `globals`: The globals() of the parent module. + + `module_dict`: MockModuleDict object. + + `main_pkg`: The main package to mock. + + `required_dependencies`: The dependencies to mock for the `main_pkg`. 
+ """ + if mock_version: + mock_map = module_dict.forward + for pkg, mock_pkg in mock_map.items(): + self.mock_version(pkg, mock_pkg) + + if not hasattr(self, "enable_mod_cache"): + self.enable_mod_cache = {} + if not hasattr(self, "disable_mod_cache"): + self.disable_mod_cache = {} + if not hasattr(self, "mock_safety_packages"): + self.mock_safety_packages = set() + + if main_pkg: + # Analyze the dependencies of the main package + cur_sys_modules = sys.modules.copy() + existing_deps = self.has_dependencies( + main_pkg, + dependencies=required_dependencies, + callback=lambda x: [dep for dep in x if dep in cur_sys_modules], + verbose=verbose, + ) + if verbose: + print( + "Existing dependencies of ", + "main_pkg: ", + main_pkg, + "existing_deps: ", + existing_deps, + ) + + self.mock_safety_packages.update(existing_deps) + + # disable non-safe modules loaded before mocking + def can_disable_mod_cache(k): # module_key + if not self.is_safe_module(k): + return False + if module_dict.in_forward_dict(k): + return True + for dep_pkg in self.mock_safety_packages: + if k.startswith(dep_pkg + ".") or k == dep_pkg: + return True + return False + + for k, v in sys.modules.copy().items(): + exclude_torch_from_cli = not (from_cli and k == "torch") + if not exclude_torch_from_cli: # torch is imported from CLI + continue + + if can_disable_mod_cache(k): + aliases = [alias for alias, value in globals.items() if value is v] + self.disable_mod_cache.update({k: (v, aliases)}) + del sys.modules[k] + for alias in aliases: + del globals[alias] + + # restore modules loaded during mocking + for k, (v, aliases) in self.enable_mod_cache.items(): + sys.modules.update({k: v}) + for alias in aliases: + globals.update({alias: v}) + + def mock_disable(self, globals, module_dict, delete_list, from_cli=False): + """Disable the mocked packages.""" + if not hasattr(self, "enable_mod_cache") or not hasattr( + self, "disable_mod_cache" + ): + RuntimeError("Please call mock_enable() first.") + + # disable modules loaded during mocking + def can_enable_mod_cache(k): # module_key + if not self.is_safe_module(k): + return False + if module_dict.in_forward_dict(k): + return True + return k in delete_list + + for k, v in sys.modules.copy().items(): + if can_enable_mod_cache(k): + aliases = [alias for alias, value in globals.items() if value is v] + self.enable_mod_cache.update({k: (v, aliases)}) + del sys.modules[k] + for alias in aliases: + del globals[alias] + + # restore modules loaded during before mocking + for k, (v, aliases) in self.disable_mod_cache.items(): + sys.modules.update({k: v}) + for alias in aliases: + globals.update({alias: v}) + + if from_cli: + torch_env = Path(__file__).parent + if str(torch_env) in sys.path: + sys.path.remove(str(torch_env)) + + self.restore_version()