diff --git a/src/simple.test.ts b/src/simple.test.ts index 9e418b2..41cc1cf 100644 --- a/src/simple.test.ts +++ b/src/simple.test.ts @@ -1,10 +1,11 @@ import { test, testProp, fc } from '@fast-check/ava'; import { HashZero as zero } from '@ethersproject/constants'; -import { SimpleMerkleTree } from './simple'; import { keccak256 } from '@ethersproject/keccak256'; +import { SimpleMerkleTree } from './simple'; import { BytesLike, HexString, concat, compare } from './bytes'; -const reverseHashPair = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare).reverse())); +const reverseNodeHash = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare).reverse())); +const otherNodeHash = (a: BytesLike, b: BytesLike): HexString => keccak256(reverseNodeHash(a, b)); // double hash import { toHex } from './bytes'; import { InvalidArgumentError, InvariantError } from './utils/errors'; @@ -13,7 +14,7 @@ const leaf = fc.uint8Array({ minLength: 32, maxLength: 32 }).map(toHex); const leaves = fc.array(leaf, { minLength: 1 }); const options = fc.record({ sortLeaves: fc.oneof(fc.constant(undefined), fc.boolean()), - nodeHash: fc.oneof(fc.constant(undefined), fc.constant(reverseHashPair)), + nodeHash: fc.oneof(fc.constant(undefined), fc.constant(reverseNodeHash)), }); const tree = fc @@ -94,8 +95,8 @@ testProp( (t, leaves) => { t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true }).render()); t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false }).render()); - t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseHashPair }).render()); - t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseHashPair }).render()); + t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseNodeHash }).render()); + t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseNodeHash }).render()); }, { numRuns: 1, seed: 0 }, ); @@ -106,8 +107,8 @@ testProp( (t, leaves) => { t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true }).dump()); t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false }).dump()); - t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseHashPair }).dump()); - t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseHashPair }).dump()); + t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseNodeHash }).dump()); + t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseNodeHash }).dump()); }, { numRuns: 1, seed: 0 }, ); @@ -115,7 +116,7 @@ testProp( testProp('dump and load', [tree], (t, [tree, options]) => { const dump = tree.dump(); const recoveredTree = SimpleMerkleTree.load(dump, options.nodeHash); - recoveredTree.validate(); + recoveredTree.validate(); // already done in load t.is(dump.hash, options.nodeHash ? 'custom' : undefined); t.is(tree.root, recoveredTree.root); @@ -128,6 +129,12 @@ testProp('reject out of bounds value index', [tree], (t, [tree]) => { t.throws(() => tree.getProof(-1), new InvalidArgumentError('Index out of bounds')); }); +// We need at least 2 leaves for internal node hashing to come into play +testProp('reject loading dump with wrong node hash', [ fc.array(leaf, { minLength: 2 }) ] , (t, leaves) => { + const dump = SimpleMerkleTree.of(leaves, { nodeHash: reverseNodeHash }).dump(); + t.throws(() => SimpleMerkleTree.load(dump, otherNodeHash), new InvariantError('Merkle tree is invalid')); +}); + test('reject invalid leaf size', t => { const invalidLeaf = '0x000000000000000000000000000000000000000000000000000000000000000000'; t.throws(() => SimpleMerkleTree.of([invalidLeaf]), { @@ -148,22 +155,23 @@ test('reject unrecognized tree dump', t => { }); test('reject malformed tree dump', t => { - const loadedTree1 = SimpleMerkleTree.load({ - format: 'simple-v1', - tree: [zero], - values: [ - { - value: '0x0000000000000000000000000000000000000000000000000000000000000001', - treeIndex: 0, - }, - ], - }); - t.throws(() => loadedTree1.getProof(0), new InvariantError('Merkle tree does not contain the expected value')); + t.throws( + () => SimpleMerkleTree.load({ + format: 'simple-v1', + tree: [zero], + values: [ + { + value: '0x0000000000000000000000000000000000000000000000000000000000000001', + treeIndex: 0, + }, + ], + }), + new InvariantError('Merkle tree does not contain the expected value') + ); - const loadedTree2 = SimpleMerkleTree.load({ + t.throws(() => SimpleMerkleTree.load({ format: 'simple-v1', tree: [zero, zero, zero], values: [{ value: zero, treeIndex: 2 }], - }); - t.throws(() => loadedTree2.getProof(0), new InvariantError('Unable to prove value')); + }), new InvariantError('Merkle tree is invalid')); }); diff --git a/src/simple.ts b/src/simple.ts index 77c1697..ff1d550 100644 --- a/src/simple.ts +++ b/src/simple.ts @@ -32,7 +32,9 @@ export class SimpleMerkleTree extends MerkleTreeImpl { nodeHash ? 'Data does not expect a custom node hashing function' : 'Data expects a custom node hashing function', ); - return new SimpleMerkleTree(data.tree, data.values, formatLeaf, nodeHash); + const tree = new SimpleMerkleTree(data.tree, data.values, formatLeaf, nodeHash); + tree.validate(); + return tree; } static verify(root: BytesLike, leaf: BytesLike, proof: BytesLike[], nodeHash?: NodeHash): boolean { diff --git a/src/standard.test.ts b/src/standard.test.ts index 2d622d4..d88c5cf 100644 --- a/src/standard.test.ts +++ b/src/standard.test.ts @@ -85,7 +85,7 @@ testProp( testProp('dump and load', [tree], (t, tree) => { const recoveredTree = StandardMerkleTree.load(tree.dump()); - recoveredTree.validate(); + recoveredTree.validate(); // already done in load t.is(tree.root, recoveredTree.root); t.is(tree.render(), recoveredTree.render()); @@ -110,19 +110,23 @@ test('reject unrecognized tree dump', t => { }); test('reject malformed tree dump', t => { - const loadedTree1 = StandardMerkleTree.load({ - format: 'standard-v1', - tree: [zero], - values: [{ value: ['0'], treeIndex: 0 }], - leafEncoding: ['uint256'], - }); - t.throws(() => loadedTree1.getProof(0), new InvariantError('Merkle tree does not contain the expected value')); - - const loadedTree2 = StandardMerkleTree.load({ - format: 'standard-v1', - tree: [zero, zero, keccak256(keccak256(zero))], - values: [{ value: ['0'], treeIndex: 2 }], - leafEncoding: ['uint256'], - }); - t.throws(() => loadedTree2.getProof(0), new InvariantError('Unable to prove value')); + t.throws( + () => StandardMerkleTree.load({ + format: 'standard-v1', + tree: [zero], + values: [{ value: ['0'], treeIndex: 0 }], + leafEncoding: ['uint256'], + }), + new InvariantError('Merkle tree does not contain the expected value'), + ); + + t.throws( + () => StandardMerkleTree.load({ + format: 'standard-v1', + tree: [zero, zero, keccak256(keccak256(zero))], + values: [{ value: ['0'], treeIndex: 2 }], + leafEncoding: ['uint256'], + }), + new InvariantError('Merkle tree is invalid'), + ); }); diff --git a/src/standard.ts b/src/standard.ts index cd49460..c69488d 100644 --- a/src/standard.ts +++ b/src/standard.ts @@ -32,7 +32,10 @@ export class StandardMerkleTree extends MerkleTreeImpl { static load(data: StandardMerkleTreeData): StandardMerkleTree { validateArgument(data.format === 'standard-v1', `Unknown format '${data.format}'`); validateArgument(data.leafEncoding !== undefined, 'Expected leaf encoding'); - return new StandardMerkleTree(data.tree, data.values, data.leafEncoding); + + const tree = new StandardMerkleTree(data.tree, data.values, data.leafEncoding); + tree.validate(); + return tree; } static verify(root: BytesLike, leafEncoding: string[], leaf: T, proof: BytesLike[]): boolean {