From 5877433908b0d97f5d4a9a33ecb3606d2f349896 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sat, 5 Jun 2021 21:25:01 -0700 Subject: [PATCH] delete entries from the cache when the TTL expires --- README.md | 2 +- client.go | 53 ++++++++++++++++++++++++++++++++++++++++--------- server.go | 4 +++- service.go | 13 ++++++------ service_test.go | 44 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 99 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index f639e8cd..ca75b4a2 100644 --- a/README.md +++ b/README.md @@ -94,7 +94,7 @@ See what needs to be done and submit a pull request :) * [x] Browse / Lookup / Register services * [x] Multiple IPv6 / IPv4 addresses support * [x] Send multiple probes (exp. back-off) if no service answers (*) -* [ ] Timestamp entries for TTL checks +* [x] Timestamp entries for TTL checks * [ ] Compare new multicasts with already received services _Notes:_ diff --git a/client.go b/client.go index 1ebed508..da685962 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "strings" + "sync" "time" "github.com/cenkalti/backoff" @@ -143,6 +144,9 @@ type client struct { ipv4conn *ipv4.PacketConn ipv6conn *ipv6.PacketConn ifaces []net.Interface + + mutex sync.Mutex + sentEntries map[string]*ServiceEntry } // Client structure constructor @@ -177,6 +181,28 @@ func newClient(opts clientOpts) (*client, error) { }, nil } +var cleanupFreq = 10 * time.Second + +// clean up entries whose TTL expired +func (c *client) cleanupSentEntries(ctx context.Context) { + ticker := time.NewTicker(cleanupFreq) + defer ticker.Stop() + for { + select { + case t := <-ticker.C: + c.mutex.Lock() + for k, e := range c.sentEntries { + if t.After(e.Expiry) { + delete(c.sentEntries, k) + } + } + c.mutex.Unlock() + case <-ctx.Done(): + return + } + } +} + // Start listeners and waits for the shutdown signal from exit channel func (c *client) mainloop(ctx context.Context, params *lookupParams) { // start listening for responses @@ -189,9 +215,12 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { } // Iterate through channels from listeners goroutines - var entries, sentEntries map[string]*ServiceEntry - sentEntries = make(map[string]*ServiceEntry) + var entries map[string]*ServiceEntry + c.sentEntries = make(map[string]*ServiceEntry) + go c.cleanupSentEntries(ctx) + for { + var now time.Time select { case <-ctx.Done(): // Context expired. Notify subscriber that we are done here. @@ -199,6 +228,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { c.shutdown() return case msg := <-msgCh: + now = time.Now() entries = make(map[string]*ServiceEntry) sections := append(msg.Answer, msg.Ns...) sections = append(sections, msg.Extra...) @@ -218,7 +248,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { params.Service, params.Domain) } - entries[rr.Ptr].TTL = rr.Hdr.Ttl + entries[rr.Ptr].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second) case *dns.SRV: if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name { continue @@ -233,7 +263,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { } entries[rr.Hdr.Name].HostName = rr.Target entries[rr.Hdr.Name].Port = int(rr.Port) - entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl + entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second) case *dns.TXT: if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name { continue @@ -247,7 +277,7 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { params.Domain) } entries[rr.Hdr.Name].Text = rr.Txt - entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl + entries[rr.Hdr.Name].Expiry = now.Add(time.Duration(rr.Hdr.Ttl) * time.Second) } } // Associate IPs in a second round as other fields should be filled by now. @@ -271,12 +301,15 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { if len(entries) > 0 { for k, e := range entries { - if e.TTL == 0 { + c.mutex.Lock() + if !e.Expiry.After(now) { delete(entries, k) - delete(sentEntries, k) + delete(c.sentEntries, k) + c.mutex.Unlock() continue } - if _, ok := sentEntries[k]; ok { + if _, ok := c.sentEntries[k]; ok { + c.mutex.Unlock() continue } @@ -286,14 +319,16 @@ func (c *client) mainloop(ctx context.Context, params *lookupParams) { // Require at least one resolved IP address for ServiceEntry // TODO: wait some more time as chances are high both will arrive. if len(e.AddrIPv4) == 0 && len(e.AddrIPv6) == 0 { + c.mutex.Unlock() continue } } // Submit entry to subscriber and cache it. // This is also a point to possibly stop probing actively for a // service entry. + c.sentEntries[k] = e + c.mutex.Unlock() params.Entries <- e - sentEntries[k] = e if !params.isBrowsing { params.disableProbing() } diff --git a/server.go b/server.go index fc6650cf..dd147362 100644 --- a/server.go +++ b/server.go @@ -21,6 +21,8 @@ const ( multicastRepetitions = 2 ) +var defaultTTL uint32 = 3200 + // Register a service by given arguments. This call will take the system's hostname // and lookup IP by that hostname. func Register(instance, service, domain string, port int, text []string, ifaces []net.Interface) (*Server, error) { @@ -173,7 +175,7 @@ func newServer(ifaces []net.Interface) (*Server, error) { ipv4conn: ipv4conn, ipv6conn: ipv6conn, ifaces: ifaces, - ttl: 3200, + ttl: defaultTTL, shouldShutdown: make(chan struct{}), } diff --git a/service.go b/service.go index 6253c543..43bbf8aa 100644 --- a/service.go +++ b/service.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "sync" + "time" ) // ServiceRecord contains the basic description of a service, which contains instance name, service type & domain @@ -103,12 +104,12 @@ func (l *lookupParams) disableProbing() { // used to answer multicast queries. type ServiceEntry struct { ServiceRecord - HostName string `json:"hostname"` // Host machine DNS name - Port int `json:"port"` // Service Port - Text []string `json:"text"` // Service info served as a TXT record - TTL uint32 `json:"ttl"` // TTL of the service record - AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address - AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address + HostName string `json:"hostname"` // Host machine DNS name + Port int `json:"port"` // Service Port + Text []string `json:"text"` // Service info served as a TXT record + Expiry time.Time `json:"expiry"` // Expiry of the service entry, will be converted to a TTL value + AddrIPv4 []net.IP `json:"-"` // Host machine IPv4 address + AddrIPv6 []net.IP `json:"-"` // Host machine IPv6 address } // NewServiceEntry constructs a ServiceEntry. diff --git a/service_test.go b/service_test.go index ae692adc..c12ed932 100644 --- a/service_test.go +++ b/service_test.go @@ -184,4 +184,48 @@ func TestSubtype(t *testing.T) { t.Fatalf("Expected port is %d, but got %d", mdnsPort, expectedResult[0].Port) } }) + + t.Run("ttl", func(t *testing.T) { + origTTL := defaultTTL + origCleanupFreq := cleanupFreq + defer func() { + defaultTTL = origTTL + cleanupFreq = origCleanupFreq + }() + defaultTTL = 2 // 2 seconds + cleanupFreq = 100 * time.Millisecond + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + go startMDNS(ctx, mdnsPort, mdnsName, mdnsSubtype, mdnsDomain) + + entries := make(chan *ServiceEntry) + var expectedResult []*ServiceEntry + go func() { + for { + select { + case s := <-entries: + expectedResult = append(expectedResult, s) + case <-ctx.Done(): + return + } + } + }() + + resolver, err := NewResolver(nil) + if err != nil { + t.Fatalf("Expected create resolver success, but got %v", err) + } + if err := resolver.Browse(ctx, mdnsService, mdnsDomain, entries); err != nil { + t.Fatalf("Expected browse success, but got %v", err) + } + + <-ctx.Done() + if len(expectedResult) != 2 { + t.Fatalf("Expected to have received 2 entries, but got %d", len(expectedResult)) + } + if expectedResult[0].ServiceInstanceName() != expectedResult[1].ServiceInstanceName() { + t.Fatalf("expected the two entries to be identical") + } + }) }