Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support iter.Seqs in [Not]Contains and [Not]ElementsMatch #1685

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions assert/assertion_format.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions assert/assertion_forward.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 14 additions & 5 deletions assert/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}

same, ok := samePointers(expected, actual)
if !ok {
//fails when the arguments are not pointers
// fails when the arguments are not pointers
return !(Fail(t, "Both arguments must be pointers", msgAndArgs...))
}

Expand All @@ -549,7 +549,7 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
func samePointers(first, second interface{}) (same bool, ok bool) {
firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second)
if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr {
return false, false //not both are pointers
return false, false // not both are pointers
}

firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second)
Expand Down Expand Up @@ -918,7 +918,7 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) {

}

// Contains asserts that the specified string, list(array, slice...) or map contains the
// Contains asserts that the specified string, list(array, slice, sequence...) or map contains the
// specified substring or element.
//
// assert.Contains(t, "Hello World", "World")
Expand All @@ -929,6 +929,7 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
h.Helper()
}

s = seqToSlice(s)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
Expand All @@ -952,6 +953,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
h.Helper()
}

s = seqToSlice(s)
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
Expand Down Expand Up @@ -1088,6 +1090,10 @@ func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface
if h, ok := t.(tHelper); ok {
h.Helper()
}
// Convert sequences to lists, if applicable
listA = seqToSlice(listA)
listB = seqToSlice(listB)

if isEmpty(listA) && isEmpty(listB) {
return true
}
Expand Down Expand Up @@ -1175,8 +1181,8 @@ func formatListDiff(listA, listB interface{}, extraA, extraB []interface{}) stri
return msg.String()
}

// NotElementsMatch asserts that the specified listA(array, slice...) is NOT equal to specified
// listB(array, slice...) ignoring the order of the elements. If there are duplicate elements,
// NotElementsMatch asserts that the specified listA(array, slice, sequence...) is NOT equal to specified
// listB(array, slice, sequence...) ignoring the order of the elements. If there are duplicate elements,
// the number of appearances of each of them in both lists should not match.
// This is an inverse of ElementsMatch.
//
Expand All @@ -1189,6 +1195,9 @@ func NotElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interf
if h, ok := t.(tHelper); ok {
h.Helper()
}
// Convert sequences to lists, if applicable
listA = seqToSlice(listA)
listB = seqToSlice(listB)
if isEmpty(listA) && isEmpty(listB) {
return Fail(t, "listA and listB contain the same elements", msgAndArgs)
}
Expand Down
179 changes: 179 additions & 0 deletions assert/assertions_seq_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
//go:build go1.23 || goexperiment.rangefunc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see a reason for this guard (see my other comment in seq_supported.go).


package assert

import (
"fmt"
"testing"
)

// go.mod version is set to 1.17, which precludes the use of generics (even though this file wouldn't be taken into
// account per the build tags).

func intSeq(s ...int) func(yield func(int) bool) {
return func(yield func(int) bool) {
for _, elem := range s {
if !yield(elem) {
break
}
}
}
}

func strSeq(s ...string) func(yield func(string) bool) {
return func(yield func(string) bool) {
for _, elem := range s {
if !yield(elem) {
break
}
}
}
}

func TestElementsMatch_Seq(t *testing.T) {
mockT := new(testing.T)

cases := []struct {
expected interface{}
actual interface{}
result bool
}{
{intSeq(), intSeq(), true},
{intSeq(1), intSeq(1), true},
{intSeq(1, 1), intSeq(1, 1), true},
{intSeq(1, 2), intSeq(1, 2), true},
{intSeq(1, 2), intSeq(2, 1), true},
{strSeq("hello", "world"), strSeq("world", "hello"), true},
{strSeq("hello", "hello"), strSeq("hello", "hello"), true},
{strSeq("hello", "hello", "world"), strSeq("hello", "world", "hello"), true},
{intSeq(), nil, true},

// not matching
{intSeq(1), intSeq(1, 1), false},
{intSeq(1, 2), intSeq(2, 2), false},
{strSeq("hello", "hello"), strSeq("hello"), false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("ElementsMatch(%#v, %#v)", seqToSlice(c.expected), seqToSlice(c.actual)), func(t *testing.T) {
res := ElementsMatch(mockT, c.actual, c.expected)

if res != c.result {
t.Errorf("ElementsMatch(%#v, %#v) should return %v", seqToSlice(c.actual), seqToSlice(c.expected), c.result)
}
})
}
}

