diff --git a/server.go b/server.go index 70fd11ac..d55e65b9 100644 --- a/server.go +++ b/server.go @@ -22,8 +22,17 @@ const ( ) // Register a service by given arguments. This call will take the system's hostname -// and lookup IP by that hostname. +// This call will take the system's hostname and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) { + return register(instance, service, domain, port, text, ifaces, false) +} + +// RegisterWithLoopback registers a service by given arguments but also registers on local loopback interface. +// and lookup IP by that hostname. +func RegisterWithLoopback(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) { + return register(instance, service, domain, port, text, ifaces, true) +} +func register(instance, service, domain string, port int, text []string, ifaces []net.Interface, loopback bool) (*Server, error) { entry := NewServiceEntry(instance, service, domain) entry.Port = port entry.Text = text @@ -58,7 +67,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces } for _, iface := range ifaces { - v4, v6 := addrsForInterface(&iface) + v4, v6 := addrsForInterface(&iface, loopback) entry.AddrIPv4 = append(entry.AddrIPv4, v4...) entry.AddrIPv6 = append(entry.AddrIPv6, v6...) } @@ -71,7 +80,7 @@ func Register(instance, service, domain string, port int, text []string, ifaces if err != nil { return nil, err } - + s.loopback = loopback s.service = entry go s.mainloop() go s.probe() @@ -146,6 +155,7 @@ type Server struct { ipv4conn *ipv4.PacketConn ipv6conn *ipv6.PacketConn ifaces []net.Interface + loopback bool shouldShutdown chan struct{} shutdownLock sync.Mutex @@ -621,7 +631,7 @@ func (s *Server) appendAddrs(list []dns.RR, ttl uint32, ifIndex int, flushCache if len(v4) == 0 && len(v6) == 0 { iface, _ := net.InterfaceByIndex(ifIndex) if iface != nil { - a4, a6 := addrsForInterface(iface) + a4, a6 := addrsForInterface(iface, s.loopback) v4 = append(v4, a4...) v6 = append(v6, a6...) } @@ -663,20 +673,27 @@ func (s *Server) appendAddrs(list []dns.RR, ttl uint32, ifIndex int, flushCache return list } -func addrsForInterface(iface *net.Interface) ([]net.IP, []net.IP) { +func addrsForInterface(iface *net.Interface, loopback bool) ([]net.IP, []net.IP) { var v4, v6, v6local []net.IP addrs, _ := iface.Addrs() for _, address := range addrs { - if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { - if ipnet.IP.To4() != nil { - v4 = append(v4, ipnet.IP) - } else { - switch ip := ipnet.IP.To16(); ip != nil { - case ip.IsGlobalUnicast(): - v6 = append(v6, ipnet.IP) - case ip.IsLinkLocalUnicast(): - v6local = append(v6local, ipnet.IP) - } + ipnet, ok := address.(*net.IPNet) + if !ok { + continue + } + if ipnet.IP.IsLoopback() && !loopback { + // loopback is disabled - skip + continue + } + + if ipnet.IP.To4() != nil { + v4 = append(v4, ipnet.IP) + } else { + switch ip := ipnet.IP.To16(); ip != nil { + case ip.IsGlobalUnicast(): + v6 = append(v6, ipnet.IP) + case ip.IsLinkLocalUnicast(): + v6local = append(v6local, ipnet.IP) } } } diff --git a/service_test.go b/service_test.go index 2c5a23ed..4d04f2e1 100644 --- a/service_test.go +++ b/service_test.go @@ -17,9 +17,16 @@ var ( mdnsPort = 8888 ) -func startMDNS(ctx context.Context, port int, name, service, domain string) { +func startMDNS(ctx context.Context, port int, name, service, domain string, loopback bool) { // 5353 is default mdns port - server, err := Register(name, service, domain, port, []string{"txtv=0", "lo=1", "la=2"}, nil) + var server *Server + var err error + + if loopback { + server, err = RegisterWithLoopback(name, service, domain, port, []string{"txtv=0", "lo=1", "la=2"}, nil) + } else { + server, err = Register(name, service, domain, port, []string{"txtv=0", "lo=1", "la=2"}, nil) + } if err != nil { panic(errors.Wrap(err, "while registering mdns service")) } @@ -29,14 +36,42 @@ func startMDNS(ctx context.Context, port int, name, service, domain string) { <-ctx.Done() log.Printf("Shutting down.") +} +func verifyResult(expectedResult *ServiceEntry, expectLoopback bool, t *testing.T) { + if expectedResult.Domain != mdnsDomain { + t.Fatalf("Expected domain is %s, but got %s", mdnsDomain, expectedResult.Domain) + } + if expectedResult.Service != mdnsService { + t.Fatalf("Expected service is %s, but got %s", mdnsService, expectedResult.Service) + } + if expectedResult.Instance != mdnsName { + t.Fatalf("Expected instance is %s, but got %s", mdnsName, expectedResult.Instance) + } + if expectedResult.Port != mdnsPort { + t.Fatalf("Expected port is %d, but got %d", mdnsPort, expectedResult.Port) + } + foundLoopback := false + for _, addr := range expectedResult.AddrIPv4 { + if addr.IsLoopback() { + foundLoopback = true + break + } + } + if expectLoopback && !foundLoopback { + t.Fatal("Expected AddrIPv4 to include loopback") + } + if !expectLoopback && foundLoopback { + t.Fatal("Unexpected AddrIPv4 loopback result") + } } func TestBasic(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - go startMDNS(ctx, mdnsPort, mdnsName, mdnsService, mdnsDomain) + loopback := false + go startMDNS(ctx, mdnsPort, mdnsName, mdnsService, mdnsDomain, loopback) time.Sleep(time.Second) @@ -54,18 +89,33 @@ func TestBasic(t *testing.T) { t.Fatalf("Expected number of service entries is 1, but got %d", len(entries)) } result := <-entries - if result.Domain != mdnsDomain { - t.Fatalf("Expected domain is %s, but got %s", mdnsDomain, result.Domain) - } - if result.Service != mdnsService { - t.Fatalf("Expected service is %s, but got %s", mdnsService, result.Service) + verifyResult(result, loopback, t) +} + +func TestWithLoopback(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + loopback := true + go startMDNS(ctx, mdnsPort, mdnsName, mdnsService, mdnsDomain, loopback) + + time.Sleep(time.Second) + + resolver, err := NewResolver(nil) + if err != nil { + t.Fatalf("Expected create resolver success, but got %v", err) } - if result.Instance != mdnsName { - t.Fatalf("Expected instance is %s, but got %s", mdnsName, result.Instance) + entries := make(chan *ServiceEntry, 100) + if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil { + t.Fatalf("Expected browse success, but got %v", err) } - if result.Port != mdnsPort { - t.Fatalf("Expected port is %d, but got %d", mdnsPort, result.Port) + <-ctx.Done() + + if len(entries) != 1 { + t.Fatalf("Expected number of service entries is 1, but got %d", len(entries)) } + result := <-entries + verifyResult(result, loopback, t) } func TestNoRegister(t *testing.T) { @@ -96,7 +146,7 @@ func TestSubtype(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain) + go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain, false) time.Sleep(time.Second) @@ -132,7 +182,7 @@ func TestSubtype(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain) + go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain, false) time.Sleep(time.Second)