From 7089e18161cff12f26027a4254ed51d88be582ca Mon Sep 17 00:00:00 2001 From: Joe Date: Tue, 17 Dec 2024 14:16:20 +0800 Subject: [PATCH] repo-sync-2024-12-16T10:32:45+0800 (#130) --- .../secretflow_serving_lib/graph_builder.py | 12 +- secretflow_serving/ops/arrow_processing.cc | 7 +- .../ops/arrow_processing_test.cc | 178 +++++++++++++++++- 3 files changed, 190 insertions(+), 7 deletions(-) diff --git a/python_lib/secretflow_serving_lib/graph_builder.py b/python_lib/secretflow_serving_lib/graph_builder.py index 088a2e0..d9ca0f9 100644 --- a/python_lib/secretflow_serving_lib/graph_builder.py +++ b/python_lib/secretflow_serving_lib/graph_builder.py @@ -13,6 +13,7 @@ # limitations under the License. import io +import logging import tarfile from itertools import chain from typing import Any, Dict, List @@ -23,7 +24,7 @@ from . import libserving # type: ignore from .api import get_op from .attr_pb2 import AttrType, AttrValue -from .bundle_pb2 import FileFormatType, ModelBundle, ModelInfo, ModelManifest +from .bundle_pb2 import FileFormatType, ModelBundle, ModelManifest from .graph_pb2 import ( DispatchType, ExecutionDef, @@ -33,7 +34,6 @@ RuntimeConfig, HeConfig, ) -from .op_pb2 import OpDef def construct_attr_value(attr_type: AttrType, value) -> AttrValue: @@ -245,9 +245,11 @@ def set_he_config(self, pk_bytes, sk_bytes, scale): def build_proto(self) -> GraphDef: '''Get the GraphDef include all nodes and executions''' - graph_def_str = libserving.graph_validator_impl( - self.graph.proto().SerializeToString() - ) + graph_def_str = self.graph.proto().SerializeToString() + + logging.info(f"check serving graph: {graph_def_str}") + libserving.graph_validator_impl(graph_def_str) + graph = GraphDef() graph.ParseFromString(graph_def_str) return graph diff --git a/secretflow_serving/ops/arrow_processing.cc b/secretflow_serving/ops/arrow_processing.cc index 951d19f..4a65fd1 100644 --- a/secretflow_serving/ops/arrow_processing.cc +++ b/secretflow_serving/ops/arrow_processing.cc @@ -170,7 +170,8 @@ ArrowProcessing::ArrowProcessing(OpKernelOptions opts) "the last compute function({}) is not returnable", end_func.name()); result_id_ = end_func.output().data_id(); - int num_fields = input_schema_list_.front()->num_fields(); + std::map table_field_num_map = { + {0, input_schema_list_.front()->num_fields()}}; std::map data_id_map = { {0, arrow::Datum::Kind::RECORD_BATCH}}; @@ -184,6 +185,7 @@ ArrowProcessing::ArrowProcessing(OpKernelOptions opts) // check ext func inputs type valid SERVING_ENFORCE(input_kinds[0] == arrow::Datum::Kind::RECORD_BATCH, errors::ErrorCode::LOGIC_ERROR); + auto num_fields = table_field_num_map[func.inputs()[0].data_id()]; if (ex_func_name == compute::ExtendFunctionName::EFN_TB_COLUMN || ex_func_name == compute::ExtendFunctionName::EFN_TB_REMOVE_COLUMN) { // std::shared_ptr column(int) const @@ -260,6 +262,9 @@ ArrowProcessing::ArrowProcessing(OpKernelOptions opts) data_id_map.emplace(func.output().data_id(), output_kind).second, errors::ErrorCode::LOGIC_ERROR, "found duplicate data_id: {}", func.output().data_id()); + if (output_kind == arrow::Datum::Kind::RECORD_BATCH) { + table_field_num_map.emplace(func.output().data_id(), num_fields); + } switch (ex_func_name) { case compute::ExtendFunctionName::EFN_TB_COLUMN: { diff --git a/secretflow_serving/ops/arrow_processing_test.cc b/secretflow_serving/ops/arrow_processing_test.cc index 368ef39..f70c751 100644 --- a/secretflow_serving/ops/arrow_processing_test.cc +++ b/secretflow_serving/ops/arrow_processing_test.cc @@ -467,7 +467,183 @@ INSTANTIATE_TEST_SUITE_P( {R"JSON([1.234, 2.78864, 3.1415926])JSON"}, {arrow::field("x1", arrow::float64())}, {R"JSON([1.234, 2.78864, 3.1415926])JSON"}, - {arrow::field("x1", arrow::float64())}})); + {arrow::field("x1", arrow::float64())}}, + /*remove first*/ + Param{true, + {R"JSON({ + "name": "EFN_TB_REMOVE_COLUMN", + "inputs": [{ + "dataId": 0 + }, { + "customScalar": { + "i64": "0" + } + }], + "output": { + "dataId": 1 + } + })JSON", + R"JSON({ + "name": "EFN_TB_COLUMN", + "inputs": [{ + "dataId": 0 + }, { + "customScalar": { + "i64": "0" + } + }], + "output": { + "dataId": 2 + } + })JSON", + R"JSON({ + "name": "subtract", + "inputs": [{ + "dataId": 2 + }, { + "customScalar": { + "d": -0.51819918217641925 + } + }], + "output": { + "dataId": 3 + } + })JSON", + R"JSON({ + "name": "subtract", + "inputs": [{ + "dataId": 2 + }, { + "customScalar": { + "d": 1.9297598961851565 + } + }], + "output": { + "dataId": 4 + } + })JSON", + R"JSON({ + "name": "abs", + "inputs": [{ + "dataId": 3 + }], + "output": { + "dataId": 5 + } + })JSON", + R"JSON({ + "name": "abs", + "inputs": [{ + "dataId": 4 + }], + "output": { + "dataId": 6 + } + })JSON", + R"JSON({ + "name": "less", + "inputs": [{ + "dataId": 5 + }, { + "customScalar": { + "d": 1e-07 + } + }], + "output": { + "dataId": 7 + } + })JSON", + R"JSON({ + "name": "less", + "inputs": [{ + "dataId": 6 + }, { + "customScalar": { + "d": 1e-07 + } + }], + "output": { + "dataId": 8 + } + })JSON", + R"JSON({ + "name": "if_else", + "inputs": [{ + "dataId": 7 + }, { + "customScalar": { + "f": 1 + } + }, { + "customScalar": { + "f": 0 + } + }], + "output": { + "dataId": 9 + } + })JSON", + R"JSON({ + "name": "if_else", + "inputs": [{ + "dataId": 8 + }, { + "customScalar": { + "f": 1 + } + }, { + "customScalar": { + "f": 0 + } + }], + "output": { + "dataId": 10 + } + })JSON", + R"JSON({ + "name": "EFN_TB_ADD_COLUMN", + "inputs": [{ + "dataId": 1 + }, { + "customScalar": { + "i64": "0" + } + }, { + "customScalar": { + "s": "contact_unknown_0" + } + }, { + "dataId": 10 + }], + "output": { + "dataId": 11 + } + })JSON", + R"JSON({ + "name": "EFN_TB_ADD_COLUMN", + "inputs": [{ + "dataId": 11 + }, { + "customScalar": { + "i64": "1" + } + }, { + "customScalar": { + "s": "contact_unknown_1" + } + }, { + "dataId": 9 + }], + "output": { + "dataId": 12 + } + })JSON"}, + {}, + {R"JSON([2])JSON"}, + {arrow::field("contact_unknown", arrow::float64())}, + {R"JSON([0])JSON", R"JSON([0])JSON"}, + {arrow::field("contact_unknown_0", arrow::float32()), + arrow::field("contact_unknown_1", arrow::float32())}})); class ArrowProcessingExceptionTest : public ::testing::TestWithParam { protected: