Skip to content

Commit

Permalink
feat: 添加 einsum 算子
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Oct 16, 2023
1 parent 26d4248 commit 1baa70c
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/07onnx/src/operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "operators/constant_of_shape.hh"
#include "operators/conv.hh"
#include "operators/cum_sum.hh"
#include "operators/einsum.hh"
#include "operators/expand.hh"
#include "operators/gather.hh"
#include "operators/gather_elements.hh"
Expand All @@ -33,7 +34,7 @@
namespace refactor::onnx {

void register_() {
// clang-format off
// clang-format off
#define REGISTER(NAME, CLASS) Operator::register_<CLASS>("onnx::" #NAME)
REGISTER(BatchNormalization, BatchNormalization);
REGISTER(Cast , Cast );
Expand All @@ -47,6 +48,7 @@ namespace refactor::onnx {
REGISTER(ConstantOfShape , ConstantOfShape );
REGISTER(Conv , Conv );
REGISTER(CumSum , CumSum );
REGISTER(Einsum , Einsum );
REGISTER(Expand , Expand );
REGISTER(Gather , Gather );
REGISTER(GatherElements , GatherElements );
Expand Down
35 changes: 35 additions & 0 deletions src/07onnx/src/operators/einsum.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// #include "computation/operators/einsum.h"
#include "einsum.hh"
#include "common.h"
#include "refactor/common.h"
#include <variant>

namespace refactor::onnx {
using Op = Einsum;

Op::Einsum(std::string equation_)
: Operator(), equation(std::move(equation_)) {}

auto Op::build(std::string_view, Attributes attributes) -> OpBox {
return OpBox(std::make_unique<Op>(std::move(attributes.at("equation").string())));
}
auto Op::typeId() -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto Op::opTypeId() const -> size_t { return typeId(); }
auto Op::opTypeName() const -> std::string_view { return "onnx::Einsum"; }

auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
if (inputs.empty()) {
return Err(InferError(ERROR_MSG("Input size error")));
}
TODO("");
}

auto Op::lower(TensorRefs) const -> computation::OpBox {
TODO("");
}

}// namespace refactor::onnx
25 changes: 25 additions & 0 deletions src/07onnx/src/operators/einsum.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef ONNX_EINSUM_HH
#define ONNX_EINSUM_HH

#include "frontend/operator.h"

namespace refactor::onnx {
using namespace frontend;

struct Einsum final : public Operator {
std::string equation;

Einsum(std::string);

static OpBox build(std::string_view, Attributes);
static size_t typeId();

size_t opTypeId() const final;
std::string_view opTypeName() const final;
InferResult infer(TensorRefs, InferOptions const &) const final;
computation::OpBox lower(TensorRefs) const final;
};

}// namespace refactor::onnx

#endif// ONNX_EINSUM_HH
20 changes: 20 additions & 0 deletions src/07onnx/test/test_einsum.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "../src/operators/einsum.hh"
#include "onnx/operators.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace frontend;
using namespace onnx;

TEST(infer, Einsum) {
onnx::register_();
auto edges = Edges{{Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(5)}, {}), ""}};
graph_topo::idx_t inputs[]{0};
// auto infered = Einsum("ij->ji").infer(TensorRefs(edges, slice(inputs, 1)), {true});
// ASSERT_TRUE(infered.isOk());
// auto outputs = std::move(infered.unwrap());
// ASSERT_EQ(outputs.size(), 1);
// auto y = std::move(outputs[0]);
// ASSERT_EQ(y->dataType, DataType::F32);
// ASSERT_EQ(y->shape, (Shape{DimExpr(5), DimExpr(2)}));
}

0 comments on commit 1baa70c

Please sign in to comment.