-
Notifications
You must be signed in to change notification settings - Fork 44
/
compose_test.go
96 lines (92 loc) · 2.57 KB
/
compose_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package sqlhooks
import (
"context"
"errors"
"reflect"
"testing"
)
var (
	// oops is the sentinel error produced by every callback of oopsHook.
	oops = errors.New("oops")

	// oopsHook fails in every phase: Before and After hand back the
	// incoming context together with oops, and OnError discards the error
	// it was given and returns oops instead.
	oopsHook = &testHooks{
		onError: func(_ context.Context, _ error, _ string, _ ...interface{}) error {
			return oops
		},
		before: func(ctx context.Context, _ string, _ ...interface{}) (context.Context, error) {
			return ctx, oops
		},
		after: func(ctx context.Context, _ string, _ ...interface{}) (context.Context, error) {
			return ctx, oops
		},
	}

	// okHook succeeds in every phase: Before and After pass the context
	// through untouched with a nil error, and OnError reports nil.
	okHook = &testHooks{
		onError: func(_ context.Context, _ error, _ string, _ ...interface{}) error {
			return nil
		},
		before: func(ctx context.Context, _ string, _ ...interface{}) (context.Context, error) {
			return ctx, nil
		},
		after: func(ctx context.Context, _ string, _ ...interface{}) (context.Context, error) {
			return ctx, nil
		},
	}
)
// TestCompose checks the error aggregation of hooks built with Compose in
// each phase (Before, After, OnError). The table expects:
//   - no failing hook (or no hooks at all) to yield nil,
//   - exactly one failing hook to yield that hook's error unchanged,
//   - several failing hooks to yield a MultipleErrors holding each failure,
//     which implies Compose keeps invoking hooks after one fails.
//
// For OnError, a nil table expectation means the composed hook should hand
// back the original cause it was given.
func TestCompose(t *testing.T) {
	for _, it := range []struct {
		name  string
		hooks Hooks
		want  error
	}{
		{"happy case", Compose(okHook, okHook), nil},
		{"no hooks", Compose(), nil},
		{"multiple errors", Compose(oopsHook, okHook, oopsHook), MultipleErrors([]error{oops, oops})},
		{"single error", Compose(okHook, oopsHook, okHook), oops},
	} {
		t.Run(it.name, func(t *testing.T) {
			// %v (not %q) is used in failure messages: %q renders a nil
			// error as "%!q(<nil>)", which obscures the happy-path rows.
			t.Run("Before", func(t *testing.T) {
				_, got := it.hooks.Before(context.Background(), "query")
				if !reflect.DeepEqual(it.want, got) {
					t.Errorf("unexpected error. want: %v, got: %v", it.want, got)
				}
			})
			t.Run("After", func(t *testing.T) {
				_, got := it.hooks.After(context.Background(), "query")
				if !reflect.DeepEqual(it.want, got) {
					t.Errorf("unexpected error. want: %v, got: %v", it.want, got)
				}
			})
			t.Run("OnError", func(t *testing.T) {
				// Checked assertion: a composed value lacking OnError should
				// fail this subtest, not panic the whole test binary.
				onErr, ok := it.hooks.(OnErrorer)
				if !ok {
					t.Fatalf("composed hooks of type %T do not implement OnErrorer", it.hooks)
				}

				cause := errors.New("crikey")
				want := it.want
				if want == nil {
					// No hook reports its own error, so the original cause
					// is expected to come back.
					want = cause
				}
				got := onErr.OnError(context.Background(), cause, "query")
				if !reflect.DeepEqual(want, got) {
					t.Errorf("unexpected error. want: %v, got: %v", want, got)
				}
			})
		})
	}
}
// TestWrapErrors exercises wrapErrors across three table rows. The table
// expects:
//   - an empty error list to fall back to the supplied default error,
//   - a single error to be returned as-is,
//   - two errors to be combined into a MultipleErrors value in input order.
func TestWrapErrors(t *testing.T) {
	first := errors.New("oops")
	second := errors.New("oops2")

	// wrapCase pairs the arguments for wrapErrors with its expected result.
	type wrapCase struct {
		name     string
		fallback error
		errs     []error
		want     error
	}

	cases := []wrapCase{
		{name: "no errors", fallback: first, errs: nil, want: first},
		{name: "single error", fallback: nil, errs: []error{first}, want: first},
		{name: "multiple errors", fallback: nil, errs: []error{first, second}, want: MultipleErrors([]error{first, second})},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := wrapErrors(tc.fallback, tc.errs)
			if reflect.DeepEqual(tc.want, got) {
				return
			}
			t.Errorf("unexpected wrapping. want: %q, got %q", tc.want, got)
		})
	}
}