From 65d49245939e66920c5ca87536be804f8aa4fad1 Mon Sep 17 00:00:00 2001 From: RahulSudarMCW Date: Mon, 13 Jan 2025 11:36:08 +0530 Subject: [PATCH] Address review comments --- CMakeLists.txt | 2 +- src/runtime.c | 6 +-- src/subgraph.c | 6 +-- src/subgraph/concatenate.c | 47 +++++------------------- src/xnnpack/node-type-defs.h | 6 +-- test/BUILD.bazel | 5 ++- test/{concatenateN.cc => concatenate.cc} | 12 +++--- test/fusion.cc | 2 +- test/subgraph-tester.h | 2 +- 9 files changed, 24 insertions(+), 64 deletions(-) rename test/{concatenateN.cc => concatenate.cc} (98%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 74d56e0ebff..5a6107a001f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1360,7 +1360,7 @@ IF(XNNPACK_BUILD_TESTS) average-pooling-2d average-pooling-2d-reshape binary - concatenateN + concatenate copy depth-to-space-2d even-split2 diff --git a/src/runtime.c b/src/runtime.c index 3c6283d4ee2..d37bf99bc7b 100644 --- a/src/runtime.c +++ b/src/runtime.c @@ -474,11 +474,7 @@ void propagate_rank( case xnn_node_type_binary_elementwise: output_value->shape.num_dims = max(input_value->shape.num_dims, input_value_b->shape.num_dims); break; - case xnn_node_type_concatenate2: - case xnn_node_type_concatenate3: - case xnn_node_type_concatenate4: - case xnn_node_type_concatenate5: - case xnn_node_type_concatenate_n: + case xnn_node_type_concatenate: case xnn_node_type_copy: case xnn_node_type_even_split2: case xnn_node_type_even_split3: diff --git a/src/subgraph.c b/src/subgraph.c index 0822605fa88..b30de02a0ea 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -874,11 +874,7 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) case xnn_node_type_binary_elementwise: case xnn_node_type_unary_elementwise: case xnn_node_type_batch_matrix_multiply: - case xnn_node_type_concatenate2: - case xnn_node_type_concatenate3: - case xnn_node_type_concatenate4: - case xnn_node_type_concatenate5: - case xnn_node_type_concatenate_n: + case xnn_node_type_concatenate: case xnn_node_type_convert: case xnn_node_type_average_pooling_2d: case xnn_node_type_copy: diff --git a/src/subgraph/concatenate.c b/src/subgraph/concatenate.c index a90a4f2bfd5..9de37eeb606 100644 --- a/src/subgraph/concatenate.c +++ b/src/subgraph/concatenate.c @@ -107,16 +107,16 @@ static enum xnn_status reshape_concatenate_operator_helper( } } -static enum xnn_status reshape_concatenate_n_operator( +static enum xnn_status reshape_concatenate_operator( struct xnn_operator_data* opdata, struct xnn_value* values, size_t num_values, - size_t num_inputs, pthreadpool_t threadpool) { enum xnn_status status; - num_inputs = opdata->num_inputs; + size_t num_inputs = opdata->num_inputs; + assert(num_inputs <= XNN_MAX_OPERATOR_OBJECTS); uint32_t input_id[XNN_MAX_OPERATOR_OBJECTS]; for (size_t i = 0; i < num_inputs; ++i) { input_id[i] = opdata->inputs[i]; @@ -217,15 +217,6 @@ static enum xnn_status setup_concatenate_operator_helper( } } -static enum xnn_status reshape_concatenaten_operator( - struct xnn_operator_data* opdata, - struct xnn_value* values, - size_t num_values, - pthreadpool_t threadpool) -{ - return reshape_concatenate_n_operator(opdata, values, num_values, opdata->num_inputs, threadpool); -} - static enum xnn_status setup_concatenate_n_operator( const struct xnn_operator_data* opdata, const struct xnn_value* values, @@ -322,7 +313,7 @@ static enum xnn_status setup_concatenaten_operator( return setup_concatenate_n_operator(opdata, values, num_values, opdata->num_inputs, threadpool); } -enum xnn_status xnn_define_concatenate_n( +enum xnn_status xnn_define_concatenate_impl( enum xnn_node_type node_type, xnn_subgraph_t subgraph, int32_t axis, @@ -355,32 +346,12 @@ enum xnn_status xnn_define_concatenate_n( } } - status = check_datatype_copyable(subgraph, input_ids[0], output_id, "first", node_type); - if (status != xnn_status_success) { - return status; - } - status = check_datatype_copyable(subgraph, input_ids[1], output_id, "second", node_type); - if (status != xnn_status_success) { - return status; - } - if (num_inputs > 2) { - status = check_datatype_copyable(subgraph, input_ids[2], output_id, "third", node_type); + for (size_t i = 0; i < num_inputs; i++) { + status = check_datatype_copyable(subgraph, input_ids[i], output_id, "ith", node_type); if (status != xnn_status_success) { return status; } } - if (num_inputs > 3) { - status = check_datatype_copyable(subgraph, input_ids[3], output_id, "fourth", node_type); - if (status != xnn_status_success) { - return status; - } - } - if (num_inputs > 4) { - status = check_datatype_copyable(subgraph, input_ids[4], output_id, "fifth", node_type); - if (status != xnn_status_success) { - return status; - } - } struct xnn_node* node = xnn_subgraph_new_node(subgraph); if (node == NULL) { @@ -395,7 +366,7 @@ enum xnn_status xnn_define_concatenate_n( node->flags = flags; node->create = create_concatenaten_operator; - node->reshape = reshape_concatenaten_operator; + node->reshape = reshape_concatenate_operator; node->setup = setup_concatenaten_operator; for (size_t i = 0; i < num_inputs; ++i) { @@ -413,6 +384,6 @@ enum xnn_status xnn_define_concatenate( uint32_t output_id, uint32_t flags) { - return xnn_define_concatenate_n( - xnn_node_type_concatenate_n, subgraph, axis, num_inputs, inputs, output_id, flags); + return xnn_define_concatenate_impl( + xnn_node_type_concatenate, subgraph, axis, num_inputs, inputs, output_id, flags); } diff --git a/src/xnnpack/node-type-defs.h b/src/xnnpack/node-type-defs.h index a0cbbe09055..bc2b20a3493 100644 --- a/src/xnnpack/node-type-defs.h +++ b/src/xnnpack/node-type-defs.h @@ -13,11 +13,7 @@ XNN_ENUM_ITEM(xnn_node_type_argmax_pooling_2d, "ArgMax Pooling 2D") XNN_ENUM_ITEM(xnn_node_type_average_pooling_2d, "Average Pooling 2D") XNN_ENUM_ITEM(xnn_node_type_batch_matrix_multiply, "Batch Matrix Multiply") XNN_ENUM_ITEM(xnn_node_type_binary_elementwise, "Binary Elementwise") -XNN_ENUM_ITEM(xnn_node_type_concatenate2, "Concatenate2") -XNN_ENUM_ITEM(xnn_node_type_concatenate3, "Concatenate3") -XNN_ENUM_ITEM(xnn_node_type_concatenate4, "Concatenate4") -XNN_ENUM_ITEM(xnn_node_type_concatenate5, "Concatenate5") -XNN_ENUM_ITEM(xnn_node_type_concatenate_n, "ConcatenateN") +XNN_ENUM_ITEM(xnn_node_type_concatenate, "Concatenate") XNN_ENUM_ITEM(xnn_node_type_convert, "Convert") XNN_ENUM_ITEM(xnn_node_type_convolution_2d, "Convolution 2D") XNN_ENUM_ITEM(xnn_node_type_copy, "Copy") diff --git a/test/BUILD.bazel b/test/BUILD.bazel index c95c2c9ebb7..e67e22be3fe 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -1698,9 +1698,9 @@ xnnpack_unit_test( ) [xnnpack_unit_test( - name = "concatenate%d_test" % n, + name = ""concatenate_test" if n == None else "concatenate%d_test" % n, srcs = [ - "concatenate%d.cc" % n, + "concatenate.cc" if n == None else "concatenate%d.cc" % n, ], deps = [ ":replicable_random_device", @@ -1717,6 +1717,7 @@ xnnpack_unit_test( 3, 4, 5, + None, ]] xnnpack_unit_test( diff --git a/test/concatenateN.cc b/test/concatenate.cc similarity index 98% rename from test/concatenateN.cc rename to test/concatenate.cc index a6d7f0359d8..3fe91db48fb 100644 --- a/test/concatenateN.cc +++ b/test/concatenate.cc @@ -101,7 +101,7 @@ template class ConcatenateNTest : public ::testing::Test { size_t RandomNumInputs() { - return std::uniform_int_distribution(2, 5)(rng); // You can adjust the range + return std::uniform_int_distribution(2, XNN_MAX_OPERATOR_OBJECTS)(rng); // You can adjust the range } @@ -176,7 +176,7 @@ TEST_F(ConcatenateNTestQS8, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_concatenate_n); + ASSERT_EQ(node->type, xnn_node_type_concatenate); ASSERT_EQ(node->params.concatenate.axis, axis); ASSERT_EQ(node->num_inputs, num_inputs); @@ -221,7 +221,7 @@ TEST_F(ConcatenateNTestQU8, define) xnn_define_concatenate(subgraph, axis,num_inputs, input_ids.data(), output_id, /*flags=*/0)); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_concatenate_n); + ASSERT_EQ(node->type, xnn_node_type_concatenate); ASSERT_EQ(node->params.concatenate.axis, axis); ASSERT_EQ(node->num_inputs, num_inputs); ASSERT_EQ(node->num_outputs, 1); @@ -259,7 +259,7 @@ TEST_F(ConcatenateNTestF16, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_concatenate_n); + ASSERT_EQ(node->type, xnn_node_type_concatenate); ASSERT_EQ(node->params.concatenate.axis, axis); ASSERT_EQ(node->num_inputs, num_inputs); ASSERT_EQ(node->num_outputs, 1); @@ -298,7 +298,7 @@ TEST_F(ConcatenateNTestF32, define) ASSERT_EQ(subgraph->num_nodes, 1); const struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_concatenate_n); + ASSERT_EQ(node->type, xnn_node_type_concatenate); ASSERT_EQ(node->params.concatenate.axis, axis); ASSERT_EQ(node->num_inputs, num_inputs); ASSERT_EQ(node->num_outputs, 1); @@ -629,7 +629,7 @@ TEST_F(ConcatenateNTestF32, Reshape) ASSERT_EQ(subgraph->num_nodes, 1); struct xnn_node* node = &subgraph->nodes[0]; - ASSERT_EQ(node->type, xnn_node_type_concatenate_n); + ASSERT_EQ(node->type, xnn_node_type_concatenate); ASSERT_EQ(node->num_inputs, num_inputs); for (int i = 0; i < num_inputs; i++) { ASSERT_EQ(node->inputs[i], input_ids[i]); diff --git a/test/fusion.cc b/test/fusion.cc index 354e1f871f7..6030fbe74f4 100644 --- a/test/fusion.cc +++ b/test/fusion.cc @@ -658,7 +658,7 @@ TEST(COPY, fused_upstream_with_multiple_outputs) { EXPECT_EQ(split_node->outputs[1], copy_out2); const xnn_node* concat_node = tester.Node(3); - ASSERT_EQ(concat_node->type, xnn_node_type_concatenate2); + ASSERT_EQ(concat_node->type, xnn_node_type_concatenate); ASSERT_EQ(concat_node->num_inputs, 2); EXPECT_EQ(concat_node->inputs[0], copy_out1); EXPECT_EQ(concat_node->inputs[1], copy_out2); diff --git a/test/subgraph-tester.h b/test/subgraph-tester.h index 1aa07113c3e..9817a97eeac 100644 --- a/test/subgraph-tester.h +++ b/test/subgraph-tester.h @@ -274,7 +274,7 @@ class SubgraphTester { return *this; } - SubgraphTester& AddConcatenate2(size_t axis, uint32_t input1_id, uint32_t input2_id, uint32_t output_id) { + SubgraphTester& AddConcatenate2(size_t axis, uint32_t input1_id, uint32_t input2_id, uint32_t output_id) { const uint32_t input_ids[] = {input1_id, input2_id}; // Create an array of input IDs const xnn_status status = xnn_define_concatenate( subgraph_.get(), axis, 2 /* num_inputs */, input_ids, output_id, 0 /* flags */);