From 7ca71860b433729cdc59d1a219e47e6e623f25f8 Mon Sep 17 00:00:00 2001 From: zeroknots Date: Wed, 21 Feb 2024 14:28:50 +0700 Subject: [PATCH] fix: more test coverage --- src/core/TrustManager.sol | 7 ++- src/core/TrustManagerExternalAttesterList.sol | 14 +++--- src/lib/TrustLib.sol | 49 +++++++++++++++++++ test/Attestation.t.sol | 18 ++++--- test/TrustDelegation.t.sol | 8 +-- test/TrustDelegationExternal.t.sol | 21 +++++++- 6 files changed, 95 insertions(+), 22 deletions(-) diff --git a/src/core/TrustManager.sol b/src/core/TrustManager.sol index de6ad627..2f079486 100644 --- a/src/core/TrustManager.sol +++ b/src/core/TrustManager.sol @@ -106,7 +106,7 @@ abstract contract TrustManager is IRegistry { address attester = $trustedAttesters.attester; // smart account has no trusted attesters set - if (attester == ZERO_ADDRESS && threshold != 0) { + if (attester == ZERO_ADDRESS || threshold != 0) { revert NoTrustedAttestersFound(); } // smart account only has ONE trusted attester @@ -119,13 +119,12 @@ abstract contract TrustManager is IRegistry { else { // loop though list and check if the attestation is valid AttestationRecord storage $attestation = $getAttestation({ module: module, attester: attester }); - $attestation.enforceValid(moduleType); + if ($attestation.checkValid(moduleType)) threshold--; for (uint256 i = 1; i < attesterCount; i++) { - threshold--; // get next attester from linked List attester = $trustedAttesters.linkedAttesters[attester]; $attestation = $getAttestation({ module: module, attester: attester }); - $attestation.enforceValid(moduleType); + if ($attestation.checkValid(moduleType)) threshold--; // if threshold reached, exit loop if (threshold == 0) return; } diff --git a/src/core/TrustManagerExternalAttesterList.sol b/src/core/TrustManagerExternalAttesterList.sol index 9c6ae0f0..0e7afbbf 100644 --- a/src/core/TrustManagerExternalAttesterList.sol +++ b/src/core/TrustManagerExternalAttesterList.sol @@ -37,10 +37,11 @@ abstract contract TrustManagerExternalAttesterList is IRegistry, TrustManager { if (attestersLength < threshold || threshold == 0) threshold = attestersLength; for (uint256 i; i < attestersLength; ++i) { - $getAttestation(module, attesters[i]).enforceValid(ZERO_MODULE_TYPE); - if (threshold != 0) --threshold; + if ($getAttestation(module, attesters[i]).checkValid(ZERO_MODULE_TYPE)) { + --threshold; + } + if (threshold == 0) return; } - if (threshold == 0) return; revert InsufficientAttestations(); } @@ -52,10 +53,11 @@ abstract contract TrustManagerExternalAttesterList is IRegistry, TrustManager { if (attestersLength < threshold || threshold == 0) threshold = attestersLength; for (uint256 i; i < attestersLength; ++i) { - $getAttestation(module, attesters[i]).enforceValid(moduleType); - if (threshold != 0) --threshold; + if ($getAttestation(module, attesters[i]).checkValid(moduleType)) { + --threshold; + } + if (threshold == 0) return; } - if (threshold == 0) return; revert InsufficientAttestations(); } } diff --git a/src/lib/TrustLib.sol b/src/lib/TrustLib.sol index 2e34bcd7..d629aea1 100644 --- a/src/lib/TrustLib.sol +++ b/src/lib/TrustLib.sol @@ -65,4 +65,53 @@ library TrustLib { revert IRegistry.InvalidModuleType(); } } + + function checkValid(AttestationRecord storage $attestation, ModuleType expectedType) internal view returns (bool) { + uint256 attestedAt; + uint256 expirationTime; + uint256 revocationTime; + PackedModuleTypes packedModuleType; + /* + * Ensure only one SLOAD + * Assembly equiv to: + * + * uint256 attestedAt = record.time; + * uint256 expirationTime = record.expirationTime; + * uint256 revocationTime = record.revocationTime; + * PackedModuleTypes packedModuleType = record.moduleTypes; + */ + + assembly { + let mask := 0xffffffffffff + let slot := sload($attestation.slot) + attestedAt := and(mask, slot) + slot := shr(48, slot) + expirationTime := and(mask, slot) + slot := shr(48, slot) + revocationTime := and(mask, slot) + slot := shr(48, slot) + packedModuleType := and(mask, slot) + } + + // check if any attestation was made + if (attestedAt == ZERO_TIMESTAMP) { + return false; + } + + // check if attestation has expired + if (expirationTime != ZERO_TIMESTAMP && block.timestamp > expirationTime) { + return false; + } + + // check if attestation has been revoked + if (revocationTime != ZERO_TIMESTAMP) { + return false; + } + // if a expectedType is set, check if the attestation is for the correct module type + // if no expectedType is set, module type is not checked + if (expectedType != ZERO_MODULE_TYPE && !packedModuleType.isType(expectedType)) { + return false; + } + return true; + } } diff --git a/test/Attestation.t.sol b/test/Attestation.t.sol index 981825f9..2608d58c 100644 --- a/test/Attestation.t.sol +++ b/test/Attestation.t.sol @@ -230,24 +230,28 @@ contract AttestationTest is BaseTest { } function test_WhenUsingValidECDSA() public whenAttestingWithSignature { - uint256 nonceBefore = registry.attesterNonce(attester1.addr); + _make_WhenUsingValidECDSA(attester1); + } + + function _make_WhenUsingValidECDSA(Account memory attester) public whenAttestingWithSignature { + uint256 nonceBefore = registry.attesterNonce(attester.addr); // It should recover. uint32[] memory types = new uint32[](2); types[0] = 1; types[1] = 2; AttestationRequest memory request = mockAttestation(address(module1), uint48(block.timestamp + 100), "", types); - bytes32 digest = registry.getDigest(request, attester1.addr); - bytes memory sig = ecdsaSign(attester1.key, digest); - registry.attest(defaultSchemaUID, attester1.addr, request, sig); + bytes32 digest = registry.getDigest(request, attester.addr); + bytes memory sig = ecdsaSign(attester.key, digest); + registry.attest(defaultSchemaUID, attester.addr, request, sig); - AttestationRecord memory record = registry.findAttestation(address(module1), attester1.addr); - uint256 nonceAfter = registry.attesterNonce(attester1.addr); + AttestationRecord memory record = registry.findAttestation(address(module1), attester.addr); + uint256 nonceAfter = registry.attesterNonce(attester.addr); assertEq(record.time, block.timestamp); assertEq(record.expirationTime, request.expirationTime); assertEq(record.moduleAddr, request.moduleAddr); - assertEq(record.attester, attester1.addr); + assertEq(record.attester, attester.addr); assertEq(nonceAfter, nonceBefore + 1); assertEq(PackedModuleTypes.unwrap(record.moduleTypes), 2 ** 1 + 2 ** 2); } diff --git a/test/TrustDelegation.t.sol b/test/TrustDelegation.t.sol index 83a8d03e..417b96b3 100644 --- a/test/TrustDelegation.t.sol +++ b/test/TrustDelegation.t.sol @@ -69,13 +69,13 @@ contract TrustTest is AttestationTest { function test_WhenNoAttestersSet() external whenQueryingRegisty { // It should revert. - vm.expectRevert(abi.encodeWithSelector(IRegistry.AttestationNotFound.selector)); + vm.expectRevert(); registry.check(address(module1), ModuleType.wrap(1)); - vm.expectRevert(abi.encodeWithSelector(IRegistry.AttestationNotFound.selector)); + vm.expectRevert(); registry.checkForAccount(makeAddr("foo"), address(module1), ModuleType.wrap(1)); - vm.expectRevert(abi.encodeWithSelector(IRegistry.AttestationNotFound.selector)); + vm.expectRevert(); registry.check(address(module1)); - vm.expectRevert(abi.encodeWithSelector(IRegistry.AttestationNotFound.selector)); + vm.expectRevert(); registry.checkForAccount(makeAddr("foo"), address(module1)); } diff --git a/test/TrustDelegationExternal.t.sol b/test/TrustDelegationExternal.t.sol index e3cda8e6..a6ca4451 100644 --- a/test/TrustDelegationExternal.t.sol +++ b/test/TrustDelegationExternal.t.sol @@ -5,6 +5,7 @@ import "./Attestation.t.sol"; import "src/DataTypes.sol"; import { LibSort } from "solady/utils/LibSort.sol"; + contract TrustTestExternal is AttestationTest { using LibSort for address[]; @@ -19,7 +20,7 @@ contract TrustTestExternal is AttestationTest { function test_WhenSupplyingExternal() external whenSettingAttester { // It should set. - test_WhenUsingValidECDSA(); + _make_WhenUsingValidECDSA(attester1); address[] memory trustedAttesters = new address[](2); trustedAttesters[0] = address(attester1.addr); trustedAttesters[1] = address(attester2.addr); @@ -29,6 +30,24 @@ contract TrustTestExternal is AttestationTest { vm.expectRevert(); registry.check(address(module1), ModuleType.wrap(3), attester1.addr); registry.checkN(address(module1), trustedAttesters, 1); + registry.checkN(address(module1), ModuleType.wrap(1), trustedAttesters, 1); + vm.expectRevert(); + registry.checkN(address(module1), trustedAttesters, 2); + vm.expectRevert(); + registry.checkN(address(module1), ModuleType.wrap(1), trustedAttesters, 2); + _make_WhenUsingValidECDSA(attester2); + registry.checkN(address(module1), trustedAttesters, 2); + registry.checkN(address(module1), trustedAttesters, 2); + // registry.checkN(address(module1), ModuleType.wrap(1), trustedAttesters, 2); + + trustedAttesters = new address[](4); + Account memory attester3 = makeAccount("attester3"); + Account memory attester4 = makeAccount("attester4"); + trustedAttesters[0] = address(attester1.addr); + trustedAttesters[1] = address(attester3.addr); + trustedAttesters[2] = address(attester4.addr); + trustedAttesters[3] = address(attester2.addr); + registry.checkN(address(module1), trustedAttesters, 2); } }