From 2d3c852d29059371a19f775672d0e6ec07b87df8 Mon Sep 17 00:00:00 2001 From: Graham Miln Date: Tue, 21 Nov 2023 14:33:04 +0100 Subject: [PATCH] Apply default domain before caching derived values If the domain parameter to Register or RegisterProxy is left empty, a default value of `local.` is used. The default was applied after derived values are cached and thus the default was effectively being ignored. --- server.go | 67 +++++++++++++++++------------------ service_test.go | 93 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 124 insertions(+), 36 deletions(-) diff --git a/server.go b/server.go index 4d907f93..4f3263b0 100644 --- a/server.go +++ b/server.go @@ -25,24 +25,11 @@ const ( // Register a service by given arguments. This call will take the system's hostname // and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) { - entry := NewServiceEntry(instance, service, domain) - entry.Port = port - entry.Text = text - - if entry.Instance == "" { - return nil, fmt.Errorf("missing service instance name") - } - if entry.Service == "" { - return nil, fmt.Errorf("missing service name") - } - if entry.Domain == "" { - entry.Domain = "local." - } - if entry.Port == 0 { - return nil, fmt.Errorf("missing port") + entry, err := newRegisterServiceEntry(instance, service, domain, port, text) + if err != nil { + return nil, err } - var err error if entry.HostName == "" { entry.HostName, err = os.Hostname() if err != nil { @@ -83,25 +70,9 @@ func Register(instance, service, domain string, port int, text []string, ifaces // RegisterProxy registers a service proxy. This call will skip the hostname/IP lookup and // will use the provided values. func RegisterProxy(instance, service, domain string, port int, host string, ips []string, text []string, ifaces []net.Interface) (*Server, error) { - entry := NewServiceEntry(instance, service, domain) - entry.Port = port - entry.Text = text - entry.HostName = host - - if entry.Instance == "" { - return nil, fmt.Errorf("missing service instance name") - } - if entry.Service == "" { - return nil, fmt.Errorf("missing service name") - } - if entry.HostName == "" { - return nil, fmt.Errorf("missing host name") - } - if entry.Domain == "" { - entry.Domain = "local" - } - if entry.Port == 0 { - return nil, fmt.Errorf("missing port") + entry, err := newRegisterServiceEntry(instance, service, domain, port, text) + if err != nil { + return nil, err } if !strings.HasSuffix(trimDot(entry.HostName), entry.Domain) { @@ -137,6 +108,30 @@ func RegisterProxy(instance, service, domain string, port int, host string, ips return s, nil } +// newRegisterServiceEntry returns a ServiceEntry with defaults substituted as required. +func newRegisterServiceEntry(instance, service, domain string, port int, text []string) (*ServiceEntry, error) { + // Required parameters + if instance == "" { + return nil, fmt.Errorf("missing service instance name") + } + if service == "" { + return nil, fmt.Errorf("missing service name") + } + if port == 0 { + return nil, fmt.Errorf("missing port") + } + // Defaulted parameters + if domain == "" { + domain = "local." + } + + entry := NewServiceEntry(instance, service, domain) + entry.Port = port + entry.Text = text + + return entry, nil +} + const ( qClassCacheFlush uint16 = 1 << 15 ) @@ -525,7 +520,7 @@ func (s *Server) serviceTypeName(resp *dns.Msg, ttl uint32) { } // Perform probing & announcement -//TODO: implement a proper probing & conflict resolution +// TODO: implement a proper probing & conflict resolution func (s *Server) probe() { q := new(dns.Msg) q.SetQuestion(s.service.ServiceInstanceName(), dns.TypePTR) diff --git a/service_test.go b/service_test.go index 2c5a23ed..62aef4f3 100644 --- a/service_test.go +++ b/service_test.go @@ -3,6 +3,7 @@ package zeroconf import ( "context" "log" + "strings" "testing" "time" @@ -164,3 +165,95 @@ func TestSubtype(t *testing.T) { } }) } + +// Test the default domain is applied. +func TestDefaultDomain(t *testing.T) { + t.Run("register", func(t *testing.T) { + server, err := Register(mdnsName, mdnsService, "", mdnsPort, []string{"txtv=0", "lo=2", "la=3"}, nil) + if err != nil { + t.Fatal(err) + } + if server == nil { + t.Fatal("expect non-nil") + } + // Check the service record's cached fields + sr := server.service.ServiceRecord + if strings.Contains(sr.serviceName, "..") { + t.Errorf("malformed service name: %s", sr.serviceName) + } + if strings.Contains(sr.serviceInstanceName, "..") { + t.Errorf("malformed service instance name: %s", sr.serviceInstanceName) + } + + t.Logf("Published service: %+v", server.service) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // Wait for context to time out + <-ctx.Done() + + t.Log("Shutting down.") + server.Shutdown() + }) + + t.Run("registerproxy", func(t *testing.T) { + server, err := RegisterProxy(mdnsName, mdnsService, "", mdnsPort, "localhost", []string{"::1"}, []string{"txtv=0", "lo=2", "la=3"}, nil) + if err != nil { + t.Fatal(err) + } + if server == nil { + t.Fatal("expect non-nil") + } + // Check the service record's cached fields + sr := server.service.ServiceRecord + if strings.Contains(sr.serviceName, "..") { + t.Errorf("malformed service name: %s", sr.serviceName) + } + if strings.Contains(sr.serviceInstanceName, "..") { + t.Errorf("malformed service instance name: %s", sr.serviceInstanceName) + } + + t.Logf("Published service: %+v", server.service) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // Wait for context to time out + <-ctx.Done() + + t.Log("Shutting down.") + server.Shutdown() + }) +} + +func TestNewRegisterServiceEntry(t *testing.T) { + tests := []struct { + name string + instance, service, domain string + port int + text []string + err bool + }{ + {"minimal", mdnsName, mdnsService, mdnsDomain, mdnsPort, []string{}, false}, + // Required parameters + {"require-instance", "", mdnsService, mdnsDomain, mdnsPort, []string{}, true}, + {"require-service", mdnsName, "", mdnsDomain, mdnsPort, []string{}, true}, + {"require-port", mdnsName, mdnsService, mdnsDomain, 0, []string{}, true}, + // Default domain + {"default-domain", mdnsName, mdnsService, "", mdnsPort, []string{}, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + se, err := newRegisterServiceEntry(test.instance, test.service, test.domain, test.port, test.text) + if test.err && err == nil { + t.Error("expect error") + } else if !test.err && err != nil { + t.Error(err) + } + if err == nil && se == nil { + t.Error("expect non-nil") + } + }) + } +}