Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ZIR-332: Add bytecode interpreter for validity polynomial calculation for keccak #168

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
426 changes: 426 additions & 0 deletions zirgen/Dialect/ByteCode/Analysis/ArmAnalysis.cpp

Large diffs are not rendered by default.

122 changes: 122 additions & 0 deletions zirgen/Dialect/ByteCode/Analysis/ArmAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "mlir/IR/Block.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

#include "zirgen/Dialect/ByteCode/IR/ByteCode.h"

namespace zirgen::ByteCode {

// Describes one "arm": a sequence of operations, every place that sequence
// was seen, and the bookkeeping needed to locate its inputs and outputs.
struct ArmInfo {
  // Location attribute recorded for this arm.
  mlir::LocationAttr loc;

  // One entry per sighting of this arm's operation sequence.  The entries are
  // non-owning views, so the operation lists they refer to are stored elsewhere.
  llvm::SmallVector<llvm::ArrayRef<mlir::Operation*>> allOps;

  // Returns one representative occurrence of the operations (the first one
  // recorded), or an empty range when nothing has been recorded.
  llvm::ArrayRef<mlir::Operation*> getOps() const {
    return allOps.empty() ? llvm::ArrayRef<mlir::Operation*>() : allOps.front();
  }

  // How many times this set of operations has been seen.
  size_t getCount() const { return allOps.size(); }

  // Per operation in this arm: how many integer arguments (as reported by
  // getByteCodeIntArgs) have to be decoded.  Temporary values that appear as
  // operation operands are not counted here.
  llvm::SmallVector<size_t, 1> opIntArgs;

  // Names of the arguments this arm needs from outside, i.e. function
  // arguments that do not live in temporary storage.
  llvm::SmallVector<mlir::StringAttr, 1> funcArgNames;

  // Count of temporary values that must be loaded before running this arm.
  size_t numLoadVals = 0;

  // Count of temporary values that must be stored after evaluating this arm.
  size_t numYieldVals = 0;

  // Every value consumed or produced by this arm, laid out in this order:
  //   1. function arguments (one per entry of funcArgNames)
  //   2. temporary values to load (numLoadVals of them)
  //   3. the results of each operation in `ops`
  llvm::SmallVector<mlir::Value> values;

  // The leading section of `values` holding the function arguments.
  llvm::ArrayRef<mlir::Value> getFuncArgValues() const {
    const size_t argCount = funcArgNames.size();
    const llvm::ArrayRef<mlir::Value> all = values;
    return all.slice(0, argCount);
  }

  // The types of the function-argument values.
  mlir::TypeRange getFuncArgTypes() const { return mlir::TypeRange(getFuncArgValues()); }

  // The section of `values` that follows the function arguments and holds the
  // temporaries to load.
  llvm::ArrayRef<mlir::Value> getLoadVals() const {
    const size_t firstLoad = funcArgNames.size();
    const llvm::ArrayRef<mlir::Value> all = values;
    return all.slice(firstLoad, numLoadVals);
  }

  // The trailing numYieldVals entries of `valueOffsets`, which are the offsets
  // of the yielded values.
  llvm::ArrayRef<size_t> getYieldOffsets() const {
    assert(numYieldVals <= valueOffsets.size());
    const size_t firstYield = valueOffsets.size() - numYieldVals;
    const llvm::ArrayRef<size_t> offsets = valueOffsets;
    return offsets.slice(firstYield);
  }

  // Offsets into `values`: first the operand offsets of each operation in this
  // arm, in order, then the offsets of all yielded values.
  llvm::SmallVector<size_t> valueOffsets;
};

// Stream-insertion operator for ArmInfo.  Only declared here; the output
// format is determined by the out-of-line definition.
llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const ArmInfo& armInfo);

// Analysis that groups operations into arms.  Instances are neither copy
// constructible nor copy assignable.
class ArmAnalysis {
public:
  // Builds a new analysis rooted at `topOp`.
  // NOTE(review): this single-argument constructor is implicit; consider
  // whether it should be marked `explicit`.
  ArmAnalysis(mlir::Operation* topOp);

  ArmAnalysis(const ArmAnalysis&) = delete;
  void operator=(const ArmAnalysis&) = delete;

  // Analyzes the supplied set of operations and returns the resulting ArmInfo.
  ArmInfo getArmInfo(llvm::ArrayRef<mlir::Operation*> ops);

