Skip to content

Commit

Permalink
Generate signals for layout members
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobdweightman committed Jan 18, 2025
1 parent b11c5cb commit bfc8ebe
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 50 deletions.
63 changes: 46 additions & 17 deletions zirgen/compiler/picus/picus.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,20 @@ enum class SignalType {
AssumeDeterministic,
};

template <typename F> void visit(AnySignal signal, F f) {
template <typename F> void visit(AnySignal signal, F f, bool visitedLayout = false) {
if (!signal) {
// no-op
} else if (auto s = dyn_cast<Signal>(signal)) {
f(s);
} else if (auto arr = dyn_cast<SignalArray>(signal)) {
for (auto elem : arr)
visit(elem, f);
visit(elem, f, visitedLayout);
} else if (auto str = dyn_cast<SignalStruct>(signal)) {
for (auto field : str) {
assert(field.getName() != "@layout");
visit(field.getValue(), f);
if (!visitedLayout || field.getName() != "@layout") {
visitedLayout = true;
visit(field.getValue(), f, visitedLayout);
}
}
}
}
Expand Down Expand Up @@ -122,18 +124,17 @@ class PicusPrinter {
workQueue.push(lookupConstructor(param.getType()));
}

// The layout is an output
if (auto layout = component.getLayout()) {
AnySignal layoutSignal = signalize("layout", layout.getType());
declareSignals(layoutSignal, SignalType::Output);
valuesToSignals.insert({layout, layoutSignal});
}

// The result is an output
AnySignal result = signalize("result", component.getOutType());
declareSignals(result, SignalType::Output);
valuesToSignals.insert({Value(), result});

// The layout is an output
if (auto layout = component.getLayout()) {
AnySignal layoutSignal = cast<SignalStruct>(result).get("@layout");
valuesToSignals.insert({layout, layoutSignal});
}

for (Operation& op : component.getBody().front()) {
visitOp(&op);
}
Expand Down Expand Up @@ -208,6 +209,7 @@ class PicusPrinter {
}

void visitOp(LookupOp lookup) {
// TODO: this doesn't handle @super member lookups!!!!!!!
auto signal = cast<SignalStruct>(valuesToSignals.at(lookup.getBase()));
auto subSignal = signal.get(lookup.getMember());
valuesToSignals.insert({lookup.getOut(), subSignal});
Expand Down Expand Up @@ -349,7 +351,7 @@ class PicusPrinter {
void visitOp(PackOp pack) {
SmallVector<NamedAttribute> fields;
for (auto [field, arg] : llvm::zip(pack.getOut().getType().getFields(), pack.getMembers())) {
if (field.isPrivate || field.name.strref() == "@layout")
if (field.isPrivate)
continue;
AnySignal member = valuesToSignals.at(arg);
fields.emplace_back(field.name, member);
Expand All @@ -368,7 +370,11 @@ class PicusPrinter {
SmallVector<Signal> rets = flatten(returnSignal);
assert(outs.size() == rets.size());
for (auto [outs, rets] : llvm::zip(outs, rets)) {
os << "(assert (= " << outs.str() << " " << rets.str() << "))\n";
// Skip emitting vacuous constraints (a = a). These can come from the same
// layout occuring at different levels of nesting within an @super member.
if (outs != rets) {
os << "(assert (= " << outs.str() << " " << rets.str() << "))\n";
}
}
}

Expand Down Expand Up @@ -459,24 +465,47 @@ class PicusPrinter {
}

// Constructs a fresh signal structure corresponding to the given type
AnySignal signalize(std::string prefix, Type type) {
AnySignal signalize(std::string prefix, Type type, AnySignal layout = nullptr) {
if (isa<ValType>(type) || isa<RefType>(type)) {
return Signal::get(ctx, prefix);
} else if (auto array = dyn_cast<ArrayLikeTypeInterface>(type)) {
SmallVector<AnySignal> elements;
for (size_t i = 0; i < array.getSize(); i++) {
std::string name = prefix + "_" + std::to_string(i);
elements.push_back(signalize(name, array.getElement()));
AnySignal sublayout;
if (auto arrLayout = cast_if_present<SignalArray>(layout)) {
sublayout = arrLayout[i];
}
elements.push_back(signalize(name, array.getElement(), sublayout));
}
return SignalArray::get(ctx, elements);
} else if (auto str = dyn_cast<StructType>(type)) {
SmallVector<NamedAttribute> fields;
// If we haven't generated a layout yet, generate it first. Then,
// recursively pass along sublayouts for reuse, so that we don't generate
// extra signals for registers at every level of nesting. For example,
// [email protected] and foo.bar.@layout should refer to the same layout.
if (!layout) {
for (auto field : str.getFields()) {
if (field.name == "@layout") {
std::string name = prefix + "_" + canonicalizeIdentifier(field.name.str());
layout = signalize(name, field.type);
break;
}
}
}
for (auto field : str.getFields()) {
if (field.name.strref() == "@layout")
if (field.name == "@layout") {
fields.emplace_back(field.name, layout);
continue;
}
if (!field.isPrivate) {
std::string name = prefix + "_" + canonicalizeIdentifier(field.name.str());
fields.emplace_back(field.name, signalize(name, field.type));
AnySignal sublayout;
if (auto strLayout = cast_if_present<SignalStruct>(layout)) {
sublayout = strLayout.get(field.name);
}
fields.emplace_back(field.name, signalize(name, field.type, sublayout));
}
}
return SignalStruct::get(ctx, fields);
Expand Down
12 changes: 6 additions & 6 deletions zirgen/compiler/picus/test/alias_layout_1.zir
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

// CHECK: (prime-number 2013265921)
// CHECK-NEXT: (begin-module Top)
// CHECK-NEXT: (output layout_a__super)
// CHECK-NEXT: (output layout_b__super__super)
// CHECK-NEXT: (output result__layout_a__super)
// CHECK-NEXT: (output result__layout_b__super__super)
// CHECK-NEXT: (output result_a__super)
// CHECK-NEXT: (output result_b__super__super)
// CHECK-NEXT: (assert (= x0 0))
// CHECK-NEXT: (assert (= x1 (- x0 layout_b__super__super)))
// CHECK-NEXT: (assert (= x1 (- x0 result__layout_b__super__super)))
// CHECK-NEXT: (assert (= x1 0))
// CHECK-NEXT: (assert (= layout_a__super layout_b__super__super))
// CHECK-NEXT: (assert (= result_a__super layout_a__super))
// CHECK-NEXT: (assert (= result_b__super__super layout_b__super__super))
// CHECK-NEXT: (assert (= result__layout_a__super result__layout_b__super__super))
// CHECK-NEXT: (assert (= result_a__super result__layout_a__super))
// CHECK-NEXT: (assert (= result_b__super__super result__layout_b__super__super))
// CHECK-NEXT: (end-module)

#[picus_analyze]
Expand Down
17 changes: 17 additions & 0 deletions zirgen/compiler/picus/test/argument_layout.zir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: zirgen %s --emit=picus | FileCheck %s

// Initially during development, we used a simpler scheme for translating
// parameter layouts than we ended up needing. This test ensures that the layout
// of the argument and result appear as intended, and without any redundancy.

// CHECK: (begin-module Foo)
// CHECK-NEXT: (input x0__layout__super__super)
// CHECK-NEXT: (input x0__super__super)
// CHECK-NEXT: (output result__layout__super__super__super)
// CHECK-NEXT: (output result__super__super__super)

#[picus_analyze]
component Foo(x: Reg) {
// x@0 = 9;
Reg(x@0)
}
4 changes: 2 additions & 2 deletions zirgen/compiler/picus/test/log.zir
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

// CHECK: (prime-number 2013265921)
// CHECK-NEXT: (begin-module Top)
// CHECK-NEXT: (output layout_x__super__super)
// CHECK-NEXT: (output result__layout_x__super__super)
// CHECK-NEXT: (assert (= x0 5))
// CHECK-NEXT: (assert (= x1 (- x0 layout_x__super__super)))
// CHECK-NEXT: (assert (= x1 (- x0 result__layout_x__super__super)))
// CHECK-NEXT: (assert (= x1 0))
// CHECK-NEXT: (call [] Log [ ])
// CHECK-NEXT: (end-module)
Expand Down
24 changes: 12 additions & 12 deletions zirgen/compiler/picus/test/map.zir
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

// CHECK: (prime-number 2013265921)
// CHECK-NEXT: (begin-module Top)
// CHECK-NEXT: (output layout_a_0__super__super__super__super)
// CHECK-NEXT: (output layout_a_1__super__super__super__super)
// CHECK-NEXT: (output layout_a_2__super__super__super__super)
// CHECK-NEXT: (output layout_a_3__super__super__super__super)
// CHECK-NEXT: (output result__layout_a_0__super__super__super__super)
// CHECK-NEXT: (output result__layout_a_1__super__super__super__super)
// CHECK-NEXT: (output result__layout_a_2__super__super__super__super)
// CHECK-NEXT: (output result__layout_a_3__super__super__super__super)
// CHECK-NEXT: (output result_a_0__super__super__super__super)
// CHECK-NEXT: (output result_a_1__super__super__super__super)
// CHECK-NEXT: (output result_a_2__super__super__super__super)
Expand All @@ -14,18 +14,18 @@
// CHECK-NEXT: (assert (= x1 2))
// CHECK-NEXT: (assert (= x2 1))
// CHECK-NEXT: (assert (= x3 0))
// CHECK-NEXT: (assert (= x4 (- x3 layout_a_0__super__super__super__super)))
// CHECK-NEXT: (assert (= x4 (- x3 result__layout_a_0__super__super__super__super)))
// CHECK-NEXT: (assert (= x4 0))
// CHECK-NEXT: (assert (= x5 (- x2 layout_a_1__super__super__super__super)))
// CHECK-NEXT: (assert (= x5 (- x2 result__layout_a_1__super__super__super__super)))
// CHECK-NEXT: (assert (= x5 0))
// CHECK-NEXT: (assert (= x6 (- x1 layout_a_2__super__super__super__super)))
// CHECK-NEXT: (assert (= x6 (- x1 result__layout_a_2__super__super__super__super)))
// CHECK-NEXT: (assert (= x6 0))
// CHECK-NEXT: (assert (= x7 (- x0 layout_a_3__super__super__super__super)))
// CHECK-NEXT: (assert (= x7 (- x0 result__layout_a_3__super__super__super__super)))
// CHECK-NEXT: (assert (= x7 0))
// CHECK-NEXT: (assert (= result_a_0__super__super__super__super layout_a_0__super__super__super__super))
// CHECK-NEXT: (assert (= result_a_1__super__super__super__super layout_a_1__super__super__super__super))
// CHECK-NEXT: (assert (= result_a_2__super__super__super__super layout_a_2__super__super__super__super))
// CHECK-NEXT: (assert (= result_a_3__super__super__super__super layout_a_3__super__super__super__super))
// CHECK-NEXT: (assert (= result_a_0__super__super__super__super result__layout_a_0__super__super__super__super))
// CHECK-NEXT: (assert (= result_a_1__super__super__super__super result__layout_a_1__super__super__super__super))
// CHECK-NEXT: (assert (= result_a_2__super__super__super__super result__layout_a_2__super__super__super__super))
// CHECK-NEXT: (assert (= result_a_3__super__super__super__super result__layout_a_3__super__super__super__super))
// CHECK-NEXT: (end-module)

#[picus_analyze]
Expand Down
27 changes: 14 additions & 13 deletions zirgen/compiler/picus/test/mux.zir
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,34 @@

// CHECK: (prime-number 2013265921)
// CHECK-NEXT: (begin-module Top)
// CHECK-NEXT: (output layout_first__super__super)
// CHECK-NEXT: (output layout_x__super__super__super)
// CHECK-NEXT: (output layout_x_arm0__super__super__super)
// CHECK-NEXT: (output layout_x_arm1__super__super__super)
// CHECK-NEXT: (output result__layout_first__super__super)
// CHECK-NEXT: (output result__layout_x__super__super__super)
// CHECK-NEXT: (output result__layout_x_arm0__super__super__super)
// CHECK-NEXT: (output result__layout_x_arm1__super__super__super)
// CHECK-NEXT: (output result_first__super__super)
// CHECK-NEXT: (output result_x__super__super)
// CHECK-NEXT: (assert (= x0 8))
// CHECK-NEXT: (assert (= x1 7))
// CHECK-NEXT: (assert (= x2 1))
// CHECK-NEXT: (assert (= x3 0))
// CHECK-NEXT: (assert (= x4 (- x3 layout_first__super__super)))
// CHECK-NEXT: (assert (= x4 (- x3 result__layout_first__super__super)))
// CHECK-NEXT: (assert (= x4 0))
// CHECK-NEXT: (assert (= x5 (- x2 layout_first__super__super)))
// CHECK-NEXT: (assert (= x5 (- x2 result__layout_first__super__super)))
// CHECK-NEXT: ; begin mux
// CHECK-NEXT: (assert (= (* x5 layout_x__super__super__super) (* x5 layout_x_arm0__super__super__super)))
// CHECK-NEXT: (assert (= x6 (- x1 layout_x_arm0__super__super__super)))
// CHECK-NEXT: (assert (= (* x5 result__layout_x__super__super__super) (* x5 result__layout_x_arm0__super__super__super)))
// CHECK-NEXT: (assert (= x6 (- x1 result__layout_x_arm0__super__super__super)))
// CHECK-NEXT: (assert (= x6 0))
// CHECK-NEXT: ; mark mux arm
// CHECK-NEXT: (assert (= (* layout_first__super__super layout_x__super__super__super) (* layout_first__super__super layout_x_arm1__super__super__super)))
// CHECK-NEXT: (assert (= x7 (- x0 layout_x_arm1__super__super__super)))
// CHECK-NEXT: (assert (= (* result__layout_first__super__super result__layout_x__super__super__super) (* result__layout_first__super__super result__layout_x_arm1__super__super__super)))
// CHECK-NEXT: (assert (= x7 (- x0 result__layout_x_arm1__super__super__super)))
// CHECK-NEXT: (assert (= x7 0))
// CHECK-NEXT: ; mark mux arm
// CHECK-NEXT: (assert (= mux_x8__super__super (+ (* x5 layout_x_arm0__super__super__super) (* layout_first__super__super layout_x_arm1__super__super__super))))
// CHECK-NEXT: (assert (= mux_x8__layout__super__super (+ (* x5 result__layout_x_arm0__super__super__super) (* result__layout_first__super__super result__layout_x_arm1__super__super__super))))
// CHECK-NEXT: (assert (= mux_x8__super__super (+ (* x5 result__layout_x_arm0__super__super__super) (* result__layout_first__super__super result__layout_x_arm1__super__super__super))))
// CHECK-NEXT: ; end mux
// CHECK-NEXT: (call [] Log [ ])
// CHECK-NEXT: (assert (= result_first__super__super layout_first__super__super))
// CHECK-NEXT: (assert (= result_x__super__super layout_x__super__super__super))
// CHECK-NEXT: (assert (= result_first__super__super result__layout_first__super__super))
// CHECK-NEXT: (assert (= result_x__super__super result__layout_x__super__super__super))
// CHECK-NEXT: (end-module)

#[picus_analyze]
Expand Down

0 comments on commit bfc8ebe

Please sign in to comment.