#12073: assert out on block and width sharded concat (#12111)
- Sweep output_mem_config in the sharded concat sweeps
- Swap x and y on CoreGrid when building the sharded memory configs so the shard shapes come out correct (see the sketch below)
- Remove rank 2 tensor cases from the n-tensor sweeps, as they are already covered by the concat_interleaved sweep
- Add some breathing room to the timeout for n-tensor concats, as formatting on host is very slow
sjameelTT authored Sep 3, 2024
1 parent f03bf93 commit 84f944d
Showing 3 changed files with 45 additions and 15 deletions.
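The CoreGrid x/y swap mentioned in the commit message is easiest to see from how ttnn.create_sharded_memory_config splits a tensor across a core grid. The sketch below is illustrative only and is not part of the diff: the tensor shape and the (grid_rows, grid_cols) pair are made up, and it assumes the sweep stores its num_cores values in (rows, cols) order.

# Illustrative sketch (not from this commit). Assumes num_cores1 is stored as
# (grid_rows, grid_cols); ttnn.CoreGrid takes x = columns and y = rows.
import ttnn

shape = [1, 1, 256, 512]   # hypothetical stand-in for concat_specs["shape1"]
num_cores1 = (4, 2)        # hypothetical (grid_rows, grid_cols) pair from the sweep

block_sharded = ttnn.create_sharded_memory_config(
    shape=shape,
    # After the fix: x takes the column count and y the row count -> a 4-row x 2-column grid.
    # Before the fix the arguments were swapped, so the grid (and shard shape) came out transposed.
    core_grid=ttnn.CoreGrid(x=num_cores1[1], y=num_cores1[0]),
    strategy=ttnn.ShardStrategy.BLOCK,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=False,
)
# For ROW_MAJOR block sharding this gives shards of roughly
# (256 / 4) x (512 / 2) = 64 x 256 elements per core.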
@@ -12,7 +12,7 @@
 from models.utility_functions import torch_random
 
 # Override the default timeout in seconds for hang detection.
-TIMEOUT = 20
+TIMEOUT = 30  # formatting on host and torch CPU call are slow
 random.seed(0)
 tiled_dim_unpadded = [32, 64, 96]
 tiled_dim_padded = [7, 16]
@@ -57,10 +57,10 @@ def generate_concat_config(tensor_counts, ranks, variable_dim, other_dims):
 parameter_tiled_interleaved = {
     f"tiled_interleaved_suite_{n}_tensors": {
         "concat_specs": list(
-            generate_concat_config(n, [4, 3, 2], tiled_dim_unpadded, tiled_dim_unpadded + tiled_dim_padded)
+            generate_concat_config(n, [4, 3], tiled_dim_unpadded, tiled_dim_unpadded + tiled_dim_padded)
         )
         + list(
-            generate_concat_config(n, [4, 3, 2], tiled_dim_padded, [32, 33])
+            generate_concat_config(n, [4, 3], tiled_dim_padded, [32, 33])
         ),  # variable dim doesn't support padding, other dims can be anything
         "dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
         "layout": [ttnn.TILE_LAYOUT],
@@ -78,10 +78,10 @@ def generate_concat_config(tensor_counts, ranks, variable_dim, other_dims):
 
 parameters_row_major_interleaved = {
     f"row_major_interleaved_suite_{n}_tensors": {
-        "concat_specs": list(generate_concat_config(n, [4, 3, 2], rm_dim_even, rm_dim_even))
-        + list(generate_concat_config(n, [4, 3, 2], rm_dim_even, rm_dim_odd))
+        "concat_specs": list(generate_concat_config(n, [4, 3], rm_dim_even, rm_dim_even))
+        + list(generate_concat_config(n, [4, 3], rm_dim_even, rm_dim_odd))
         + list(
-            generate_concat_config(n, [4, 3, 2], rm_dim_odd, rm_dim_even)
+            generate_concat_config(n, [4, 3], rm_dim_odd, rm_dim_even)
         ),  # variable dim doesn't support padding, other dims can be anything
         "dtype": [ttnn.bfloat16],
         "layout": [ttnn.ROW_MAJOR_LAYOUT],
@@ -171,6 +171,10 @@ def generate_concat_width_config(nonvariable_dim, variable_dims, num_cores_choic
         "input_mem_config": [
             ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG,
         ],
+        "output_mem_config": [
+            ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG,
+            ttnn.L1_MEMORY_CONFIG,
+        ],
     }
 }
 
@@ -188,6 +192,10 @@ def generate_concat_width_config(nonvariable_dim, variable_dims, num_cores_choic
         "input_mem_config": [
             ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG,
         ],
+        "output_mem_config": [
+            ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG,
+            ttnn.L1_MEMORY_CONFIG,
+        ],
     }
 }
 
@@ -205,6 +213,10 @@ def generate_concat_width_config(nonvariable_dim, variable_dims, num_cores_choic
         "input_mem_config": [
             ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
         ],
+        "output_mem_config": [
+            ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
+            ttnn.L1_MEMORY_CONFIG,
+        ],
     }
 }
 
@@ -250,6 +262,7 @@ def run(
     dtype,
     layout,
     input_mem_config,
+    output_mem_config,
     *,
     device,
 ) -> list:
@@ -258,51 +271,66 @@ def run(
     if input_mem_config == ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG:
         input_mem_config_a = ttnn.create_sharded_memory_config(
             shape=concat_specs["shape1"],
-            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores1"][0], y=concat_specs["num_cores1"][1]),
+            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores1"][1], y=concat_specs["num_cores1"][0]),
             strategy=ttnn.ShardStrategy.BLOCK,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=False,
         )
         input_mem_config_b = ttnn.create_sharded_memory_config(
             shape=concat_specs["shape2"],
-            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores2"][0], y=concat_specs["num_cores2"][1]),
+            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores2"][1], y=concat_specs["num_cores2"][0]),
             strategy=ttnn.ShardStrategy.BLOCK,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=False,
         )
     elif input_mem_config == ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG:
         input_mem_config_a = ttnn.create_sharded_memory_config(
             shape=concat_specs["shape1"],
-            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores1"][0], y=concat_specs["num_cores1"][1]),
+            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores1"][1], y=concat_specs["num_cores1"][0]),
             strategy=ttnn.ShardStrategy.HEIGHT,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=False,
         )
         input_mem_config_b = ttnn.create_sharded_memory_config(
             shape=concat_specs["shape2"],
-            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores2"][0], y=concat_specs["num_cores2"][1]),
+            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores2"][1], y=concat_specs["num_cores2"][0]),
             strategy=ttnn.ShardStrategy.HEIGHT,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=False,
         )
     else:
         input_mem_config_a = ttnn.create_sharded_memory_config(
             shape=concat_specs["shape1"],
-            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores1"][0], y=concat_specs["num_cores1"][1]),
+            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores1"][1], y=concat_specs["num_cores1"][0]),
             strategy=ttnn.ShardStrategy.WIDTH,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=False,
         )
         input_mem_config_b = ttnn.create_sharded_memory_config(
             shape=concat_specs["shape1"],
-            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores2"][0], y=concat_specs["num_cores2"][1]),
+            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores2"][1], y=concat_specs["num_cores2"][0]),
             strategy=ttnn.ShardStrategy.WIDTH,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=False,
         )
 
     torch_input_tensors.append(torch_random(concat_specs["shape1"], -0.1, 0.1, dtype=torch.bfloat16))
     torch_input_tensors.append(torch_random(concat_specs["shape2"], -0.1, 0.1, dtype=torch.bfloat16))
+    torch_output_tensor = torch.concat(torch_input_tensors, dim=concat_specs["dim"])
 
+    if (
+        output_mem_config == ttnn.L1_BLOCK_SHARDED_MEMORY_CONFIG
+        or output_mem_config == ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG
+        or output_mem_config == ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG
+    ):
+        shape_list = list(torch_output_tensor.size())
+        output_mem_config = ttnn.create_sharded_memory_config(
+            shape=shape_list,
+            core_grid=ttnn.CoreGrid(x=concat_specs["num_cores1"][1], y=concat_specs["num_cores1"][0]),
+            strategy=ttnn.ShardStrategy.BLOCK,
+            orientation=ttnn.ShardOrientation.ROW_MAJOR,
+            use_height_and_width_as_shard_shape=False,
+        )
 
     input_tensors = []
     input_tensors.append(
         ttnn.from_torch(
@@ -323,8 +351,7 @@ def run(
         )
     )
     start_time = start_measuring_time()
-    result_tensor = ttnn.concat(input_tensors, dim=concat_specs["dim"], memory_config=input_mem_config_a)
+    result_tensor = ttnn.concat(input_tensors, dim=concat_specs["dim"], memory_config=output_mem_config)
     e2e_perf = stop_measuring_time(start_time)
     output_tensor = ttnn.to_torch(result_tensor)
-    torch_output_tensor = torch.concat(torch_input_tensors, dim=concat_specs["dim"])
     return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf]
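For reference, here is a minimal, self-contained version of what the updated run() now exercises: sharded row-major inputs concatenated into an explicitly requested output memory config, instead of always reusing input_mem_config_a as before. This is an illustrative sketch rather than sweep code; the shapes, the 2-core grid, and the width (dim=3) concat are assumptions.

# Illustrative sketch (not from this commit): concat two height-sharded,
# row-major inputs and write the result to an explicitly chosen output
# memory config, as the sweep now does via the swept output_mem_config.
import torch
import ttnn

device = ttnn.open_device(device_id=0)

shape = [1, 1, 64, 32]
height_sharded = ttnn.create_sharded_memory_config(
    shape=shape,
    core_grid=ttnn.CoreGrid(x=1, y=2),  # 2 cores, each holding 32 rows
    strategy=ttnn.ShardStrategy.HEIGHT,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=False,
)

inputs = [
    ttnn.from_torch(
        torch.rand(shape, dtype=torch.bfloat16),
        layout=ttnn.ROW_MAJOR_LAYOUT,
        device=device,
        memory_config=height_sharded,
    )
    for _ in range(2)
]

# Concatenate along the width and request an interleaved L1 output rather
# than reusing the input sharding.
out = ttnn.concat(inputs, dim=3, memory_config=ttnn.L1_MEMORY_CONFIG)
assert list(ttnn.to_torch(out).shape) == [1, 1, 64, 64]

ttnn.close_device(device)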
@@ -50,6 +50,9 @@ void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) c
         TT_FATAL(in_ref.is_sharded() == shard_first, "All tensors must be sharded or all must be interleaved");
         if (shard_first) {
             TT_FATAL((in_ref.get_layout() == Layout::ROW_MAJOR), "Only row major supported for sharded concat.");
+            TT_FATAL(in_ref.shard_spec().has_value(), "Sharded tensors must have a shard spec.");
+            TT_FATAL(in_ref.memory_config().memory_layout != TensorMemoryLayout::BLOCK_SHARDED, "Block sharded inputs are not supported");
+            TT_FATAL(in_ref.memory_config().memory_layout != TensorMemoryLayout::WIDTH_SHARDED, "Width sharded inputs are not supported");
         }
     }
     if (shard_first) {
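The net effect of the new TT_FATAL checks is that block- and width-sharded inputs now fail validation immediately instead of producing an incorrect concat. A hedged sketch of what a regression test for this behaviour could look like follows; the pytest device fixture, shapes, core grid, and the assumption that TT_FATAL surfaces as a Python RuntimeError are illustrative and not taken from the commit.

# Illustrative sketch (not from this commit): after this change, handing
# width-sharded (or block-sharded) inputs to ttnn.concat is expected to raise
# during validation rather than return a wrong result. Assumes the usual
# pytest "device" fixture and that TT_FATAL surfaces as RuntimeError in Python.
import pytest
import torch
import ttnn


def test_width_sharded_concat_is_rejected(device):
    shape = [1, 1, 32, 128]
    width_sharded = ttnn.create_sharded_memory_config(
        shape=shape,
        core_grid=ttnn.CoreGrid(x=2, y=1),  # 2 cores, each holding 64 columns
        strategy=ttnn.ShardStrategy.WIDTH,
        orientation=ttnn.ShardOrientation.ROW_MAJOR,
        use_height_and_width_as_shard_shape=False,
    )
    inputs = [
        ttnn.from_torch(
            torch.rand(shape, dtype=torch.bfloat16),
            layout=ttnn.ROW_MAJOR_LAYOUT,
            device=device,
            memory_config=width_sharded,
        )
        for _ in range(2)
    ]
    with pytest.raises(RuntimeError):
        ttnn.concat(inputs, dim=2, memory_config=width_sharded)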
