Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RahulSundarMCW committed Jan 13, 2025
1 parent 4085c8c commit 65d4924
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 64 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,7 @@ IF(XNNPACK_BUILD_TESTS)
average-pooling-2d
average-pooling-2d-reshape
binary
concatenateN
concatenate
copy
depth-to-space-2d
even-split2
Expand Down
6 changes: 1 addition & 5 deletions src/runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -474,11 +474,7 @@ void propagate_rank(
case xnn_node_type_binary_elementwise:
output_value->shape.num_dims = max(input_value->shape.num_dims, input_value_b->shape.num_dims);
break;
case xnn_node_type_concatenate2:
case xnn_node_type_concatenate3:
case xnn_node_type_concatenate4:
case xnn_node_type_concatenate5:
case xnn_node_type_concatenate_n:
case xnn_node_type_concatenate:
case xnn_node_type_copy:
case xnn_node_type_even_split2:
case xnn_node_type_even_split3:
Expand Down
6 changes: 1 addition & 5 deletions src/subgraph.c
Original file line number Diff line number Diff line change
Expand Up @@ -874,11 +874,7 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph)
case xnn_node_type_binary_elementwise:
case xnn_node_type_unary_elementwise:
case xnn_node_type_batch_matrix_multiply:
case xnn_node_type_concatenate2:
case xnn_node_type_concatenate3:
case xnn_node_type_concatenate4:
case xnn_node_type_concatenate5:
case xnn_node_type_concatenate_n:
case xnn_node_type_concatenate:
case xnn_node_type_convert:
case xnn_node_type_average_pooling_2d:
case xnn_node_type_copy:
Expand Down
47 changes: 9 additions & 38 deletions src/subgraph/concatenate.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,16 @@ static enum xnn_status reshape_concatenate_operator_helper(
}
}

