diff --git a/src/subgraph/even-split.c b/src/subgraph/even-split.c index 181bd2d3ce9..fe16af8c2e2 100644 --- a/src/subgraph/even-split.c +++ b/src/subgraph/even-split.c @@ -301,18 +301,18 @@ static enum xnn_status check_datatype_copyable( return xnn_subgraph_check_quantization_parameter_matches(node_type, input_id, input_value, output_id, output_value); } -enum xnn_status xnn_define_even_split_impl( - enum xnn_node_type node_type, +enum xnn_status xnn_define_even_split( xnn_subgraph_t subgraph, int32_t split_dim, uint32_t input_id, - size_t num_outputs, + uint32_t num_outputs, const uint32_t* output_ids, uint32_t flags) { - assert(num_outputs > 1); - assert(num_outputs < 5); + assert(num_outputs >= 1); + assert(num_outputs <= XNN_MAX_OUTPUTS); + enum xnn_node_type node_type = xnn_node_type_even_split; enum xnn_status status; if ((status = xnn_subgraph_check_xnnpack_initialized(node_type)) != xnn_status_success) { return status; @@ -342,8 +342,8 @@ enum xnn_status xnn_define_even_split_impl( return xnn_status_invalid_parameter; } - for(int i = 0; i < num_outputs; ++i){ - check_datatype_copyable(subgraph, input_id, output_ids[i], "Nth", node_type); + for (int i = 0; i < num_outputs; ++i) { + check_datatype_copyable(subgraph, input_id, output_ids[i], "Nth", node_type); } struct xnn_node* node = xnn_subgraph_new_node(subgraph); @@ -365,16 +365,4 @@ enum xnn_status xnn_define_even_split_impl( node->flags = flags; return xnn_status_success; -}; - -inline enum xnn_status xnn_define_even_split( - xnn_subgraph_t subgraph, - int32_t split_dim, - uint32_t input_id, - uint32_t num_outputs, - const uint32_t* output_ids, - uint32_t flags) -{ - return xnn_define_even_split_impl( - xnn_node_type_even_split, subgraph, split_dim, input_id, num_outputs, output_ids, flags); } diff --git a/test/BUILD.bazel b/test/BUILD.bazel index e0af25fde7f..93c77791d0c 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -1804,7 +1804,7 @@ xnnpack_unit_test( [xnnpack_unit_test( name = "even_split_test" if n == None else "even_split%d_test" % n, srcs = [ - "even_split" if n == None else "even-split%s.cc" % n, + "even-split.cc" if n == None else "even-split%s.cc" % n, ], deps = [ ":replicable_random_device",