From 213ac7ae5361351d6d7a7a5f5f26039355f16818 Mon Sep 17 00:00:00 2001 From: trajep Date: Sat, 10 Aug 2024 03:41:54 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=93=BD=EF=B8=8F=20=20Olive=20StrEnumBase?= =?UTF-8?q?=20IntEnumBase=20(#1290)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Describe your changes https://github.com/python/cpython/issues/100458 Quarot require python>=3.11 where the mixin usage of (str, Enum) did not work. This PR is used to create olive strEnum based on python version. ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link --- examples/utils/generator.py | 5 +++-- olive/auto_optimizer/__init__.py | 4 ++-- olive/common/config_utils.py | 5 ++--- olive/common/constants.py | 4 ++-- olive/common/utils.py | 24 +++++++++++++++++++--- olive/constants.py | 6 +++--- olive/data/component/text_generation.py | 4 ++-- olive/data/constants.py | 10 ++++----- olive/engine/packaging/packaging_config.py | 6 +++--- olive/evaluator/metric.py | 10 ++++----- olive/hardware/accelerator.py | 4 ++-- olive/passes/onnx/model_builder.py | 6 +++--- olive/passes/onnx/nvmo_quantization.py | 6 +++--- olive/passes/onnx/onnx_dag.py | 6 +++--- olive/passes/onnx/vitis_ai/quant_utils.py | 5 +++-- olive/passes/openvino/quantization.py | 10 ++++----- olive/passes/pass_config.py | 4 ++-- olive/passes/pytorch/autoawq.py | 4 ++-- olive/passes/pytorch/quarot.py | 4 ++-- olive/platform_sdk/qualcomm/constants.py | 14 ++++++------- olive/strategy/search_parameter.py | 5 ++--- 21 files changed, 82 insertions(+), 64 deletions(-) diff --git a/examples/utils/generator.py b/examples/utils/generator.py index f767a92f9..1f117d423 100644 --- a/examples/utils/generator.py +++ b/examples/utils/generator.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -11,13 +10,15 @@ from onnxruntime import InferenceSession, OrtValue, SessionOptions from transformers import PreTrainedTokenizer +from olive.common.utils import StrEnumBase + if TYPE_CHECKING: from kv_cache_utils import Cache, IOBoundCache from numpy.typing import NDArray from onnx import ValueInfoProto -class AdapterMode(Enum): +class AdapterMode(StrEnumBase): """Enum for adapter modes.""" inputs = "inputs" diff --git a/olive/auto_optimizer/__init__.py b/olive/auto_optimizer/__init__.py index fa4ea072b..43b3c55d2 100644 --- a/olive/auto_optimizer/__init__.py +++ b/olive/auto_optimizer/__init__.py @@ -5,12 +5,12 @@ import logging from copy import deepcopy -from enum import Enum from typing import List, Optional from olive.auto_optimizer.regulate_mixins import RegulatePassConfigMixin from olive.common.config_utils import ConfigBase from olive.common.pydantic_v1 import validator +from olive.common.utils import StrEnumBase from olive.data.config import DataConfig from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware.accelerator import AcceleratorSpec @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -class Precision(str, Enum): +class Precision(StrEnumBase): FP32 = "fp32" FP16 = "fp16" INT8 = "int8" diff --git a/olive/common/config_utils.py b/olive/common/config_utils.py index 37f38b76b..994375c96 100644 --- a/olive/common/config_utils.py +++ b/olive/common/config_utils.py @@ -5,7 +5,6 @@ import inspect import json import logging -from enum import Enum from functools import partial from pathlib import Path from types import FunctionType, MethodType @@ -14,7 +13,7 @@ import yaml from olive.common.pydantic_v1 import BaseModel, create_model, root_validator, validator -from olive.common.utils import hash_function, hash_object +from olive.common.utils import StrEnumBase, hash_function, hash_object logger = logging.getLogger(__name__) @@ -212,7 +211,7 @@ def gather_nested_field(cls, values): return values -class CaseInsensitiveEnum(str, Enum): +class CaseInsensitiveEnum(StrEnumBase): """StrEnum class that is insensitive to the case of the input string. Note: Only insensitive when creating the enum object like `CaseInsensitiveEnum("value")`. diff --git a/olive/common/constants.py b/olive/common/constants.py index 2cc9774d8..39e35b373 100644 --- a/olive/common/constants.py +++ b/olive/common/constants.py @@ -2,10 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum +from olive.common.utils import StrEnumBase -class OS(str, Enum): +class OS(StrEnumBase): WINDOWS = "Windows" LINUX = "Linux" diff --git a/olive/common/utils.py b/olive/common/utils.py index e499a2dcb..e2a951446 100644 --- a/olive/common/utils.py +++ b/olive/common/utils.py @@ -13,21 +13,39 @@ import shlex import shutil import subprocess +import sys import tempfile import time from pathlib import Path from typing import Dict, List, Optional, Tuple, Union -from olive.common.constants import OS - logger = logging.getLogger(__name__) +if sys.version_info >= (3, 11): + from enum import IntEnum, StrEnum + + class StrEnumBase(StrEnum): + pass + + class IntEnumBase(IntEnum): + pass + +else: + from enum import Enum + + class StrEnumBase(str, Enum): + pass + + class IntEnumBase(int, Enum): + pass + + def run_subprocess(cmd, env=None, cwd=None, check=False): logger.debug("Running command: %s", cmd) assert isinstance(cmd, (str, list)), f"cmd must be a string or a list, got {type(cmd)}." - windows = platform.system() == OS.WINDOWS + windows = platform.system() == "Windows" if isinstance(cmd, str): # In posix model, the cmd string will be handled with specific posix rules. # https://docs.python.org/3.8/library/shlex.html#parsing-rules diff --git a/olive/constants.py b/olive/constants.py index e3a38f633..ee40c761d 100644 --- a/olive/constants.py +++ b/olive/constants.py @@ -2,10 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum +from olive.common.utils import StrEnumBase -class Framework(str, Enum): +class Framework(StrEnumBase): """Framework of the model.""" ONNX = "ONNX" @@ -16,7 +16,7 @@ class Framework(str, Enum): OPENVINO = "OpenVINO" -class ModelFileFormat(str, Enum): +class ModelFileFormat(StrEnumBase): """Given a framework, there might be 1 or more on-disk model file format(s), model save/Load logic may differ.""" ONNX = "ONNX" diff --git a/olive/data/component/text_generation.py b/olive/data/component/text_generation.py index 2181c13dd..64720648c 100644 --- a/olive/data/component/text_generation.py +++ b/olive/data/component/text_generation.py @@ -3,7 +3,6 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum from pathlib import Path from random import Random from typing import Callable, Dict, List, Optional, Union @@ -14,11 +13,12 @@ from olive.common.config_utils import ConfigBase, validate_config, validate_object from olive.common.pydantic_v1 import validator from olive.common.user_module_loader import UserModuleLoader +from olive.common.utils import StrEnumBase from olive.data.component.dataset import BaseDataset from olive.data.constants import IGNORE_INDEX -class TextGenStrategy(str, Enum): +class TextGenStrategy(StrEnumBase): """Strategy for tokenizing a dataset.""" LINE_BY_LINE = "line-by-line" # each line is a sequence, in order of appearance diff --git a/olive/data/constants.py b/olive/data/constants.py index 35768b992..5ea7b4889 100644 --- a/olive/data/constants.py +++ b/olive/data/constants.py @@ -3,13 +3,13 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum +from olive.common.utils import StrEnumBase # index for targets that should be ignored when computing metrics IGNORE_INDEX = -100 -class DataComponentType(Enum): +class DataComponentType(StrEnumBase): """enumerate for the different types of data components.""" # dataset component type: to load data into memory @@ -22,13 +22,13 @@ class DataComponentType(Enum): DATALOADER = "dataloader" -class DataContainerType(Enum): +class DataContainerType(StrEnumBase): """enumerate for the different types of data containers.""" DATA_CONTAINER = "data_container" -class DefaultDataComponent(Enum): +class DefaultDataComponent(StrEnumBase): """enumerate for the default data components.""" LOAD_DATASET = "default_load_dataset" @@ -37,7 +37,7 @@ class DefaultDataComponent(Enum): DATALOADER = "default_dataloader" -class DefaultDataContainer(Enum): +class DefaultDataContainer(StrEnumBase): """enumerate for the default data containers.""" DATA_CONTAINER = "DataContainer" diff --git a/olive/engine/packaging/packaging_config.py b/olive/engine/packaging/packaging_config.py index 25b4a7158..3c0e6421c 100644 --- a/olive/engine/packaging/packaging_config.py +++ b/olive/engine/packaging/packaging_config.py @@ -2,12 +2,12 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum from typing import Optional, Union from olive.common.config_utils import CaseInsensitiveEnum, ConfigBase, NestedConfig, validate_config from olive.common.constants import BASE_IMAGE from olive.common.pydantic_v1 import validator +from olive.common.utils import StrEnumBase class PackagingType(CaseInsensitiveEnum): @@ -43,7 +43,7 @@ class DockerfilePackagingConfig(CommonPackagingConfig): requirements_file: Optional[str] = None -class InferencingServerType(str, Enum): +class InferencingServerType(StrEnumBase): AzureMLOnline = "AzureMLOnline" AzureMLBatch = "AzureMLBatch" @@ -54,7 +54,7 @@ class InferenceServerConfig(ConfigBase): scoring_script: str -class AzureMLModelModeType(str, Enum): +class AzureMLModelModeType(StrEnumBase): download = "download" copy = "copy" diff --git a/olive/evaluator/metric.py b/olive/evaluator/metric.py index 071d6d0d0..dc8bbdeed 100644 --- a/olive/evaluator/metric.py +++ b/olive/evaluator/metric.py @@ -3,11 +3,11 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from enum import Enum from typing import Any, Dict, List, Optional, Union from olive.common.config_utils import ConfigBase, validate_config from olive.common.pydantic_v1 import validator +from olive.common.utils import StrEnumBase from olive.data.config import DataConfig from olive.evaluator.accuracy import AccuracyBase from olive.evaluator.metric_config import LatencyMetricConfig, MetricGoal, ThroughputMetricConfig, get_user_config_class @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) -class MetricType(str, Enum): +class MetricType(StrEnumBase): # TODO(trajep): support throughput ACCURACY = "accuracy" LATENCY = "latency" @@ -23,7 +23,7 @@ class MetricType(str, Enum): CUSTOM = "custom" -class AccuracySubType(str, Enum): +class AccuracySubType(StrEnumBase): ACCURACY_SCORE = "accuracy_score" F1_SCORE = "f1_score" PRECISION = "precision" @@ -32,7 +32,7 @@ class AccuracySubType(str, Enum): PERPLEXITY = "perplexity" -class LatencySubType(str, Enum): +class LatencySubType(StrEnumBase): # unit: millisecond AVG = "avg" MAX = "max" @@ -45,7 +45,7 @@ class LatencySubType(str, Enum): P999 = "p999" -class ThroughputSubType(str, Enum): +class ThroughputSubType(StrEnumBase): # unit: token per second, tps AVG = "avg" MAX = "max" diff --git a/olive/hardware/accelerator.py b/olive/hardware/accelerator.py index 12fdfcde3..71eaa6c15 100644 --- a/olive/hardware/accelerator.py +++ b/olive/hardware/accelerator.py @@ -4,15 +4,15 @@ # -------------------------------------------------------------------------- import logging from dataclasses import dataclass -from enum import Enum from typing import List, Optional, Union +from olive.common.utils import StrEnumBase from olive.hardware.constants import DEVICE_TO_EXECUTION_PROVIDERS logger = logging.getLogger(__name__) -class Device(str, Enum): +class Device(StrEnumBase): CPU = "cpu" CPU_SPR = "cpu_spr" GPU = "gpu" diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index 85abaf437..9dc619b8c 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -9,10 +9,10 @@ import logging import os import tempfile -from enum import Enum from pathlib import Path from typing import Any, Dict, Union +from olive.common.utils import IntEnumBase, StrEnumBase from olive.hardware.accelerator import AcceleratorSpec, Device from olive.model import HfModelHandler, ONNXModelHandler from olive.model.utils import resolve_onnx_path @@ -28,7 +28,7 @@ class ModelBuilder(Pass): See https://github.com/microsoft/onnxruntime-genai """ - class Precision(str, Enum): + class Precision(StrEnumBase): FP32 = "fp32" FP16 = "fp16" INT4 = "int4" @@ -36,7 +36,7 @@ class Precision(str, Enum): def __str__(self) -> str: return self.value - class AccuracyLevel(int, Enum): + class AccuracyLevel(IntEnumBase): fp32 = 1 fp16 = 2 bf16 = 3 diff --git a/olive/passes/onnx/nvmo_quantization.py b/olive/passes/onnx/nvmo_quantization.py index 4a8d6d435..4df046d3e 100644 --- a/olive/passes/onnx/nvmo_quantization.py +++ b/olive/passes/onnx/nvmo_quantization.py @@ -4,11 +4,11 @@ # -------------------------------------------------------------------------- import logging from copy import deepcopy -from enum import Enum from pathlib import Path from typing import Any, Dict, Union from olive.common.config_utils import validate_config +from olive.common.utils import StrEnumBase from olive.data.config import DataConfig from olive.hardware.accelerator import AcceleratorSpec from olive.model import OliveModelHandler @@ -34,7 +34,7 @@ class NVModelOptQuantization(Pass): """Quantize ONNX model with Nvidia-ModelOpt.""" - class Precision(str, Enum): + class Precision(StrEnumBase): FP8 = "fp8" INT8 = "int8" INT4 = "int4" @@ -42,7 +42,7 @@ class Precision(str, Enum): def __str__(self) -> str: return self.value - class Algorithm(str, Enum): + class Algorithm(StrEnumBase): RTN = "RTN" AWQ = "AWQ" diff --git a/olive/passes/onnx/onnx_dag.py b/olive/passes/onnx/onnx_dag.py index f14d9ac18..6c77785df 100644 --- a/olive/passes/onnx/onnx_dag.py +++ b/olive/passes/onnx/onnx_dag.py @@ -4,7 +4,6 @@ # -------------------------------------------------------------------------- import logging from collections import defaultdict -from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Set, Union @@ -13,6 +12,7 @@ from olive.common.config_utils import ConfigBase from olive.common.pydantic_v1 import Field +from olive.common.utils import StrEnumBase if TYPE_CHECKING: from onnx import ModelProto @@ -20,14 +20,14 @@ logger = logging.getLogger(__name__) -class SpecialInput(str, Enum): +class SpecialInput(StrEnumBase): """Special inputs for ONNX nodes.""" INPUT = "__input__" # user input INITIALIZER = "__initializer__" # constant initializer -class SpecialOutput(str, Enum): +class SpecialOutput(StrEnumBase): """Special outputs for ONNX nodes.""" OUTPUT = "__output__" # model output diff --git a/olive/passes/onnx/vitis_ai/quant_utils.py b/olive/passes/onnx/vitis_ai/quant_utils.py index 91b7b1956..c269ff343 100644 --- a/olive/passes/onnx/vitis_ai/quant_utils.py +++ b/olive/passes/onnx/vitis_ai/quant_utils.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: MIT # import sys -from enum import Enum import numpy as np import onnx @@ -11,11 +10,13 @@ from onnxruntime.quantization.quant_utils import get_qmin_qmax_for_qType, quantize_nparray from packaging import version +from olive.common.utils import IntEnumBase + # pylint: skip-file # ruff: noqa -class PowerOfTwoMethod(Enum): +class PowerOfTwoMethod(IntEnumBase): NonOverflow = 0 MinMSE = 1 diff --git a/olive/passes/openvino/quantization.py b/olive/passes/openvino/quantization.py index e8decd5b7..b4bb42d79 100644 --- a/olive/passes/openvino/quantization.py +++ b/olive/passes/openvino/quantization.py @@ -3,11 +3,11 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from enum import Enum from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union from olive.common.config_utils import validate_config +from olive.common.utils import StrEnumBase from olive.data.config import DataConfig from olive.hardware.accelerator import AcceleratorSpec, Device from olive.model import OliveModelHandler @@ -41,22 +41,22 @@ def _default_validate_func(model: "CompiledModel", validation_loader) -> float: return accuracy_score(predictions, references) -class ModelTypeEnum(str, Enum): +class ModelTypeEnum(StrEnumBase): TRANSFORMER = "TRANSFORMER" -class PresetEnum(str, Enum): +class PresetEnum(StrEnumBase): PERFORMANCE = "PERFORMANCE" MIXED = "MIXED" -class IgnoreScopeTypeEnum(str, Enum): +class IgnoreScopeTypeEnum(StrEnumBase): NAMES = "names" TYPES = "types" PATTERNS = "patterns" -class DropTypeEnum(str, Enum): +class DropTypeEnum(StrEnumBase): ABSOLUTE = "ABSOLUTE" RELATIVE = "RELATIVE" diff --git a/olive/passes/pass_config.py b/olive/passes/pass_config.py index 7fd2bbc8e..895b7606c 100644 --- a/olive/passes/pass_config.py +++ b/olive/passes/pass_config.py @@ -2,17 +2,17 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum from pathlib import Path from typing import Callable, Dict, List, Optional, Type, Union from olive.common.config_utils import ConfigBase, ConfigParam, ParamCategory, validate_object from olive.common.pydantic_v1 import Field, create_model, validator +from olive.common.utils import StrEnumBase from olive.resource_path import validate_resource_path from olive.strategy.search_parameter import SearchParameter, SpecialParamValue, json_to_search_parameter -class PassParamDefault(str, Enum): +class PassParamDefault(StrEnumBase): """Default values for passes.""" DEFAULT_VALUE = "DEFAULT_VALUE" diff --git a/olive/passes/pytorch/autoawq.py b/olive/passes/pytorch/autoawq.py index 1faa3ff80..6de8ff913 100644 --- a/olive/passes/pytorch/autoawq.py +++ b/olive/passes/pytorch/autoawq.py @@ -3,11 +3,11 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from enum import Enum from typing import Any, Dict, Union import torch +from olive.common.utils import StrEnumBase from olive.data.config import DataConfig from olive.hardware.accelerator import AcceleratorSpec, Device from olive.model import HfModelHandler, PyTorchModelHandler @@ -24,7 +24,7 @@ class AutoAWQQuantizer(Pass): _requires_user_script = True - class ModelDtype(str, Enum): + class ModelDtype(StrEnumBase): # input model's data type, we can assume the model is all float type # sometime, the model is in double type, but we can convert it to float type # before quantization diff --git a/olive/passes/pytorch/quarot.py b/olive/passes/pytorch/quarot.py index 19d420247..14a6f295f 100644 --- a/olive/passes/pytorch/quarot.py +++ b/olive/passes/pytorch/quarot.py @@ -5,7 +5,6 @@ # ------------------------------------------------------------------------- import logging import sys -from enum import Enum from typing import Any, Dict, Union import torch @@ -13,6 +12,7 @@ from torch.utils.data import DataLoader, SubsetRandomSampler from olive.common.config_utils import validate_config +from olive.common.utils import StrEnumBase from olive.constants import ModelFileFormat from olive.data.config import DataConfig from olive.hardware.accelerator import AcceleratorSpec @@ -32,7 +32,7 @@ class QuaRot(Pass): This pass only supports HfModelHandler. """ - class ModelDtype(str, Enum): + class ModelDtype(StrEnumBase): # input model's data type, we can assume the model is all float type # sometime, the model is in double type, but we can convert it to float type # before quantization diff --git a/olive/platform_sdk/qualcomm/constants.py b/olive/platform_sdk/qualcomm/constants.py index 4385702cb..f9bb06175 100644 --- a/olive/platform_sdk/qualcomm/constants.py +++ b/olive/platform_sdk/qualcomm/constants.py @@ -2,10 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -from enum import Enum +from olive.common.utils import StrEnumBase -class SDKTargetDevice(str, Enum): +class SDKTargetDevice(StrEnumBase): x86_64_linux = "x86_64-linux-clang" x86_64_windows = "x86_64-windows-msvc" # evaluation only @@ -14,20 +14,20 @@ class SDKTargetDevice(str, Enum): aarch64_android = "aarch64-android" -class SNPEDevice(str, Enum): +class SNPEDevice(StrEnumBase): CPU = "cpu" GPU = "gpu" DSP = "dsp" AIP = "aip" -class InputType(str, Enum): +class InputType(StrEnumBase): DEFAULT = "default" IMAGE = "image" OPAQUE = "opaque" -class InputLayout(str, Enum): +class InputLayout(StrEnumBase): NCDHW = "NCDHW" NDHWC = "NDHWC" NCHW = "NCHW" @@ -42,7 +42,7 @@ class InputLayout(str, Enum): NONTRIVIAL = "NONTRIVIAL" -class PerfProfile(str, Enum): +class PerfProfile(StrEnumBase): SYSTEM_SETTINGS = "system_settings" POWER_SAVER = "power_saver" BALANCED = "balanced" @@ -51,7 +51,7 @@ class PerfProfile(str, Enum): BURST = "burst" -class ProfilingLevel(str, Enum): +class ProfilingLevel(StrEnumBase): OFF = "off" BASIC = "basic" MODERATE = "moderate" diff --git a/olive/strategy/search_parameter.py b/olive/strategy/search_parameter.py index eca2d9f38..e1adfc50b 100644 --- a/olive/strategy/search_parameter.py +++ b/olive/strategy/search_parameter.py @@ -3,10 +3,9 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from abc import ABC, abstractmethod -from enum import Enum from typing import Any, Dict, List, Tuple, Union -from olive.common.utils import flatten_dict, unflatten_dict +from olive.common.utils import StrEnumBase, flatten_dict, unflatten_dict class SearchParameter(ABC): @@ -33,7 +32,7 @@ def to_json(self): raise NotImplementedError -class SpecialParamValue(str, Enum): +class SpecialParamValue(StrEnumBase): """Special values for parameters. IGNORED: the parameter gets the value "OLIVE_IGNORED_PARAM_VALUE". The pass might ignore this parameter.