From 46f53434a160f94ca904d78e90f4153dde33c60d Mon Sep 17 00:00:00 2001
From: Simone Magnani
Date: Mon, 19 Feb 2024 11:59:30 +0100
Subject: [PATCH] introduce Map.Drain API to traverse a map while also
 deleting entries

This commit introduces the `Map.Drain` API to traverse a map while also
removing its entries. It reuses the existing `MapIterator` structure and
adds a new unexported method that handles the draining. The tests verify
that the behavior is as expected and that this API returns an error when
invoked on an unsupported map type, such as an array, for which
`Map.Iterate` should be used instead.

Kernel support for the underlying `LookupAndDelete` operation was
introduced in:

1. 5.14 for BPF_MAP_TYPE_HASH, BPF_MAP_TYPE_PERCPU_HASH,
   BPF_MAP_TYPE_LRU_HASH and BPF_MAP_TYPE_LRU_PERCPU_HASH.
2. 4.20 for BPF_MAP_TYPE_QUEUE and BPF_MAP_TYPE_STACK.

Depending on the target map type, do not expect the `Map.Drain` API to
work on older kernels.

From the user's perspective, usage is similar to `Map.Iterate`, as shown
below:

```go
m, err := NewMap(&MapSpec{
	Type:       Hash,
	KeySize:    4,
	ValueSize:  8,
	MaxEntries: 10,
})
// populate the map here and defer Close

it := m.Drain()
for it.Next(keyPtr, &value) {
	// at this point the entry no longer exists in the underlying map.
	...
}
```

Signed-off-by: Simone Magnani
---
 map.go      | 133 +++++++++++++++++++++++++++++++++++++++++++---------
 map_test.go | 133 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 245 insertions(+), 21 deletions(-)

