#17134: Enable test_transformer_2d_model UT in SD #17534

Merged 1 commit on Feb 4, 2025.
@@ -22,118 +22,6 @@


-@skip_for_grayskull()
-@pytest.mark.parametrize(
-    "input_shape, index1, index2, attention_head_dim, block",
-    [
-        (
-            (2, 320, 32, 32),
-            3,
-            2,
-            40,
-            "up",
-        ),
-        (
-            (2, 640, 16, 16),
-            1,
-            1,
-            80,
-            "down",
-        ),
-        (
-            (2, 1280, 4, 4),
-            2,
-            1,
-            160,
-            "down",
-        ),
-        (
-            (2, 1280, 8, 8),
-            2,
-            1,
-            160,
-            "down",
-        ),
-    ],
-)
-@pytest.mark.parametrize("model_name", ["CompVis/stable-diffusion-v1-4"])
-def test_transformer_2d_model_256x256(
-    input_shape, index1, index2, block, attention_head_dim, model_name, device, reset_seeds
-):
-    pytest.skip()
-    encoder_hidden_states = [1, 2, 77, 768]
-    timestep = (None,)
-    class_labels = (None,)
-    cross_attention_kwargs = (None,)
-    return_dict = True
-
-    num_layers = 1
-    num_attention_heads = 8
-    norm_num_groups = 32
-    norm_type = "layer_norm"
-    cross_attention_dim = 768
-    upcast_attention = False
-
-    _, in_channels, _, _ = input_shape
-
-    input = torch.randn(input_shape) * 0.01
-    encoder_hidden_states = torch.randn(encoder_hidden_states)
-
-    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float32)
-    unet = pipe.unet
-    unet.eval()
-    config = unet.config
-    transformer = pipe.unet.mid_block.attentions[0]
-
-    parameters = preprocess_model_parameters(
-        model_name=model_name, initialize_model=lambda: unet, custom_preprocessor=custom_preprocessor, device=device
-    )
-
-    if block == "up":
-        parameters = parameters.up_blocks[index1].attentions[index2]
-        transformer = pipe.unet.up_blocks[index1].attentions[index2]
-    elif block == "down":
-        parameters = parameters.down_blocks[index1].attentions[index2]
-        transformer = pipe.unet.down_blocks[index1].attentions[index2]
-    elif block == "mid":
-        parameters = parameters.mid_block.attentions[0]
-        transformer = pipe.unet.mid_block.attentions[0]
-
-    torch_output = transformer(input, encoder_hidden_states.squeeze(0)).sample
-
-    ttnn_hidden_state = ttnn.from_torch(input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
-
-    ttnn_encoder_hidden_states = ttnn.from_torch(
-        encoder_hidden_states, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device
-    )
-
-    ttnn_transformer = transformer_2d_model(
-        hidden_states=ttnn_hidden_state,
-        parameters=parameters,
-        config=config,
-        encoder_hidden_states=ttnn_encoder_hidden_states,
-        timestep=timestep,
-        class_labels=class_labels,
-        cross_attention_kwargs=cross_attention_kwargs,
-        return_dict=return_dict,
-        num_attention_heads=num_attention_heads,
-        attention_head_dim=attention_head_dim,
-        in_channels=in_channels,
-        out_channels=in_channels,
-        num_layers=num_layers,
-        norm_num_groups=norm_num_groups,
-        norm_type=norm_type,
-        device=device,
-        cross_attention_dim=cross_attention_dim,
-        upcast_attention=upcast_attention,
-    )
-
-    ttnn_output_torch = ttnn.to_torch(ttnn.to_layout(ttnn.from_device(ttnn_transformer), layout=ttnn.ROW_MAJOR_LAYOUT))
-
-    assert_with_pcc(torch_output, ttnn_output_torch, 0.99)
-
-
@skip_for_grayskull()
-@pytest.mark.skip(reason="#9599: Tests are failing.")
@pytest.mark.parametrize(
"input_shape, index1, index2, attention_head_dim, block ",
[
@@ -227,9 +115,19 @@ def test_transformer_2d_model_512x512(
        packer_l1_acc=False,
    )
    model = transformer_2d_model(
-        device, parameters, {}, input_shape[0], input_shape[2], input_shape[3], compute_kernel_config
+        device, parameters, input_shape[0], input_shape[2], input_shape[3], compute_kernel_config
    )
    ttnn_hidden_state = pre_process_input(model.device, ttnn_hidden_state)
+    ttnn_hidden_state = ttnn.reshape(
+        ttnn_hidden_state,
+        (
+            1,
+            1,
+            ttnn_hidden_state.shape[0] * ttnn_hidden_state.shape[1] * ttnn_hidden_state.shape[2],
+            ttnn_hidden_state.shape[-1],
+        ),
+    )
+
    output = model(
        hidden_states=ttnn_hidden_state,
        config=config,
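For reference, the test now flattens the NHWC hidden state into the (1, 1, batch * height * width, channels) shape before invoking the model. A minimal pure-PyTorch sketch of the same shape transformation (a hypothetical standalone example; the real test does this on ttnn tensors via ttnn.reshape):

```python
import torch

# Hypothetical NCHW input matching one of the test's parametrized shapes.
batch, channels, height, width = 2, 1280, 8, 8
x = torch.randn(batch, channels, height, width)

x = x.permute(0, 2, 3, 1)                              # NCHW -> NHWC, as pre_process_input does
x = x.reshape(1, 1, batch * height * width, channels)  # flatten batch and spatial dims into one axis

assert x.shape == (1, 1, 128, 1280)
```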
@@ -24,9 +24,7 @@ def __init__(
            for i, resnet in enumerate(parameters.resnets)
        ]
        self.attentions = [
-            transformer_2d_model(
-                device, attention, reader_patterns_cache, batch_size, input_height, input_width, compute_kernel_config
-            )
+            transformer_2d_model(device, attention, batch_size, input_height, input_width, compute_kernel_config)
            for attention in parameters.attentions
        ]
        self.downsample_2d = downsample_2d(
@@ -38,9 +38,7 @@ def __init__(
            for resnet in parameters.resnets
        ]
        self.attentions = [
-            transformer_2d_model(
-                device, attention, reader_patterns_cache, batch_size, input_height, input_width, compute_kernel_config
-            )
+            transformer_2d_model(device, attention, batch_size, input_height, input_width, compute_kernel_config)
            for attention in parameters.attentions
        ]

@@ -28,9 +28,7 @@ def ttnn_to_torch(input):


class transformer_2d_model:
-    def __init__(
-        self, device, parameters, reader_patterns_cache, batch_size, input_height, input_width, compute_kernel_config
-    ):
+    def __init__(self, device, parameters, batch_size, input_height, input_width, compute_kernel_config):
        self.device = device
        self.compute_kernel_config = compute_kernel_config
        parameters.proj_in.weight, parameters.proj_in.bias = permute_conv_parameters(
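The same signature change is mirrored at every call site in the down/mid/up blocks (the hunks above and the one below): the reader_patterns_cache argument is dropped from transformer_2d_model. A self-contained toy sketch of the migration, where the class body is only a stub standing in for the real ttnn implementation:

```python
# Stub mirroring the new signature; the real class lives in the SD ttnn model code.
class transformer_2d_model:
    def __init__(self, device, parameters, batch_size, input_height, input_width, compute_kernel_config):
        self.device = device
        self.compute_kernel_config = compute_kernel_config

# Old call site (no longer valid):
#   transformer_2d_model(device, attention, reader_patterns_cache,
#                        batch_size, input_height, input_width, compute_kernel_config)
# New call site (all arguments here are placeholder values):
model = transformer_2d_model("device0", {"attention": None}, 2, 32, 32, None)
```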
@@ -21,9 +21,7 @@ def __init__(
            for resnet in parameters.resnets
        ]
        self.attentions = [
-            transformer_2d_model(
-                device, attention, reader_patterns_cache, batch_size, input_height, input_width, compute_kernel_config
-            )
+            transformer_2d_model(device, attention, batch_size, input_height, input_width, compute_kernel_config)
            for attention in parameters.attentions
        ]

@@ -21,39 +21,7 @@ def is_tile_dim_alligned(dim):


def pre_process_input(device, tensor):
-    batch_size = tensor.shape[0]
-    input_channels = tensor.shape[1]
-    input_height = tensor.shape[2]
-    input_width = tensor.shape[3]
-    tensor = fallback_ops.permute(tensor, (0, 2, 3, 1), output_layout=ttnn.ROW_MAJOR_LAYOUT, output_on_device=False)
-    tensor = ttnn.to_device(tensor, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-    tensor = ttnn.to_layout(tensor, ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)
-    return tensor
-    import math
-
-    assert input_channels == tensor.padded_shape[3]
-    padded_input_channels = math.ceil(input_channels / 32) * 32
-    if padded_input_channels != input_channels:
-        tensor = fallback_ops.pad(
-            tensor,
-            (0, padded_input_channels - input_channels, 0, 0, 0, 0),
-            output_layout=ttnn.ROW_MAJOR_LAYOUT,
-            output_on_device=False,
-        )
-    # Reshape 4d to 2d
-    tensor = fallback_ops.reshape(
-        tensor,
-        1,
-        1,
-        batch_size * input_height * input_width,
-        padded_input_channels,
-        output_layout=ttnn.ROW_MAJOR_LAYOUT,
-        output_on_device=False,
-    )
-    tensor = ttnn.Tensor(tensor)
-    tensor = ttnn.to_device(tensor, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
-    tensor = ttnn.to_layout(tensor, ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG)
-    return tensor
+    return ttnn.permute(tensor, (0, 2, 3, 1))


def pad_encoder_hidden_states(device, tensor, required_sequence_length):
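Note that the old pre_process_input returned immediately after the permute/to_device/to_layout sequence, so the channel padding and 4d-to-2d reshape below that early return were unreachable; this change deletes the dead code, drops the explicit to_device and tile-layout conversion, and reduces the helper to a single permute, leaving the flatten step to callers (as in the test hunk above). A pure-PyTorch sketch of the surviving behavior (hypothetical; the real helper operates on ttnn tensors):

```python
import torch

def pre_process_input_sketch(tensor):
    # NCHW -> NHWC, the only transformation the new one-line helper performs
    return tensor.permute(0, 2, 3, 1)

x = torch.randn(2, 320, 32, 32)                               # NCHW
assert pre_process_input_sketch(x).shape == (2, 32, 32, 320)  # NHWC
```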