From b8bb7c90f5f34346cdf3bbd1352d74c0fd7553d4 Mon Sep 17 00:00:00 2001 From: Urjit Singh Bhatia Date: Mon, 3 Feb 2020 06:16:36 +0000 Subject: [PATCH] Add loopback support --- server.go | 53 ++++++++++++++++++++------------ service_test.go | 81 +++++++++++++++++++++++++++++++++++++++---------- 2 files changed, 99 insertions(+), 35 deletions(-) diff --git a/server.go b/server.go index 86701514..15c3745a 100644 --- a/server.go +++ b/server.go @@ -1,6 +1,7 @@ package zeroconf import ( + "errors" "fmt" "log" "math/rand" @@ -10,12 +11,9 @@ import ( "sync" "time" + "github.com/miekg/dns" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" - - "errors" - - "github.com/miekg/dns" ) const ( @@ -24,8 +22,17 @@ const ( ) // Register a service by given arguments. This call will take the system's hostname -// and lookup IP by that hostname. +// This call will take the system's hostname and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) { + return register(instance, service, domain, port, text, ifaces, false) +} + +// RegisterWithLoopback registers a service by given arguments but also registers on local loopback interface. +// and lookup IP by that hostname. +func RegisterWithLoopback(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) { + return register(instance, service, domain, port, text, ifaces, true) +} +func register(instance, service, domain string, port int, text []string, ifaces []net.Interface, loopback bool) (*Server, error) { entry := NewServiceEntry(instance, service, domain) entry.Port = port entry.Text = text @@ -60,7 +67,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces } for _, iface := range ifaces { - v4, v6 := addrsForInterface(&iface) + v4, v6 := addrsForInterface(&iface, loopback) entry.AddrIPv4 = append(entry.AddrIPv4, v4...) entry.AddrIPv6 = append(entry.AddrIPv6, v6...) } @@ -73,7 +80,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces if err != nil { return nil, err } - + s.loopback = loopback s.service = entry go s.mainloop() go s.probe() @@ -148,6 +155,7 @@ type Server struct { ipv4conn *ipv4.PacketConn ipv6conn *ipv6.PacketConn ifaces []net.Interface + loopback bool shouldShutdown chan struct{} shutdownLock sync.Mutex @@ -605,7 +613,7 @@ func (s *Server) appendAddrs(list []dns.RR, ttl uint32, ifIndex int, flushCache if len(v4) == 0 && len(v6) == 0 { iface, _ := net.InterfaceByIndex(ifIndex) if iface != nil { - a4, a6 := addrsForInterface(iface) + a4, a6 := addrsForInterface(iface, s.loopback) v4 = append(v4, a4...) v6 = append(v6, a6...) } @@ -647,20 +655,27 @@ func (s *Server) appendAddrs(list []dns.RR, ttl uint32, ifIndex int, flushCache return list } -func addrsForInterface(iface *net.Interface) ([]net.IP, []net.IP) { +func addrsForInterface(iface *net.Interface, loopback bool) ([]net.IP, []net.IP) { var v4, v6, v6local []net.IP addrs, _ := iface.Addrs() for _, address := range addrs { - if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { - if ipnet.IP.To4() != nil { - v4 = append(v4, ipnet.IP) - } else { - switch ip := ipnet.IP.To16(); ip != nil { - case ip.IsGlobalUnicast(): - v6 = append(v6, ipnet.IP) - case ip.IsLinkLocalUnicast(): - v6local = append(v6local, ipnet.IP) - } + ipnet, ok := address.(*net.IPNet) + if !ok { + continue + } + if ipnet.IP.IsLoopback() && !loopback { + // loopback is disabled - skip + continue + } + + if ipnet.IP.To4() != nil { + v4 = append(v4, ipnet.IP) + } else { + switch ip := ipnet.IP.To16(); ip != nil { + case ip.IsGlobalUnicast(): + v6 = append(v6, ipnet.IP) + case ip.IsLinkLocalUnicast(): + v6local = append(v6local, ipnet.IP) } } } diff --git a/service_test.go b/service_test.go index 61c5f1f1..208bd6ee 100644 --- a/service_test.go +++ b/service_test.go @@ -16,9 +16,16 @@ var ( mdnsPort = 8888 ) -func startMDNS(ctx context.Context, port int, name, service, domain string) { +func startMDNS(ctx context.Context, port int, name, service, domain string, loopback bool) { // 5353 is default mdns port - server, err := Register(name, service, domain, port, []string{"txtv=0", "lo=1", "la=2"}, nil) + var server *Server + var err error + + if loopback { + server, err = RegisterWithLoopback(name, service, domain, port, []string{"txtv=0", "lo=1", "la=2"}, nil) + } else { + server, err = Register(name, service, domain, port, []string{"txtv=0", "lo=1", "la=2"}, nil) + } if err != nil { panic(errors.Wrap(err, "while registering mdns service")) } @@ -28,14 +35,44 @@ func startMDNS(ctx context.Context, port int, name, service, domain string) { <-ctx.Done() log.Printf("Shutting down.") +} +func verifyResult(expectedResult []*ServiceEntry, expectLoopback bool, t *testing.T) { + if len(expectedResult) != 1 { + t.Fatalf("Expected number of service entries is 1, but got %d", len(expectedResult)) + } + if expectedResult[0].Domain != mdnsDomain { + t.Fatalf("Expected domain is %s, but got %s", mdnsDomain, expectedResult[0].Domain) + } + if expectedResult[0].Service != mdnsService { + t.Fatalf("Expected service is %s, but got %s", mdnsService, expectedResult[0].Service) + } + if expectedResult[0].Instance != mdnsName { + t.Fatalf("Expected instance is %s, but got %s", mdnsName, expectedResult[0].Instance) + } + if expectedResult[0].Port != mdnsPort { + t.Fatalf("Expected port is %d, but got %d", mdnsPort, expectedResult[0].Port) + } + foundLoopback := false + for _, addr := range expectedResult[0].AddrIPv4 { + if addr.IsLoopback() { + foundLoopback = true + break + } + } + if expectLoopback && !foundLoopback { + t.Fatal("Expected AddrIPv4 to include loopback") + } + if !expectLoopback && foundLoopback { + t.Fatal("Unexpected AddrIPv4 loopback result") + } } func TestBasic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - go startMDNS(ctx, mdnsPort, mdnsName, mdnsService, mdnsDomain) + go startMDNS(ctx, mdnsPort, mdnsName, mdnsService, mdnsDomain, false) time.Sleep(time.Second) @@ -55,21 +92,33 @@ func TestBasic(t *testing.T) { } <-ctx.Done() - if len(expectedResult) != 1 { - t.Fatalf("Expected number of service entries is 1, but got %d", len(expectedResult)) - } - if expectedResult[0].Domain != mdnsDomain { - t.Fatalf("Expected domain is %s, but got %s", mdnsDomain, expectedResult[0].Domain) - } - if expectedResult[0].Service != mdnsService { - t.Fatalf("Expected service is %s, but got %s", mdnsService, expectedResult[0].Service) - } - if expectedResult[0].Instance != mdnsName { - t.Fatalf("Expected instance is %s, but got %s", mdnsName, expectedResult[0].Instance) + verifyResult(expectedResult, false, t) +} + +func TestWithLoopback(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go startMDNS(ctx, mdnsPort, mdnsName, mdnsService, mdnsDomain, true) + + time.Sleep(time.Second) + + resolver, err := NewResolver(nil) + if err != nil { + t.Fatalf("Expected create resolver success, but got %v", err) } - if expectedResult[0].Port != mdnsPort { - t.Fatalf("Expected port is %d, but got %d", mdnsPort, expectedResult[0].Port) + entries := make(chan *ServiceEntry) + expectedResult := []*ServiceEntry{} + go func(results <-chan *ServiceEntry) { + s := <-results + expectedResult = append(expectedResult, s) + }(entries) + + if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil { + t.Fatalf("Expected browse success, but got %v", err) } + <-ctx.Done() + verifyResult(expectedResult, true, t) } func TestNoRegister(t *testing.T) {