tenstorrent#17134: Fix SD cross attn upblock unit test (tenstorrent#1…
s-jovic authored Feb 6, 2025
1 parent 739362e commit 2def718
Showing 1 changed file with 83 additions and 52 deletions.
@@ -2,31 +2,18 @@

# SPDX-License-Identifier: Apache-2.0

import torch
from diffusers import StableDiffusionPipeline
from loguru import logger
import ttnn
import pytest
from torch import nn

from models.utility_functions import tt_to_torch_tensor, torch_random
from tests.ttnn.utils_for_testing import assert_with_pcc
import torch
import ttnn

from models.utility_functions import (
skip_for_grayskull,
)
from models.demos.wormhole.stable_diffusion.custom_preprocessing import custom_preprocessor
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_cross_attn_upblock_new_conv import (
cross_attention_upblock2d,
)

from models.demos.wormhole.stable_diffusion.custom_preprocessing import custom_preprocessor

from models.utility_functions import skip_for_grayskull, torch_random
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.wormhole.stable_diffusion.tt.ttnn_functional_utility_functions import (
pre_process_input,
weight_to_bfp8,
post_process_output,
)
from tests.ttnn.utils_for_testing import assert_with_pcc


def ttnn_to_torch(input):
@@ -36,15 +23,59 @@ def ttnn_to_torch(input):
return input


def prepare_input_and_push_to_device(input, device, memory_config):
input = torch.permute(input, (0, 2, 3, 1))
input = torch.reshape(
input,
(
1,
1,
input.shape[0] * input.shape[1] * input.shape[2],
input.shape[3],
),
)

input = ttnn.from_torch(input, ttnn.bfloat16)
input = ttnn.to_layout(input, ttnn.TILE_LAYOUT)
input = ttnn.to_dtype(input, ttnn.bfloat8_b)
return ttnn.to_device(input, device, memory_config=memory_config)
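
For reference, a minimal shape walk-through of what the new helper does before tiling and pushing to device (illustrative only, not part of the committed test; shapes taken from the first parametrize case below):

# Illustrative only, not part of the committed test: the NCHW -> flattened layout
# conversion performed by prepare_input_and_push_to_device, using the hidden_states
# shape (2, 1280, 16, 16) from the first parametrize case below.
import torch

x = torch.randn(2, 1280, 16, 16)                      # (N, C, H, W)
x = torch.permute(x, (0, 2, 3, 1))                    # (N, H, W, C)
x = torch.reshape(x, (1, 1, x.shape[0] * x.shape[1] * x.shape[2], x.shape[3]))
assert x.shape == (1, 1, 2 * 16 * 16, 1280)           # (1, 1, 512, 1280), ready for TILE_LAYOUT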


@skip_for_grayskull()
@pytest.mark.skip(reason="#9599: Tests are failing.")
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.parametrize(
"hidden_states, res_hidden_states_tuple, index, prev_output_channel, in_channels ,out_channels",
"hidden_states, res_hidden_states_tuple, index, prev_output_channel, in_channels, out_channels, shard_end_core, shard_shape",
[
((2, 1280, 16, 16), ([2, 640, 16, 16], [2, 1280, 16, 16], [2, 1280, 16, 16]), 1, 1280, 640, 1280),
((2, 1280, 32, 32), ([2, 320, 32, 32], [2, 640, 32, 32], [2, 640, 32, 32]), 2, 1280, 320, 640),
((2, 640, 64, 64), ([2, 320, 64, 64], [2, 320, 64, 64], [2, 320, 64, 64]), 3, 640, 320, 320),
(
(2, 1280, 16, 16),
([2, 640, 16, 16], [2, 1280, 16, 16], [2, 1280, 16, 16]),
1,
1280,
640,
1280,
(7, 3),
[128, 160],
),
(
(2, 1280, 32, 32),
([2, 320, 32, 32], [2, 640, 32, 32], [2, 640, 32, 32]),
2,
1280,
320,
640,
(7, 7),
[256, 160],
),
(
(2, 640, 64, 64),
([2, 320, 64, 64], [2, 320, 64, 64], [2, 320, 64, 64]),
3,
640,
320,
320,
(4, 7),
[1024, 128],
),
],
)
@pytest.mark.parametrize("temb", [[1, 1, 2, 1280]])
@@ -66,14 +97,15 @@ def test_cross_attn_up_block_2d_512x512(
prev_output_channel,
in_channels,
out_channels,
shard_end_core,
shard_shape,
):
# TODO
# setup pytorch model
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32)
unet = pipe.unet
unet.eval()
config = unet.config
state_dict = unet.state_dict()
unet_upblock = pipe.unet.up_blocks[index]

parameters = preprocess_model_parameters(
@@ -122,36 +154,40 @@ def test_cross_attn_up_block_2d_512x512(
cross_attention_kwargs = (None,)
return_dict = True
num_layers_transformer = 1
norm_num_groups = 32
cross_attention_dim = 768
attention_bias = False
sample_size = None
num_vector_embeds = None
patch_size = None
activation_fn = "geglu"
num_embeds_ada_norm = None
use_linear_projection = False
only_cross_attention = False
upcast_attention = False
norm_type = "layer_norm"
norm_elementwise_affine = True
attn_num_head_channels = 8

hidden_state = ttnn.from_torch(hidden_state, ttnn.bfloat16)
hidden_state = ttnn.to_layout(hidden_state, ttnn.TILE_LAYOUT)
hidden_state = ttnn.to_device(hidden_state, device, memory_config=ttnn.L1_MEMORY_CONFIG)

res0 = ttnn.from_torch(res0, ttnn.bfloat16)
res0 = ttnn.to_layout(res0, ttnn.TILE_LAYOUT)
res0 = ttnn.to_device(res0, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)

res1 = ttnn.from_torch(res1, ttnn.bfloat16)
res1 = ttnn.to_layout(res1, ttnn.TILE_LAYOUT)
res1 = ttnn.to_device(res1, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
hidden_state = prepare_input_and_push_to_device(
hidden_state,
device,
ttnn.MemoryConfig(
ttnn.TensorMemoryLayout.BLOCK_SHARDED,
ttnn.BufferType.L1,
ttnn.ShardSpec(
ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(shard_end_core[0], shard_end_core[1]),
),
}
),
shard_shape,
ttnn.ShardOrientation.ROW_MAJOR,
),
),
)

res2 = ttnn.from_torch(res2, ttnn.bfloat16)
res2 = ttnn.to_layout(res2, ttnn.TILE_LAYOUT)
res2 = ttnn.to_device(res2, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
res0 = prepare_input_and_push_to_device(res0, device, ttnn.DRAM_MEMORY_CONFIG)
res1 = prepare_input_and_push_to_device(res1, device, ttnn.DRAM_MEMORY_CONFIG)
res2 = prepare_input_and_push_to_device(res2, device, ttnn.DRAM_MEMORY_CONFIG)
res_hidden_states_tuple = (res0, res1, res2)
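
In concrete terms (illustrative, first parametrize case): the hidden state lands in a block-sharded L1 buffer spread over an 8x4 core grid, each core holding a 128x160 slice of the flattened (1, 1, 512, 1280) tensor, while the three residuals stay in interleaved DRAM. The equivalent config, written out standalone with the same ttnn primitives used above:

# Illustrative only: the L1 block-sharded memory config used for hidden_state in
# the first case, built outside the test for clarity.
grid = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(7, 3))})
shard_spec = ttnn.ShardSpec(grid, [128, 160], ttnn.ShardOrientation.ROW_MAJOR)
l1_block_sharded = ttnn.MemoryConfig(
    ttnn.TensorMemoryLayout.BLOCK_SHARDED, ttnn.BufferType.L1, shard_spec
)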

temb = temb.permute(2, 0, 1, 3) # pre-permute temb
temb = ttnn.from_torch(temb, ttnn.bfloat16)
@@ -166,12 +202,7 @@ def test_cross_attn_up_block_2d_512x512(
add_upsample = True
if index == 3:
add_upsample = False
hidden_state = weight_to_bfp8(pre_process_input(device, hidden_state))
res_hidden_states_tuple = (
weight_to_bfp8(pre_process_input(device, res0)),
weight_to_bfp8(pre_process_input(device, res1)),
weight_to_bfp8(pre_process_input(device, res2)),
)

op = model(
hidden_state,
res_hidden_states_tuple,
Expand All @@ -180,7 +211,7 @@ def test_cross_attn_up_block_2d_512x512(
out_channels,
temb_channels,
num_layers=3,
resnet_eps=1e-6,
resnet_eps=1e-5,
resnet_time_scale_shift="default",
resnet_act_fn="silu",
resnet_groups=32,
@@ -214,4 +245,4 @@ def test_cross_attn_up_block_2d_512x512(
op = torch.reshape(op, (N, H * 2, W * 2, Cout))
op = op.permute(0, 3, 1, 2)

assert_with_pcc(torch_output, op, 0.92)
assert_with_pcc(torch_output, op, 0.91)
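
The threshold drop from 0.92 to 0.91 slightly relaxes the Pearson-correlation check; a rough sketch of what such a PCC comparison computes (assumed; the real check is assert_with_pcc from tests.ttnn.utils_for_testing):

# Rough sketch only, not the actual assert_with_pcc implementation.
import torch

def pcc(golden: torch.Tensor, actual: torch.Tensor) -> float:
    # Pearson correlation coefficient between the flattened tensors.
    g = golden.flatten().to(torch.float32)
    a = actual.flatten().to(torch.float32)
    return torch.corrcoef(torch.stack([g, a]))[0, 1].item()

# The test now requires pcc(torch_output, op) >= 0.91 instead of 0.92.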
