Skip to content

Commit

Permalink
Allow binding on address with arbitrary zone
Browse files Browse the repository at this point in the history
Signed-off-by: Xu Liu <[email protected]>
  • Loading branch information
xliuxu committed Sep 26, 2024
1 parent 6da6235 commit 8cf5475
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 25 deletions.
41 changes: 27 additions & 14 deletions addr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"net/netip"
"strconv"
)

// An Addr is an IPv6 unicast address.
Expand All @@ -18,14 +19,16 @@ const (
)

// chooseAddr selects an Addr from the interface based on the specified Addr type.
func chooseAddr(addrs []net.Addr, zone string, addr Addr) (netip.Addr, error) {
func chooseAddr(addrs []net.Addr, zone string, zoneIndex int, addr Addr) (netip.Addr, error) {
// Does the caller want an unspecified address?
if addr == Unspecified {
return netip.IPv6Unspecified().WithZone(zone), nil
}

// Select an IPv6 address from the interface's addresses.
var match func(ip netip.Addr) bool
var preferred netip.Addr
var err error
switch addr {
case LinkLocal:
match = (netip.Addr).IsLinkLocalUnicast
Expand All @@ -38,27 +41,30 @@ func chooseAddr(addrs []net.Addr, zone string, addr Addr) (netip.Addr, error) {
}
default:
// Special case: try to match Addr as a literal IPv6 address.
ip, err := netip.ParseAddr(string(addr))
preferred, err = netip.ParseAddr(string(addr))
if err != nil {
return netip.Addr{}, fmt.Errorf("ndp: invalid IPv6 address: %q", addr)
}

if err := checkIPv6(ip); err != nil {
if err := checkIPv6(preferred); err != nil {
return netip.Addr{}, err
}

match = func(check netip.Addr) bool {
return ip == check
return preferred == check ||
preferred == check.WithZone(zone) ||
preferred == check.WithZone(strconv.Itoa(zoneIndex))
}
}

return findAddr(addrs, addr, zone, match)
found := findAddr(addrs, preferred, zone, match)
if !found.IsValid() {
return netip.Addr{}, fmt.Errorf("ndp: no valid IPv6 address found for %q", addr)
}
return found, nil
}

// findAddr searches for a valid IPv6 address in the slice of net.Addr that
// matches the input function. If none is found, the IPv6 unspecified address
// "::" is returned.
func findAddr(addrs []net.Addr, addr Addr, zone string, match func(ip netip.Addr) bool) (netip.Addr, error) {
// matches the input function. If none is found, it returns an invalid netip.Addr.
func findAddr(addrs []net.Addr, preferred netip.Addr, zone string, match func(ip netip.Addr) bool) netip.Addr {
for _, a := range addrs {
ipn, ok := a.(*net.IPNet)
if !ok {
Expand All @@ -75,11 +81,18 @@ func findAddr(addrs []net.Addr, addr Addr, zone string, match func(ip netip.Addr

// From here on, we can assume that only IPv6 addresses are
// being checked.
if match(ip) {
return ip.WithZone(zone), nil
if !match(ip) {
continue
}
// If a preferred address is set, use it directly.
if preferred.IsValid() {
if preferred.Zone() == "" {
return preferred.WithZone(zone)
}
return preferred
}
return ip.WithZone(zone)
}

// No matching address on this interface.
return netip.Addr{}, fmt.Errorf("ndp: address %q not found on interface %q", addr, zone)
return netip.Addr{}
}
44 changes: 34 additions & 10 deletions addr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ package ndp
import (
"net"
"net/netip"
"strconv"
"testing"

"github.com/google/go-cmp/cmp"
)

func Test_chooseAddr(t *testing.T) {
// Assumed zone for all tests.
const zone = "eth0"
zone := "eth0"
zoneId := 1

var (
ip4 = net.IPv4(192, 168, 1, 1).To4()
Expand Down Expand Up @@ -60,43 +62,67 @@ func Test_chooseAddr(t *testing.T) {
},
{
name: "ok, unspecified",
ip: netip.IPv6Unspecified(),
ip: netip.IPv6Unspecified().WithZone(zone),
addr: Unspecified,
ok: true,
},
{
name: "ok, GUA",
addrs: addrs,
ip: netip.MustParseAddr("2001:db8::1"),
ip: netip.MustParseAddr("2001:db8::1").WithZone(zone),
addr: Global,
ok: true,
},
{
name: "ok, ULA",
addrs: addrs,
ip: netip.MustParseAddr("fc00::1"),
ip: netip.MustParseAddr("fc00::1").WithZone(zone),
addr: UniqueLocal,
ok: true,
},
{
name: "ok, LLA",
addrs: addrs,
ip: netip.MustParseAddr("fe80::1"),
ip: netip.MustParseAddr("fe80::1").WithZone(zone),
addr: LinkLocal,
ok: true,
},
{
name: "ok, arbitrary",
addrs: addrs,
ip: netip.MustParseAddr("2001:db8::1000"),
ip: netip.MustParseAddr("2001:db8::1000").WithZone(zone),
addr: Addr(ip6.String()),
ok: true,
},
{
name: "ok, arbitrary with zone",
addrs: addrs,
ip: netip.MustParseAddr("2001:db8::1000").WithZone(zone),
addr: Addr("2001:db8::1000%eth0"),
ok: true,
},
{
name: "arbitrary with mismatched zone",
addrs: addrs,
addr: Addr("2001:db8::1000%eth1"),
},
{
name: "ok, arbitrary with zone id",
addrs: addrs,
ip: netip.MustParseAddr("2001:db8::1000").WithZone(strconv.Itoa(zoneId)),
addr: Addr("2001:db8::1000%1"),
ok: true,
},
{
name: "arbitrary with mismatched zone id",
addrs: addrs,
addr: Addr("2001:db8::1000%2"),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ipa, err := chooseAddr(tt.addrs, zone, tt.addr)
ipa, err := chooseAddr(tt.addrs, zone, zoneId, tt.addr)

if err != nil && tt.ok {
t.Fatalf("unexpected error: %v", err)
Expand All @@ -108,9 +134,7 @@ func Test_chooseAddr(t *testing.T) {
t.Logf("OK error: %v", err)
return
}

ttipa := tt.ip.WithZone(zone)
if diff := cmp.Diff(ttipa, ipa, cmp.Comparer(addrEqual)); diff != "" {
if diff := cmp.Diff(tt.ip, ipa, cmp.Comparer(addrEqual)); diff != "" {
t.Fatalf("unexpected IPv6 address (-want +got):\n%s", diff)
}
})
Expand Down
2 changes: 1 addition & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func Listen(ifi *net.Interface, addr Addr) (*Conn, netip.Addr, error) {
return nil, netip.Addr{}, err
}

ip, err := chooseAddr(addrs, ifi.Name, addr)
ip, err := chooseAddr(addrs, ifi.Name, ifi.Index, addr)
if err != nil {
return nil, netip.Addr{}, err
}
Expand Down

0 comments on commit 8cf5475

Please sign in to comment.