diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 80555f278505..c281c772dbd2 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -302,10 +302,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] + # make sure we don't download PyTorch weights + ignore_patterns = "*.bin" + if cls != FlaxDiffusionPipeline: requested_pipeline_class = cls.__name__ else: requested_pipeline_class = config_dict.get("_class_name", cls.__name__) + requested_pipeline_class = ( + requested_pipeline_class + if requested_pipeline_class.startswith("Flax") + else "Flax" + requested_pipeline_class + ) + user_agent = {"pipeline_class": requested_pipeline_class} user_agent = http_user_agent(user_agent) @@ -319,6 +328,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token=use_auth_token, revision=revision, allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, user_agent=user_agent, ) else: @@ -337,7 +347,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if config_dict["_class_name"].startswith("Flax") else "Flax" + config_dict["_class_name"] ) - pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + pipeline_class = getattr(diffusers_module, class_name) # some modules can be passed directly to the init # in this case they are already instantiated in `kwargs` diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 5c94df25cca0..94c1e135abe5 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -395,6 +395,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name] + # make sure we don't download flax weights + ignore_patterns = "*.msgpack" + if custom_pipeline is not None: allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] @@ -417,6 +420,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_auth_token=use_auth_token, revision=revision, allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, user_agent=user_agent, ) else: diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index c11287339a5b..1654518f1e93 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -73,6 +73,22 @@ def test_progress_bar(capsys): assert captured.err == "", "Progress bar should be disabled" +class DownloadTests(unittest.TestCase): + def test_download_only_pytorch(self): + with tempfile.TemporaryDirectory() as tmpdirname: + # pipeline has Flax weights + _ = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] + files = [item for sublist in all_root_files for item in sublist] + + # None of the downloaded files should be a flax file even if we have some here: + # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack + assert not any(f.endswith(".msgpack") for f in files) + + class CustomPipelineTests(unittest.TestCase): def test_load_custom_pipeline(self): pipeline = DiffusionPipeline.from_pretrained( diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index ae52fa689bef..ac5e2621a514 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile import unittest import numpy as np @@ -24,12 +26,29 @@ if is_flax_available(): import jax import jax.numpy as jnp - from diffusers import FlaxDDIMScheduler, FlaxStableDiffusionPipeline + from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline from flax.jax_utils import replicate from flax.training.common_utils import shard from jax import pmap +@require_flax +class DownloadTests(unittest.TestCase): + def test_download_only_pytorch(self): + with tempfile.TemporaryDirectory() as tmpdirname: + # pipeline has Flax weights + _ = FlaxDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))] + files = [item for sublist in all_root_files for item in sublist] + + # None of the downloaded files should be a PyTorch file even if we have some here: + # https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin + assert not any(f.endswith(".bin") for f in files) + + @slow @require_flax class FlaxPipelineTests(unittest.TestCase):