diff --git a/mmagic/models/editors/consistency_models/consistencymodel.py b/mmagic/models/editors/consistency_models/consistencymodel.py
index eb31aefa1..94b407d41 100644
--- a/mmagic/models/editors/consistency_models/consistencymodel.py
+++ b/mmagic/models/editors/consistency_models/consistencymodel.py
@@ -13,7 +13,7 @@
 from mmagic.registry import MODELS
 from mmagic.structures import DataSample
 from mmagic.utils import ForwardInputs
-from .consistencymodel_utils import (device, get_generator, get_sample_fn,
+from .consistencymodel_utils import (get_generator, get_sample_fn,
                                      get_sigmas_karras, karras_sample)
 
 ModelType = Union[Dict, nn.Module]
@@ -56,6 +56,9 @@ def __init__(self,
         super().__init__(data_preprocessor=data_preprocessor)
 
         self.num_classes = num_classes
+        self.device = torch.device('cpu')
+        if torch.cuda.is_available():
+            self.device = torch.device('cuda')
         if 'consistency' in training_mode:
             self.distillation = True
         else:
@@ -117,7 +120,7 @@ def __init__(self,
             self.model.load_state_dict(
                 torch.load(model_path, map_location='cpu'))
 
-        self.model.to(device())
+        self.model.to(self.device)
 
         if sampler == 'multistep':
             assert len(ts) > 0
@@ -147,7 +150,7 @@ def infer(self, class_id: Optional[int] = None):
             (self.batch_size, 3, self.image_size, self.image_size),
             steps=self.steps,
             model_kwargs=self.model_kwargs,
-            device=device(),
+            device=self.device,
             clip_denoised=self.clip_denoised,
             sampler=self.sampler,
             sigma_min=self.sigma_min,
@@ -216,18 +219,18 @@ def forward(self,
                 self.sigma_min,
                 self.sigma_max,
                 self.diffusion.rho,
-                device=device())
+                device=self.device)
         else:
             sigmas = get_sigmas_karras(
                 self.steps,
                 self.sigma_min,
                 self.sigma_max,
                 self.diffusion.rho,
-                device=device())
+                device=self.device)
 
         noise = self.generator.randn(
             *(self.batch_size, 3, self.image_size, self.image_size),
-            device=device()) * self.sigma_max
+            device=self.device) * self.sigma_max
 
         sample_fn = get_sample_fn(self.sampler)
 
@@ -291,11 +294,11 @@ def label_fn(self, class_id):
                 'it should be within the range (0,num_classes).'
             classes = torch.tensor(
                 [int(class_id) for i in range(self.batch_size)],
-                device=device())
+                device=self.device)
         else:
             classes = torch.randint(
                 low=0,
                 high=self.num_classes,
                 size=(self.batch_size, ),
-                device=device())
+                device=self.device)
         return classes
diff --git a/mmagic/models/editors/consistency_models/consistencymodel_utils.py b/mmagic/models/editors/consistency_models/consistencymodel_utils.py
index 73354f8e6..ba9e2e567 100644
--- a/mmagic/models/editors/consistency_models/consistencymodel_utils.py
+++ b/mmagic/models/editors/consistency_models/consistencymodel_utils.py
@@ -7,30 +7,6 @@
 import torch.nn as nn
 
 
-def device():
-    """return torch.device."""
-    if torch.cuda.is_available():
-        return torch.device('cuda')
-    return torch.device('cpu')
-
-
-def get_weightings(weight_schedule, snrs, sigma_data):
-    """return weightings."""
-    if weight_schedule == 'snr':
-        weightings = snrs
-    elif weight_schedule == 'snr+1':
-        weightings = snrs + 1
-    elif weight_schedule == 'karras':
-        weightings = snrs + 1.0 / sigma_data**2
-    elif weight_schedule == 'truncated-snr':
-        weightings = torch.clamp(snrs, min=1.0)
-    elif weight_schedule == 'uniform':
-        weightings = torch.ones_like(snrs)
-    else:
-        raise NotImplementedError()
-    return weightings
-
-
 class SiLU(nn.Module):
     """PyTorch 1.7 has SiLU, but we support PyTorch 1.5."""
 
@@ -654,7 +630,7 @@ def __init__(self, num_samples, seed=0):
         self.seed = seed
         self.rng_cpu = torch.Generator()
         if torch.cuda.is_available():
-            self.rng_cuda = torch.Generator(device())
+            self.rng_cuda = torch.Generator(torch.device('cuda'))
         self.set_seed(seed)
 
     def get_global_size_and_indices(self, size):
@@ -737,7 +713,8 @@ def __init__(self, num_samples, seed=0):
         self.rng_cpu = [torch.Generator() for _ in range(num_samples)]
         if torch.cuda.is_available():
             self.rng_cuda = [
-                torch.Generator(device()) for _ in range(num_samples)
+                torch.Generator(torch.device('cuda'))
+                for _ in range(num_samples)
             ]
         self.set_seed(seed)
 
diff --git a/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py b/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
index e704e8093..408c6d980 100644
--- a/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
+++ b/tests/test_models/test_editors/test_consistency_models/test_consistency_models.py
@@ -82,17 +82,20 @@
 @pytest.mark.skipif(
     'win' in platform.system().lower(),
     reason='skip on windows due to limited RAM.')
-class TestDeblurGanV2(TestCase):
+class TestConsistencyModels(TestCase):
 
     def test_init(self):
         model = ConsistencyModel(
             unet=unet_config,
             denoiser=denoiser_config,
             data_preprocessor=DataPreprocessor())
+        if torch.cuda.is_available():
+            self.assertEqual(model.device, torch.device('cuda'))
         self.assertIsInstance(model, ConsistencyModel)
         self.assertIsInstance(model.data_preprocessor, DataPreprocessor)
         self.assertIsInstance(model.model, ConsistencyUNetModel)
         self.assertIsInstance(model.diffusion, KarrasDenoiser)
+
         unet_cfg = deepcopy(unet_config)
         diffuse_cfg = deepcopy(denoiser_config)
         unet = MODELS.build(unet_cfg)
@@ -114,6 +117,11 @@ def test_onestep_infer(self):
         for datasample in result:
             assert datasample.fake_img.shape == (3, model.image_size,
                                                  model.image_size)
+        result, labels = model.infer()
+        assert len(result) == model.batch_size
+        assert len(labels) == model.batch_size
+        for datasample in result:
+            assert datasample.shape == (model.image_size, model.image_size, 3)
 
     def test_multistep_infer(self):
         model = MODELS.build(config_multistep)
@@ -127,6 +135,11 @@ def test_multistep_infer(self):
         for datasample in result:
             assert datasample.fake_img.shape == (3, model.image_size,
                                                  model.image_size)
+        result, labels = model.infer()
+        assert len(result) == model.batch_size
+        assert len(labels) == model.batch_size
+        for datasample in result:
+            assert datasample.shape == (model.image_size, model.image_size, 3)
 
 def teardown_module():
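
A note on the pattern this patch settles on: the device is resolved once in `__init__` and stored as `self.device`, then passed explicitly to every tensor-creating call, instead of re-querying a module-level `device()` helper at each step. Below is a minimal sketch of the same idea; the class and method names are illustrative only, not part of mmagic:

```python
import torch
import torch.nn as nn


class DeviceAwareSampler(nn.Module):
    """Illustrative only: resolve the device once at construction,
    as ConsistencyModel now does, instead of calling a global
    device() helper at every sampling step."""

    def __init__(self):
        super().__init__()
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        self.scale = nn.Parameter(torch.ones(1))
        # Mirror `self.model.to(self.device)` from the patch.
        self.to(self.device)

    def sample(self, batch_size: int) -> torch.Tensor:
        # Create tensors directly on the stored device so sampling
        # also works on CPU-only hosts.
        noise = torch.randn(batch_size, 3, device=self.device)
        return noise * self.scale


sampler = DeviceAwareSampler()
print(sampler.sample(2).device)  # cpu, or cuda:0 when CUDA is available
```

The upshot is that a model constructed on a CPU-only host never touches CUDA, which is what lets the updated tests run unconditionally.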