Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sdk): remove hex encoding for segment hash #1805

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions sdk/manifest.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type Manifest struct {
EncryptionInformation `json:"encryptionInformation"`
Payload `json:"payload"`
Assertions []Assertion `json:"assertions,omitempty"`
TDFVersion string `json:"tdf_spec_version,omitempty"`
}

type attributeObject struct {
Expand Down
62 changes: 50 additions & 12 deletions sdk/tdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
)

const (
sdkVersion = "4.3.0"
maxFileSizeSupported = 68719476736 // 64gb
defaultMimeType = "application/octet-stream"
tdfAsZip = "zip"
Expand Down Expand Up @@ -197,7 +198,8 @@ func (s SDK) CreateTDFContext(ctx context.Context, writer io.Writer, reader io.R
return nil, fmt.Errorf("io.writer.Write failed: %w", err)
}

segmentSig, err := calculateSignature(cipherData, tdfObject.payloadKey[:], tdfConfig.segmentIntegrityAlgorithm)
segmentSig, err := calculateSignature(cipherData, tdfObject.payloadKey[:],
tdfConfig.segmentIntegrityAlgorithm, false)
if err != nil {
return nil, fmt.Errorf("splitKey.GetSignaturefailed: %w", err)
}
Expand All @@ -216,7 +218,8 @@ func (s SDK) CreateTDFContext(ctx context.Context, writer io.Writer, reader io.R
readPos += readSize
}

rootSignature, err := calculateSignature([]byte(aggregateHash), tdfObject.payloadKey[:], tdfConfig.integrityAlgorithm)
rootSignature, err := calculateSignature([]byte(aggregateHash), tdfObject.payloadKey[:],
tdfConfig.integrityAlgorithm, false)
if err != nil {
return nil, fmt.Errorf("splitKey.GetSignaturefailed: %w", err)
}
Expand Down Expand Up @@ -263,11 +266,17 @@ func (s SDK) CreateTDFContext(ctx context.Context, writer io.Writer, reader io.R
tmpAssertion.Statement = assertion.Statement
tmpAssertion.AppliesToState = assertion.AppliesToState

hashOfAssertion, err := tmpAssertion.GetHash()
hashOfAssertionAsHex, err := tmpAssertion.GetHash()
if err != nil {
return nil, err
}

hashOfAssertion := make([]byte, hex.DecodedLen(len(hashOfAssertionAsHex)))
_, err = hex.Decode(hashOfAssertion, hashOfAssertionAsHex)
if err != nil {
return nil, fmt.Errorf("error decoding hex string: %w", err)
}

var completeHashBuilder strings.Builder
completeHashBuilder.WriteString(aggregateHash)
completeHashBuilder.Write(hashOfAssertion)
Expand All @@ -284,7 +293,7 @@ func (s SDK) CreateTDFContext(ctx context.Context, writer io.Writer, reader io.R
assertionSigningKey = assertion.SigningKey
}

if err := tmpAssertion.Sign(string(hashOfAssertion), string(encoded), assertionSigningKey); err != nil {
if err := tmpAssertion.Sign(string(hashOfAssertionAsHex), string(encoded), assertionSigningKey); err != nil {
return nil, fmt.Errorf("failed to sign assertion: %w", err)
}

Expand Down Expand Up @@ -322,6 +331,14 @@ func (r *Reader) Manifest() Manifest {
// prepare the manifest for TDF
func (s SDK) prepareManifest(ctx context.Context, t *TDFObject, tdfConfig TDFConfig) error { //nolint:funlen,gocognit // Better readability keeping it as is
manifest := Manifest{}

version, err := ParseVersion(sdkVersion)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the ability to extend this version based on the decision of going with option 2 in this adr.

#1677

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if err != nil {
return fmt.Errorf("ReadVersion failed:%w", err)
}

manifest.TDFVersion = version.String()

if len(tdfConfig.splitPlan) == 0 && len(tdfConfig.kasInfoList) == 0 {
return fmt.Errorf("%w: no key access template specified or inferred", errInvalidKasInfo)
}
Expand Down Expand Up @@ -567,6 +584,8 @@ func (r *Reader) WriteTo(writer io.Writer) (int64, error) {
}
}

isLegacyTDF := r.manifest.TDFVersion == ""

var totalBytes int64
var payloadReadOffset int64
for _, seg := range r.manifest.EncryptionInformation.IntegrityInformation.Segments {
Expand All @@ -585,7 +604,7 @@ func (r *Reader) WriteTo(writer io.Writer) (int64, error) {
sigAlg = GMAC
}

payloadSig, err := calculateSignature(readBuf, r.payloadKey, sigAlg)
payloadSig, err := calculateSignature(readBuf, r.payloadKey, sigAlg, isLegacyTDF)
if err != nil {
return totalBytes, fmt.Errorf("splitKey.GetSignaturefailed: %w", err)
}
Expand Down Expand Up @@ -646,6 +665,7 @@ func (r *Reader) ReadAt(buf []byte, offset int64) (int, error) { //nolint:funlen
return 0, ErrTDFPayloadReadFail
}

isLegacyTDF := r.manifest.TDFVersion == ""
var decryptedBuf bytes.Buffer
var payloadReadOffset int64
for index, seg := range r.manifest.EncryptionInformation.IntegrityInformation.Segments {
Expand All @@ -669,7 +689,7 @@ func (r *Reader) ReadAt(buf []byte, offset int64) (int, error) { //nolint:funlen
sigAlg = GMAC
}

payloadSig, err := calculateSignature(readBuf, r.payloadKey, sigAlg)
payloadSig, err := calculateSignature(readBuf, r.payloadKey, sigAlg, isLegacyTDF)
if err != nil {
return 0, fmt.Errorf("splitKey.GetSignaturefailed: %w", err)
}
Expand Down Expand Up @@ -933,18 +953,29 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn
}

// Get the hash of the assertion
hashOfAssertion, err := assertion.GetHash()
hashOfAssertionAsHex, err := assertion.GetHash()
if err != nil {
return fmt.Errorf("%w: failed to get hash of assertion: %w", ErrAssertionFailure{ID: assertion.ID}, err)
}

hashOfAssertion := make([]byte, hex.DecodedLen(len(hashOfAssertionAsHex)))
_, err = hex.Decode(hashOfAssertion, hashOfAssertionAsHex)
if err != nil {
return fmt.Errorf("error decoding hex string: %w", err)
}

isLegacyTDF := r.manifest.TDFVersion == ""
if isLegacyTDF {
hashOfAssertion = hashOfAssertionAsHex
}

var completeHashBuilder bytes.Buffer
completeHashBuilder.Write(aggregateHash.Bytes())
completeHashBuilder.Write(hashOfAssertion)

base64Hash := ocrypto.Base64Encode(completeHashBuilder.Bytes())

if string(hashOfAssertion) != assertionHash {
if string(hashOfAssertionAsHex) != assertionHash {
return fmt.Errorf("%w: assertion hash missmatch", ErrAssertionFailure{ID: assertion.ID})
}

Expand Down Expand Up @@ -972,29 +1003,36 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn
}

// calculateSignature calculate signature of data of the given algorithm.
func calculateSignature(data []byte, secret []byte, alg IntegrityAlgorithm) (string, error) {
func calculateSignature(data []byte, secret []byte, alg IntegrityAlgorithm, isLegacyTDF bool) (string, error) {
if alg == HS256 {
hmac := ocrypto.CalculateSHA256Hmac(secret, data)
return hex.EncodeToString(hmac), nil
if isLegacyTDF {
return hex.EncodeToString(hmac), nil
}
return string(hmac), nil
}
if kGMACPayloadLength > len(data) {
return "", fmt.Errorf("fail to create gmac signature")
}

return hex.EncodeToString(data[len(data)-kGMACPayloadLength:]), nil
if isLegacyTDF {
return hex.EncodeToString(data[len(data)-kGMACPayloadLength:]), nil
}
return string(data[len(data)-kGMACPayloadLength:]), nil
}

// validate the root signature
func validateRootSignature(manifest Manifest, aggregateHash, secret []byte) (bool, error) {
rootSigAlg := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Algorithm
rootSigValue := manifest.EncryptionInformation.IntegrityInformation.RootSignature.Signature
isLegacyTDF := manifest.TDFVersion == ""

sigAlg := HS256
if strings.EqualFold(gmacIntegrityAlgorithm, rootSigAlg) {
sigAlg = GMAC
}

sig, err := calculateSignature(aggregateHash, secret, sigAlg)
sig, err := calculateSignature(aggregateHash, secret, sigAlg, isLegacyTDF)
if err != nil {
return false, fmt.Errorf("splitkey.getSignature failed:%w", err)
}
Expand Down
47 changes: 38 additions & 9 deletions sdk/tdf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/lib/ocrypto"
kaspb "github.com/opentdf/platform/protocol/go/kas"
Expand Down Expand Up @@ -264,7 +266,7 @@ func (s *TDFSuite) Test_SimpleTDF() {
"https://example.com/attr/Classification/value/X",
}

expectedTdfSize := int64(2095)
expectedTdfSize := int64(2058)
tdfFilename := "secure-text.tdf"
plainText := "Virtru"
{
Expand Down Expand Up @@ -394,7 +396,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
assertionVerificationKeys: nil,
disableAssertionVerification: false,
expectedSize: 2896,
expectedSize: 2689,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -427,7 +429,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
DefaultKey: defaultKey,
},
disableAssertionVerification: false,
expectedSize: 2896,
expectedSize: 2689,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -476,7 +478,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: false,
expectedSize: 3195,
expectedSize: 2988,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -516,7 +518,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: false,
expectedSize: 2896,
expectedSize: 2689,
},
{
assertions: []AssertionConfig{
Expand All @@ -533,7 +535,7 @@ func (s *TDFSuite) Test_TDFWithAssertion() {
},
},
disableAssertionVerification: true,
expectedSize: 2302,
expectedSize: 2180,
},
} {
expectedTdfSize := test.expectedSize
Expand Down Expand Up @@ -642,7 +644,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
SigningKey: defaultKey,
},
},
expectedSize: 2896,
expectedSize: 2689,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -690,7 +692,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
},
},
},
expectedSize: 3195,
expectedSize: 2988,
},
{
assertions: []AssertionConfig{
Expand Down Expand Up @@ -724,7 +726,7 @@ func (s *TDFSuite) Test_TDFWithAssertionNegativeTests() {
assertionVerificationKeys: &AssertionVerificationKeys{
DefaultKey: defaultKey,
},
expectedSize: 2896,
expectedSize: 2689,
},
} {
expectedTdfSize := test.expectedSize
Expand Down Expand Up @@ -1479,3 +1481,30 @@ func (s *TDFSuite) checkIdentical(file, checksum string) bool {
c := h.Sum(nil)
return checksum == fmt.Sprintf("%x", c)
}

func TestParseVersion(t *testing.T) {
tests := []struct {
input string
expected Version
hasError bool
}{
{"1.2.3", Version{Major: 1, Minor: 2, Patch: 3}, false},
{"1.2.3+p1", Version{Major: 1, Minor: 2, Patch: 3, Preview: 1}, false},
{"1.2.3+p1.2", Version{Major: 1, Minor: 2, Patch: 3, Preview: 1, Revision: 2}, false},
{"1.2", Version{}, true},
{"1.2.3+p", Version{}, true},
{"1.2.3+p1.", Version{}, true},
}

for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result, err := ParseVersion(tt.input)
if tt.hasError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, tt.expected, *result)
}
})
}
}
76 changes: 76 additions & 0 deletions sdk/version-parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package sdk

