Skip to content

Commit

Permalink
move validation of pre auth key out of db
Browse files Browse the repository at this point in the history
This move separates the logic a bit and allow us to
write specific errors for the caller, in this case the web
layer so we can present the user with the correct error
codes without bleeding web stuff into a generic validate.

Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Feb 1, 2025
1 parent e76aaa9 commit 57214cb
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 180 deletions.
35 changes: 31 additions & 4 deletions hscontrol/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,42 @@ func (h *Headscale) waitForFollowup(
}
}

// canUsePreAuthKey checks if a pre auth key can be used.
func canUsePreAuthKey(pak *types.PreAuthKey) error {
if pak == nil {
return NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil)
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return NewHTTPError(http.StatusUnauthorized, "authkey expired", nil)
}

// we don't need to check if has been used before
if pak.Reusable {
return nil
}

if pak.Used {
return NewHTTPError(http.StatusUnauthorized, "authkey already used", nil)
}

return nil
}

func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
// TODO(kradalby) Refactor and get the validate away from the database
// so we can return nice http errors.
pak, err := h.db.ValidatePreAuthKey(regReq.Auth.AuthKey)
pak, err := h.db.GetPreAuthKey(regReq.Auth.AuthKey)
if err != nil {
return nil, fmt.Errorf("invalid pre auth key: %w", err)
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
return nil, err
}

err = canUsePreAuthKey(pak)
if err != nil {
return nil, err
}

nodeToRegister := types.Node{
Expand Down
130 changes: 130 additions & 0 deletions hscontrol/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package hscontrol

import (
"net/http"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
)

func TestCanUsePreAuthKey(t *testing.T) {
now := time.Now()
past := now.Add(-time.Hour)
future := now.Add(time.Hour)

tests := []struct {
name string
pak *types.PreAuthKey
wantErr bool
err HTTPError
}{
{
name: "valid reusable key",
pak: &types.PreAuthKey{
Reusable: true,
Used: false,
Expiration: &future,
},
wantErr: false,
},
{
name: "valid non-reusable key",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: &future,
},
wantErr: false,
},
{
name: "expired key",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: &past,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
},
{
name: "used non-reusable key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: &future,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
},
{
name: "used reusable key",
pak: &types.PreAuthKey{
Reusable: true,
Used: true,
Expiration: &future,
},
wantErr: false,
},
{
name: "no expiration date",
pak: &types.PreAuthKey{
Reusable: false,
Used: false,
Expiration: nil,
},
wantErr: false,
},
{
name: "nil preauth key",
pak: nil,
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "invalid authkey", nil),
},
{
name: "expired and used key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: &past,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey expired", nil),
},
{
name: "no expiration and used key",
pak: &types.PreAuthKey{
Reusable: false,
Used: true,
Expiration: nil,
},
wantErr: true,
err: NewHTTPError(http.StatusUnauthorized, "authkey already used", nil),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := canUsePreAuthKey(tt.pak)
if tt.wantErr {
if err == nil {
t.Errorf("expected error but got none")
} else {
httpErr, ok := err.(HTTPError)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {
if diff := cmp.Diff(tt.err, httpErr); diff != "" {
t.Errorf("unexpected error (-want +got):\n%s", diff)
}
}
}
} else {
if err != nil {
t.Errorf("expected no error but got %v", err)
}
}
})
}
}
73 changes: 18 additions & 55 deletions hscontrol/db/preauth_keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (

"github.com/juanfont/headscale/hscontrol/types"
"gorm.io/gorm"
"tailscale.com/types/ptr"
"tailscale.com/util/set"
)

Expand Down Expand Up @@ -64,6 +63,7 @@ func CreatePreAuthKey(
}

now := time.Now().UTC()
// TODO(kradalby): unify the key generations spread all over the code.
kstr, err := generateKey()
if err != nil {
return nil, err
Expand Down Expand Up @@ -108,18 +108,21 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
return keys, nil
}

// GetPreAuthKey returns a PreAuthKey for a given key.
func GetPreAuthKey(tx *gorm.DB, user string, key string) (*types.PreAuthKey, error) {
pak, err := ValidatePreAuthKey(tx, key)
if err != nil {
return nil, err
}
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return GetPreAuthKey(rx, key)
})
}

if pak.User.Name != user {
return nil, ErrUserMismatch
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
// for checking if the key is usable (expired or used).
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if err := tx.Preload("User").First(&pak, "key = ?", key).Error; err != nil {
return nil, ErrPreAuthKeyNotFound
}

return pak, nil
return &pak, nil
}

// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
Expand All @@ -140,15 +143,6 @@ func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
})
}

// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}

return nil
}

// UsePreAuthKey marks a PreAuthKey as used.
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
k.Used = true
Expand All @@ -159,44 +153,13 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
return nil
}

func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return ValidatePreAuthKey(rx, k)
})
}

// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used.
func ValidatePreAuthKey(tx *gorm.DB, k string) (*types.PreAuthKey, error) {
pak := types.PreAuthKey{}
if result := tx.Preload("User").First(&pak, "key = ?", k); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrPreAuthKeyNotFound
}

if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return nil, ErrPreAuthKeyExpired
}

if pak.Reusable { // we don't need to check if has been used before
return &pak, nil
}

nodes := types.Nodes{}
if err := tx.
Preload("AuthKey").
Where(&types.Node{AuthKeyID: ptr.To(pak.ID)}).
Find(&nodes).Error; err != nil {
return nil, err
}

if len(nodes) != 0 || pak.Used {
return nil, ErrSingleUseAuthKeyHasBeenUsed
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}

return &pak, nil
return nil
}

func generateKey() (string, error) {
Expand Down
Loading

0 comments on commit 57214cb

Please sign in to comment.