func TestNotElementsMatch_Seq(t *testing.T) {
mockT := new(testing.T)

cases := []struct {
expected interface{}
actual interface{}
result bool
}{
// not matching
{intSeq(1), intSeq(), true},
{intSeq(), intSeq(2), true},
{intSeq(1), intSeq(2), true},
{intSeq(1), intSeq(1, 1), true},
{intSeq(1, 2), intSeq(3, 4), true},
{intSeq(3, 4), intSeq(1, 2), true},
{intSeq(1, 1, 2, 3), intSeq(1, 2, 3), true},
{strSeq("hello"), strSeq("world"), true},
{strSeq("hello", "hello"), strSeq("world", "world"), true},

// matching
{intSeq(), nil, false},
{intSeq(), intSeq(), false},
{intSeq(1), intSeq(1), false},
{intSeq(1, 1), intSeq(1, 1), false},
{intSeq(1, 2), intSeq(2, 1), false},
{intSeq(1, 1, 2), intSeq(1, 2, 1), false},
{strSeq("hello", "world"), strSeq("world", "hello"), false},
{strSeq("hello", "hello"), strSeq("hello", "hello"), false},
{strSeq("hello", "hello", "world"), strSeq("hello", "world", "hello"), false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("NotElementsMatch(%#v, %#v)", seqToSlice(c.expected), seqToSlice(c.actual)), func(t *testing.T) {
res := NotElementsMatch(mockT, c.actual, c.expected)

if res != c.result {
t.Errorf("NotElementsMatch(%#v, %#v) should return %v", seqToSlice(c.actual), seqToSlice(c.expected), c.result)
}
})
}
}

func TestContainsNotContains_Seq(t *testing.T) {

type A struct {
Name, Value string
}
complexSeq := func(s ...*A) func(yield func(*A) bool) {
return func(yield func(*A) bool) {
for _, elem := range s {
if !yield(elem) {
break
}
}
}
}

list := []string{"Foo", "Bar"}

complexList := []*A{
{"b", "c"},
{"d", "e"},
{"g", "h"},
{"j", "k"},
}

cases := []struct {
expected interface{}
actual interface{}
result bool
}{
{strSeq(list...), "Bar", true},
{strSeq(list...), "Salut", false},
{complexSeq(complexList...), &A{"g", "h"}, true},
{complexSeq(complexList...), &A{"g", "e"}, false},
}

for _, c := range cases {
t.Run(fmt.Sprintf("Contains(%#v, %#v)", seqToSlice(c.expected), seqToSlice(c.actual)), func(t *testing.T) {
mockT := new(testing.T)
res := Contains(mockT, c.expected, c.actual)

if res != c.result {
if res {
t.Errorf(
"Contains(%#v, %#v) should return true:\n\t%#v contains %#v",
seqToSlice(c.expected), seqToSlice(c.actual), seqToSlice(c.expected), seqToSlice(c.actual))
} else {
t.Errorf(
"Contains(%#v, %#v) should return false:\n\t%#v does not contain %#v",
seqToSlice(c.expected), seqToSlice(c.actual), seqToSlice(c.expected), seqToSlice(c.actual))
}
}
})
}

for _, c := range cases {
t.Run(fmt.Sprintf("NotContains(%#v, %#v)", c.expected, c.actual), func(t *testing.T) {
mockT := new(testing.T)
res := NotContains(mockT, c.expected, c.actual)

// NotContains should be inverse of Contains. If it's not, something is wrong
if res == Contains(mockT, c.expected, c.actual) {
if res {
t.Errorf("NotContains(%#v, %#v) should return true:\n\t%#v does not contains %#v", c.expected, c.actual, c.expected, c.actual)
} else {
t.Errorf("NotContains(%#v, %#v) should return false:\n\t%#v contains %#v", c.expected, c.actual, c.expected, c.actual)
}
}
})
}
}
44 changes: 44 additions & 0 deletions assert/seq_supported.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
//go:build go1.23 || goexperiment.rangefunc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This compile guard doesn't seem necessary. There is nothing in seqToSlice implementation that requires Go 1.23. Sequences functions can be implemented in Go below 1.23. Go 1.23 only adds syntactic sugar in for loops.

