Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write deadlock fixes and overall reliability improvements #30

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 53 additions & 23 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,23 @@ const (
IPv4AndIPv6 = IPv4 | IPv6 // default option
)

var initialQueryInterval = 4 * time.Second
var (
initialQueryInterval = 4 * time.Second
defaultClientWriteTimeout = 10 * time.Second
)

// Client structure encapsulates both IPv4/IPv6 UDP connections.
type client struct {
ipv4conn *ipv4.PacketConn
ipv6conn *ipv6.PacketConn
ifaces []net.Interface
ipv4conn *ipv4.PacketConn
ipv6conn *ipv6.PacketConn
interfaces NetInterfaceList
writeTimeout time.Duration
}

type clientOpts struct {
listenOn IPType
ifaces []net.Interface
listenOn IPType
ifaces []net.Interface
writeTimeout time.Duration
}

// ClientOption fills the option struct to configure intefaces, etc.
Expand All @@ -63,6 +68,13 @@ func SelectIfaces(ifaces []net.Interface) ClientOption {
}
}

// ClientWriteTimeout sets timeout for writing to the socket
func ClientWriteTimeout(duration time.Duration) ClientOption {
return func(o *clientOpts) {
o.writeTimeout = duration
}
}

// Browse for all services of a given type in a given domain.
// Received entries are sent on the entries channel.
// It blocks until the context is canceled (or an error occurs).
Expand Down Expand Up @@ -100,7 +112,8 @@ func Lookup(ctx context.Context, instance, service, domain string, entries chan<
func applyOpts(options ...ClientOption) clientOpts {
// Apply default configuration and load supplied options.
var conf = clientOpts{
listenOn: IPv4AndIPv6,
listenOn: IPv4AndIPv6,
writeTimeout: defaultClientWriteTimeout,
}
for _, o := range options {
if o != nil {
Expand Down Expand Up @@ -137,11 +150,12 @@ func newClient(opts clientOpts) (*client, error) {
if len(ifaces) == 0 {
ifaces = listMulticastInterfaces()
}
ifaceList := NewInterfaceList(ifaces)
// IPv4 interfaces
var ipv4conn *ipv4.PacketConn
if (opts.listenOn & IPv4) > 0 {
var err error
ipv4conn, err = joinUdp4Multicast(ifaces)
ipv4conn, err = joinUdp4Multicast(ifaceList)
if err != nil {
return nil, err
}
Expand All @@ -150,16 +164,17 @@ func newClient(opts clientOpts) (*client, error) {
var ipv6conn *ipv6.PacketConn
if (opts.listenOn & IPv6) > 0 {
var err error
ipv6conn, err = joinUdp6Multicast(ifaces)
ipv6conn, err = joinUdp6Multicast(ifaceList)
if err != nil {
return nil, err
}
}

return &client{
ipv4conn: ipv4conn,
ipv6conn: ipv6conn,
ifaces: ifaces,
ipv4conn: ipv4conn,
ipv6conn: ipv6conn,
interfaces: ifaceList,
writeTimeout: opts.writeTimeout,
}, nil
}

Expand Down Expand Up @@ -428,47 +443,62 @@ func (c *client) query(params *lookupParams) error {
m.SetQuestion(serviceName, dns.TypePTR)
}
m.RecursionDesired = false
return c.sendQuery(m)
// only send multicast queries to interfaces that we have joined
return c.sendQuery(m, NetInterfaceStateFlagMulticastJoined)
}

// Pack the dns.Msg and write to available connections (multicast)
func (c *client) sendQuery(msg *dns.Msg) error {
func (c *client) sendQuery(msg *dns.Msg, requiredFlags ...NetInterfaceStateFlag) error {
buf, err := msg.Pack()
if err != nil {
return err
return fmt.Errorf("failed to pack msg %v: %w", msg, err)
}
if c.ipv4conn != nil {
// See https://pkg.go.dev/golang.org/x/net/ipv4#pkg-note-BUG
// As of Golang 1.18.4
// On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented.
var wcm ipv4.ControlMessage
for ifi := range c.ifaces {
for _, intf := range c.interfaces {
if !intf.HasFlags(NetInterfaceScopeIPv4, requiredFlags...) {
continue
}
switch runtime.GOOS {
case "darwin", "ios", "linux":
wcm.IfIndex = c.ifaces[ifi].Index
wcm.IfIndex = intf.Index
default:
if err := c.ipv4conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil {
if err := c.ipv4conn.SetMulticastInterface(&intf.Interface); err != nil {
log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err)
}
}
c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr)
setDeadline(c.writeTimeout, c.ipv4conn)
n, err := c.ipv4conn.WriteTo(buf, &wcm, ipv4Addr)
if err == nil && n > 0 {
intf.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMessageSent)
}
}
}
if c.ipv6conn != nil {
// See https://pkg.go.dev/golang.org/x/net/ipv6#pkg-note-BUG
// As of Golang 1.18.4
// On Windows, the ControlMessage for ReadFrom and WriteTo methods of PacketConn is not implemented.
var wcm ipv6.ControlMessage
for ifi := range c.ifaces {
for _, intf := range c.interfaces {
if !intf.HasFlags(NetInterfaceScopeIPv6, requiredFlags...) {
continue
}
switch runtime.GOOS {
case "darwin", "ios", "linux":
wcm.IfIndex = c.ifaces[ifi].Index
wcm.IfIndex = intf.Index
default:
if err := c.ipv6conn.SetMulticastInterface(&c.ifaces[ifi]); err != nil {
if err := c.ipv6conn.SetMulticastInterface(&intf.Interface); err != nil {
log.Printf("[WARN] mdns: Failed to set multicast interface: %v", err)
}
}
c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr)
setDeadline(c.writeTimeout, c.ipv6conn)
n, err := c.ipv6conn.WriteTo(buf, &wcm, ipv6Addr)
if err == nil && n > 0 {
intf.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMessageSent)
}
}
}
return nil
Expand Down
39 changes: 20 additions & 19 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ var (
}
)

func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) {
func joinUdp6Multicast(interfaces []*NetInterface) (*ipv6.PacketConn, error) {
if len(interfaces) == 0 {
return nil, fmt.Errorf("no interfaces to join multicast on")
}

udpConn, err := net.ListenUDP("udp6", mdnsWildcardAddrIPv6)
if err != nil {
return nil, err
Expand All @@ -45,19 +49,15 @@ func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) {
pkConn := ipv6.NewPacketConn(udpConn)
pkConn.SetControlMessage(ipv6.FlagInterface, true)

if len(interfaces) == 0 {
interfaces = listMulticastInterfaces()
}
// log.Println("Using multicast interfaces: ", interfaces)

var failedJoins int
var anySucceeded bool
for _, iface := range interfaces {
if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil {
// log.Println("Udp6 JoinGroup failed for iface ", iface)
failedJoins++
if err := pkConn.JoinGroup(&iface.Interface, &net.UDPAddr{IP: mdnsGroupIPv6}); err == nil {
iface.SetFlag(NetInterfaceScopeIPv6, NetInterfaceStateFlagMulticastJoined)
anySucceeded = true
}
}
if failedJoins == len(interfaces) {
if !anySucceeded {
pkConn.Close()
return nil, fmt.Errorf("udp6: failed to join any of these interfaces: %v", interfaces)
}
Expand All @@ -67,7 +67,11 @@ func joinUdp6Multicast(interfaces []net.Interface) (*ipv6.PacketConn, error) {
return pkConn, nil
}

func joinUdp4Multicast(interfaces []net.Interface) (*ipv4.PacketConn, error) {
func joinUdp4Multicast(interfaces []*NetInterface) (*ipv4.PacketConn, error) {
if len(interfaces) == 0 {
return nil, fmt.Errorf("no interfaces to join multicast on")
}

udpConn, err := net.ListenUDP("udp4", mdnsWildcardAddrIPv4)
if err != nil {
// log.Printf("[ERR] bonjour: Failed to bind to udp4 mutlicast: %v", err)
Expand All @@ -78,19 +82,16 @@ func joinUdp4Multicast(interfaces []net.Interface) (*ipv4.PacketConn, error) {
pkConn := ipv4.NewPacketConn(udpConn)
pkConn.SetControlMessage(ipv4.FlagInterface, true)

if len(interfaces) == 0 {
interfaces = listMulticastInterfaces()
}
// log.Println("Using multicast interfaces: ", interfaces)
var anySucceed bool

var failedJoins int
for _, iface := range interfaces {
if err := pkConn.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil {
// log.Println("Udp4 JoinGroup failed for iface ", iface)
failedJoins++
if err := pkConn.JoinGroup(&iface.Interface, &net.UDPAddr{IP: mdnsGroupIPv4}); err == nil {
anySucceed = true
iface.SetFlag(NetInterfaceScopeIPv4, NetInterfaceStateFlagMulticastJoined)
}
}
if failedJoins == len(interfaces) {
if !anySucceed {
pkConn.Close()
return nil, fmt.Errorf("udp4: failed to join any of these interfaces: %v", interfaces)
}
Expand Down
86 changes: 86 additions & 0 deletions netinterface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package zeroconf

import (
"net"
"sync/atomic"
)

type NetInterface struct {
net.Interface
stateIPv4 NetInterfaceStateFlag
stateIPv6 NetInterfaceStateFlag
}

type NetInterfaceScope int

const (
NetInterfaceScopeIPv4 NetInterfaceScope = iota
NetInterfaceScopeIPv6
)

type NetInterfaceList []*NetInterface

type NetInterfaceStateFlag uint32

const (
NetInterfaceStateFlagMulticastJoined NetInterfaceStateFlag = 1 << iota // we have joined the multicast group on this interface
NetInterfaceStateFlagMessageSent // we have successfully sent at least one message on this interface
)

func (i *NetInterface) HasFlags(scope NetInterfaceScope, flags ...NetInterfaceStateFlag) bool {
for _, flag := range flags {
if !i.HasFlag(scope, flag) {
return false
}
}
return true
}

func (i *NetInterface) loadFlag(address *NetInterfaceStateFlag) NetInterfaceStateFlag {
return NetInterfaceStateFlag(atomic.LoadUint32((*uint32)(address)))
}

func (i *NetInterface) HasFlag(scope NetInterfaceScope, flag NetInterfaceStateFlag) bool {
if scope == NetInterfaceScopeIPv4 {
return NetInterfaceStateFlag(i.loadFlag(&i.stateIPv4)&flag) != 0
} else if scope == NetInterfaceScopeIPv6 {
return NetInterfaceStateFlag(i.loadFlag(&i.stateIPv6)&flag) != 0
}
return false
}

func (i *NetInterface) SetFlag(scope NetInterfaceScope, flag NetInterfaceStateFlag) {
if scope == NetInterfaceScopeIPv4 {
i.setFlag(&i.stateIPv4, flag)
} else if scope == NetInterfaceScopeIPv6 {
i.setFlag(&i.stateIPv6, flag)
}
}

func (i *NetInterface) setFlag(address *NetInterfaceStateFlag, flag NetInterfaceStateFlag) {
// If atomic value != previously loaded value, then repeat the operation
// If they are equal, then we can safely set the new value
// This is the way to ensure atomicity of the operation
for {
loadedValue := uint32(i.loadFlag(address))
if atomic.CompareAndSwapUint32((*uint32)(address), loadedValue, loadedValue|uint32(flag)) {
break
}
}
}

func (list NetInterfaceList) GetByIndex(index int) *NetInterface {
for _, iface := range list {
if iface.Index == index {
return iface
}
}
return nil
}

func NewInterfaceList(ifaces []net.Interface) (list NetInterfaceList) {
for i := range ifaces {
list = append(list, &NetInterface{Interface: ifaces[i]})
}
return
}
Loading