Skip to content

Commit

Permalink
Code structure cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mdehoog committed Nov 30, 2024
1 parent 24f3b27 commit 1a3806c
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 81 deletions.
24 changes: 19 additions & 5 deletions src/Asn1Decode.sol
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ library Asn1Decode {
}

/*
* @dev Extract value of bitstring node from DER-encoded structure
* @dev Extract pointer of bitstring node from DER-encoded structure
* @param der The DER-encoded ASN1 structure
* @param ptr Points to the indices of the current node
* @return A pointer to a bitstring
Expand All @@ -86,6 +86,12 @@ library Asn1Decode {
return LibNodePtr.toNodePtr(ptr.header(), ptr.content() + 1, ptr.length() - 1);
}

/*
* @dev Extract value of bitstring node from DER-encoded structure
* @param der The DER-encoded ASN1 structure
* @param ptr Points to the indices of the current node
* @return A bitstring encoded in a uint256
*/
function bitstringUintAt(bytes memory der, NodePtr ptr) internal pure returns (uint256) {
require(der[ptr.header()] == 0x03, "Not type BIT STRING");
uint256 len = ptr.length() - 1;
Expand Down Expand Up @@ -116,6 +122,12 @@ library Asn1Decode {
return uint256(readBytesN(der, ptr.content(), len) >> (32 - len) * 8);
}

/*
* @dev Extract value of a positive integer node from DER-encoded structure
* @param der The DER-encoded ASN1 structure
* @param ptr Points to the indices of the current node
* @return 384-bit uint encoded in uint128 and uint256
*/
function uint384At(bytes memory der, NodePtr ptr) internal pure returns (uint128, uint256) {
require(der[ptr.header()] == 0x02, "Not type INTEGER");
require(der[ptr.content()] & 0x80 == 0, "Not positive");
Expand All @@ -131,6 +143,12 @@ library Asn1Decode {
);
}

/*
* @dev Extract value of a timestamp from DER-encoded structure
* @param der The DER-encoded ASN1 structure
* @param ptr Points to the indices of the current node
* @return UNIX timestamp (seconds since 1970/01/01)
*/
function timestampAt(bytes memory der, NodePtr ptr) internal pure returns (uint256) {
uint16 _years;
uint256 offset = ptr.content();
Expand All @@ -152,10 +170,6 @@ library Asn1Decode {
return timestampFromDateTime(_years, _months, _days, _hours, _mins, _secs);
}

function byteAtOffset(bytes memory der, NodePtr ptr, uint256 offset) internal pure returns (bytes1) {
return der[ptr.content() + offset];
}

function readNodeLength(bytes memory der, uint256 ix) private pure returns (NodePtr) {
uint256 length;
uint80 ixFirstContentByte;
Expand Down
177 changes: 101 additions & 76 deletions src/NitroValidator.sol
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.15;

import {console} from "forge-std/console.sol";
import {Sha2Ext} from "./Sha2Ext.sol";
import {Asn1Decode} from "./Asn1Decode.sol";
import {NitroAttestation} from "./NitroAttestation.sol";
import {ECDSA384} from "./ECDSA384.sol";
import {LibBytes} from "./LibBytes.sol";
import {NodePtr, LibNodePtr} from "./NodePtr.sol";

// adapted from https://github.com/marlinprotocol/NitroProver/blob/f1d368d1f172ad3a55cd2aaaa98ad6a6e7dcde9d/src/NitroProver.sol
// and https://github.com/marlinprotocol/NitroProver/blob/f1d368d1f172ad3a55cd2aaaa98ad6a6e7dcde9d/src/CertManager.sol

contract NitroValidator {
using Asn1Decode for bytes;
using NitroAttestation for bytes;
Expand All @@ -32,7 +34,8 @@ contract NitroValidator {
// 1.3.132.0.34 {iso(1) identified-organization(3) certicom(132) curve(0) ansip384r1(34)} represents NIST 384-bit elliptic curve
bytes32 public constant SECP_384_R1_OID = keccak256(hex"2b81040022");

bytes32 public constant DIGEST_VALUE = keccak256("SHA384");
// attestation / certificate constants
bytes32 public constant ATTESTATION_DIGEST = keccak256("SHA384");
bytes32 public constant BASIC_CONSTRAINTS_OID = keccak256(hex"551d13");
bytes32 public constant KEY_USAGE_OID = keccak256(hex"551d0f");

Expand Down Expand Up @@ -95,13 +98,18 @@ contract NitroValidator {
);
}

function validateAttestation(bytes memory attestationTbs, bytes memory signature) external {
function validateAttestation(bytes memory attestationTbs, bytes memory signature)
external
returns (NitroAttestation.Ptrs memory)
{
NitroAttestation.Ptrs memory ptrs = attestationTbs.parseAttestation();

require(ptrs.moduleID.length() > 0, "no module id");
require(ptrs.timestamp > 0, "no timestamp");
require(ptrs.cabundle.length > 0, "no cabundle");
require(attestationTbs.keccak(ptrs.digest.content(), ptrs.digest.length()) == DIGEST_VALUE, "invalid digest");
require(
attestationTbs.keccak(ptrs.digest.content(), ptrs.digest.length()) == ATTESTATION_DIGEST, "invalid digest"
);
require(1 <= ptrs.pcrs.length && ptrs.pcrs.length <= 32, "invalid pcrs");
require(
attestationTbs[ptrs.publicKey.header()] == Asn1Decode.NULL_VALUE
Expand All @@ -128,7 +136,9 @@ contract NitroValidator {
parentCache = _verifyCert(attestationTbs, ptrs.cert, certHash, true, parentCache);

bytes memory hash = Sha2Ext.sha384(attestationTbs, 0, attestationTbs.length);
verifySignature(parentCache.pubKey, hash, signature);
_verifySignature(parentCache.pubKey, hash, signature);

return ptrs;
}

function _verifyCert(
Expand Down Expand Up @@ -164,154 +174,169 @@ contract NitroValidator {
return cache;
}

function _verifyCertSignature(bytes memory certificate, NodePtr ptr, bytes memory pubKey) internal view {
NodePtr sigAlgoPtr = certificate.nextSiblingOf(ptr);
require(certificate.keccak(sigAlgoPtr.content(), sigAlgoPtr.length()) == CERT_ALGO_OID, "invalid cert sig algo");

bytes memory hash = Sha2Ext.sha384(certificate, ptr.header(), ptr.totalLength());
bytes memory sigPacked = packSig(certificate, sigAlgoPtr);
verifySignature(pubKey, hash, sigPacked);
}

function packSig(bytes memory certificate, NodePtr sigAlgoPtr) internal pure returns (bytes memory) {
NodePtr sigPtr = certificate.nextSiblingOf(sigAlgoPtr);
NodePtr sigBPtr = certificate.bitstring(sigPtr);
NodePtr sigRoot = certificate.rootOf(sigBPtr);
NodePtr sigRPtr = certificate.firstChildOf(sigRoot);
NodePtr sigSPtr = certificate.nextSiblingOf(sigRPtr);
(uint128 rhi, uint256 rlo) = certificate.uint384At(sigRPtr);
(uint128 shi, uint256 slo) = certificate.uint384At(sigSPtr);
return abi.encodePacked(rhi, rlo, shi, slo);
}

function _parseTbs(bytes memory certificate, NodePtr ptr, bool clientCert)
internal
view
returns (uint256, int256, bytes memory)
returns (uint256 notAfter, int256 maxPathLen, bytes memory pubKey)
{
NodePtr versionPtr = certificate.firstChildOf(ptr);
NodePtr vPtr = certificate.firstChildOf(versionPtr);
uint256 version = certificate.uintAt(vPtr);
// as extensions are used in cert, version should be 3 (value 2) as per https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.1
require(version == 2, "version should be 3");
NodePtr serialPtr = certificate.nextSiblingOf(versionPtr);
NodePtr sigAlgoPtr = certificate.nextSiblingOf(serialPtr);

require(certificate.keccak(sigAlgoPtr.content(), sigAlgoPtr.length()) == CERT_ALGO_OID, "invalid cert sig algo");
return _parseTbs2(certificate, sigAlgoPtr, clientCert);
uint256 version = certificate.uintAt(vPtr);
// as extensions are used in cert, version should be 3 (value 2) as per https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.1
require(version == 2, "version should be 3");

(notAfter, maxPathLen, pubKey) = _parseTbsInner(certificate, sigAlgoPtr, clientCert);
}

function _parseTbs2(bytes memory certificate, NodePtr sigAlgoPtr, bool clientCert)
function _parseTbsInner(bytes memory certificate, NodePtr sigAlgoPtr, bool clientCert)
internal
view
returns (uint256, int256, bytes memory)
returns (uint256 notAfter, int256 maxPathLen, bytes memory pubKey)
{
NodePtr issuerPtr = certificate.nextSiblingOf(sigAlgoPtr);
NodePtr validityPtr = certificate.nextSiblingOf(issuerPtr);
uint256 notAfter = _verifyValidity(certificate, validityPtr);
NodePtr subjectPtr = certificate.nextSiblingOf(validityPtr);
(int256 maxPathLen, bytes memory pubKey) = _verifyTbs2(certificate, subjectPtr, clientCert);
return (notAfter, maxPathLen, pubKey);
}
NodePtr subjectPublicKeyInfoPtr = certificate.nextSiblingOf(subjectPtr);
NodePtr extensionsPtr = certificate.nextSiblingOf(subjectPublicKeyInfoPtr);

function _verifyValidity(bytes memory certificate, NodePtr validityPtr) internal view returns (uint256) {
NodePtr notBeforePtr = certificate.firstChildOf(validityPtr);
uint256 notBefore = certificate.timestampAt(notBeforePtr);
require(notBefore <= block.timestamp, "certificate not valid yet");
NodePtr notAfterPtr = certificate.nextSiblingOf(notBeforePtr);
uint256 notAfter = certificate.timestampAt(notAfterPtr);
require(notAfter >= block.timestamp, "certificate not valid anymore");
return notAfter;
notAfter = _verifyValidity(certificate, validityPtr);
maxPathLen = _verifyExtensions(certificate, extensionsPtr, clientCert);
pubKey = _parsePubKey(certificate, subjectPublicKeyInfoPtr);
}

function _verifyTbs2(bytes memory certificate, NodePtr subjectPtr, bool clientCert)
function _parsePubKey(bytes memory certificate, NodePtr subjectPublicKeyInfoPtr)
internal
pure
returns (int256, bytes memory)
returns (bytes memory subjectPubKey)
{
NodePtr subjectPublicKeyInfoPtr = certificate.nextSiblingOf(subjectPtr);
NodePtr pubKeyAlgoPtr = certificate.firstChildOf(subjectPublicKeyInfoPtr);
NodePtr pubKeyAlgoIdPtr = certificate.firstChildOf(pubKeyAlgoPtr);
NodePtr algoParamsPtr = certificate.nextSiblingOf(pubKeyAlgoIdPtr);
NodePtr subjectPublicKeyPtr = certificate.nextSiblingOf(pubKeyAlgoPtr);
NodePtr subjectPubKeyPtr = certificate.bitstring(subjectPublicKeyPtr);

require(
certificate.keccak(pubKeyAlgoIdPtr.content(), pubKeyAlgoIdPtr.length()) == EC_PUB_KEY_OID,
"invalid cert algo id"
);

NodePtr algoParamsPtr = certificate.nextSiblingOf(pubKeyAlgoIdPtr);
require(
certificate.keccak(algoParamsPtr.content(), algoParamsPtr.length()) == SECP_384_R1_OID,
"invalid cert algo param"
);

NodePtr subjectPublicKeyPtr = certificate.nextSiblingOf(pubKeyAlgoPtr);
NodePtr subjectPubKeyPtr = certificate.bitstring(subjectPublicKeyPtr);
uint256 end = subjectPubKeyPtr.content() + subjectPubKeyPtr.length();
bytes memory subjectPubKey = certificate.slice(end - 96, end);
subjectPubKey = certificate.slice(end - 96, end);
}

NodePtr extensionsPtr = certificate.nextSiblingOf(subjectPublicKeyInfoPtr);
int256 maxPathLen = _verifyExtensions(certificate, extensionsPtr, clientCert);
function _verifyValidity(bytes memory certificate, NodePtr validityPtr) internal view returns (uint256 notAfter) {
NodePtr notBeforePtr = certificate.firstChildOf(validityPtr);
NodePtr notAfterPtr = certificate.nextSiblingOf(notBeforePtr);

uint256 notBefore = certificate.timestampAt(notBeforePtr);
notAfter = certificate.timestampAt(notAfterPtr);

return (maxPathLen, subjectPubKey);
require(notBefore <= block.timestamp, "certificate not valid yet");
require(notAfter >= block.timestamp, "certificate not valid anymore");
}

function _verifyExtensions(bytes memory certificate, NodePtr extensionsPtr, bool clientCert)
internal
pure
returns (int256)
returns (int256 maxPathLen)
{
int256 maxPathLen = -1;
require(certificate[extensionsPtr.header()] == 0xa3, "invalid extensions");
extensionsPtr = certificate.firstChildOf(extensionsPtr);
NodePtr extensionPtr = certificate.firstChildOf(extensionsPtr);
uint256 end = extensionsPtr.content() + extensionsPtr.length();
bool basicConstraintsFound = false;
bool keyUsageFound = false;
maxPathLen = -1;

while (true) {
NodePtr oidPtr = certificate.firstChildOf(extensionPtr);
bytes32 oid = certificate.keccak(oidPtr.content(), oidPtr.length());

if (oid == BASIC_CONSTRAINTS_OID || oid == KEY_USAGE_OID) {
NodePtr valuePtr = certificate.nextSiblingOf(oidPtr);

if (certificate[valuePtr.header()] == 0x01) {
// skip optional critical bool
require(valuePtr.length() == 1, "invalid critical bool value");
valuePtr = certificate.nextSiblingOf(valuePtr);
}

valuePtr = certificate.octetString(valuePtr);

if (oid == BASIC_CONSTRAINTS_OID) {
basicConstraintsFound = true;
NodePtr basicConstraintsPtr = certificate.firstChildOf(valuePtr);
if (certificate[basicConstraintsPtr.header()] == 0x01) {
// skip optional isCA bool
require(basicConstraintsPtr.length() == 1, "invalid isCA bool value");
basicConstraintsPtr = certificate.nextSiblingOf(basicConstraintsPtr);
}
if (certificate[basicConstraintsPtr.header()] == 0x02) {
maxPathLen = int256(certificate.uintAt(basicConstraintsPtr));
}
maxPathLen = _verifyBasicConstraintsExtension(certificate, valuePtr);
} else {
keyUsageFound = true;
uint256 value = certificate.bitstringUintAt(valuePtr);
// bits are reversed (DigitalSignature 0x01 => 0x80, CertSign 0x32 => 0x04)
if (clientCert) {
require(value & 0x80 == 0x80, "DigitalSignature must be present");
} else {
require(value & 0x04 == 0x04, "CertSign must be present");
}
_verifyKeyUsageExtension(certificate, valuePtr, clientCert);
}
}

if (extensionPtr.content() + extensionPtr.length() == end) {
break;
}
extensionPtr = certificate.nextSiblingOf(extensionPtr);
}

require(basicConstraintsFound, "basicConstraints not found");
require(keyUsageFound, "keyUsage not found");
require(!clientCert || maxPathLen == -1, "maxPathLen must be undefined for client cert");
}

function _verifyBasicConstraintsExtension(bytes memory certificate, NodePtr valuePtr)
internal
pure
returns (int256 maxPathLen)
{
maxPathLen = -1;
NodePtr basicConstraintsPtr = certificate.firstChildOf(valuePtr);
if (certificate[basicConstraintsPtr.header()] == 0x01) {
// skip optional isCA bool
require(basicConstraintsPtr.length() == 1, "invalid isCA bool value");
basicConstraintsPtr = certificate.nextSiblingOf(basicConstraintsPtr);
}
if (certificate[basicConstraintsPtr.header()] == 0x02) {
maxPathLen = int256(certificate.uintAt(basicConstraintsPtr));
}
}

function _verifyKeyUsageExtension(bytes memory certificate, NodePtr valuePtr, bool clientCert) internal pure {
uint256 value = certificate.bitstringUintAt(valuePtr);
// bits are reversed (DigitalSignature 0x01 => 0x80, CertSign 0x32 => 0x04)
if (clientCert) {
require(maxPathLen == -1, "maxPathLen must be undefined for client cert");
require(value & 0x80 == 0x80, "DigitalSignature must be present");
} else {
require(value & 0x04 == 0x04, "CertSign must be present");
}
return maxPathLen;
}

function verifySignature(bytes memory pubKey, bytes memory hash, bytes memory sig) internal view {
function _verifyCertSignature(bytes memory certificate, NodePtr ptr, bytes memory pubKey) internal view {
NodePtr sigAlgoPtr = certificate.nextSiblingOf(ptr);
require(certificate.keccak(sigAlgoPtr.content(), sigAlgoPtr.length()) == CERT_ALGO_OID, "invalid cert sig algo");

bytes memory hash = Sha2Ext.sha384(certificate, ptr.header(), ptr.totalLength());

NodePtr sigPtr = certificate.nextSiblingOf(sigAlgoPtr);
NodePtr sigBPtr = certificate.bitstring(sigPtr);
NodePtr sigRoot = certificate.rootOf(sigBPtr);
NodePtr sigRPtr = certificate.firstChildOf(sigRoot);
NodePtr sigSPtr = certificate.nextSiblingOf(sigRPtr);
(uint128 rhi, uint256 rlo) = certificate.uint384At(sigRPtr);
(uint128 shi, uint256 slo) = certificate.uint384At(sigSPtr);
bytes memory sigPacked = abi.encodePacked(rhi, rlo, shi, slo);

_verifySignature(pubKey, hash, sigPacked);
}

function _verifySignature(bytes memory pubKey, bytes memory hash, bytes memory sig) internal view {
ECDSA384.Parameters memory CURVE_PARAMETERS = ECDSA384.Parameters({
a: CURVE_A,
b: CURVE_B,
Expand Down

0 comments on commit 1a3806c

Please sign in to comment.