Skip to content

Commit

Permalink
repo-sync-2024-12-16T10:32:45+0800 (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
oeqqwq authored Dec 17, 2024
1 parent ce43e4f commit 7089e18
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 7 deletions.
12 changes: 7 additions & 5 deletions python_lib/secretflow_serving_lib/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import io
import logging
import tarfile
from itertools import chain
from typing import Any, Dict, List
Expand All @@ -23,7 +24,7 @@
from . import libserving # type: ignore
from .api import get_op
from .attr_pb2 import AttrType, AttrValue
from .bundle_pb2 import FileFormatType, ModelBundle, ModelInfo, ModelManifest
from .bundle_pb2 import FileFormatType, ModelBundle, ModelManifest
from .graph_pb2 import (
DispatchType,
ExecutionDef,
Expand All @@ -33,7 +34,6 @@
RuntimeConfig,
HeConfig,
)
from .op_pb2 import OpDef


def construct_attr_value(attr_type: AttrType, value) -> AttrValue:
Expand Down Expand Up @@ -245,9 +245,11 @@ def set_he_config(self, pk_bytes, sk_bytes, scale):

def build_proto(self) -> GraphDef:
'''Get the GraphDef include all nodes and executions'''
graph_def_str = libserving.graph_validator_impl(
self.graph.proto().SerializeToString()
)
graph_def_str = self.graph.proto().SerializeToString()

logging.info(f"check serving graph: {graph_def_str}")
libserving.graph_validator_impl(graph_def_str)

graph = GraphDef()
graph.ParseFromString(graph_def_str)
return graph
Expand Down
7 changes: 6 additions & 1 deletion secretflow_serving/ops/arrow_processing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ ArrowProcessing::ArrowProcessing(OpKernelOptions opts)
"the last compute function({}) is not returnable", end_func.name());
result_id_ = end_func.output().data_id();

int num_fields = input_schema_list_.front()->num_fields();
std::map<int, int> table_field_num_map = {
{0, input_schema_list_.front()->num_fields()}};
std::map<int32_t, arrow::Datum::Kind> data_id_map = {
{0, arrow::Datum::Kind::RECORD_BATCH}};

Expand All @@ -184,6 +185,7 @@ ArrowProcessing::ArrowProcessing(OpKernelOptions opts)
// check ext func inputs type valid
SERVING_ENFORCE(input_kinds[0] == arrow::Datum::Kind::RECORD_BATCH,
errors::ErrorCode::LOGIC_ERROR);
auto num_fields = table_field_num_map[func.inputs()[0].data_id()];
if (ex_func_name == compute::ExtendFunctionName::EFN_TB_COLUMN ||
ex_func_name == compute::ExtendFunctionName::EFN_TB_REMOVE_COLUMN) {
// std::shared_ptr<Array> column(int) const
Expand Down Expand Up @@ -260,6 +262,9 @@ ArrowProcessing::ArrowProcessing(OpKernelOptions opts)
data_id_map.emplace(func.output().data_id(), output_kind).second,
errors::ErrorCode::LOGIC_ERROR, "found duplicate data_id: {}",
func.output().data_id());
if (output_kind == arrow::Datum::Kind::RECORD_BATCH) {
table_field_num_map.emplace(func.output().data_id(), num_fields);
}

switch (ex_func_name) {
case compute::ExtendFunctionName::EFN_TB_COLUMN: {
Expand Down
178 changes: 177 additions & 1 deletion secretflow_serving/ops/arrow_processing_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,183 @@ INSTANTIATE_TEST_SUITE_P(
{R"JSON([1.234, 2.78864, 3.1415926])JSON"},
{arrow::field("x1", arrow::float64())},
{R"JSON([1.234, 2.78864, 3.1415926])JSON"},
{arrow::field("x1", arrow::float64())}}));
{arrow::field("x1", arrow::float64())}},
/*remove first*/
Param{true,
{R"JSON({
"name": "EFN_TB_REMOVE_COLUMN",
"inputs": [{
"dataId": 0
}, {
"customScalar": {
"i64": "0"
}
}],
"output": {
"dataId": 1
}
})JSON",
R"JSON({
"name": "EFN_TB_COLUMN",
"inputs": [{
"dataId": 0
}, {
"customScalar": {
"i64": "0"
}
}],
"output": {
"dataId": 2
}
})JSON",
R"JSON({
"name": "subtract",
"inputs": [{
"dataId": 2
}, {
"customScalar": {
"d": -0.51819918217641925
}
}],
"output": {
"dataId": 3
}
})JSON",
R"JSON({
"name": "subtract",
"inputs": [{
"dataId": 2
}, {
"customScalar": {
"d": 1.9297598961851565
}
}],
"output": {
"dataId": 4
}
})JSON",
R"JSON({
"name": "abs",
"inputs": [{
"dataId": 3
}],
"output": {
"dataId": 5
}
})JSON",
R"JSON({
"name": "abs",
"inputs": [{
"dataId": 4
}],
"output": {
"dataId": 6
}
})JSON",
R"JSON({
"name": "less",
"inputs": [{
"dataId": 5
}, {
"customScalar": {
"d": 1e-07
}
}],
"output": {
"dataId": 7
}
})JSON",
R"JSON({
"name": "less",
"inputs": [{
"dataId": 6
}, {
"customScalar": {
"d": 1e-07
}
}],
"output": {
"dataId": 8
}
})JSON",
R"JSON({
"name": "if_else",
"inputs": [{
"dataId": 7
}, {
"customScalar": {
"f": 1
}
}, {
"customScalar": {
"f": 0
}
}],
"output": {
"dataId": 9
}
})JSON",
R"JSON({
"name": "if_else",
"inputs": [{
"dataId": 8
}, {
"customScalar": {
"f": 1
}
}, {
"customScalar": {
"f": 0
}
}],
"output": {
"dataId": 10
}
})JSON",
R"JSON({
"name": "EFN_TB_ADD_COLUMN",
"inputs": [{
"dataId": 1
}, {
"customScalar": {
"i64": "0"
}
}, {
"customScalar": {
"s": "contact_unknown_0"
}
}, {
"dataId": 10
}],
"output": {
"dataId": 11
}
})JSON",
R"JSON({
"name": "EFN_TB_ADD_COLUMN",
"inputs": [{
"dataId": 11
}, {
"customScalar": {
"i64": "1"
}
}, {
"customScalar": {
"s": "contact_unknown_1"
}
}, {
"dataId": 9
}],
"output": {
"dataId": 12
}
})JSON"},
{},
{R"JSON([2])JSON"},
{arrow::field("contact_unknown", arrow::float64())},
{R"JSON([0])JSON", R"JSON([0])JSON"},
{arrow::field("contact_unknown_0", arrow::float32()),
arrow::field("contact_unknown_1", arrow::float32())}}));

class ArrowProcessingExceptionTest : public ::testing::TestWithParam<Param> {
protected:
Expand Down

0 comments on commit 7089e18

Please sign in to comment.