Skip to content

Commit

Permalink
map: move per-CPU validation into unmarshalPerCPUValue
Browse files Browse the repository at this point in the history
Move validation required for reflect based decoding into
unmarshalPerCPU value. This ensures that the checks are
as close as possible to where they are required.

Signed-off-by: Lorenz Bauer <[email protected]>
  • Loading branch information
lmb committed Nov 17, 2023
1 parent 2816d2f commit eea7482
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 28 deletions.
27 changes: 9 additions & 18 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,34 +684,25 @@ func (m *Map) lookupAndDeletePerCPU(key, valueOut any, flags MapLookupFlags) err
return unmarshalPerCPUValue(slice, int(m.valueSize), valueBytes)
}

// ensurePerCPUSlice allocates a slice for a per-CPU value if necessary.
func ensurePerCPUSlice(sliceOrPtr any, elemLength int) (any, error) {
possibleCPUs, err := PossibleCPU()
if err != nil {
return nil, err
}

sliceOrPtrType := reflect.TypeOf(sliceOrPtr)
if sliceOrPtrType.Kind() == reflect.Slice {
sliceValue := reflect.ValueOf(sliceOrPtr)
if sliceValue.Len() != possibleCPUs {
return nil, fmt.Errorf("per-cpu slice is incorrect length, expected %d, got %d",
possibleCPUs, sliceValue.Len())
}
if sliceOrPtrType.Elem().Kind() == reflect.Pointer {
for i := 0; i < sliceValue.Len(); i++ {
if !sliceValue.Index(i).Elem().CanAddr() {
return nil, fmt.Errorf("per-cpu slice elements cannot be nil")
}
}
}
return sliceValue.Interface(), nil
// The target is a slice, the caller is responsible for ensuring that
// size is correct.
return sliceOrPtr, nil
}

slicePtrType := sliceOrPtrType
if slicePtrType.Kind() != reflect.Ptr || slicePtrType.Elem().Kind() != reflect.Slice {
return nil, fmt.Errorf("per-cpu value requires a slice or a pointer to slice")
}

possibleCPUs, err := PossibleCPU()
if err != nil {
return nil, err
}

sliceType := slicePtrType.Elem()
slice := reflect.MakeSlice(sliceType, possibleCPUs, possibleCPUs)

Expand Down
14 changes: 8 additions & 6 deletions marshalers.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func marshalPerCPUValue(slice any, elemLength int) (sys.Pointer, error) {
func unmarshalPerCPUValue(slice any, elemLength int, buf []byte) error {
sliceType := reflect.TypeOf(slice)
if sliceType.Kind() != reflect.Slice {
return fmt.Errorf("per-cpu value requires a slice")
return fmt.Errorf("per-CPU value requires a slice")
}

possibleCPUs, err := PossibleCPU()
Expand All @@ -97,19 +97,21 @@ func unmarshalPerCPUValue(slice any, elemLength int, buf []byte) error {
}

sliceValue := reflect.ValueOf(slice)
if sliceValue.Len() < possibleCPUs {
// Should be impossible here from ensurePerCPUSlice(),
// but avoid a panic in the loop.
return fmt.Errorf("per-cpu value slice len %d is less than possibleCPU %d",
sliceValue.Len(), possibleCPUs)
if sliceValue.Len() != possibleCPUs {
return fmt.Errorf("per-CPU slice has incorrect length, expected %d, got %d",
possibleCPUs, sliceValue.Len())
}

sliceElemType := sliceType.Elem()
sliceElemIsPointer := sliceElemType.Kind() == reflect.Ptr
stride := internal.Align(elemLength, 8)
for i := 0; i < possibleCPUs; i++ {
var elem any
v := sliceValue.Index(i)
if sliceElemIsPointer {
if !v.Elem().CanAddr() {
return fmt.Errorf("per-CPU slice elements cannot be nil")
}
elem = v.Elem().Addr().Interface()
} else {
elem = v.Addr().Interface()
Expand Down
8 changes: 4 additions & 4 deletions marshalers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"testing"

"github.com/cilium/ebpf/internal"

qt "github.com/frankban/quicktest"
)

Expand All @@ -29,9 +30,8 @@ func TestUnmarshalPerCPUValue(t *testing.T) {
qt.Assert(t, slice, qt.DeepEquals, expected)

smallSlice := make([]uint32, possibleCPUs-1)
qt.Assert(t, unmarshalPerCPUValue(smallSlice, elemLength, buf), qt.IsNotNil)

err = unmarshalPerCPUValue(smallSlice, elemLength, buf)
if err == nil {
t.Fatal("expected error")
}
nilElemSlice := make([]*uint32, possibleCPUs)
qt.Assert(t, unmarshalPerCPUValue(nilElemSlice, elemLength, buf), qt.IsNotNil)
}

0 comments on commit eea7482

Please sign in to comment.