diff --git a/migrations/capcons/capabilities.go b/migrations/capcons/capabilities.go index 8b9314a46e..8b06c93794 100644 --- a/migrations/capcons/capabilities.go +++ b/migrations/capcons/capabilities.go @@ -43,7 +43,7 @@ type Path struct { type AccountCapabilities struct { capabilities []AccountCapability - sortOnce sync.Once + sorted bool } func (c *AccountCapabilities) Record( @@ -63,6 +63,9 @@ func (c *AccountCapabilities) Record( }, }, ) + + // Reset the sorted flag, if new entries are added. + c.sorted = false } // ForEachSorted will first sort the capabilities list, @@ -79,22 +82,24 @@ func (c *AccountCapabilities) ForEachSorted( } func (c *AccountCapabilities) sort() { - c.sortOnce.Do( - func() { - slices.SortFunc( - c.capabilities, - func(a, b AccountCapability) int { - pathA := a.TargetPath - pathB := b.TargetPath - - return cmp.Or( - cmp.Compare(pathA.Domain, pathB.Domain), - strings.Compare(pathA.Identifier, pathB.Identifier), - ) - }, + if c.sorted { + return + } + + slices.SortFunc( + c.capabilities, + func(a, b AccountCapability) int { + pathA := a.TargetPath + pathB := b.TargetPath + + return cmp.Or( + cmp.Compare(pathA.Domain, pathB.Domain), + strings.Compare(pathA.Identifier, pathB.Identifier), ) }, ) + + c.sorted = true } type AccountsCapabilities struct { diff --git a/migrations/capcons/capabilities_test.go b/migrations/capcons/capabilities_test.go index d12599f39e..b4b4869668 100644 --- a/migrations/capcons/capabilities_test.go +++ b/migrations/capcons/capabilities_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" @@ -67,17 +68,50 @@ func TestCapabilitiesIteration(t *testing.T) { nil, ) + require.False(t, caps.sorted) + var paths []interpreter.PathValue + caps.ForEachSorted(func(capability AccountCapability) bool { + paths = append(paths, capability.TargetPath) + return true + }) + + require.True(t, caps.sorted) + + assert.Equal( + t, + []interpreter.PathValue{ + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "a"), + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "b"), + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "c"), + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "a"), + interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "b"), + }, + paths, + ) + caps.Record( + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "aa"), + nil, + interpreter.StorageKey{}, + nil, + ) + + require.False(t, caps.sorted) + + paths = make([]interpreter.PathValue, 0) caps.ForEachSorted(func(capability AccountCapability) bool { paths = append(paths, capability.TargetPath) return true }) + require.True(t, caps.sorted) + assert.Equal( t, []interpreter.PathValue{ interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "a"), + interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "aa"), interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "b"), interpreter.NewUnmeteredPathValue(common.PathDomainStorage, "c"), interpreter.NewUnmeteredPathValue(common.PathDomainPublic, "a"),