Copy link
Author

@misberner misberner Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair. My thinking was that w/o range over func, there is no such thing as a "sequence" in Go and thus no basis for semantically treating a function of that form as a sequence (I am not aware of any use of the yield func pattern in Go outside of/before range-over-func). But I agree this is pedantry, so fine to change it in the name of simplification.


package assert

import "reflect"

var (
boolType = reflect.TypeOf(true)
)

// seqToSlice checks if x is a sequence, and converts it to a slice of the
// same element type. Otherwise, x is returned as-is.
func seqToSlice(x interface{}) interface{} {
misberner marked this conversation as resolved.
Show resolved Hide resolved
if x == nil {
return nil
}

xv := reflect.ValueOf(x)
xt := xv.Type()
// We're looking for a function with exactly one input parameter and no return values.
if xt.Kind() != reflect.Func || xt.NumIn() != 1 || xt.NumOut() != 0 {
return x
}

// The input parameter should be of type func(T) bool
paramType := xt.In(0)
if paramType.Kind() != reflect.Func || paramType.NumIn() != 1 || paramType.NumOut() != 1 || paramType.Out(0) != boolType {
return x
}

elemType := paramType.In(0)
resultType := reflect.SliceOf(elemType)
result := reflect.MakeSlice(resultType, 0, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not necessary to allocate a zero element slice, as it will be discarded on the first call to append.
Instead a nil slice is just enough:

Suggested change
result := reflect.MakeSlice(resultType, 0, 0)
result := reflect.New(resultType).Elem()

See https://go.dev/play/p/_9VZP__CIS8

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed not necessary, thanks for pointing this out. Though I'll go with this to save a new:

Suggested change
result := reflect.MakeSlice(resultType, 0, 0)
result := reflect.Zero(resultType)


yieldFunc := reflect.MakeFunc(paramType, func(args []reflect.Value) []reflect.Value {
result = reflect.Append(result, args[0])
return []reflect.Value{reflect.ValueOf(true)}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allocate []reflect.Value{reflect.ValueOf(true)} out of the yieldFunc to allow reuse.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do.

})

// Call the function with the yield function as the argument
xv.Call([]reflect.Value{yieldFunc})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This extracts only the first element of the sequence. From the description of the function I expect the whole sequence to be serialized into the slice using a loop.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is incorrect, as otherwise obviously tests wouldn't be passing. The iteration itself happens in the function wrapped in xv. A for loop ranging over this function is syntactic sugar/magic for treating this single call as a coroutine, switching control between the function body and the loop head/body. If we call this function w/o the for/coroutine magic, it will just do a full iteration, calling yieldFunc for every element.


return result.Interface()
}
9 changes: 9 additions & 0 deletions assert/seq_unsupported.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
//go:build !go1.23 && !goexperiment.rangefunc

package assert

// seqToSlice would convert a sequence of elements to a slice of the respective type.
// However, since sequences are not supported given the build tags, it just returns x as-is.
func seqToSlice(x interface{}) interface{} {
return x
}
Loading
Loading