import (
"fmt"
"os"
"strings"
)

type Version struct {
Major int
Minor int
Patch int
Preview int
Revision int
}

func ReadVersion() (*Version, error) {
content, err := os.ReadFile("VERSION")
if err != nil {
return nil, fmt.Errorf("reading VERSION file: %w", err)
}

return ParseVersion(strings.TrimSpace(string(content)))
}

func ParseVersion(v string) (*Version, error) {
const maxParts = 2
var ver Version
var preview, revision string

parts := strings.SplitN(v, "+p", maxParts)
mainVersion := parts[0]

if len(parts) > 1 {
if parts[1] == "" {
return nil, fmt.Errorf("invalid preview format")
}
previewParts := strings.SplitN(parts[1], ".", maxParts)
preview = previewParts[0]
if len(previewParts) > 1 {
if previewParts[1] == "" {
return nil, fmt.Errorf("invalid revision format")
}
revision = previewParts[1]
}
}

if _, err := fmt.Sscanf(mainVersion, "%d.%d.%d", &ver.Major, &ver.Minor, &ver.Patch); err != nil {
return nil, fmt.Errorf("parsing version: %w", err)
}

if preview != "" {
if _, err := fmt.Sscanf(preview, "%d", &ver.Preview); err != nil {
return nil, fmt.Errorf("parsing preview version: %w", err)
}
}

if revision != "" {
if _, err := fmt.Sscanf(revision, "%d", &ver.Revision); err != nil {
return nil, fmt.Errorf("parsing revision: %w", err)
}
}

return &ver, nil
}

func (v *Version) String() string {
base := fmt.Sprintf("%d.%d.%d", v.Major, v.Minor, v.Patch)
if v.Preview > 0 {
if v.Revision > 0 {
return fmt.Sprintf("%s+p%d.%d", base, v.Preview, v.Revision)
}
return fmt.Sprintf("%s+p%d", base, v.Preview)
}
return base
}
Loading