  // All distinct operations that have to be placed in separate arms.
  llvm::ArrayRef<ArmInfo> getDistinctOps() const { return distinctOps; }

  // All multi-operation arm candidates that were found and that occur at
  // least kMinArmUseCount times.
  llvm::ArrayRef<ArmInfo> getMultiOpArms() const { return multiOpArms; }

  // The input and output types.
  mlir::FunctionType getFunctionType() const { return funcType; }

  // The stored argument names.
  llvm::ArrayRef<mlir::StringAttr> getArgNames() const { return argNames; }

private:
  void calcDispatchKey(mlir::Operation* op);

  // Backing storage for the two ArmInfo accessors above.
  std::vector<ArmInfo> distinctOps;
  std::vector<ArmInfo> multiOpArms;

  // Owned lists of operations.
  // NOTE(review): presumably the storage that the ArrayRefs in
  // ArmInfo::allOps point into - confirm against ArmAnalysis.cpp.
  std::vector<std::vector<mlir::Operation*>> blockOpStorage;

  mlir::FunctionType funcType;
  llvm::SmallVector<mlir::StringAttr> argNames;

  // The implementation helper fills in the private state directly.
  friend struct ArmAnalysisImpl;
};

} // namespace zirgen::ByteCode
19 changes: 19 additions & 0 deletions zirgen/Dialect/ByteCode/Analysis/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

package(default_visibility = ["//visibility:public"])

# C++ library target containing the arm analysis (ArmAnalysis.h / ArmAnalysis.cpp).
cc_library(
    name = "Analysis",
    srcs = ["ArmAnalysis.cpp"],
    hdrs = ["ArmAnalysis.h"],
    deps = [
        "//zirgen/Dialect/ByteCode/IR",
        "@llvm-project//mlir:FuncDialect",
    ],
)
77 changes: 77 additions & 0 deletions zirgen/Dialect/ByteCode/IR/Attrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "llvm/ADT/TypeSwitch.h"

#include "zirgen/Dialect/ByteCode/IR/ByteCode.h"

using namespace mlir;

namespace zirgen::ByteCode {

// Disabled: everything between `#if 0` and `#endif` is compiled out.
//
// getDispatchKey builds a DispatchKeyAttr describing `op` from:
//   - the operation name,
//   - the operand types and result types,
//   - one StringAttr per integer argument, named "<operation name>_<index>",
//   - the argument numbers of those operands that are block arguments.
//
// NOTE(review): no definition of DispatchKeyAttr is visible alongside this
// code - confirm whether this function should be revived or deleted.
#if 0
DispatchKeyAttr getDispatchKey(Operation* op) {
SmallVector<size_t> intArgs;
// Integer arguments are only collected when the op implements ByteCodeOpInterface.
if (auto bcInterface = llvm::dyn_cast<ByteCodeOpInterface>(op)) {
bcInterface.getByteCodeIntArgs(intArgs);
}

auto operandTypes = llvm::to_vector(op->getOperandTypes());
auto resultTypes = llvm::to_vector(op->getResultTypes());

// Only the number of integer arguments is used here; each kind is named
// after the operation and the argument's position.
SmallVector<mlir::Attribute> intKinds;
for (auto idx : llvm::seq(intArgs.size())) {
intKinds.push_back(StringAttr::get(
op->getContext(), (op->getName().getStringRef() + "_" + std::to_string(idx)).str()));
}

// Operands that are not block arguments contribute nothing to this list.
SmallVector<size_t> blockArgNums;
for (Value operand : op->getOperands()) {
if (auto blockArg = llvm::dyn_cast<BlockArgument>(operand)) {
blockArgNums.push_back(blockArg.getArgNumber());
}
}

return DispatchKeyAttr::get(op->getContext(),
/*operationName=*/op->getName().getStringRef(),
operandTypes,
resultTypes,
intKinds,
/*blockArgs=*/blockArgNums);
}
#endif

// Produces a textual name for an integer-kind attribute:
//   - a UnitAttr yields the fixed name "unit";
//   - a StringAttr yields its string value unchanged;
//   - anything else yields the attribute's printed form with every double
//     quote and space character removed.
std::string getNameForIntKind(mlir::Attribute intKind) {
  if (llvm::isa<UnitAttr>(intKind))
    return "unit";

  if (auto named = llvm::dyn_cast<StringAttr>(intKind))
    return named.str();

  // Fallback: print the attribute into a string, then strip the characters
  // that are not wanted in the resulting name.
  std::string printed;
  llvm::raw_string_ostream stream(printed);
  stream << intKind;

  auto isUnwanted = [](char ch) { return ch == ' ' || ch == '"'; };
  llvm::erase_if(printed, isUnwanted);
  return printed;
}

} // namespace zirgen::ByteCode
20 changes: 20 additions & 0 deletions zirgen/Dialect/ByteCode/IR/Attrs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@


