Skip to content

Commit

Permalink
fix potential deadlock, add test scenario for disabling and deleting …
Browse files Browse the repository at this point in the history
…route

Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Nov 8, 2023
1 parent fbb23db commit e045601
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 13 deletions.
6 changes: 6 additions & 0 deletions hscontrol/db/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ func (hsdb *HSDatabase) GetNodeByMachineKey(
hsdb.mu.RLock()
defer hsdb.mu.RUnlock()

return hsdb.getNodeByMachineKey(machineKey)
}

func (hsdb *HSDatabase) getNodeByMachineKey(
machineKey key.MachinePublic,
) (*types.Node, error) {
mach := types.Node{}
if result := hsdb.db.
Preload("AuthKey").
Expand Down
39 changes: 30 additions & 9 deletions hscontrol/db/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,14 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error {
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.IsExitRoute() {
route.Enabled = false
route.IsPrimary = false
err = hsdb.db.Save(route).Error
err = hsdb.failoverRouteWithNotify(route)
if err != nil {
return err
}

err = hsdb.failoverRouteWithNotify(route)
route.Enabled = false
route.IsPrimary = false
err = hsdb.db.Save(route).Error
if err != nil {
return err
}
Expand Down Expand Up @@ -229,14 +229,15 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error {
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.IsExitRoute() {
if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
return err
}

err := hsdb.failoverRouteWithNotify(route)
if err != nil {
return nil
}

if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil {
return err
}

} else {

routes, err := hsdb.getNodeRoutes(&node)
Expand Down Expand Up @@ -489,20 +490,32 @@ func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error {

var nodes types.Nodes

log.Trace().
Str("hostname", r.Node.Hostname).
Msg("loading machines with new primary routes from db")

for _, key := range changedKeys {
node, err := hsdb.GetNodeByMachineKey(key)
node, err := hsdb.getNodeByMachineKey(key)
if err != nil {
return err
}

nodes = append(nodes, node)
}

log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notifying peers about primary route change")

hsdb.notifier.NotifyAll(types.StateUpdate{
Type: types.StatePeerChanged,
Changed: nodes,
})

log.Trace().
Str("hostname", r.Node.Hostname).
Msg("notified peers about primary route change")

return nil
}

Expand Down Expand Up @@ -571,6 +584,10 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
return nil, err
}

log.Trace().
Str("hostname", newPrimary.Node.Hostname).
Msg("removed primary from old route")

// Set primary for the new primary
newPrimary.IsPrimary = true
err = hsdb.db.Save(&newPrimary).Error
Expand All @@ -580,6 +597,10 @@ func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, erro
return nil, err
}

log.Trace().
Str("hostname", newPrimary.Node.Hostname).
Msg("set primary to new route")

rKey, err := r.Node.MachinePublicKey()
if err != nil {
return nil, err
Expand Down
189 changes: 185 additions & 4 deletions integration/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,17 @@ func TestHASubnetRouterFailover(t *testing.T) {

// Verify that the client has routes from the primary machine
srs1, err := subRouter1.Status()
srs2, err := subRouter2.Status()

clientStatus, err := client.Status()
assertNoErr(t, err)

srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]

assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)

assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
Expand Down Expand Up @@ -431,13 +436,15 @@ func TestHASubnetRouterFailover(t *testing.T) {
// TODO(kradalby): Check client status
// Route is expected to be on SR2

srs2, err := subRouter2.Status()
srs2, err = subRouter2.Status()

clientStatus, err = client.Status()
assertNoErr(t, err)

srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey]
srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]

assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)

if srs2PeerStatus.PrimaryRoutes != nil {
Expand Down Expand Up @@ -489,8 +496,10 @@ func TestHASubnetRouterFailover(t *testing.T) {
clientStatus, err = client.Status()
assertNoErr(t, err)

srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]

assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assertNotNil(t, srs2PeerStatus.PrimaryRoutes)

if srs2PeerStatus.PrimaryRoutes != nil {
Expand Down Expand Up @@ -523,12 +532,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err)
assert.Len(t, routesAfter1Up, 2)

// Node 1 is not primary
// Node 1 is primary
assert.Equal(t, true, routesAfter1Up[0].Advertised)
assert.Equal(t, true, routesAfter1Up[0].Enabled)
assert.Equal(t, true, routesAfter1Up[0].IsPrimary)

// Node 2 is primary
// Node 2 is not primary
assert.Equal(t, true, routesAfter1Up[1].Advertised)
assert.Equal(t, true, routesAfter1Up[1].Enabled)
assert.Equal(t, false, routesAfter1Up[1].IsPrimary)
Expand All @@ -538,8 +547,10 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err)

srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]

assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)

if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
Expand Down Expand Up @@ -586,8 +597,178 @@ func TestHASubnetRouterFailover(t *testing.T) {
assertNoErr(t, err)

srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]

assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)

if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs1PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
)
}

// Disable the route of subnet router 1, making it failover to 2
t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"disable",
"--route",
fmt.Sprintf("%d", routesAfter2Up[0].Id),
})
assertNoErr(t, err)

time.Sleep(5 * time.Second)

var routesAfterDisabling1 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterDisabling1,
)
assertNoErr(t, err)
assert.Len(t, routesAfterDisabling1, 2)

// Node 1 is not primary
assert.Equal(t, true, routesAfterDisabling1[0].Advertised)
assert.Equal(t, false, routesAfterDisabling1[0].Enabled)
assert.Equal(t, false, routesAfterDisabling1[0].IsPrimary)

// Node 2 is primary
assert.Equal(t, true, routesAfterDisabling1[1].Advertised)
assert.Equal(t, true, routesAfterDisabling1[1].Enabled)
assert.Equal(t, true, routesAfterDisabling1[1].IsPrimary)

// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)

srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]

assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.NotNil(t, srs2PeerStatus.PrimaryRoutes)

if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}

// enable the route of subnet router 1, no change expected
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"enable",
"--route",
fmt.Sprintf("%d", routesAfter2Up[0].Id),
})
assertNoErr(t, err)

time.Sleep(5 * time.Second)

var routesAfterEnabling1 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterEnabling1,
)
assertNoErr(t, err)
assert.Len(t, routesAfterEnabling1, 2)

// Node 1 is not primary
assert.Equal(t, true, routesAfterEnabling1[0].Advertised)
assert.Equal(t, true, routesAfterEnabling1[0].Enabled)
assert.Equal(t, false, routesAfterEnabling1[0].IsPrimary)

// Node 2 is primary
assert.Equal(t, true, routesAfterEnabling1[1].Advertised)
assert.Equal(t, true, routesAfterEnabling1[1].Enabled)
assert.Equal(t, true, routesAfterEnabling1[1].IsPrimary)

// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)

srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]

assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
assert.NotNil(t, srs2PeerStatus.PrimaryRoutes)

if srs2PeerStatus.PrimaryRoutes != nil {
assert.Contains(
t,
srs2PeerStatus.PrimaryRoutes.AsSlice(),
netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
)
}

// delete the route of subnet router 2, failover to one expected
t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname())
_, err = headscale.Execute(
[]string{
"headscale",
"routes",
"delete",
"--route",
fmt.Sprintf("%d", routesAfterEnabling1[1].Id),
})
assertNoErr(t, err)

time.Sleep(5 * time.Second)

var routesAfterDeleting2 []*v1.Route
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"routes",
"list",
"--output",
"json",
},
&routesAfterDeleting2,
)
assertNoErr(t, err)
assert.Len(t, routesAfterDeleting2, 1)

t.Logf("routes after deleting2 %#v", routesAfterDeleting2)

// Node 1 is primary
assert.Equal(t, true, routesAfterDeleting2[0].Advertised)
assert.Equal(t, true, routesAfterDeleting2[0].Enabled)
assert.Equal(t, true, routesAfterDeleting2[0].IsPrimary)

// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
assertNoErr(t, err)

srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]

assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
assert.Nil(t, srs2PeerStatus.PrimaryRoutes)

if srs1PeerStatus.PrimaryRoutes != nil {
assert.Contains(
Expand Down

0 comments on commit e045601

Please sign in to comment.