diff --git a/map.go b/map.go
index c5010b419..8256289d4 100644
--- a/map.go
+++ b/map.go
@@ -1344,10 +1344,31 @@ func batchCount(keys, values any) (int, error) {
 //
 // It's not possible to guarantee that all keys in a map will be
 // returned if there are concurrent modifications to the map.
+//
+// Iterating a hash map from which keys are being deleted is not
+// safe. You may see the same key multiple times. Iteration may
+// also abort with an error; see IsIterationAborted.
+//
+// Iterating a queue/stack map returns an error (NextKey invalid
+// argument); use the [Map.Drain] API instead.
 func (m *Map) Iterate() *MapIterator {
 	return newMapIterator(m)
 }
 
+// Drain traverses a map while also removing its entries.
+//
+// It's safe to create multiple drainers at the same time,
+// but their respective outputs will differ.
+//
+// Draining a map that does not support entry removal, such as
+// an array, returns an error (LookupAndDelete not supported);
+// use the [Map.Iterate] API instead.
+func (m *Map) Drain() *MapIterator {
+	it := newMapIterator(m)
+	it.drain = true
+	return it
+}
+
 // Close the Map's underlying file descriptor, which could unload the
 // Map from the kernel if it is not pinned or in use by a loaded Program.
 func (m *Map) Close() error {
@@ -1602,6 +1623,12 @@ func marshalMap(m *Map, length int) ([]byte, error) {
 	return buf, nil
 }
 
+// isKeyValueMap returns true if the map supports key-value pairs
+// (e.g. hash) and false for value-only maps (e.g. queue).
+func isKeyValueMap(m *Map) bool {
+	return m.keySize != 0
+}
+
 // MapIterator iterates a Map.
 //
 // See Map.Iterate.
@@ -1611,7 +1638,7 @@ type MapIterator struct {
 	// of []byte to avoid allocations.
 	cursor            any
 	count, maxEntries uint32
-	done              bool
+	done, drain       bool
 	err               error
 }
 
@@ -1622,12 +1649,56 @@ func newMapIterator(target *Map) *MapIterator {
 	}
 }
 
+// cursorToKeyOut copies the key currently held in the cursor into the
+// provided argument. On error it sets a non-nil error on the
+// MapIterator and returns false.
+func (mi *MapIterator) cursorToKeyOut(keyOut interface{}) bool {
+	buf := mi.cursor.([]byte)
+	if ptr, ok := keyOut.(unsafe.Pointer); ok {
+		copy(unsafe.Slice((*byte)(ptr), len(buf)), buf)
+	} else {
+		mi.err = sysenc.Unmarshal(keyOut, buf)
+	}
+	return mi.err == nil
+}
+
+// fetchNextKey loads the key following the provided one into the cursor.
+func (mi *MapIterator) fetchNextKey(key interface{}) bool {
+	mi.err = mi.target.NextKey(key, mi.cursor)
+	if mi.err == nil {
+		return true
+	}
+
+	if errors.Is(mi.err, ErrKeyNotExist) {
+		mi.done = true
+		mi.err = nil
+	} else {
+		mi.err = fmt.Errorf("get next key: %w", mi.err)
+	}
+
+	return false
+}
+
+// drainMapEntry removes the entry identified by the cursor from the
+// underlying map and unmarshals its value into valueOut.
+func (mi *MapIterator) drainMapEntry(valueOut interface{}) bool {
+	mi.err = mi.target.LookupAndDelete(mi.cursor, valueOut)
+	if mi.err == nil {
+		mi.count++
+		return true
+	}
+
+	if errors.Is(mi.err, ErrKeyNotExist) {
+		mi.err = nil
+	} else {
+		mi.err = fmt.Errorf("lookup_and_delete key: %w", mi.err)
+	}
+
+	return false
+}
+
 // Next decodes the next key and value.
 //
-// Iterating a hash map from which keys are being deleted is not
-// safe. You may see the same key multiple times. Iteration may
-// also abort with an error, see IsIterationAborted.
-//
 // Returns false if there are no more entries. You must check
 // the result of Err afterwards.
 //
@@ -1636,6 +1707,38 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool {
 	if mi.err != nil || mi.done {
 		return false
 	}
+	if mi.drain {
+		return mi.nextDrain(keyOut, valueOut)
+	}
+	return mi.nextIterate(keyOut, valueOut)
+}
+
+func (mi *MapIterator) nextDrain(keyOut, valueOut interface{}) bool {
+	// Handle value-only maps (e.g. queue).
+	if !isKeyValueMap(mi.target) {
+		if keyOut != nil {
+			mi.err = fmt.Errorf("non-nil keyOut provided for a map without keys, must be nil instead")
+			return false
+		}
+		return mi.drainMapEntry(valueOut)
+	}
+
+	if mi.cursor == nil {
+		mi.cursor = make([]byte, mi.target.keySize)
+	}
+
+	// Always fetch the first key: since each entry is deleted as it is drained,
+	// the whole map is traversed despite concurrent operations (ordering may differ).
+	for mi.err == nil && mi.fetchNextKey(nil) {
+		if mi.drainMapEntry(valueOut) {
+			return mi.cursorToKeyOut(keyOut)
+		}
+	}
+	return false
+}
+
+func (mi *MapIterator) nextIterate(keyOut, valueOut interface{}) bool {
+	var key interface{}
 
 	// For array-like maps NextKey returns nil only after maxEntries
 	// iterations.
@@ -1645,17 +1748,12 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool {
 		// is returned. If we pass an uninitialized []byte instead, it'll see a
 		// non-nil interface and try to marshal it.
 		mi.cursor = make([]byte, mi.target.keySize)
-		mi.err = mi.target.NextKey(nil, mi.cursor)
+		key = nil
 	} else {
-		mi.err = mi.target.NextKey(mi.cursor, mi.cursor)
+		key = mi.cursor
 	}
 
-	if errors.Is(mi.err, ErrKeyNotExist) {
-		mi.done = true
-		mi.err = nil
-		return false
-	} else if mi.err != nil {
-		mi.err = fmt.Errorf("get next key: %w", mi.err)
+	if !mi.fetchNextKey(key) {
 		return false
 	}
 
@@ -1677,14 +1775,7 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool {
 			return false
 		}
 
-		buf := mi.cursor.([]byte)
-		if ptr, ok := keyOut.(unsafe.Pointer); ok {
-			copy(unsafe.Slice((*byte)(ptr), len(buf)), buf)
-		} else {
-			mi.err = sysenc.Unmarshal(keyOut, buf)
-		}
-
-		return mi.err == nil
+		return mi.cursorToKeyOut(keyOut)
 	}
 
 	mi.err = fmt.Errorf("%w", ErrIterationAborted)
diff --git a/map_test.go b/map_test.go
index dc66a205e..659c4933b 100644
--- a/map_test.go
+++ b/map_test.go
@@ -1174,6 +1174,139 @@ func TestMapIteratorAllocations(t *testing.T) {
 	qt.Assert(t, qt.Equals(allocs, float64(0)))
 }
 
+func TestMapDrain(t *testing.T) {
+	for _, mapType := range []MapType{
+		Hash,
+		Queue,
+	} {
+		t.Run(mapType.String(), func(t *testing.T) {
+			var (
+				keySize, value uint32
+				keyPtr         interface{}
+				values         = []uint32{}
+				data           = []uint32{0, 1}
+			)
+
+			if mapType == Queue {
+				testutils.SkipOnOldKernel(t, "4.20", "map type queue")
+				keyPtr = nil
+				keySize = 0
+			}
+
+			if mapType == Hash {
+				testutils.SkipOnOldKernel(t, "5.14", "lookup_and_delete on hash map")
+				keyPtr = new(uint32)
+				keySize = 4
+			}
+
+			m, err := NewMap(&MapSpec{
+				Type:       mapType,
+				KeySize:    keySize,
+				ValueSize:  4,
+				MaxEntries: 2,
+			})
+			qt.Assert(t, qt.IsNil(err))
+			defer m.Close()
+
+			// Draining an empty map must return no entries.
+			entries := m.Drain()
+			qt.Assert(t, qt.IsFalse(entries.Next(keyPtr, &value)))
+			qt.Assert(t, qt.IsNil(entries.Err()))
+
+			for _, v := range data {
+				if keySize == 0 {
+					err = m.Put(nil, uint32(v))
+				} else {
+					err = m.Put(uint32(v), uint32(v))
+				}
+				qt.Assert(t, qt.IsNil(err))
+			}
+
+			entries = m.Drain()
+			for entries.Next(keyPtr, &value) {
+				values = append(values, value)
+			}
+			qt.Assert(t, qt.IsNil(entries.Err()))
+
+			sort.Slice(values, func(i, j int) bool { return values[i] < values[j] })
+			qt.Assert(t, qt.DeepEquals(values, data))
+		})
+	}
+}
+
+func TestDrainWrongMap(t *testing.T) {
+	arr, err := NewMap(&MapSpec{
+		Type:       Array,
+		KeySize:    4,
+		ValueSize:  4,
+		MaxEntries: 10,
+	})
+	qt.Assert(t, qt.IsNil(err))
+	defer arr.Close()
+
+	var key, value uint32
+	entries := arr.Drain()
+
+	qt.Assert(t, qt.IsFalse(entries.Next(&key, &value)))
+	qt.Assert(t, qt.IsNotNil(entries.Err()))
+	t.Log(entries.Err())
+}
+
+func TestMapDrainerAllocations(t *testing.T) {
+	for _, mapType := range []MapType{
+		Hash,
+		Queue,
+	} {
+		t.Run(mapType.String(), func(t *testing.T) {
+			var (
+				keySize, value uint32
+				keyPtr         interface{}
+			)
+
+			if mapType == Queue {
+				testutils.SkipOnOldKernel(t, "4.20", "map type queue")
+				keyPtr = nil
+				keySize = 0
+			}
+
+			if mapType == Hash {
+				testutils.SkipOnOldKernel(t, "5.14", "lookup_and_delete on hash map")
+				keyPtr = new(uint32)
+				keySize = 4
+			}
+
+			m, err := NewMap(&MapSpec{
+				Type:       mapType,
+				KeySize:    keySize,
+				ValueSize:  4,
+				MaxEntries: 10,
+			})
+			qt.Assert(t, qt.ErrorIs(err, nil))
+			defer m.Close()
+
+			for i := 0; i < int(m.MaxEntries()); i++ {
+				if keySize == 0 {
+					err = m.Put(nil, uint32(i))
+				} else {
+					err = m.Put(uint32(i), uint32(i))
+				}
+				if err != nil {
+					t.Fatal(err)
+				}
+			}
+
+			iter := m.Drain()
+			allocs := testing.AllocsPerRun(int(m.MaxEntries()-1), func() {
+				if !iter.Next(keyPtr, &value) {
while draining: %w", iter.Err()) + } + }) + + qt.Assert(t, qt.Equals(allocs, float64(0))) + }) + } +} + func TestMapBatchLookupAllocations(t *testing.T) { testutils.SkipIfNotSupported(t, haveBatchAPI())