Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pipelines user_agent, ignore CI requests #1058

Merged
merged 4 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/pr_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ concurrency:
cancel-in-progress: true

env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 60
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/push_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
- main

env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
Expand Down
46 changes: 42 additions & 4 deletions src/diffusers/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,25 @@

import os
import shutil
import sys
from pathlib import Path
from typing import Optional
from typing import Dict, Optional, Union
from uuid import uuid4

from huggingface_hub import HfFolder, Repository, whoami

from .pipeline_utils import DiffusionPipeline
from .utils import deprecate, is_modelcards_available, logging
from . import __version__
from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
from .utils.import_utils import (
_flax_version,
_jax_version,
_onnxruntime_version,
_torch_version,
is_flax_available,
is_modelcards_available,
is_onnx_available,
is_torch_available,
)


if is_modelcards_available():
Expand All @@ -33,6 +45,32 @@


MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES


def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""
Formats a user-agent string with basic info about a request.
"""
ua = f"diffusers/{__version__}; python/{sys.version.split()[0]}; session_id/{SESSION_ID}"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is how it's done in transformers, now here too :)

if DISABLE_TELEMETRY:
return ua + "; telemetry/off"
if is_torch_available():
ua += f"; torch/{_torch_version}"
if is_flax_available():
ua += f"; jax/{_jax_version}"
ua += f"; flax/{_flax_version}"
if is_onnx_available():
ua += f"; onnxruntime/{_onnxruntime_version}"
# CI will set this value to True
if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
ua += "; is_ci/true"
if isinstance(user_agent, dict):
ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += "; " + user_agent
return ua


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
Expand Down Expand Up @@ -101,7 +139,7 @@ def init_git_repo(args, at_init: bool = False):

def push_to_hub(
args,
pipeline: DiffusionPipeline,
pipeline,
repo: Repository,
commit_message: Optional[str] = "End of training",
blocking: bool = True,
Expand Down
9 changes: 9 additions & 0 deletions src/diffusers/pipeline_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tqdm.auto import tqdm

from .configuration_utils import ConfigMixin
from .hub_utils import http_user_agent
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
Expand Down Expand Up @@ -301,6 +302,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]

if cls != FlaxDiffusionPipeline:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flax pipelines now count too

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool!

requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
user_agent = http_user_agent(user_agent)

# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
Expand All @@ -311,6 +319,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
user_agent=user_agent,
)
else:
cached_folder = pretrained_model_name_or_path
Expand Down
10 changes: 7 additions & 3 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
from PIL import Image
from tqdm.auto import tqdm

from . import __version__
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import http_user_agent
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import (
CONFIG_NAME,
Expand Down Expand Up @@ -398,10 +398,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]

requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class}
if cls != DiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
if custom_pipeline is not None:
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)

# download all allow_patterns
cached_folder = snapshot_download(
Expand Down
4 changes: 3 additions & 1 deletion src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@
logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False


_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available:
Expand Down Expand Up @@ -136,6 +137,7 @@
_modelcards_available = False


_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available:
candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")
Expand Down