Skip to content

Commit

Permalink
remove one unit test due to memory limit
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Jul 13, 2023
1 parent acbe99b commit 3b5e576
Showing 1 changed file with 19 additions and 34 deletions.
53 changes: 19 additions & 34 deletions tests/test_models/test_editors/test_controlnet/test_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def test_init_weights(self):

def test_infer(self):
control_sd = self.control_sd
control = torch.ones([1, 3, 64, 64])

def mock_encode_prompt(prompt, do_classifier_free_guidance,
num_images_per_prompt, *args, **kwargs):
Expand All @@ -105,49 +104,35 @@ def mock_encode_prompt(prompt, do_classifier_free_guidance,
encode_prompt = control_sd._encode_prompt
control_sd._encode_prompt = mock_encode_prompt

prompt = 'an insect robot preparing a delicious meal'

# one prompt, one control, repeat 1 time
result = control_sd.infer(
prompt,
control=control,
height=64,
width=64,
num_inference_steps=1,
return_type='numpy')
assert result['samples'].shape == (1, 3, 64, 64)
self._test_infer(control_sd, 1, 1, 1, 1)

# two prompt, two control, repeat 1 time
result = control_sd.infer([prompt, prompt],
control=control,
height=64,
width=64,
num_inference_steps=1,
return_type='numpy')
assert result['samples'].shape == (2, 3, 64, 64)
self._test_infer(control_sd, 2, 2, 1, 2)

# one prompt, one control, repeat 2 times
self._test_infer(control_sd, 1, 1, 2, 2)

# two prompt, two control, repeat 2 times
# NOTE: skip this due to memory limit
# self._test_infer(control_sd, 2, 2, 2, 4)

control_sd._encode_prompt = encode_prompt

def _test_infer(self, control_sd, num_prompt, num_control, num_repeat,
tar_shape):
prompt = ''
control = torch.ones([1, 3, 64, 64])

result = control_sd.infer(
prompt,
control=control,
[prompt] * num_prompt,
control=[control] * num_control,
height=64,
width=64,
num_images_per_prompt=2,
num_images_per_prompt=num_repeat,
num_inference_steps=1,
return_type='numpy')
assert result['samples'].shape == (2, 3, 64, 64)

# two prompt, two control, repeat 2 times
result = control_sd.infer([prompt, prompt],
control=[control, control],
height=64,
width=64,
num_images_per_prompt=2,
num_inference_steps=1,
return_type='numpy')
assert result['samples'].shape == (4, 3, 64, 64)

control_sd._encode_prompt = encode_prompt
assert result['samples'].shape == (tar_shape, 3, 64, 64)

def test_val_step(self):
control_sd = self.control_sd
Expand Down

0 comments on commit 3b5e576

Please sign in to comment.