static enum xnn_status reshape_concatenate_n_operator(
static enum xnn_status reshape_concatenate_operator(
struct xnn_operator_data* opdata,
struct xnn_value* values,
size_t num_values,
size_t num_inputs,
pthreadpool_t threadpool)
{
enum xnn_status status;

num_inputs = opdata->num_inputs;
size_t num_inputs = opdata->num_inputs;
assert(num_inputs <= XNN_MAX_OPERATOR_OBJECTS);
uint32_t input_id[XNN_MAX_OPERATOR_OBJECTS];
for (size_t i = 0; i < num_inputs; ++i) {
input_id[i] = opdata->inputs[i];
Expand Down Expand Up @@ -217,15 +217,6 @@ static enum xnn_status setup_concatenate_operator_helper(
}
}

static enum xnn_status reshape_concatenaten_operator(
struct xnn_operator_data* opdata,
struct xnn_value* values,
size_t num_values,
pthreadpool_t threadpool)
{
return reshape_concatenate_n_operator(opdata, values, num_values, opdata->num_inputs, threadpool);
}

static enum xnn_status setup_concatenate_n_operator(
const struct xnn_operator_data* opdata,
const struct xnn_value* values,
Expand Down Expand Up @@ -322,7 +313,7 @@ static enum xnn_status setup_concatenaten_operator(
return setup_concatenate_n_operator(opdata, values, num_values, opdata->num_inputs, threadpool);
}

enum xnn_status xnn_define_concatenate_n(
enum xnn_status xnn_define_concatenate_impl(
enum xnn_node_type node_type,
xnn_subgraph_t subgraph,
int32_t axis,
Expand Down Expand Up @@ -355,32 +346,12 @@ enum xnn_status xnn_define_concatenate_n(
}
}

status = check_datatype_copyable(subgraph, input_ids[0], output_id, "first", node_type);
if (status != xnn_status_success) {
return status;
}
status = check_datatype_copyable(subgraph, input_ids[1], output_id, "second", node_type);
if (status != xnn_status_success) {
return status;
}
if (num_inputs > 2) {
status = check_datatype_copyable(subgraph, input_ids[2], output_id, "third", node_type);
for (size_t i = 0; i < num_inputs; i++) {
status = check_datatype_copyable(subgraph, input_ids[i], output_id, "ith", node_type);
if (status != xnn_status_success) {
return status;
}
}
if (num_inputs > 3) {
status = check_datatype_copyable(subgraph, input_ids[3], output_id, "fourth", node_type);
if (status != xnn_status_success) {
return status;
}
}
if (num_inputs > 4) {
status = check_datatype_copyable(subgraph, input_ids[4], output_id, "fifth", node_type);
if (status != xnn_status_success) {
return status;
}
}

struct xnn_node* node = xnn_subgraph_new_node(subgraph);
if (node == NULL) {
Expand All @@ -395,7 +366,7 @@ enum xnn_status xnn_define_concatenate_n(
node->flags = flags;

node->create = create_concatenaten_operator;
node->reshape = reshape_concatenaten_operator;
node->reshape = reshape_concatenate_operator;
node->setup = setup_concatenaten_operator;

for (size_t i = 0; i < num_inputs; ++i) {
Expand All @@ -413,6 +384,6 @@ enum xnn_status xnn_define_concatenate(
uint32_t output_id,
uint32_t flags)
{
return xnn_define_concatenate_n(
xnn_node_type_concatenate_n, subgraph, axis, num_inputs, inputs, output_id, flags);
return xnn_define_concatenate_impl(
xnn_node_type_concatenate, subgraph, axis, num_inputs, inputs, output_id, flags);
}
6 changes: 1 addition & 5 deletions src/xnnpack/node-type-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ XNN_ENUM_ITEM(xnn_node_type_argmax_pooling_2d, "ArgMax Pooling 2D")
XNN_ENUM_ITEM(xnn_node_type_average_pooling_2d, "Average Pooling 2D")
XNN_ENUM_ITEM(xnn_node_type_batch_matrix_multiply, "Batch Matrix Multiply")
XNN_ENUM_ITEM(xnn_node_type_binary_elementwise, "Binary Elementwise")
XNN_ENUM_ITEM(xnn_node_type_concatenate2, "Concatenate2")
XNN_ENUM_ITEM(xnn_node_type_concatenate3, "Concatenate3")
XNN_ENUM_ITEM(xnn_node_type_concatenate4, "Concatenate4")
XNN_ENUM_ITEM(xnn_node_type_concatenate5, "Concatenate5")
XNN_ENUM_ITEM(xnn_node_type_concatenate_n, "ConcatenateN")
XNN_ENUM_ITEM(xnn_node_type_concatenate, "Concatenate")
XNN_ENUM_ITEM(xnn_node_type_convert, "Convert")
XNN_ENUM_ITEM(xnn_node_type_convolution_2d, "Convolution 2D")
XNN_ENUM_ITEM(xnn_node_type_copy, "Copy")
Expand Down
5 changes: 3 additions & 2 deletions test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1698,9 +1698,9 @@ xnnpack_unit_test(
)

[xnnpack_unit_test(
name = "concatenate%d_test" % n,
name = "concatenate_test" if n == None else "concatenate%d_test" % n,
srcs = [
"concatenate%d.cc" % n,
"concatenate.cc" if n == None else "concatenate%d.cc" % n,
],
deps = [
":replicable_random_device",
Expand All @@ -1717,6 +1717,7 @@ xnnpack_unit_test(
3,
4,
5,
None,
]]

xnnpack_unit_test(
Expand Down
12 changes: 6 additions & 6 deletions test/concatenateN.cc → test/concatenate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ template <typename T> class ConcatenateNTest : public ::testing::Test {

size_t RandomNumInputs()
{
return std::uniform_int_distribution<size_t>(2, 5)(rng); // You can adjust the range
return std::uniform_int_distribution<size_t>(2, XNN_MAX_OPERATOR_OBJECTS)(rng); // You can adjust the range
}


Expand Down Expand Up @@ -176,7 +176,7 @@ TEST_F(ConcatenateNTestQS8, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_concatenate_n);
ASSERT_EQ(node->type, xnn_node_type_concatenate);
ASSERT_EQ(node->params.concatenate.axis, axis);
ASSERT_EQ(node->num_inputs, num_inputs);

Expand Down Expand Up @@ -221,7 +221,7 @@ TEST_F(ConcatenateNTestQU8, define)
xnn_define_concatenate(subgraph, axis,num_inputs, input_ids.data(), output_id, /*flags=*/0));

const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_concatenate_n);
ASSERT_EQ(node->type, xnn_node_type_concatenate);
ASSERT_EQ(node->params.concatenate.axis, axis);
ASSERT_EQ(node->num_inputs, num_inputs);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -259,7 +259,7 @@ TEST_F(ConcatenateNTestF16, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_concatenate_n);
ASSERT_EQ(node->type, xnn_node_type_concatenate);
ASSERT_EQ(node->params.concatenate.axis, axis);
ASSERT_EQ(node->num_inputs, num_inputs);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -298,7 +298,7 @@ TEST_F(ConcatenateNTestF32, define)

ASSERT_EQ(subgraph->num_nodes, 1);
const struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_concatenate_n);
ASSERT_EQ(node->type, xnn_node_type_concatenate);
ASSERT_EQ(node->params.concatenate.axis, axis);
ASSERT_EQ(node->num_inputs, num_inputs);
ASSERT_EQ(node->num_outputs, 1);
Expand Down Expand Up @@ -629,7 +629,7 @@ TEST_F(ConcatenateNTestF32, Reshape)

ASSERT_EQ(subgraph->num_nodes, 1);
struct xnn_node* node = &subgraph->nodes[0];
ASSERT_EQ(node->type, xnn_node_type_concatenate_n);
ASSERT_EQ(node->type, xnn_node_type_concatenate);
ASSERT_EQ(node->num_inputs, num_inputs);
for (int i = 0; i < num_inputs; i++) {
ASSERT_EQ(node->inputs[i], input_ids[i]);
Expand Down
2 changes: 1 addition & 1 deletion test/fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ TEST(COPY, fused_upstream_with_multiple_outputs) {
EXPECT_EQ(split_node->outputs[1], copy_out2);

const xnn_node* concat_node = tester.Node(3);
ASSERT_EQ(concat_node->type, xnn_node_type_concatenate2);
ASSERT_EQ(concat_node->type, xnn_node_type_concatenate);
ASSERT_EQ(concat_node->num_inputs, 2);
EXPECT_EQ(concat_node->inputs[0], copy_out1);
EXPECT_EQ(concat_node->inputs[1], copy_out2);
Expand Down
2 changes: 1 addition & 1 deletion test/subgraph-tester.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class SubgraphTester {
return *this;
}

SubgraphTester& AddConcatenate2(size_t axis, uint32_t input1_id, uint32_t input2_id, uint32_t output_id) {
SubgraphTester& AddConcatenate2(size_t axis, uint32_t input1_id, uint32_t input2_id, uint32_t output_id) {
const uint32_t input_ids[] = {input1_id, input2_id}; // Create an array of input IDs
const xnn_status status = xnn_define_concatenate(
subgraph_.get(), axis, 2 /* num_inputs */, input_ids, output_id, 0 /* flags */);
Expand Down

0 comments on commit 65d4924

Please sign in to comment.