Skip to content

Commit

Permalink
addressing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mustyantsev committed Jan 15, 2025
1 parent a00059a commit 15c66d7
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 141 deletions.
62 changes: 27 additions & 35 deletions sdk/assertion.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,47 +123,39 @@ func (a Assertion) GetHash() ([]byte, error) {
return ocrypto.SHA256AsHex(transformedJSON), nil
}

type FlexibleValue struct {
AsString *string
AsObject map[string]interface{}
}

func (fv FlexibleValue) MarshalJSON() ([]byte, error) {
if fv.AsObject != nil {
objAsJSON, err := json.Marshal(fv.AsObject)
if err != nil {
return nil, err
}
return json.Marshal(string(objAsJSON))
func (s *Statement) UnmarshalJSON(data []byte) error {
// Define a custom struct for deserialization
type Alias Statement
aux := &struct {
Value json.RawMessage `json:"value,omitempty"`
*Alias
}{
Alias: (*Alias)(s),
}

if fv.AsString != nil {
return json.Marshal(*fv.AsString)
if err := json.Unmarshal(data, &aux); err != nil {
return err
}

return json.Marshal(nil)
}

func (fv *FlexibleValue) UnmarshalJSON(data []byte) error {
// Try to unmarshal as a raw string
var strValue string
if err := json.Unmarshal(data, &strValue); err == nil {
var temp map[string]interface{}
if json.Unmarshal([]byte(strValue), &temp) == nil {
fv.AsObject = temp
} else {
fv.AsString = &strValue
// Attempt to decode Value as an object
var temp map[string]interface{}
if json.Unmarshal(aux.Value, &temp) == nil {
// Re-encode the object as a string and assign to Value
objAsString, err := json.Marshal(temp)
if err != nil {
return err
}
return nil
}

var objValue map[string]interface{}
if err := json.Unmarshal(data, &objValue); err == nil {
fv.AsObject = objValue
return nil
s.Value = string(objAsString)
} else {
// Assign raw string to Value
var str string
if err := json.Unmarshal(aux.Value, &str); err != nil {
return fmt.Errorf("value is neither a valid JSON object nor a string: %s", string(aux.Value))
}
s.Value = str
}

return fmt.Errorf("value is neither a valid JSON object nor a string")
return nil
}

// Statement includes information applying to the scope of the assertion.
Expand All @@ -174,7 +166,7 @@ type Statement struct {
// Schema describes the schema of the payload. (e.g. tdf)
Schema string `json:"schema,omitempty" validate:"required"`
// Value is the payload of the assertion.
Value FlexibleValue `json:"value,omitempty" validate:"required"`
Value string `json:"value,omitempty" validate:"required"`
}

// Binding enforces cryptographic integrity of the assertion.
Expand Down
71 changes: 30 additions & 41 deletions sdk/assertion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ func TestTDFWithAssertion(t *testing.T) {
Statement: Statement{
Format: "json+stanag5636",
Schema: "urn:nato:stanag:5636:A:1:elements:json",
Value: FlexibleValue{
AsString: func() *string {
val := "{\"ocl\":{\"pol\":\"62c76c68-d73d-4628-8ccc-4c1e18118c22\",\"cls\":\"SECRET\",\"catl\":[{\"type\":\"P\",\"name\":\"Releasable To\",\"vals\":[\"usa\"]}],\"dcr\":\"2024-10-21T20:47:36Z\"},\"context\":{\"@base\":\"urn:nato:stanag:5636:A:1:elements:json\"}}"
return &val
}(),
},
Value: "{\"ocl\":{\"pol\":\"62c76c68-d73d-4628-8ccc-4c1e18118c22\",\"cls\":\"SECRET\",\"catl\":[{\"type\":\"P\",\"name\":\"Releasable To\",\"vals\":[\"usa\"]}],\"dcr\":\"2024-10-21T20:47:36Z\"},\"context\":{\"@base\":\"urn:nato:stanag:5636:A:1:elements:json\"}}",
},
}

Expand All @@ -42,32 +37,31 @@ func TestTDFWithAssertion(t *testing.T) {

func TestTDFWithAssertionJsonObject(t *testing.T) {
// Define the assertion config with a JSON object in the statement value
value := `{
"ocl": {
"pol": "2ccf11cb-6c9a-4e49-9746-a7f0a295945d",
"cls": "SECRET",
"catl": [
{
"type": "P",
"name": "Releasable To",
"vals": ["usa"]
}
],
"dcr": "2024-12-17T13:00:52Z"
},
"context": {
"@base": "urn:nato:stanag:5636:A:1:elements:json"
}
}`
assertionConfig := AssertionConfig{
ID: "ab43266781e64b51a4c52ffc44d6152c",
Type: "handling",
Scope: "payload",
AppliesToState: "", // Use "" or a pointer to a string if necessary
Statement: Statement{
Format: "json-structured",
Value: FlexibleValue{
AsObject: map[string]interface{}{ // Correct usage of FlexibleValue
"ocl": map[string]interface{}{
"pol": "2ccf11cb-6c9a-4e49-9746-a7f0a295945d",
"cls": "SECRET",
"catl": []map[string]interface{}{
{
"type": "P",
"name": "Releasable To",
"vals": []string{"usa"},
},
},
"dcr": "2024-12-17T13:00:52Z",
},
"context": map[string]interface{}{
"@base": "urn:nato:stanag:5636:A:1:elements:json",
},
},
},
Value: value,
},
}

Expand All @@ -80,28 +74,23 @@ func TestTDFWithAssertionJsonObject(t *testing.T) {
Statement: assertionConfig.Statement,
}

// Serialize the JSON object in the statement value
serializedStatementValue, err := json.Marshal(assertion.Statement.Value.AsObject)
require.NoError(t, err)
var obj map[string]interface{}
err := json.Unmarshal([]byte(assertionConfig.Statement.Value), &obj)
require.NoError(t, err, "Unmarshaling the Value into a map should succeed")

// Ensure the serialized value is valid JSON
var deserialized map[string]interface{}
err = json.Unmarshal(serializedStatementValue, &deserialized)
require.NoError(t, err)
ocl, ok := obj["ocl"].(map[string]interface{})
require.True(t, ok, "Parsed Value should contain 'ocl' as an object")
require.Equal(t, "SECRET", ocl["cls"], "'cls' field should match")
require.Equal(t, "2ccf11cb-6c9a-4e49-9746-a7f0a295945d", ocl["pol"], "'pol' field should match")

// Set the serialized value back into the statement
assertion.Statement.Value = FlexibleValue{
AsString: func() *string {
val := string(serializedStatementValue)
return &val
}(),
}
context, ok := obj["context"].(map[string]interface{})
require.True(t, ok, "Parsed Value should contain 'context' as an object")
require.Equal(t, "urn:nato:stanag:5636:A:1:elements:json", context["@base"], "'@base' field should match")

// Calculate the hash of the assertion
hashOfAssertion, err := assertion.GetHash()
require.NoError(t, err)

// Assert the expected hash (example hash, replace with actual expected value)
expectedHash := "c1733259597a7025d2fdbd000a68c5ee3652cf2cd61c0be8f92f941c521cee92"
expectedHash := "722dd40a90a0f7ec718fb156207a647e64daa43c0ae1f033033473a172c72aee"
assert.Equal(t, expectedHash, string(hashOfAssertion))
}
51 changes: 12 additions & 39 deletions sdk/tdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,46 +796,22 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn
for _, keyAccessObj := range r.manifest.EncryptionInformation.KeyAccessObjs {
client := newKASClient(r.dialOptions, r.tokenSource, &r.kasSessionKey)

ss := keySplitStep{KAS: keyAccessObj.KasURL, SplitID: keyAccessObj.SplitID}

var err error
var wrappedKey []byte
if !mixedSplits { //nolint:nestif // todo: subfunction
wrappedKey, err = client.unwrap(ctx, keyAccessObj, r.manifest.EncryptionInformation.Policy)
if err != nil {
errToReturn := fmt.Errorf("doPayloadKeyUnwrap splitKey.rewrap failed: %w", err)
if strings.Contains(err.Error(), codes.InvalidArgument.String()) {
return fmt.Errorf("%w: %w", ErrRewrapBadRequest, errToReturn)
}
if strings.Contains(err.Error(), codes.PermissionDenied.String()) {
return fmt.Errorf("%w: %w", errRewrapForbidden, errToReturn)
}
return errToReturn

wrappedKey, err = client.unwrap(ctx, keyAccessObj, r.manifest.EncryptionInformation.Policy)
if err != nil {
errToReturn := fmt.Errorf("doPayloadKeyUnwrap splitKey.rewrap failed: %w", err)
if strings.Contains(err.Error(), codes.InvalidArgument.String()) {
return fmt.Errorf("%w: %w", ErrRewrapBadRequest, errToReturn)
}
} else {
knownSplits[ss.SplitID] = true
if foundSplits[ss.SplitID] {
// already found
continue
}
wrappedKey, err = client.unwrap(ctx, keyAccessObj, r.manifest.EncryptionInformation.Policy)
if err != nil {
errToReturn := fmt.Errorf("kao unwrap failed for split %v: %w", ss, err)
if !strings.Contains(err.Error(), codes.InvalidArgument.String()) {
skippedSplits[ss] = fmt.Errorf("%w: %w", ErrRewrapBadRequest, errToReturn)
}
if !strings.Contains(err.Error(), codes.PermissionDenied.String()) {
skippedSplits[ss] = fmt.Errorf("%w: %w", errRewrapForbidden, errToReturn)
}
skippedSplits[ss] = errToReturn
continue
if strings.Contains(err.Error(), codes.PermissionDenied.String()) {
return fmt.Errorf("%w: %w", errRewrapForbidden, errToReturn)
}
return errToReturn
}

for keyByteIndex, keyByte := range wrappedKey {
payloadKey[keyByteIndex] ^= keyByte
}
foundSplits[ss.SplitID] = true
copy(payloadKey[:], wrappedKey)

if len(keyAccessObj.EncryptedMetadata) != 0 {
gcm, err := ocrypto.NewAESGcm(wrappedKey)
Expand Down Expand Up @@ -863,10 +839,6 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn

unencryptedMetadata = metaData
}

if r.manifest.EncryptionInformation.KeyAccessObjs[0].SplitID == "" {
break
}
}

if mixedSplits && len(knownSplits) > len(foundSplits) {
Expand Down Expand Up @@ -948,7 +920,8 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn

base64Hash := ocrypto.Base64Encode(completeHashBuilder.Bytes())

if string(hashOfAssertion) != assertionHash {
stras := string(hashOfAssertion)
if stras != assertionHash {
return fmt.Errorf("%w: assertion hash missmatch", ErrAssertionFailure{ID: assertion.ID})
}

Expand Down
Loading

0 comments on commit 15c66d7

Please sign in to comment.