[Loading] Ignore unneeded files (#1107)
* [Loading] Ignore unneeded files

* up
patrickvonplaten authored Nov 2, 2022
1 parent cbcd051 commit c39a511
Showing 4 changed files with 51 additions and 2 deletions.
12 changes: 11 additions & 1 deletion src/diffusers/pipeline_flax_utils.py
@@ -302,10 +302,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]

# make sure we don't download PyTorch weights
ignore_patterns = "*.bin"

if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
requested_pipeline_class = (
requested_pipeline_class
if requested_pipeline_class.startswith("Flax")
else "Flax" + requested_pipeline_class
)

user_agent = {"pipeline_class": requested_pipeline_class}
user_agent = http_user_agent(user_agent)

@@ -319,6 +328,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
else:
@@ -337,7 +347,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if config_dict["_class_name"].startswith("Flax")
else "Flax" + config_dict["_class_name"]
)
- pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
+ pipeline_class = getattr(diffusers_module, class_name)

# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
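The allow_patterns / ignore_patterns pair above is forwarded to huggingface_hub's snapshot_download, which filters repository files with shell-style globs before anything is fetched. A minimal standalone sketch of the Flax case (the repo id and allow patterns are illustrative, not the exact values diffusers computes, and the call needs network access):

    from huggingface_hub import snapshot_download

    # Fetch a pipeline repo while skipping PyTorch weights, mirroring the
    # ignore_patterns="*.bin" added above. Repo id and allow patterns are
    # chosen for illustration only.
    snapshot_path = snapshot_download(
        "hf-internal-testing/tiny-stable-diffusion-pipe",
        allow_patterns=["unet/*", "scheduler/*", "model_index.json"],
        ignore_patterns="*.bin",
    )
    print(snapshot_path)  # cached snapshot directory containing no *.bin files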
4 changes: 4 additions & 0 deletions src/diffusers/pipeline_utils.py
@@ -395,6 +395,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]

# make sure we don't download flax weights
ignore_patterns = "*.msgpack"

if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]

@@ -417,6 +420,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
else:
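The PyTorch loader uses the same mechanism with the roles reversed: Flax *.msgpack weights are skipped. The patterns are fnmatch-style globs, so a quick local approximation of what a given ignore pattern would drop (not the hub's exact implementation) looks like this:

    from fnmatch import fnmatch

    repo_files = [
        "unet/diffusion_pytorch_model.bin",
        "unet/diffusion_flax_model.msgpack",
        "scheduler/scheduler_config.json",
        "model_index.json",
    ]

    ignore_pattern = "*.msgpack"  # as set in pipeline_utils.py above
    kept = [f for f in repo_files if not fnmatch(f, ignore_pattern)]
    print(kept)  # every file except the Flax .msgpack weights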
16 changes: 16 additions & 0 deletions tests/test_pipelines.py
@@ -73,6 +73,22 @@ def test_progress_bar(capsys):
assert captured.err == "", "Progress bar should be disabled"


class DownloadTests(unittest.TestCase):
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
_ = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
files = [item for sublist in all_root_files for item in sublist]

# None of the downloaded files should be a flax file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_flax_model.msgpack
assert not any(f.endswith(".msgpack") for f in files)


class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):
pipeline = DiffusionPipeline.from_pretrained(
21 changes: 20 additions & 1 deletion tests/test_pipelines_flax.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import numpy as np
@@ -24,12 +26,29 @@
if is_flax_available():
import jax
import jax.numpy as jnp
- from diffusers import FlaxDDIMScheduler, FlaxStableDiffusionPipeline
+ from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap


@require_flax
class DownloadTests(unittest.TestCase):
def test_download_only_pytorch(self):
with tempfile.TemporaryDirectory() as tmpdirname:
# pipeline has Flax weights
_ = FlaxDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None, cache_dir=tmpdirname
)

all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname, os.listdir(tmpdirname)[0], "snapshots"))]
files = [item for sublist in all_root_files for item in sublist]

# None of the downloaded files should be a PyTorch file even if we have some here:
# https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe/blob/main/unet/diffusion_pytorch_model.bin
assert not any(f.endswith(".bin") for f in files)


@slow
@require_flax
class FlaxPipelineTests(unittest.TestCase):
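Both new tests hit the Hub, so running them locally needs network access, and the Flax variant additionally needs jax/flax installed; an invocation would look roughly like:

    python -m pytest tests/test_pipelines.py::DownloadTests::test_download_only_pytorch
    python -m pytest tests/test_pipelines_flax.py::DownloadTests::test_download_only_pytorch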
