Skip to content

Commit

Permalink
Table augmented
Browse files Browse the repository at this point in the history
  • Loading branch information
hecmas committed Dec 3, 2024
1 parent 5211645 commit d401515
Showing 1 changed file with 99 additions and 31 deletions.
130 changes: 99 additions & 31 deletions state-machines/binary/pil/binary_table.pil
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,33 @@ require "constants.pil";

// PIL Binary Operations Table used by Binary
// Running Total
// MINU/MINU_W (OP:0x09) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^19
// MIN/MIN_W (OP:0x0a) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^20
// MAXU/MAXU_W (OP:0x0b) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^20 + 2^19
// MAX/MAX_W (OP:0x0c) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^21
// LTU/LTU_W (OP:0x04) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^18
// LT/LT_W (OP:0x05) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^19
// EQ/EQ_W (OP:0x08) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^19 + 2^18
// ADD/ADD_W (OP:0x02) ** 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20
// SUB/SUB_W (OP:0x03) ** 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 + 2^18
// LEU/LEU_W (OP:0x06) * 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^18 + 2^17
// LE/LE_W (OP:0x07) * 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^19
// AND/AND_W (OP:0x20) 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^19 + 2^17
// OR/OR_W (OP:0x21) 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^19 + 2^18
// XOR/XOR_W (OP:0x22) 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^21 + 2^20 + 2^19 + 2^18 + 2^17
// EXT_32 (OP:0x23) 2^8 (A) x 2^1 (CIN) x 2^2 (FLAGS) = 2^16 | 2^21 + 2^20 + 2^19 + 2^18 + 2^17 + 2^11 => < 2^22
// MINU/MINU_W (OP:0x09) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^19
// MIN/MIN_W (OP:0x0a) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^20
// MAXU/MAXU_W (OP:0x0b) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^20 + 2^19
// MAX/MAX_W (OP:0x0c) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) x 2^1 (RESULT_IS_A) = 2^19 | 2^21
// LT_ABS_NP (OP:????) * 2^16 (AxB) x 2^1 (LAST) x 2^2 (CIN) = 2^19 | 2^21 + 2^19
// LT_ABS_PN (OP:????) * 2^16 (AxB) x 2^1 (LAST) x 2^2 (CIN) = 2^19 | 2^21 + 2^20
// LTU/LTU_W (OP:0x04) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 + 2^18
// LT/LT_W (OP:0x05) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 + 2^19
// EQ/EQ_W (OP:0x08) * 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^21 + 2^20 + 2^19 + 2^18
// ADD/ADD_W (OP:0x02) 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^22
// SUB/SUB_W (OP:0x03) 2^16 (AxB) x 2^1 (LAST) x 2^1 (CIN) = 2^18 | 2^22 + 2^18
// LEU/LEU_W (OP:0x06) * 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^18 + 2^17
// LE/LE_W (OP:0x07) * 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^19
// AND/AND_W (OP:0x20) ** 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^19 + 2^17
// OR/OR_W (OP:0x21) ** 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^19 + 2^18
// XOR/XOR_W (OP:0x22) ** 2^16 (AxB) x 2^1 (LAST) = 2^17 | 2^22 + 2^19 + 2^18 + 2^17
// EXT_32 (OP:0x23) *** 2^8 (A) x 2^1 (CIN) x 2^2 (FLAGS) = 2^11 | 2^22 + 2^19 + 2^18 + 2^17 + 2^11 => < 2^23
// --------------------------------------------------------------------------------------------------------------------------
// (*) Use carry
// (**) Do not use last indicator, but it is used for simplicity of the lookup
// (*) Uses the carry of the last byte of the result (use_last_carry)
// (**) The op do not use LAST, but the binary does so we need to consider it
// (***) The op do not use CIN, but the binary does so we need to consider it
// Note: EXT_32 is the only unary operation

const int EXT_32_OP = 0x23;
const int BINARY_TABLE_ID = 125;

airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) {
airtemplate BinaryTable(const int N = 2**23, const int disable_fixed = 0) {

#pragma memory m1 start
col witness multiplicity;
Expand All @@ -39,28 +42,41 @@ airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) {
return;
}

if (N < 2**22) {
error(`N must be at least 2^22, but N=${N} was provided`);
if (N < 2**23) {
error(`N must be at least 2^23, but N=${N} was provided`);
}

#pragma timer tt start
#pragma timer t1 start

col fixed A = [0..255]...; // Input A (8 bits)
// Input A (8 bits)
col fixed A = [0..255]...;

col fixed B = [[0:P2_8..255:P2_8]:62,0:P2_11]...; // Input B (8 bits)
// Input B (8 bits)
col fixed B = [[0:P2_8..255:P2_8]:78, // 78 = 4*8 + 2*8 + 3*4 + 2*4 + 2*2 + 3*2
0:P2_11]...; // B is 0 for EXT_32

col fixed LAST = [[0:P2_16, 1:P2_16]:(4*4), // Indicator of the last byte (1 bit)
[0:P2_16, 1:P2_16]:(5*2),
[0:P2_16, 1:P2_16]:5,
// Indicator of the last byte (1 bit)
col fixed LAST = [[0:P2_16, 1:P2_16]:(4*4), // MINU,MIN,MAXU,MAX
[0:P2_16, 1:P2_16]:(2*4), // LT_ABS_NP,LT_ABS_PN
[0:P2_16, 1:P2_16]:(3*2), // LTU,LT,EQ
[0:P2_16, 1:P2_16]:(2*2), // ADD,SUB
[0:P2_16, 1:P2_16]:2, // LEU,LE
[0:P2_16, 1:P2_16]:3, // AND,OR,XOR
0:P2_11]...;

col fixed CIN = [[0:P2_17, 1:P2_17]:(4*2), // Input carry (1 bit)
[0:P2_17, 1:P2_17]:5,
0:(P2_17*5),
[0:P2_8, 1:P2_8]:4]...;
// Input carry (1/2 bits)
col fixed CIN = [[0:P2_17, 1:P2_17]:(4*2), // MINU,MIN,MAXU,MAX
[0:P2_17..3:P2_17]:2, // LT_ABS_NP,LT_ABS_PN
[0:P2_17, 1:P2_17]:3, // LTU,LT,EQ
[0:P2_17, 1:P2_17]:2, // ADD,SUB
0:(P2_17*2), // LEU,LE
0:(P2_17*3), // AND,OR,XOR
[0:P2_8, 1:P2_8]:4]...; // EXT_32

// Operation opcode (fixed values)
col fixed OP = [0x09:P2_19, 0x0a:P2_19, 0x0b:P2_19, 0x0c:P2_19, // MINU,MIN,MAXU,MAX
0x??:P2_19, 0x??:P2_19, // LT_ABS_NP,LT_ABS_PN
0x04:P2_18, 0x05:P2_18, 0x08:P2_18, // LTU,LT,EQ
0x02:P2_18, 0x03:P2_18, // ADD,SUB
0x06:P2_17, 0x07:P2_17, // LEU,LE
Expand All @@ -70,8 +86,8 @@ airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) {
// NOTE: MINU/MINU_W, MIN/MIN_W, MAXU/MAXU_W, MAX/MAX_W has double size because
// the result_is_a is 0 in the first half and 1 in the second half.

const int TABLE_SIZE = P2_19 * 4 + P2_18 * 5 + P2_17 * 6;
const int TABLE_BASE_EXT32 = P2_16 * 62;
const int TABLE_SIZE = P2_19 * 6 + P2_18 * 5 + P2_17 * 5 + P2_11;
const int TABLE_BASE_EXT32 = P2_16 * 78;

#pragma timer t1 end
#pragma timer t2 start
Expand Down Expand Up @@ -178,6 +194,58 @@ airtemplate BinaryTable(const int N = 2**22, const int disable_fixed = 0) {
}
op_is_min_max = 1;

case 0x0d: // LT_ABS_NP
// Both necessary carries are encoded by cin in binary as
// cin = 0bYX,
// where X is the carry of the LT operation and Y is
// the carry of the operation a ^ 0xFF + _cop

// Decode the carries
const int _clt = cin & 0x01;
const int _cop = cin & 0x02;

const int _a = a ^ 0xFF + _cop; // _cop should be 1 at the first byte and _a >> 8 at the rest
const int _b = b;

if ((_a & 0xFF) < _b) {
cout = 1;
c = plast;
} else if ((_a & 0xFF) == _b) {
cout = _clt;
c = plast * _clt;
}

// Encode the result carries
cout += 2*(_a >> 8);

use_last_carry = plast;

case 0x0e: // LT_ABS_PN
// Both necessary carries are encoded by cin in binary as
// cin = 0bYX,
// where X is the carry of the LT operation and Y is
// the carry of the operation b ^ 0xFF + _cop

// Decode the carries
const int _clt = cin & 0x01;
const int _cop = cin & 0x02;

const int _a = a;
const int _b = b ^ 0xFF + _cop; // _cop should be 1 at the first byte and _b >> 8 at the rest

if (_a < (_b & 0xFF)) {
cout = 1;
c = plast;
} else if (_a == (_b & 0xFF)) {
cout = _clt;
c = plast * _clt;
}

// Encode the result carries
cout += 2*(_b >> 8);

use_last_carry = plast;

case 0x20: // AND
c = a & b;

Expand Down

0 comments on commit d401515

Please sign in to comment.