diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 81c75fecec6e..75ab6c7c5f2a 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -10,6 +10,7 @@ concurrency: cancel-in-progress: true env: + DIFFUSERS_IS_CI: yes OMP_NUM_THREADS: 8 MKL_NUM_THREADS: 8 PYTEST_TIMEOUT: 60 diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index dfd83aa9af46..c3993563d546 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -6,6 +6,7 @@ on: - main env: + DIFFUSERS_IS_CI: yes HF_HOME: /mnt/cache OMP_NUM_THREADS: 8 MKL_NUM_THREADS: 8 diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py index 1f8cc0db0fc9..8bf0933a1dce 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/hub_utils.py @@ -16,13 +16,25 @@ import os import shutil +import sys from pathlib import Path -from typing import Optional +from typing import Dict, Optional, Union +from uuid import uuid4 from huggingface_hub import HfFolder, Repository, whoami -from .pipeline_utils import DiffusionPipeline -from .utils import deprecate, is_modelcards_available, logging +from . import __version__ +from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging +from .utils.import_utils import ( + _flax_version, + _jax_version, + _onnxruntime_version, + _torch_version, + is_flax_available, + is_modelcards_available, + is_onnx_available, + is_torch_available, +) if is_modelcards_available(): @@ -33,6 +45,32 @@ MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" +SESSION_ID = uuid4().hex +DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES + + +def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: + """ + Formats a user-agent string with basic info about a request. + """ + ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}" + if DISABLE_TELEMETRY: + return ua + "; telemetry/off" + if is_torch_available(): + ua += f"; torch/{_torch_version}" + if is_flax_available(): + ua += f"; jax/{_jax_version}" + ua += f"; flax/{_flax_version}" + if is_onnx_available(): + ua += f"; onnxruntime/{_onnxruntime_version}" + # CI will set this value to True + if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES: + ua += "; is_ci/true" + if isinstance(user_agent, dict): + ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) + elif isinstance(user_agent, str): + ua += "; " + user_agent + return ua def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): @@ -101,7 +139,7 @@ def init_git_repo(args, at_init: bool = False): def push_to_hub( args, - pipeline: DiffusionPipeline, + pipeline, repo: Repository, commit_message: Optional[str] = "End of training", blocking: bool = True, diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index e96c0c7467f3..f0330531e17f 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -29,6 +29,7 @@ from tqdm.auto import tqdm from .configuration_utils import ConfigMixin +from .hub_utils import http_user_agent from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging @@ -301,6 +302,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + if cls != FlaxDiffusionPipeline: + requested_pipeline_class = cls.__name__ + else: + requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + user_agent = {"pipeline_class": requested_pipeline_class} + user_agent = http_user_agent(user_agent) + # download all allow_patterns cached_folder = snapshot_download( pretrained_model_name_or_path, @@ -311,6 +319,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token=use_auth_token, revision=revision, allow_patterns=allow_patterns, + user_agent=user_agent, ) else: cached_folder = pretrained_model_name_or_path diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index c0a44363a2f9..92623f9f9245 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -30,9 +30,9 @@ from PIL import Image from tqdm.auto import tqdm -from . import __version__ from .configuration_utils import ConfigMixin from .dynamic_modules_utils import get_class_from_dynamic_module +from .hub_utils import http_user_agent from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from .utils import ( CONFIG_NAME, @@ -398,10 +398,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if custom_pipeline is not None: allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] - requested_pipeline_class = config_dict.get("_class_name", cls.__name__) - user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class} + if cls != DiffusionPipeline: + requested_pipeline_class = cls.__name__ + else: + requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + user_agent = {"pipeline_class": requested_pipeline_class} if custom_pipeline is not None: user_agent["custom_pipeline"] = custom_pipeline + user_agent = http_user_agent(user_agent) # download all allow_patterns cached_folder = snapshot_download( diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b2aabee70c92..2a5f7f64dd07 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -90,7 +90,8 @@ logger.info("Disabling Tensorflow because USE_TORCH is set") _tf_available = False - +_jax_version = "N/A" +_flax_version = "N/A" if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None if _flax_available: @@ -136,6 +137,7 @@ _modelcards_available = False +_onnxruntime_version = "N/A" _onnx_available = importlib.util.find_spec("onnxruntime") is not None if _onnx_available: candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")