From 3cf2764344b1ffc926dbf8f7f33aaf15ec0196a8 Mon Sep 17 00:00:00 2001 From: Margaret Ma Date: Mon, 3 Feb 2025 14:29:53 -0500 Subject: [PATCH] do not initilize empty labelset --- deployment/address_book.go | 81 +++++-- deployment/address_book_labels.go | 24 +- deployment/address_book_labels_test.go | 70 ++++-- deployment/address_book_test.go | 312 ++++++++++++++++++++++++- 4 files changed, 434 insertions(+), 53 deletions(-) diff --git a/deployment/address_book.go b/deployment/address_book.go index e123a2116bd..fa0e777c177 100644 --- a/deployment/address_book.go +++ b/deployment/address_book.go @@ -10,6 +10,7 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ethereum/go-ethereum/common" "github.com/pkg/errors" + chainsel "github.com/smartcontractkit/chain-selectors" ) @@ -39,16 +40,10 @@ type TypeAndVersion struct { func (tv TypeAndVersion) String() string { if len(tv.Labels) == 0 { - return fmt.Sprintf("%s %s", tv.Type, tv.Version.String()) + return fmt.Sprintf("%s %s", tv.Type, tv.Version) } - // Use the LabelSet's String method for sorted labels - sortedLabels := tv.Labels.String() - return fmt.Sprintf("%s %s %s", - tv.Type, - tv.Version.String(), - sortedLabels, - ) + return fmt.Sprintf("%s %s %s", tv.Type, tv.Version, tv.Labels) } func (tv TypeAndVersion) Equal(other TypeAndVersion) bool { @@ -83,7 +78,7 @@ func TypeAndVersionFromString(s string) (TypeAndVersion, error) { if err != nil { return TypeAndVersion{}, err } - labels := make(LabelSet) + var labels LabelSet if len(parts) > 2 { labels = NewLabelSet(parts[2:]...) } @@ -98,8 +93,21 @@ func NewTypeAndVersion(t ContractType, v semver.Version) TypeAndVersion { return TypeAndVersion{ Type: t, Version: v, - Labels: make(LabelSet), // empty set, + Labels: nil, + } +} + +// DeepClone returns a copy of the TypeAndVersion struct with its Labels cloned. +func (tv TypeAndVersion) DeepClone() TypeAndVersion { + // Make a shallow copy first + out := tv + + // Now deep-copy the Labels map + if tv.Labels != nil { + out.Labels = tv.Labels.DeepClone() } + + return out } // AddressBook is a simple interface for storing and retrieving contract addresses across @@ -172,7 +180,7 @@ func (m *AddressBookMap) Addresses() (map[uint64]map[string]TypeAndVersion, erro // maps are mutable and pass via a pointer // creating a copy of the map to prevent concurrency // read and changes outside object-bound - return m.cloneAddresses(m.addressesByChain), nil + return m.deepCloneAddresses(m.addressesByChain), nil } func (m *AddressBookMap) AddressesForChain(chainSelector uint64) (map[string]TypeAndVersion, error) { @@ -245,7 +253,7 @@ func (m *AddressBookMap) Remove(ab AddressBook) error { return nil } -// cloneAddresses creates a deep copy of map[uint64]map[string]TypeAndVersion object +// cloneAddresses creates a shallow copy of map[uint64]map[string]TypeAndVersion object func (m *AddressBookMap) cloneAddresses(input map[uint64]map[string]TypeAndVersion) map[uint64]map[string]TypeAndVersion { result := make(map[uint64]map[string]TypeAndVersion) for chainSelector, chainAddresses := range input { @@ -254,6 +262,23 @@ func (m *AddressBookMap) cloneAddresses(input map[uint64]map[string]TypeAndVersi return result } +// deepCloneAddresses creates a deep copy of map[uint64]map[string]TypeAndVersion object +func (m *AddressBookMap) deepCloneAddresses( + input map[uint64]map[string]TypeAndVersion, +) map[uint64]map[string]TypeAndVersion { + result := make(map[uint64]map[string]TypeAndVersion, len(input)) + for chainSelector, chainAddresses := range input { + // Make a new map for the nested addresses + newChainMap := make(map[string]TypeAndVersion, len(chainAddresses)) + for addr, tv := range chainAddresses { + // Use the DeepClone method on the TypeAndVersion + newChainMap[addr] = tv.DeepClone() + } + result[chainSelector] = newChainMap + } + return result +} + // TODO: Maybe could add an environment argument // which would ensure only mainnet/testnet chain selectors are used // for further safety? @@ -307,11 +332,15 @@ type typeVersionKey struct { } func tvKey(tv TypeAndVersion) typeVersionKey { - sortedLabels := tv.Labels.String() + var labels string + if tv.Labels != nil { + labels = tv.Labels.String() + } + return typeVersionKey{ Type: tv.Type, Version: tv.Version.String(), - Labels: sortedLabels, + Labels: labels, } } @@ -328,8 +357,12 @@ func AddressesContainBundle(addrs map[string]TypeAndVersion, wantTypes []TypeAnd // They match exactly (Type, Version, Labels) counts[wantKey]++ if counts[wantKey] > 1 { + var labels string + if wantTV.Labels != nil { + labels = wantTV.Labels.String() + } return false, fmt.Errorf("found more than one instance of contract %s %s (labels=%s)", - wantTV.Type, wantTV.Version.String(), wantTV.Labels.String()) + wantTV.Type, wantTV.Version, labels) } } } @@ -340,9 +373,23 @@ func AddressesContainBundle(addrs map[string]TypeAndVersion, wantTypes []TypeAnd } // AddLabel adds a string to the LabelSet in the TypeAndVersion. -func (tv *TypeAndVersion) AddLabel(label string) { +func (tv *TypeAndVersion) AddLabel(label ...string) { if tv.Labels == nil { tv.Labels = make(LabelSet) } - tv.Labels.Add(label) + tv.Labels.Add(label...) +} + +func (tv *TypeAndVersion) RemoveLabel(label ...string) { + if tv.Labels == nil { + return + } + tv.Labels.Remove(label...) +} + +func (tv *TypeAndVersion) LabelsString() string { + if tv.Labels == nil { + return "" + } + return tv.Labels.String() } diff --git a/deployment/address_book_labels.go b/deployment/address_book_labels.go index f559a39078a..27d0717e971 100644 --- a/deployment/address_book_labels.go +++ b/deployment/address_book_labels.go @@ -18,13 +18,17 @@ func NewLabelSet(labels ...string) LabelSet { } // Add inserts a labels into the set. -func (ls LabelSet) Add(labels string) { - ls[labels] = struct{}{} +func (ls LabelSet) Add(labels ...string) { + for _, label := range labels { + ls[label] = struct{}{} + } } // Remove deletes a labels from the set, if it exists. -func (ls LabelSet) Remove(labels string) { - delete(ls, labels) +func (ls LabelSet) Remove(labels ...string) { + for _, label := range labels { + delete(ls, label) + } } // Contains checks if the set contains the given labels. @@ -65,3 +69,15 @@ func (ls LabelSet) Equal(other LabelSet) bool { } return true } + +// DeepClone returns a copy of the LabelSet. +func (ls LabelSet) DeepClone() LabelSet { + if ls == nil { + return nil + } + out := make(LabelSet, len(ls)) + for label := range ls { + out[label] = struct{}{} + } + return out +} diff --git a/deployment/address_book_labels_test.go b/deployment/address_book_labels_test.go index f42e3568cba..743dd9048ec 100644 --- a/deployment/address_book_labels_test.go +++ b/deployment/address_book_labels_test.go @@ -8,51 +8,71 @@ import ( func TestNewLabelSet(t *testing.T) { t.Run("no labels", func(t *testing.T) { - ms := NewLabelSet() - assert.Empty(t, ms, "expected empty set") + ls := NewLabelSet() + assert.Empty(t, ls, "expected empty set") }) t.Run("some labels", func(t *testing.T) { - ms := NewLabelSet("foo", "bar") - assert.Len(t, ms, 2) - assert.True(t, ms.Contains("foo")) - assert.True(t, ms.Contains("bar")) - assert.False(t, ms.Contains("baz")) + ls := NewLabelSet("foo", "bar") + assert.Len(t, ls, 2) + assert.True(t, ls.Contains("foo")) + assert.True(t, ls.Contains("bar")) + assert.False(t, ls.Contains("baz")) }) } func TestLabelSet_Add(t *testing.T) { - ms := NewLabelSet("initial") - ms.Add("new") + ls := NewLabelSet("initial") + ls.Add("new") - assert.True(t, ms.Contains("initial"), "expected 'initial' in set") - assert.True(t, ms.Contains("new"), "expected 'new' in set") - assert.Len(t, ms, 2, "expected 2 distinct labels in set") + assert.True(t, ls.Contains("initial"), "expected 'initial' in set") + assert.True(t, ls.Contains("new"), "expected 'new' in set") + assert.Len(t, ls, 2, "expected 2 distinct labels in set") // Add duplicate "new" again; size should remain 2 - ms.Add("new") - assert.Len(t, ms, 2, "expected size to remain 2 after adding a duplicate") + ls.Add("new") + assert.Len(t, ls, 2, "expected size to remain 2 after adding a duplicate") + + // Add multiple labels at once + ls.Add("label1", "label2", "label3") + assert.Len(t, ls, 5, "expected 5 distinct labels in set") // 2 previous + 3 new + assert.True(t, ls.Contains("label1")) + assert.True(t, ls.Contains("label2")) + assert.True(t, ls.Contains("label3")) } func TestLabelSet_Remove(t *testing.T) { - ms := NewLabelSet("remove_me", "keep") - ms.Remove("remove_me") + ls := NewLabelSet("remove_me", "keep", "label1", "label2", "label3") + ls.Remove("remove_me") - assert.False(t, ms.Contains("remove_me"), "expected 'remove_me' to be removed") - assert.True(t, ms.Contains("keep"), "expected 'keep' to remain") - assert.Len(t, ms, 1, "expected set size to be 1 after removal") + assert.False(t, ls.Contains("remove_me"), "expected 'remove_me' to be removed") + assert.True(t, ls.Contains("keep"), "expected 'keep' to remain") + assert.True(t, ls.Contains("label1"), "expected 'label1' to remain") + assert.True(t, ls.Contains("label2"), "expected 'label2' to remain") + assert.True(t, ls.Contains("label3"), "expected 'label3' to remain") + assert.Len(t, ls, 4, "expected set size to be 4 after removal") // Removing a non-existent item shouldn't change the size - ms.Remove("non_existent") - assert.Len(t, ms, 1, "expected size to remain 1 after removing a non-existent item") + ls.Remove("non_existent") + assert.Len(t, ls, 4, "expected size to remain 4 after removing a non-existent item") + + // Remove multiple labels at once + ls.Remove("label2", "label4") + + assert.Len(t, ls, 3, "expected 3 distinct labels in set after removal") // keep, label1, label3 + assert.True(t, ls.Contains("keep")) + assert.True(t, ls.Contains("label1")) + assert.False(t, ls.Contains("label2")) + assert.True(t, ls.Contains("label3")) + assert.False(t, ls.Contains("label4")) } func TestLabelSet_Contains(t *testing.T) { - ms := NewLabelSet("foo", "bar") + ls := NewLabelSet("foo", "bar") - assert.True(t, ms.Contains("foo")) - assert.True(t, ms.Contains("bar")) - assert.False(t, ms.Contains("baz")) + assert.True(t, ls.Contains("foo")) + assert.True(t, ls.Contains("bar")) + assert.False(t, ls.Contains("baz")) } // TestLabelSet_String tests the String() method of the LabelSet type. diff --git a/deployment/address_book_test.go b/deployment/address_book_test.go index 0c6b228da2e..090aa2e7e4e 100644 --- a/deployment/address_book_test.go +++ b/deployment/address_book_test.go @@ -7,11 +7,203 @@ import ( "github.com/Masterminds/semver/v3" "github.com/ethereum/go-ethereum/common" - chainsel "github.com/smartcontractkit/chain-selectors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + chainsel "github.com/smartcontractkit/chain-selectors" ) +func TestTypeAndVersion_NewTypeAndVersion(t *testing.T) { + contractType := ContractType("TestContract") + version := semver.MustParse("1.0.0") + + tv1 := NewTypeAndVersion(contractType, *version) + tv2 := NewTypeAndVersion(contractType, *version) + tv3 := TypeAndVersion{ + Type: "TestContract", + Version: *version, + } + + assert.True(t, tv1.Equal(tv2), "expected tv1 to be equal to tv2") + assert.Equal(t, tv1, tv3, "expected tv1 to be equal to tv3") +} + +func TestTypeAndVersion_String(t *testing.T) { + contractType := ContractType("TestContract") + version := semver.MustParse("1.0.0") + + tests := []struct { + name string + tv TypeAndVersion + expected string + }{ + { + name: "Nil labels", + tv: TypeAndVersion{ + Type: contractType, + Version: *version, + Labels: nil, + }, + expected: "TestContract 1.0.0", + }, + { + name: "Empty labels", + tv: TypeAndVersion{ + Type: contractType, + Version: *version, + Labels: make(LabelSet), + }, + expected: "TestContract 1.0.0", + }, + { + name: "With labels", + tv: TypeAndVersion{ + Type: contractType, + Version: *version, + Labels: NewLabelSet("alpha", "beta"), + }, + expected: "TestContract 1.0.0 alpha beta", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, tt.tv.String(), "unexpected string representation") + }) + } +} + +func TestTypeAndVersion_DeepClone(t *testing.T) { + tests := []struct { + name string + input TypeAndVersion + mutate func(tv *TypeAndVersion) + wantEqual bool + }{ + { + name: "No labels", + input: TypeAndVersion{ + Type: "MyContract", + Version: *semver.MustParse("1.2.3"), + Labels: nil, + }, + mutate: func(tv *TypeAndVersion) { + tv.Type = "Mutated" + tv.Version = *semver.MustParse("9.9.9") + }, + wantEqual: false, + }, + { + name: "With labels", + input: TypeAndVersion{ + Type: "AnotherContract", + Version: *semver.MustParse("2.0.1"), + Labels: NewLabelSet("fast", "secure"), + }, + mutate: func(tv *TypeAndVersion) { + tv.Labels.Add("new-label") + }, + wantEqual: false, + }, + { + name: "Empty label set", + input: TypeAndVersion{ + Type: "EmptyLabelContract", + Version: *semver.MustParse("0.1.0"), + Labels: NewLabelSet(), // empty, but allocated + }, + mutate: func(tv *TypeAndVersion) { + tv.Labels.Add("test-label") + }, + wantEqual: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clone + clone := tt.input.DeepClone() + + // Before mutation, the clone should be Equal to the original + assert.True(t, tt.input.Equal(clone), + "DeepClone result should initially match the input") + + // Mutate the clone + tt.mutate(&clone) + + // If wantEqual is false, the original should differ from the mutated clone + if !tt.wantEqual { + assert.False(t, tt.input.Equal(clone), + "Mutating the clone should not affect the original if deep-cloned") + } else { + assert.True(t, tt.input.Equal(clone), + "Mutating the clone incorrectly affected the original") + } + }) + } +} + +func TestAddressBookMap_DeepCloneAddresses(t *testing.T) { + // Prepare some TypeAndVersion items + tvA := TypeAndVersion{ + Type: "ContractA", + Version: *semver.MustParse("1.0.0"), + Labels: NewLabelSet("labelA"), + } + tvB := TypeAndVersion{ + Type: "ContractB", + Version: *semver.MustParse("1.1.0"), + Labels: NewLabelSet("labelB1", "labelB2"), + } + + // Build our sample input + inputMap := map[uint64]map[string]TypeAndVersion{ + 111: { + "0x1234": tvA, + }, + 222: { + "0xABCD": tvB, + }, + } + + ab := NewMemoryAddressBookFromMap(inputMap) + + // Addresses() is supposed to return a deep clone + clonedAddrs, err := ab.Addresses() + require.NoError(t, err) + + // Now mutate something in the clone to see if the original is affected + clonedAddrs[111]["0x1234"] = TypeAndVersion{ + Type: "MutatedType", + Version: *semver.MustParse("9.9.9"), + Labels: NewLabelSet("mutated"), + } + + // Check original is not mutated + originalAddrs, err := ab.Addresses() + require.NoError(t, err) + + // The original 111 -> 0x1234 should remain tvA + assert.Equal(t, tvA, originalAddrs[111]["0x1234"], + "Mutating cloned addresses must not affect the original") + + // Also check that the `Labels` inside each TypeAndVersion are deeply cloned + // For example, add a label to the clone's tvB + cloneTvB := clonedAddrs[222]["0xABCD"] + cloneTvB.Labels.Add("extra-label") + clonedAddrs[222]["0xABCD"] = cloneTvB + + // Now see if the original's version is unchanged + originalTvB := originalAddrs[222]["0xABCD"] + assert.False(t, originalTvB.Labels.Contains("extra-label"), + "Original TypeAndVersion's Labels should not reflect changes to the clone") + + // Optionally, ensure the rest of the original is still correct + assert.Equal(t, tvB, originalTvB, + "Original TypeAndVersion for 222 -> 0xABCD should remain unchanged") +} + func TestAddressBook_Save(t *testing.T) { ab := NewMemoryAddressBook() onRamp100 := NewTypeAndVersion("OnRamp", Version1_0_0) @@ -254,8 +446,8 @@ func TestAddressesContainBundle(t *testing.T) { // Create one with labels onRamp100WithLabels := NewTypeAndVersion("OnRamp", Version1_0_0) - onRamp100WithLabels.Labels.Add("sa") - onRamp100WithLabels.Labels.Add("staging") + onRamp100WithLabels.AddLabel("sa") + onRamp100WithLabels.AddLabel("staging") addr1 := common.HexToAddress("0x1").String() addr2 := common.HexToAddress("0x2").String() @@ -264,7 +456,7 @@ func TestAddressesContainBundle(t *testing.T) { tests := []struct { name string addrs map[string]TypeAndVersion // input address map - wantTypes []TypeAndVersion // the “bundle” we want + wantTypes []TypeAndVersion // the "bundle" we want wantErr bool wantErrMsg string wantResult bool // expected boolean return when no error @@ -332,7 +524,7 @@ func TestAddressesContainBundle(t *testing.T) { } for _, tt := range tests { - tt := tt // capture range variable + tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() @@ -404,7 +596,7 @@ func TestTypeAndVersionFromString(t *testing.T) { } for _, tt := range tests { - tt := tt // capture range variable + tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() @@ -464,7 +656,7 @@ func TestTypeAndVersion_AddLabels(t *testing.T) { } for _, tt := range tests { - tt := tt // capture range variable + tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() @@ -491,3 +683,109 @@ func TestTypeAndVersion_AddLabels(t *testing.T) { }) } } + +func TestTypeAndVersion_RemoveLabel(t *testing.T) { + contractType := ContractType("TestContract") + version := semver.MustParse("1.0.0") + + tests := []struct { + name string + initialLabels []string + toRemove []string + wantLabels LabelSet + }{ + { + name: "Remove from nil labels", + initialLabels: nil, + toRemove: []string{"alpha"}, + wantLabels: NewLabelSet(), + }, + { + name: "Remove from empty labels", + initialLabels: []string{}, + toRemove: []string{"alpha"}, + wantLabels: NewLabelSet(), + }, + { + name: "Remove existing label", + initialLabels: []string{"alpha", "beta"}, + toRemove: []string{"alpha"}, + wantLabels: NewLabelSet("beta"), + }, + { + name: "Remove non-existing label", + initialLabels: []string{"alpha"}, + toRemove: []string{"beta"}, + wantLabels: NewLabelSet("alpha"), + }, + { + name: "Remove multiple labels", + initialLabels: []string{"alpha", "beta", "gamma"}, + toRemove: []string{"alpha", "gamma"}, + wantLabels: NewLabelSet("beta"), + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + tv := TypeAndVersion{ + Type: contractType, + Version: *version, + } + + if tt.initialLabels != nil { + tv.Labels = NewLabelSet(tt.initialLabels...) + } + + tv.RemoveLabel(tt.toRemove...) + + assert.True(t, tt.wantLabels.Equal(tv.Labels), "unexpected labels after removal") + }) + } +} + +func TestTypeAndVersion_LabelsString(t *testing.T) { + contractType := ContractType("TestContract") + version := semver.MustParse("1.0.0") + + tests := []struct { + name string + labels []string + expected string + }{ + { + name: "Nil labels", + labels: nil, + expected: "", + }, + { + name: "Empty labels", + labels: []string{}, + expected: "", + }, + { + name: "Single label", + labels: []string{"alpha"}, + expected: "alpha", + }, + { + name: "Multiple labels", + labels: []string{"alpha", "beta"}, + expected: "alpha beta", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + tv := TypeAndVersion{ + Type: contractType, + Version: *version, + Labels: NewLabelSet(tt.labels...), + } + + assert.Equal(t, tt.expected, tv.LabelsString(), "unexpected labels string") + }) + } +}