From eea7482ab581845bc2f74a6e0b753e186d3ff5c0 Mon Sep 17 00:00:00 2001 From: Lorenz Bauer Date: Fri, 17 Nov 2023 12:51:17 +0000 Subject: [PATCH] map: move per-CPU validation into unmarshalPerCPUValue Move validation required for reflect based decoding into unmarshalPerCPU value. This ensures that the checks are as close as possible to where they are required. Signed-off-by: Lorenz Bauer --- map.go | 27 +++++++++------------------ marshalers.go | 14 ++++++++------ marshalers_test.go | 8 ++++---- 3 files changed, 21 insertions(+), 28 deletions(-) diff --git a/map.go b/map.go index 4a487f14d..dc2524590 100644 --- a/map.go +++ b/map.go @@ -684,27 +684,13 @@ func (m *Map) lookupAndDeletePerCPU(key, valueOut any, flags MapLookupFlags) err return unmarshalPerCPUValue(slice, int(m.valueSize), valueBytes) } +// ensurePerCPUSlice allocates a slice for a per-CPU value if necessary. func ensurePerCPUSlice(sliceOrPtr any, elemLength int) (any, error) { - possibleCPUs, err := PossibleCPU() - if err != nil { - return nil, err - } - sliceOrPtrType := reflect.TypeOf(sliceOrPtr) if sliceOrPtrType.Kind() == reflect.Slice { - sliceValue := reflect.ValueOf(sliceOrPtr) - if sliceValue.Len() != possibleCPUs { - return nil, fmt.Errorf("per-cpu slice is incorrect length, expected %d, got %d", - possibleCPUs, sliceValue.Len()) - } - if sliceOrPtrType.Elem().Kind() == reflect.Pointer { - for i := 0; i < sliceValue.Len(); i++ { - if !sliceValue.Index(i).Elem().CanAddr() { - return nil, fmt.Errorf("per-cpu slice elements cannot be nil") - } - } - } - return sliceValue.Interface(), nil + // The target is a slice, the caller is responsible for ensuring that + // size is correct. + return sliceOrPtr, nil } slicePtrType := sliceOrPtrType @@ -712,6 +698,11 @@ func ensurePerCPUSlice(sliceOrPtr any, elemLength int) (any, error) { return nil, fmt.Errorf("per-cpu value requires a slice or a pointer to slice") } + possibleCPUs, err := PossibleCPU() + if err != nil { + return nil, err + } + sliceType := slicePtrType.Elem() slice := reflect.MakeSlice(sliceType, possibleCPUs, possibleCPUs) diff --git a/marshalers.go b/marshalers.go index 9a3865ec6..d77a5fb81 100644 --- a/marshalers.go +++ b/marshalers.go @@ -88,7 +88,7 @@ func marshalPerCPUValue(slice any, elemLength int) (sys.Pointer, error) { func unmarshalPerCPUValue(slice any, elemLength int, buf []byte) error { sliceType := reflect.TypeOf(slice) if sliceType.Kind() != reflect.Slice { - return fmt.Errorf("per-cpu value requires a slice") + return fmt.Errorf("per-CPU value requires a slice") } possibleCPUs, err := PossibleCPU() @@ -97,12 +97,11 @@ func unmarshalPerCPUValue(slice any, elemLength int, buf []byte) error { } sliceValue := reflect.ValueOf(slice) - if sliceValue.Len() < possibleCPUs { - // Should be impossible here from ensurePerCPUSlice(), - // but avoid a panic in the loop. - return fmt.Errorf("per-cpu value slice len %d is less than possibleCPU %d", - sliceValue.Len(), possibleCPUs) + if sliceValue.Len() != possibleCPUs { + return fmt.Errorf("per-CPU slice has incorrect length, expected %d, got %d", + possibleCPUs, sliceValue.Len()) } + sliceElemType := sliceType.Elem() sliceElemIsPointer := sliceElemType.Kind() == reflect.Ptr stride := internal.Align(elemLength, 8) @@ -110,6 +109,9 @@ func unmarshalPerCPUValue(slice any, elemLength int, buf []byte) error { var elem any v := sliceValue.Index(i) if sliceElemIsPointer { + if !v.Elem().CanAddr() { + return fmt.Errorf("per-CPU slice elements cannot be nil") + } elem = v.Elem().Addr().Interface() } else { elem = v.Addr().Interface() diff --git a/marshalers_test.go b/marshalers_test.go index cb3ca7e0a..6695791b3 100644 --- a/marshalers_test.go +++ b/marshalers_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/cilium/ebpf/internal" + qt "github.com/frankban/quicktest" ) @@ -29,9 +30,8 @@ func TestUnmarshalPerCPUValue(t *testing.T) { qt.Assert(t, slice, qt.DeepEquals, expected) smallSlice := make([]uint32, possibleCPUs-1) + qt.Assert(t, unmarshalPerCPUValue(smallSlice, elemLength, buf), qt.IsNotNil) - err = unmarshalPerCPUValue(smallSlice, elemLength, buf) - if err == nil { - t.Fatal("expected error") - } + nilElemSlice := make([]*uint32, possibleCPUs) + qt.Assert(t, unmarshalPerCPUValue(nilElemSlice, elemLength, buf), qt.IsNotNil) }