From 64df5b6fe4406f8ad63e3ddf9b097b1dbd202da3 Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Wed, 29 Nov 2023 16:54:20 +0100 Subject: [PATCH] Update CI and cleanups (#7) --- client_test.go | 386 +++++++++++++----------------------------------- oauth2_test.go | 247 +++++++++++++++++-------------- token_test.go | 151 ++++++++----------- wrapper_test.go | 22 +-- 4 files changed, 305 insertions(+), 501 deletions(-) diff --git a/client_test.go b/client_test.go index 83eb809..588de65 100644 --- a/client_test.go +++ b/client_test.go @@ -13,75 +13,43 @@ import ( func TestExchangeRequest(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() != "/token" { - t.Errorf("Unexpected exchange request URL %q", r.URL) - } + mustEqual(t, r.URL.String(), "/token") headerAuth := r.Header.Get("Authorization") - if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want { - t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want) - } + mustEqual(t, headerAuth, "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=") headerContentType := r.Header.Get("Content-Type") - if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header %q", headerContentType) - } + mustEqual(t, headerContentType, "application/x-www-form-urlencoded") body, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("Failed reading request body: %s.", err) - } - - if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { - t.Errorf("Unexpected exchange payload; got %q", body) - } + mustOk(t, err) + mustEqual(t, string(body), "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL") w.Header().Set("Content-Type", "application/x-www-form-urlencoded") - _, _ = w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) + fmt.Fprint(w, "access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer") }) defer ts.Close() client := newClient(ts.URL) tok, err := client.Exchange(context.Background(), "exchange-code") - if err != nil { - t.Error(err) - } - if !tok.Valid() { - t.Fatalf("Token invalid. Got: %#v", tok) - } - if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" { - t.Errorf("Unexpected access token, %#v.", tok.AccessToken) - } - if tok.TokenType != "bearer" { - t.Errorf("Unexpected token type, %#v.", tok.TokenType) - } - scope := tok.Extra("scope") - if scope != "user" { - t.Errorf("Unexpected value for scope: %v", scope) - } + mustOk(t, err) + mustEqual(t, tok.Valid(), true) + mustEqual(t, tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c") + mustEqual(t, tok.TokenType, "bearer") + mustEqual(t, tok.Extra("scope"), "user") } func TestClientExchangeWithParams(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { - got := r.Header.Get("Authorization") - - want := "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y=" - if got != want { - t.Errorf("Authorization header = %q; want %q", got, want) - } + headerAuth := r.Header.Get("Authorization") + mustEqual(t, headerAuth, "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y=") body, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("Failed reading request body: %s.", err) - } - - want = "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL" - if string(body) != want { - t.Errorf("got %v want %v", string(body), want) - } + mustOk(t, err) + mustEqual(t, string(body), "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL") w.Header().Set("Content-Type", "application/x-www-form-urlencoded") - _, _ = w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) + fmt.Fprint(w, "access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer") }) defer ts.Close() @@ -94,80 +62,64 @@ func TestClientExchangeWithParams(t *testing.T) { TokenURL: ts.URL + "/token", }) - _, err := client.ExchangeWithParams(context.Background(), "exchange-code", url.Values{"foo": {"bar"}}) - if err != nil { - t.Error(err) - } + _, err := client.ExchangeWithParams( + context.Background(), + "exchange-code", + url.Values{"foo": {"bar"}}, + ) + mustOk(t, err) } func TestExchangeRequest_BadResponse(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) + fmt.Fprint(w, `{"scope": "user", "token_type": "bearer"}`) }) defer ts.Close() client := newClient(ts.URL) _, err := client.Exchange(context.Background(), "code") - if err == nil { - t.Error("expected error from missing access_token") - } + mustFail(t, err) } func TestExchangeRequest_BadResponseType(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) + fmt.Fprint(w, `{"access_token":123, "scope": "user", "token_type": "bearer"}`) }) defer ts.Close() client := newClient(ts.URL) _, err := client.Exchange(context.Background(), "exchange-code") - if err == nil { - t.Error("expected error from non-string access_token") - } + mustFail(t, err) } func TestTokenRetrieveError(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() != "/token" { - t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) - } + mustEqual(t, r.URL.String(), "/token") w.Header().Set("Content-type", "application/json") w.WriteHeader(http.StatusBadRequest) - _, _ = w.Write([]byte(`{"error": "invalid_grant"}`)) + fmt.Fprint(w, `{"error": "invalid_grant"}`) }) defer ts.Close() conf := newClient(ts.URL) _, err := conf.Exchange(context.Background(), "exchange-code") - if err == nil { - t.Fatalf("got no error, expected one") - } + mustFail(t, err) expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`) - if errStr := err.Error(); errStr != expected { - t.Fatalf("got %#v, expected %#v", errStr, expected) - } + mustEqual(t, err.Error(), expected) } func TestRetrieveToken_InParams(t *testing.T) { const clientID = "client-id" ts := newServer(func(w http.ResponseWriter, r *http.Request) { - got := r.FormValue("client_id") - want := clientID - if got != want { - t.Errorf("client_id = %q; want %q", got, want) - } + mustEqual(t, r.FormValue("client_id"), clientID) + mustEqual(t, r.FormValue("client_secret"), "") - got = r.FormValue("client_secret") - want = "" - if got != want { - t.Errorf("client_secret = %q; want empty", got) - } w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + fmt.Fprint(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) }) defer ts.Close() @@ -179,9 +131,7 @@ func TestRetrieveToken_InParams(t *testing.T) { }) _, err := client.Exchange(context.Background(), "nil") - if err != nil { - t.Errorf("RetrieveToken = %v; want no error", err) - } + mustOk(t, err) } func TestRetrieveToken_InHeaderMode(t *testing.T) { @@ -190,19 +140,12 @@ func TestRetrieveToken_InHeaderMode(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { user, pass, ok := r.BasicAuth() - if !ok { - t.Error("expected with HTTP Basic Authentication") - } - - if user != clientID { - t.Errorf("client_id = %q; want %q", user, clientID) - } - if pass != clientSecret { - t.Errorf("client_secret = %q; want %q", pass, clientSecret) - } + mustEqual(t, ok, true) + mustEqual(t, user, clientID) + mustEqual(t, pass, clientSecret) w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + fmt.Fprint(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) }) defer ts.Close() @@ -214,9 +157,7 @@ func TestRetrieveToken_InHeaderMode(t *testing.T) { }) _, err := client.Exchange(context.Background(), "nil") - if err != nil { - t.Errorf("RetrieveToken = %v; want no error", err) - } + mustOk(t, err) } func TestRetrieveToken_AutoDetect(t *testing.T) { @@ -224,21 +165,16 @@ func TestRetrieveToken_AutoDetect(t *testing.T) { const clientSecret = "client-secret" ts := newServer(func(w http.ResponseWriter, r *http.Request) { - got := r.FormValue("client_id") - want := clientID - if got != want { + if r.FormValue("client_id") != clientID { w.WriteHeader(http.StatusInternalServerError) - _, _ = io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + fmt.Fprint(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) return } - got = r.FormValue("client_secret") - want = clientSecret - if got != want { - t.Errorf("client_secret = %q; want empty", got) - } + mustEqual(t, r.FormValue("client_secret"), clientSecret) + w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + fmt.Fprint(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) }) defer ts.Close() @@ -250,153 +186,85 @@ func TestRetrieveToken_AutoDetect(t *testing.T) { }) _, err := client.Exchange(context.Background(), "test") - if err != nil { - t.Errorf("RetrieveToken = %v; want no error", err) - } + mustOk(t, err) } func TestExchangeRequest_WithParams(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() != "/token" { - t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) - } + mustEqual(t, r.URL.String(), "/token") headerAuth := r.Header.Get("Authorization") - if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { - t.Errorf("Unexpected authorization header, %v is found.", headerAuth) - } + mustEqual(t, headerAuth, "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=") headerContentType := r.Header.Get("Content-Type") - if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) - } + mustEqual(t, headerContentType, "application/x-www-form-urlencoded") body, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("Failed reading request body: %s.", err) - } - if string(body) != "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { - t.Errorf("Unexpected exchange payload, %v is found.", string(body)) - } + mustOk(t, err) + mustEqual(t, string(body), "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL") + w.Header().Set("Content-Type", "application/x-www-form-urlencoded") - _, _ = w.Write([]byte("access_token=ProperToken&scope=user&token_type=bearer")) + fmt.Fprint(w, "access_token=ProperToken&scope=user&token_type=bearer") }) defer ts.Close() client := newClient(ts.URL) tok, err := client.ExchangeWithParams(context.Background(), "exchange-code", url.Values{"foo": {"bar"}}) - if err != nil { - t.Error(err) - } - if !tok.Valid() { - t.Fatalf("Token invalid. Got: %#v", tok) - } - if tok.AccessToken != "ProperToken" { - t.Errorf("Unexpected access token, %#v.", tok.AccessToken) - } - if tok.TokenType != "bearer" { - t.Errorf("Unexpected token type, %#v.", tok.TokenType) - } - scope := tok.Extra("scope") - if scope != "user" { - t.Errorf("Unexpected value for scope: %v", scope) - } + mustOk(t, err) + mustEqual(t, tok.Valid(), true) + mustEqual(t, tok.AccessToken, "ProperToken") + mustEqual(t, tok.TokenType, "bearer") + mustEqual(t, tok.Extra("scope"), "user") } func TestExchangeRequest_JSONResponse(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { - if r.URL.String() != "/token" { - t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) - } + mustEqual(t, r.URL.String(), "/token") headerAuth := r.Header.Get("Authorization") - if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { - t.Errorf("Unexpected authorization header, %v is found.", headerAuth) - } + mustEqual(t, headerAuth, "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=") headerContentType := r.Header.Get("Content-Type") - if headerContentType != "application/x-www-form-urlencoded" { - t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) - } + mustEqual(t, headerContentType, "application/x-www-form-urlencoded") body, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("Failed reading request body: %s.", err) - } - - if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { - t.Errorf("Unexpected exchange payload, %v is found.", string(body)) - } + mustOk(t, err) + mustEqual(t, string(body), "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL") w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"access_token": "ProperToken", "scope": "user", "token_type": "bearer", "expires_in": 86400}`)) + fmt.Fprint(w, `{"access_token": "ProperToken", "scope": "user", "token_type": "bearer", "expires_in": 86400}`) }) defer ts.Close() client := newClient(ts.URL) tok, err := client.Exchange(context.Background(), "exchange-code") - if err != nil { - t.Error(err) - } - - if !tok.Valid() { - t.Fatalf("Token invalid. Got: %#v", tok) - } - - if tok.AccessToken != "ProperToken" { - t.Errorf("Unexpected access token, %#v.", tok.AccessToken) - } - - if tok.TokenType != "bearer" { - t.Errorf("Unexpected token type, %#v.", tok.TokenType) - } - - scope := tok.Extra("scope") - if scope != "user" { - t.Errorf("Unexpected value for scope: %v", scope) - } - - expiresIn := tok.Extra("expires_in") - if expiresIn != float64(86400) { - t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn) - } + mustOk(t, err) + mustEqual(t, tok.Valid(), true) + mustEqual(t, tok.AccessToken, "ProperToken") + mustEqual(t, tok.TokenType, "bearer") + mustEqual(t, tok.Extra("scope"), "user") + mustEqual(t, tok.Extra("expires_in").(float64), float64(86400)) } func TestExchangeRequest_JSONResponse_Expiry(t *testing.T) { - seconds := int32((24 * time.Hour).Seconds()) - - f := func(expires string, want, nullExpires bool) { - t.Helper() - - testExchangeRequestJSONResponseExpiry(t, expires, want, nullExpires) + testCases := []struct { + expires string + want bool + nullExpires bool + }{ + {`"expires_in": 86400`, true, false}, + {`"expires_in": "86400"`, true, false}, + {`"expires_in": null`, true, true}, + {`"expires_in": false`, false, false}, + {`"expires_in": {}`, false, false}, + {`"expires_in": "zzz"`, false, false}, + } + + for _, tc := range testCases { + testExchangeRequestJSONResponseExpiry(t, tc.expires, tc.want, tc.nullExpires) } - - f( - fmt.Sprintf(`"expires_in": %d`, seconds), - true, false, - ) - f( - fmt.Sprintf(`"expires_in": "%d"`, seconds), - true, false, - ) - f( - `"expires_in": null`, - true, true, - ) - f( - `"expires_in": false`, - false, false, - ) - f( - `"expires_in": {}`, - false, false, - ) - f( - `"expires_in": "zzz"`, - false, false, - ) } func testExchangeRequestJSONResponseExpiry(t *testing.T, exp string, want, nullExpires bool) { @@ -422,9 +290,7 @@ func testExchangeRequestJSONResponseExpiry(t *testing.T, exp string, want, nullE if !want { return } - if !tok.Valid() { - t.Fatalf("Token invalid. Got: %#v", tok) - } + mustEqual(t, tok.Valid(), true) expiry := tok.Expiry if nullExpires && expiry.IsZero() { @@ -437,57 +303,29 @@ func testExchangeRequestJSONResponseExpiry(t *testing.T, exp string, want, nullE func TestPasswordCredentialsTokenRequest(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - expected := "/token" - if r.URL.String() != expected { - t.Errorf("URL = %q; want %q", r.URL, expected) - } + mustEqual(t, r.URL.String(), "/token") headerAuth := r.Header.Get("Authorization") - expected = "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" - if headerAuth != expected { - t.Errorf("Authorization header = %q; want %q", headerAuth, expected) - } + mustEqual(t, headerAuth, "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=") headerContentType := r.Header.Get("Content-Type") - expected = "application/x-www-form-urlencoded" - if headerContentType != expected { - t.Errorf("Content-Type header = %q; want %q", headerContentType, expected) - } + mustEqual(t, headerContentType, "application/x-www-form-urlencoded") body, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("Failed reading request body: %s.", err) - } - - expected = "grant_type=password&password=password1&scope=scope1+scope2&username=user1" - if string(body) != expected { - t.Errorf("res.Body = %q; want %q", string(body), expected) - } + mustOk(t, err) + mustEqual(t, string(body), "grant_type=password&password=password1&scope=scope1+scope2&username=user1") w.Header().Set("Content-Type", "application/x-www-form-urlencoded") - _, _ = w.Write([]byte("access_token=ProperToken&scope=user&token_type=bearer")) + fmt.Fprint(w, "access_token=ProperToken&scope=user&token_type=bearer") }) defer ts.Close() client := newClient(ts.URL) tok, err := client.CredentialsToken(context.Background(), "user1", "password1") - if err != nil { - t.Error(err) - } - if !tok.Valid() { - t.Fatalf("Token invalid. Got: %#v", tok) - } - - expected := "ProperToken" - if tok.AccessToken != expected { - t.Errorf("AccessToken = %q; want %q", tok.AccessToken, expected) - } - - expected = "bearer" - if tok.TokenType != expected { - t.Errorf("TokenType = %q; want %q", tok.TokenType, expected) - } + mustOk(t, err) + mustEqual(t, tok.Valid(), true) + mustEqual(t, tok.AccessToken, "ProperToken") + mustEqual(t, tok.TokenType, "bearer") } // func TestTokenRefreshRequest(t *testing.T) { @@ -495,9 +333,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { // if r.URL.String() == "/somethingelse" { // return // } -// if r.URL.String() != "/token" { -// t.Errorf("Unexpected token refresh request URL %q", r.URL) -// } +// mustEqual(t, r.URL.String(), "/token") // headerContentType := r.Header.Get("Content-Type") // if headerContentType != "application/x-www-form-urlencoded" { // t.Errorf("Unexpected Content-Type header %q", headerContentType) @@ -520,9 +356,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { // if r.URL.String() == "/somethingelse" { // return // } -// if r.URL.String() != "/token" { -// t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL) -// } +// mustEqual(t, r.URL.String(), "/token") // headerContentType := r.Header.Get("Content-Type") // if headerContentType != "application/x-www-form-urlencoded" { // t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) @@ -552,9 +386,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { // conf := newConf(ts.URL) // tkr := conf.TokenSource(context.Background(), &Token{RefreshToken: "OLD_REFRESH_TOKEN"}) // tk, err := tkr.Token() -// if err != nil { -// t.Errorf("got err = %v; want none", err) -// return +// mustOk(t, err) // } // if want := "NEW_REFRESH_TOKEN"; tk.RefreshToken != want { // t.Errorf("RefreshToken = %q; want %q", tk.RefreshToken, want) @@ -572,9 +404,7 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { // const oldRefreshToken = "OLD_REFRESH_TOKEN" // tkr := conf.TokenSource(context.Background(), &Token{RefreshToken: oldRefreshToken}) // tk, err := tkr.Token() -// if err != nil { -// t.Fatalf("got err = %v; want none", err) -// } +// mustOk(t, err) // if tk.RefreshToken != oldRefreshToken { // t.Errorf("RefreshToken = %q; want %q", tk.RefreshToken, oldRefreshToken) // } @@ -595,13 +425,9 @@ func TestPasswordCredentialsTokenRequest(t *testing.T) { // c := conf.Client(context.Background(), tok) // req, err := http.NewRequest("GET", ts.URL, nil) -// if err != nil { -// t.Error(err) -// } +// mustOk(t, err) // _, err = c.Do(req) -// if err != nil { -// t.Error(err) -// } +// mustOk(t, err) // } func TestRetrieveTokenWithContexts(t *testing.T) { @@ -609,7 +435,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { ts := newServer(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - _, _ = io.WriteString(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) + fmt.Fprint(w, `{"access_token": "ACCESS_TOKEN", "token_type": "bearer"}`) }) defer ts.Close() @@ -620,9 +446,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { Mode: AutoDetectMode, }) _, err := client.retrieveToken(context.Background(), url.Values{}) - if err != nil { - t.Errorf("RetrieveToken (with background context) = %v; want no error", err) - } + mustOk(t, err) retrieved := make(chan struct{}) cancellingts := newServer(func(w http.ResponseWriter, r *http.Request) { @@ -641,9 +465,7 @@ func TestRetrieveTokenWithContexts(t *testing.T) { cancel() _, err = client.retrieveToken(ctx, url.Values{}) close(retrieved) - if err == nil { - t.Errorf("RetrieveToken (with cancelled context) = nil; want error") - } + mustFail(t, err) } func newClient(url string) *Client { diff --git a/oauth2_test.go b/oauth2_test.go index d5967e2..b7bb510 100644 --- a/oauth2_test.go +++ b/oauth2_test.go @@ -3,132 +3,155 @@ package oauth2 import ( "net/http" "net/url" + "reflect" "testing" ) func TestAuthCodeURL(t *testing.T) { - f := func(cfg Config, state string, want string) { - client := NewClient(http.DefaultClient, cfg) - - url := client.AuthCodeURL(state) - if url != want { - t.Errorf("got %q; want %q", url, want) - } - } - - f( - Config{ - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - RedirectURL: "REDIRECT_URL", - Scopes: nil, - AuthURL: "server:1234/auth", - TokenURL: "", + testCases := []struct { + cfg Config + state string + want string + }{ + { + Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: "REDIRECT_URL", + Scopes: nil, + AuthURL: "server:1234/auth", + TokenURL: "", + }, + "test-state", + `server:1234/auth?client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&state=test-state`, }, - "test-state", - `server:1234/auth?client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&state=test-state`, - ) - - f( - Config{ - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - RedirectURL: "REDIRECT_URL", - Scopes: nil, - AuthURL: "server:1234/auth?foo=bar", - TokenURL: "", + { + Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: "REDIRECT_URL", + Scopes: nil, + AuthURL: "server:1234/auth?foo=bar", + TokenURL: "", + }, + "test-state", + `server:1234/auth?foo=bar&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&state=test-state`, }, - "test-state", - `server:1234/auth?foo=bar&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&state=test-state`, - ) - - f( - Config{ - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - RedirectURL: "REDIRECT_URL", - Scopes: []string{"scope1", "scope2"}, - AuthURL: "server:1234/auth", - TokenURL: "", + { + Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: "REDIRECT_URL", + Scopes: []string{"scope1", "scope2"}, + AuthURL: "server:1234/auth", + TokenURL: "", + }, + "test-state", + `server:1234/auth?client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=test-state`, }, - "test-state", - `server:1234/auth?client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=test-state`, - ) - - f( - Config{ - ClientID: "CLIENT_ID", - AuthURL: "server:1234/auth-url", - TokenURL: "", + { + Config{ + ClientID: "CLIENT_ID", + AuthURL: "server:1234/auth-url", + TokenURL: "", + }, + "", + `server:1234/auth-url?client_id=CLIENT_ID&response_type=code`, }, - "", - `server:1234/auth-url?client_id=CLIENT_ID&response_type=code`, - ) -} - -func TestAuthCodeURLWithParams(t *testing.T) { - f := func(cfg Config, state string, params url.Values, want string) { - client := NewClient(http.DefaultClient, cfg) + } - url := client.AuthCodeURLWithParams(state, params) - if url != want { - t.Errorf("got %q; want %q", url, want) - } + for _, tc := range testCases { + client := NewClient(http.DefaultClient, tc.cfg) + url := client.AuthCodeURL(tc.state) + mustEqual(t, url, tc.want) } +} - f( - Config{ - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - RedirectURL: "REDIRECT_URL", - Scopes: nil, - AuthURL: "server:1234/auth", - TokenURL: "", +func TestAuthCodeURLWithParams(t *testing.T) { + testCases := []struct { + cfg Config + state string + params url.Values + want string + }{ + { + Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: "REDIRECT_URL", + Scopes: nil, + AuthURL: "server:1234/auth", + TokenURL: "", + }, + "test-state", + nil, + `server:1234/auth?client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&state=test-state`, }, - "test-state", - nil, - `server:1234/auth?client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&state=test-state`, - ) - - f( - Config{ - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - RedirectURL: "REDIRECT_URL", - Scopes: nil, - AuthURL: "server:1234/auth?foo=bar", - TokenURL: "", + { + Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: "REDIRECT_URL", + Scopes: nil, + AuthURL: "server:1234/auth?foo=bar", + TokenURL: "", + }, + "test-state", + nil, + `server:1234/auth?foo=bar&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&state=test-state`, }, - "test-state", - nil, - `server:1234/auth?foo=bar&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&state=test-state`, - ) - - f( - Config{ - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - RedirectURL: "REDIRECT_URL", - Scopes: []string{"scope1", "scope2"}, - AuthURL: "server:1234/auth", - TokenURL: "", + { + Config{ + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: "REDIRECT_URL", + Scopes: []string{"scope1", "scope2"}, + AuthURL: "server:1234/auth", + TokenURL: "", + }, + "test-state", + url.Values{ + "access_type": []string{"anything"}, + "param1": []string{"value1"}, + }, + `server:1234/auth?access_type=anything&client_id=CLIENT_ID¶m1=value1&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=test-state`, }, - "test-state", - url.Values{ - "access_type": []string{"anything"}, - "param1": []string{"value1"}, + { + Config{ + ClientID: "CLIENT_ID", + AuthURL: "server:1234/auth-url", + TokenURL: "", + }, + "", + nil, + `server:1234/auth-url?client_id=CLIENT_ID&response_type=code`, }, - `server:1234/auth?access_type=anything&client_id=CLIENT_ID¶m1=value1&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=test-state`, - ) + } - f( - Config{ - ClientID: "CLIENT_ID", - AuthURL: "server:1234/auth-url", - TokenURL: "", - }, - "", - nil, - `server:1234/auth-url?client_id=CLIENT_ID&response_type=code`, - ) + for _, tc := range testCases { + client := NewClient(http.DefaultClient, tc.cfg) + url := client.AuthCodeURLWithParams(tc.state, tc.params) + mustEqual(t, url, tc.want) + } +} + +func mustOk(tb testing.TB, err error) { + tb.Helper() + if err != nil { + tb.Fatal(err) + } +} + +func mustFail(tb testing.TB, err error) { + tb.Helper() + if err == nil { + tb.Fatal() + } +} + +func mustEqual[T any](tb testing.TB, have, want T) { + tb.Helper() + if !reflect.DeepEqual(have, want) { + tb.Fatalf("\nhave: %+v\nwant: %+v\n", have, want) + } } diff --git a/token_test.go b/token_test.go index 9b30c33..818be8a 100644 --- a/token_test.go +++ b/token_test.go @@ -7,99 +7,66 @@ import ( ) func TestTokenTypeMethod(t *testing.T) { - f := func(token *Token, want string) { - t.Helper() - - got := token.Type() - if got != want { - t.Errorf("got %v; want %v", got, want) - } + testCases := []struct { + token *Token + want string + }{ + {&Token{}, "Bearer"}, + {&Token{TokenType: "beAREr"}, "Bearer"}, + {&Token{TokenType: "beAREr"}, "Bearer"}, + {&Token{TokenType: "basic"}, "Basic"}, + {&Token{TokenType: "Basic"}, "Basic"}, + {&Token{TokenType: "mac"}, "MAC"}, + {&Token{TokenType: "MAC"}, "MAC"}, + {&Token{TokenType: "mAc"}, "MAC"}, + {&Token{TokenType: "unknown"}, "unknown"}, } - f( - &Token{}, "Bearer", - ) - f( - &Token{TokenType: "beAREr"}, "Bearer", - ) - f( - &Token{TokenType: "beAREr"}, "Bearer", - ) - f( - &Token{TokenType: "basic"}, "Basic", - ) - f( - &Token{TokenType: "Basic"}, "Basic", - ) - f( - &Token{TokenType: "mac"}, "MAC", - ) - f( - &Token{TokenType: "MAC"}, "MAC", - ) - f( - &Token{TokenType: "mAc"}, "MAC", - ) - f( - &Token{TokenType: "unknown"}, "unknown", - ) + for _, tc := range testCases { + mustEqual(t, tc.token.Type(), tc.want) + } } func TestTokenExtra(t *testing.T) { const wantKey = "extra-key" - f := func(key string, value, want interface{}) { - t.Helper() - - extra := map[string]interface{}{ - key: value, - } - token := &Token{ - Raw: extra, - } - - got := token.Extra(wantKey) - if got != want { - t.Errorf("Extra(%q) = %q; want %q", key, got, want) - } + testCases := []struct { + key string + value any + want any + }{ + {wantKey, "abc", "abc"}, + {wantKey, 123, 123}, + {wantKey, "", ""}, + {"other-key", "def", nil}, } - f("extra-key", "abc", "abc") - f("extra-key", 123, 123) - f("extra-key", "", "") - f("other-key", "def", nil) + for _, tc := range testCases { + token := &Token{Raw: map[string]any{ + tc.key: tc.value, + }} + mustEqual(t, token.Extra(wantKey), tc.want) + } } func TestTokenExpiry(t *testing.T) { now := time.Now() timeNow = func() time.Time { return now } - defer func() { timeNow = time.Now }() - - f := func(token *Token, want bool) { - t.Helper() - - got := token.IsExpired() - if got != want { - t.Errorf("got %v; want %v", got, want) - } + t.Cleanup(func() { timeNow = time.Now }) + + testCases := []struct { + token *Token + want bool + }{ + {&Token{Expiry: now.Add(12 * time.Second)}, false}, + {&Token{Expiry: now.Add(expiryDelta)}, false}, + {&Token{Expiry: now.Add(expiryDelta - 1*time.Nanosecond)}, true}, + {&Token{Expiry: now.Add(-1 * time.Hour)}, true}, } - f( - &Token{Expiry: now.Add(12 * time.Second)}, - false, - ) - f( - &Token{Expiry: now.Add(expiryDelta)}, - false, - ) - f( - &Token{Expiry: now.Add(expiryDelta - 1*time.Nanosecond)}, - true, - ) - f( - &Token{Expiry: now.Add(-1 * time.Hour)}, - true, - ) + for _, tc := range testCases { + mustEqual(t, tc.token.IsExpired(), tc.want) + } } func TestExtraValueRetrieval(t *testing.T) { @@ -109,30 +76,30 @@ func TestExtraValueRetrieval(t *testing.T) { "expires_in": "86400.92", "server_time": "1443571905.5606415", "referer_ip": "10.0.0.1", - "etag": "\"afZYj912P4alikMz_P11982\"", + "etag": `"afZYj912P4alikMz_P11982"`, "request_id": "86400", "untrimmed": " untrimmed ", } + values := url.Values{} for key, value := range kvmap { values.Set(key, value) } - tok := Token{Raw: values} - f := func(key string, want interface{}) { - t.Helper() - - value := tok.Extra(key) - if value != want { - t.Errorf("got %q; want %q", value, want) - } + testCases := []struct { + key string + value any + }{ + {"scope", "user"}, + {"server_time", 1443571905.5606415}, + {"referer_ip", "10.0.0.1"}, + {"expires_in", 86400.92}, + {"request_id", int64(86400)}, + {"untrimmed", " untrimmed "}, } - f("scope", "user") - f("server_time", 1443571905.5606415) - f("referer_ip", "10.0.0.1") - f("expires_in", 86400.92) - f("request_id", int64(86400)) - f("untrimmed", " untrimmed ") + for _, tc := range testCases { + mustEqual(t, tok.Extra(tc.key), tc.value) + } } diff --git a/wrapper_test.go b/wrapper_test.go index 7e514ac..ab529b4 100644 --- a/wrapper_test.go +++ b/wrapper_test.go @@ -7,28 +7,20 @@ import ( ) func TestWrap(t *testing.T) { - apikey := "Test-Api-Key-123" + const apikey = "Test-Api-Key-123" ts := newServer(func(w http.ResponseWriter, r *http.Request) { - got := r.Header.Get("Authorization") - want := apikey - if got != want { - t.Errorf("Authorization header = %q; want %q", got, want) - } + mustEqual(t, r.Header.Get("Authorization"), apikey) + w.WriteHeader(http.StatusOK) }) defer ts.Close() c := &http.Client{Timeout: 5 * time.Second} wc, err := Wrap("Authorization", apikey, c) - if err != nil { - t.Fatal(err) - } + mustOk(t, err) + resp, err := wc.Get(ts.URL) - if err != nil { - t.Fatal(err) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("status code %#v, want %#v", resp.StatusCode, http.StatusOK) - } + mustOk(t, err) + mustEqual(t, resp.StatusCode, http.StatusOK) }