Skip to content

Commit

Permalink
Repo sync (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakinxc authored Mar 31, 2023
1 parent f57abb6 commit f84c843
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 24 deletions.
4 changes: 2 additions & 2 deletions examples/cpp/pir/keyword_pir_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,12 @@ int main(int argc, char **argv) {

for (size_t i = 0; i < query_result.first.size(); ++i) {
std::vector<std::string> result_ids =
absl::StrSplit(query_batch_items[query_result.first[i]], ",");
absl::StrSplit(query_batch_items[query_result.first[i]], ',');

SPU_ENFORCE(result_ids.size() == ids.size());

std::vector<std::string> result_labels =
absl::StrSplit(query_result.second[i], ",");
absl::StrSplit(query_result.second[i], ',');
SPU_ENFORCE(result_labels.size() == label_columns_name.size());

for (size_t j = 0; j < result_ids.size(); ++j) {
Expand Down
3 changes: 0 additions & 3 deletions libspu/compiler/front_end/hlo_importer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,6 @@ void runHloPasses(xla::HloModule *module) {
// elimination has to come after that pass.
pipeline.AddPass<ZeroSizedHloElimination>();

// FIXME: For public gather, this might actually slower
pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateAllGathers);

pipeline.AddPass<TupleSimplifier>();
pipeline.AddPass<WhileLoopSimplifier>();

Expand Down
8 changes: 8 additions & 0 deletions libspu/kernel/hlo/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,14 @@ spu_cc_library(
],
)

spu_cc_test(
name = "utils_test",
srcs = ["utils_test.cc"],
deps = [
":utils",
],
)

spu_cc_library(
name = "shuffle",
srcs = ["shuffle.cc"],
Expand Down
2 changes: 1 addition & 1 deletion libspu/kernel/hlo/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void forEachIndex(absl::Span<const int64_t> shape,
const auto rank = static_cast<int64_t>(shape.size());
// Allows handling R0 arrays, such that the visitor function will be called
// once with the proper empty indexes.
int64_t n = rank - 1;
int64_t n = rank;
std::vector<int64_t> indexes(base.begin(), base.end());

while (n >= 0) {
Expand Down
59 changes: 59 additions & 0 deletions libspu/kernel/hlo/utils_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Copyright 2023 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "libspu/kernel/hlo/utils.h"

#include "gtest/gtest.h"

namespace spu {

TEST(UtilsTest, ForEachIndexScalar) {
int64_t counter = 0;
kernel::forEachIndex({}, {}, {}, {}, [&](absl::Span<const int64_t> idx) {
EXPECT_TRUE(idx.empty());
++counter;
});

EXPECT_EQ(counter, 1);
}

TEST(UtilsTest, ForEachIndex1D) {
int64_t counter = 0;
std::vector<int64_t> expected_idx = {0, 1, 2};
kernel::forEachIndex({3}, {0}, {3}, {1}, [&](absl::Span<const int64_t> idx) {
EXPECT_EQ(idx.size(), 1);
EXPECT_EQ(idx[0], expected_idx[counter]);
++counter;
});

EXPECT_EQ(counter, 3);
}

TEST(UtilsTest, ForEachIndex2D) {
int64_t counter = 0;
std::vector<std::vector<int64_t>> expected_idx = {
{0, 0}, {0, 1}, {1, 0}, {1, 1}};

kernel::forEachIndex({2, 2}, {0, 0}, {2, 2}, {1, 1},
[&](absl::Span<const int64_t> idx) {
EXPECT_EQ(idx.size(), 2);
EXPECT_EQ(idx[0], expected_idx[counter][0]);
EXPECT_EQ(idx[1], expected_idx[counter][1]);
++counter;
});

EXPECT_EQ(counter, 4);
}

} // namespace spu
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def has_ext_modules(self):
long_description=io.open(
os.path.join(ROOT_DIR, "README.md"), "r", encoding="utf-8"
).read(),
url="https://github.com/secretflow/secretflow/spu",
url="https://github.com/secretflow/spu",
keywords=("spu mpc secretflow compiler vm ABY3 secure computation"),
classifiers=[
"Programming Language :: Python :: 3.8",
Expand Down
2 changes: 1 addition & 1 deletion sml/utils/emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"name": "spu-emulation",
"ipam": {
"driver": "default",
"config": [{"subnet": SAMPLE_CIDR}],
"config": [{"subnet": SAMPLE_CIDR, "gateway": "172.16.238.1"}],
},
}
},
Expand Down
31 changes: 16 additions & 15 deletions spu/tests/jnp_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,18 +567,19 @@ def test_shift(self, name, status, dtype, shape, rnd_factory):
),
)

def test_gather(self):
jnp_fn = lambda x, indices: jnp.take(x, indices)
spu_fn = sim_jax(self._sim, jnp_fn)
x_rng = jtu.rand_int(self._rng, low=0, high=32)
indices_rng = jtu.rand_int(self._rng, low=0, high=9)
args = [x_rng((10,), np.int32), indices_rng((3,), np.int32)]
jnp_out = jnp_fn(*args)
spu_out = spu_fn(*args)
npt.assert_equal(
spu_out,
jnp_out,
err_msg="take faild.\nx = {}, indices = {}\nspu = {}\njnp = {}".format(
args[0], args[1], spu_out, jnp_out
),
)
# FIXME(anakinxc): Reenable once we fix secret gather
# def test_gather(self):
# jnp_fn = lambda x, indices: jnp.take(x, indices)
# spu_fn = sim_jax(self._sim, jnp_fn)
# x_rng = jtu.rand_int(self._rng, low=0, high=32)
# indices_rng = jtu.rand_int(self._rng, low=0, high=9)
# args = [x_rng((10,), np.int32), indices_rng((3,), np.int32)]
# jnp_out = jnp_fn(*args)
# spu_out = spu_fn(*args)
# npt.assert_equal(
# spu_out,
# jnp_out,
# err_msg="take faild.\nx = {}, indices = {}\nspu = {}\njnp = {}".format(
# args[0], args[1], spu_out, jnp_out
# ),
# )
2 changes: 1 addition & 1 deletion spu/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.


__version__ = "0.3.2b10"
__version__ = "0.3.2b11"

0 comments on commit f84c843

Please sign in to comment.