diff --git a/map.go b/map.go index c5010b419..8256289d4 100644 --- a/map.go +++ b/map.go @@ -1344,10 +1344,31 @@ func batchCount(keys, values any) (int, error) { // // It's not possible to guarantee that all keys in a map will be // returned if there are concurrent modifications to the map. +// +// Iterating a hash map from which keys are being deleted is not +// safe. You may see the same key multiple times. Iteration may +// also abort with an error, see IsIterationAborted. +// +// Iterating a queue/stack map returns an error (NextKey invalid +// argument): [Map.Drain] API should be used instead. func (m *Map) Iterate() *MapIterator { return newMapIterator(m) } +// Drain traverses a map while also removing entries. +// +// It's safe to create multiple drainers at the same time, +// but their respective outputs will differ. +// +// Draining a map that does not support entry removal such as +// an array return an error (LookupAndDelete not supported): +// [Map.Iterate] API should be used instead. +func (m *Map) Drain() *MapIterator { + it := newMapIterator(m) + it.drain = true + return it +} + // Close the Map's underlying file descriptor, which could unload the // Map from the kernel if it is not pinned or in use by a loaded Program. func (m *Map) Close() error { @@ -1602,6 +1623,12 @@ func marshalMap(m *Map, length int) ([]byte, error) { return buf, nil } +// isKeyValueMap returns true if map supports key-value pairs (ex. hash) +// and false in case of value-only maps (ex. queue). +func isKeyValueMap(m *Map) bool { + return m.keySize != 0 +} + // MapIterator iterates a Map. // // See Map.Iterate. @@ -1611,7 +1638,7 @@ type MapIterator struct { // of []byte to avoid allocations. cursor any count, maxEntries uint32 - done bool + done, drain bool err error } @@ -1622,12 +1649,56 @@ func newMapIterator(target *Map) *MapIterator { } } +// cursorToKeyOut copies the current value held in the cursor to the +// provided argument. In case of errors, returns false and sets a +// non-nil error in the MapIterator. +func (mi *MapIterator) cursorToKeyOut(keyOut interface{}) bool { + buf := mi.cursor.([]byte) + if ptr, ok := keyOut.(unsafe.Pointer); ok { + copy(unsafe.Slice((*byte)(ptr), len(buf)), buf) + } else { + mi.err = sysenc.Unmarshal(keyOut, buf) + } + return mi.err == nil +} + +// fetchNextKey loads into the cursor the key following the provided one. +func (mi *MapIterator) fetchNextKey(key interface{}) bool { + mi.err = mi.target.NextKey(key, mi.cursor) + if mi.err == nil { + return true + } + + if errors.Is(mi.err, ErrKeyNotExist) { + mi.done = true + mi.err = nil + } else { + mi.err = fmt.Errorf("get next key: %w", mi.err) + } + + return false +} + +// drainMapEntry removes and returns the key held in the cursor +// from the underlying map. +func (mi *MapIterator) drainMapEntry(valueOut interface{}) bool { + mi.err = mi.target.LookupAndDelete(mi.cursor, valueOut) + if mi.err == nil { + mi.count++ + return true + } + + if errors.Is(mi.err, ErrKeyNotExist) { + mi.err = nil + } else { + mi.err = fmt.Errorf("lookup_and_delete key: %w", mi.err) + } + + return false +} + // Next decodes the next key and value. // -// Iterating a hash map from which keys are being deleted is not -// safe. You may see the same key multiple times. Iteration may -// also abort with an error, see IsIterationAborted. -// // Returns false if there are no more entries. You must check // the result of Err afterwards. // @@ -1636,6 +1707,38 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool { if mi.err != nil || mi.done { return false } + if mi.drain { + return mi.nextDrain(keyOut, valueOut) + } + return mi.nextIterate(keyOut, valueOut) +} + +func (mi *MapIterator) nextDrain(keyOut, valueOut interface{}) bool { + // Handle value-only maps (ex. queue). + if !isKeyValueMap(mi.target) { + if keyOut != nil { + mi.err = fmt.Errorf("non-nil keyOut provided for map without a key, must be nil instead") + return false + } + return mi.drainMapEntry(valueOut) + } + + if mi.cursor == nil { + mi.cursor = make([]byte, mi.target.keySize) + } + + // Always retrieve first key in the map. This should ensure that the whole map + // is traversed, despite concurrent operations (ordering of items might differ). + for mi.err == nil && mi.fetchNextKey(nil) { + if mi.drainMapEntry(valueOut) { + return mi.cursorToKeyOut(keyOut) + } + } + return false +} + +func (mi *MapIterator) nextIterate(keyOut, valueOut interface{}) bool { + var key interface{} // For array-like maps NextKey returns nil only after maxEntries // iterations. @@ -1645,17 +1748,12 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool { // is returned. If we pass an uninitialized []byte instead, it'll see a // non-nil interface and try to marshal it. mi.cursor = make([]byte, mi.target.keySize) - mi.err = mi.target.NextKey(nil, mi.cursor) + key = nil } else { - mi.err = mi.target.NextKey(mi.cursor, mi.cursor) + key = mi.cursor } - if errors.Is(mi.err, ErrKeyNotExist) { - mi.done = true - mi.err = nil - return false - } else if mi.err != nil { - mi.err = fmt.Errorf("get next key: %w", mi.err) + if !mi.fetchNextKey(key) { return false } @@ -1677,14 +1775,7 @@ func (mi *MapIterator) Next(keyOut, valueOut interface{}) bool { return false } - buf := mi.cursor.([]byte) - if ptr, ok := keyOut.(unsafe.Pointer); ok { - copy(unsafe.Slice((*byte)(ptr), len(buf)), buf) - } else { - mi.err = sysenc.Unmarshal(keyOut, buf) - } - - return mi.err == nil + return mi.cursorToKeyOut(keyOut) } mi.err = fmt.Errorf("%w", ErrIterationAborted) diff --git a/map_test.go b/map_test.go index dc66a205e..659c4933b 100644 --- a/map_test.go +++ b/map_test.go @@ -1174,6 +1174,139 @@ func TestMapIteratorAllocations(t *testing.T) { qt.Assert(t, qt.Equals(allocs, float64(0))) } +func TestMapDrain(t *testing.T) { + for _, mapType := range []MapType{ + Hash, + Queue, + } { + t.Run(mapType.String(), func(t *testing.T) { + var ( + keySize, value uint32 + keyPtr interface{} + values = []uint32{} + data = []uint32{0, 1} + ) + + if mapType == Queue { + testutils.SkipOnOldKernel(t, "4.20", "map type queue") + keyPtr = nil + keySize = 0 + } + + if mapType == Hash { + testutils.SkipOnOldKernel(t, "5.14", "map type hash") + keyPtr = new(uint32) + keySize = 4 + } + + m, err := NewMap(&MapSpec{ + Type: mapType, + KeySize: keySize, + ValueSize: 4, + MaxEntries: 2, + }) + qt.Assert(t, qt.IsNil(err)) + defer m.Close() + + // Assert drain empty map. + entries := m.Drain() + qt.Assert(t, qt.IsFalse(entries.Next(keyPtr, &value))) + qt.Assert(t, qt.IsNil(entries.Err())) + + for _, v := range data { + if keySize == 0 { + err = m.Put(nil, uint32(v)) + } else { + err = m.Put(uint32(v), uint32(v)) + } + qt.Assert(t, qt.IsNil(err)) + } + + entries = m.Drain() + for entries.Next(keyPtr, &value) { + values = append(values, value) + } + qt.Assert(t, qt.IsNil(entries.Err())) + + sort.Slice(values, func(i, j int) bool { return values[i] < values[j] }) + qt.Assert(t, qt.DeepEquals(values, data)) + }) + } +} + +func TestDrainWrongMap(t *testing.T) { + arr, err := NewMap(&MapSpec{ + Type: Array, + KeySize: 4, + ValueSize: 4, + MaxEntries: 10, + }) + qt.Assert(t, qt.IsNil(err)) + defer arr.Close() + + var key, value uint32 + entries := arr.Drain() + + qt.Assert(t, qt.IsFalse(entries.Next(&key, &value))) + qt.Assert(t, qt.IsNotNil(entries.Err())) + fmt.Println(entries.Err()) +} + +func TestMapDrainerAllocations(t *testing.T) { + for _, mapType := range []MapType{ + Hash, + Queue, + } { + t.Run(mapType.String(), func(t *testing.T) { + var ( + keySize, value uint32 + keyPtr interface{} + ) + + if mapType == Queue { + testutils.SkipOnOldKernel(t, "4.20", "map type queue") + keyPtr = nil + keySize = 0 + } + + if mapType == Hash { + testutils.SkipOnOldKernel(t, "5.14", "map type hash") + keyPtr = new(uint32) + keySize = 4 + } + + m, err := NewMap(&MapSpec{ + Type: mapType, + KeySize: keySize, + ValueSize: 4, + MaxEntries: 10, + }) + qt.Assert(t, qt.ErrorIs(err, nil)) + defer m.Close() + + for i := 0; i < int(m.MaxEntries()); i++ { + if keySize == 0 { + err = m.Put(nil, uint32(i)) + } else { + err = m.Put(uint32(i), uint32(i)) + } + if err != nil { + t.Fatal(err) + } + } + + iter := m.Drain() + allocs := testing.AllocsPerRun(int(m.MaxEntries()-1), func() { + if !iter.Next(keyPtr, &value) { + t.Fatal("Next failed while draining: %w", iter.Err()) + } + }) + + qt.Assert(t, qt.Equals(allocs, float64(0))) + }) + } +} + func TestMapBatchLookupAllocations(t *testing.T) { testutils.SkipIfNotSupported(t, haveBatchAPI())