// Attribute wrapper class whose ValueType is bool and which converts to
// IntegerAttr / TypedAttr.
//
// NOTE(review): the shape of this class (bool ValueType, a get() that returns
// BoolAttr, boolean getValue()) looks like it was adapted from a boolean
// attribute - confirm the intended semantics for an "int kind" attribute.
// NOTE(review): get() returns BoolAttr rather than IntKindAttr - confirm that
// this is intentional.
class IntKindAttr : public Attribute {
public:
using Attribute::Attribute;
using ValueType = bool;

// Declared to build the attribute from a bool in the given context; note the
// BoolAttr return type.
static BoolAttr get(MLIRContext* context, bool value);

/// Enable conversion to IntegerAttr and its interfaces. This uses conversion
/// vs. inheritance to avoid bringing in all of IntegerAttrs methods.
operator IntegerAttr() const { return IntegerAttr(impl); }
operator TypedAttr() const { return IntegerAttr(impl); }

/// Return the boolean value of this attribute.
bool getValue() const;

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Attribute attr);
};
62 changes: 62 additions & 0 deletions zirgen/Dialect/ByteCode/IR/Attrs.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2025 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BYTECODE_ATTRS
#define BYTECODE_ATTRS

include "mlir/IR/AttrTypeBase.td"
include "zirgen/Dialect/ByteCode/IR/Dialect.td"

// Common base for attributes defined in the ByteCode dialect: forwards `name`
// and `traits` to AttrDef and uses `attrMnemonic` as the attribute mnemonic.
class ByteCodeAttr<string name, string attrMnemonic, list<Trait> traits = []>
    : AttrDef<ByteCodeDialect, name, traits> {
  let mnemonic = attrMnemonic;
}

// Describes one temporary buffer: its name and how many elements to allocate.
// Assembly form: $bufName `size` $size
def TempBufAttr : ByteCodeAttr<"TempBuf", "temp_buf"> {
  // Added so this attribute is documented like its sibling attributes.
  let summary = "A named temporary buffer and the number of elements to allocate for it";
  let parameters = (ins
    // Name of this temporary buffer, and the intKind of values used to index into it.
    "mlir::StringAttr": $bufName,
    // Number of elements of this temporary buffer to be allocated.
    "size_t": $size
  );
  let assemblyFormat = [{ $bufName `size` $size }];
}

// Array attribute whose elements are all TempBufAttr.
def TempBufArrayAttr : TypedArrayAttrBase<TempBufAttr, "Array of temporary buffers">;

// Describes one kind of encoded integer: the kind's name and its encoded width.
// Assembly form: $intKind `u` $encodedBits
def IntKindInfoAttr : ByteCodeAttr<"IntKindInfo", "int_kind_info"> {
  let summary = "Information on a set of encoded integers in a bytecode encoding";

  // intKind:     name of this integer kind
  // encodedBits: number of bits used to encode values of this kind
  let parameters = (ins
    "mlir::StringAttr":$intKind,
    "size_t":$encodedBits
  );

  let assemblyFormat = [{ $intKind `u` $encodedBits }];
}

// Array attribute whose elements are all IntKindInfoAttr.
def IntKindInfoArrayAttr : TypedArrayAttrBase<IntKindInfoAttr, "Array of integer kinds">;

// Holds an encoded bytecode program together with the temporary buffers that
// must be allocated to execute it.
// Assembly form: $encoded `temps` $tempBufs
def EncodedAttr : ByteCodeAttr<"Encoded", "encoded"> {
  let summary = "An encoded bytecode program for passing to an executor";
  let parameters = (ins
    // Actual encoded data
    StringRefParameter<>: $encoded,
    // Sizes of any temporary buffers that need to be allocated to execute this bytecode.
    ArrayRefParameter<"zirgen::ByteCode::TempBufAttr">: $tempBufs
  );
  let assemblyFormat = [{ $encoded `temps` $tempBufs }];
}

#endif // BYTECODE_ATTRS
Loading
Loading