Skip to content

Commit

Permalink
Add address limit support for DNS rules
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Feb 4, 2024
1 parent 5d1738b commit fd705d9
Show file tree
Hide file tree
Showing 14 changed files with 272 additions and 69 deletions.
12 changes: 7 additions & 5 deletions adapter/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ type InboundContext struct {

// rule cache

IPCIDRMatchSource bool
SourceAddressMatch bool
SourcePortMatch bool
DestinationAddressMatch bool
DestinationPortMatch bool
IPCIDRMatchSource bool
SourceAddressMatch bool
SourcePortMatch bool
DestinationAddressMatch bool
DestinationPortMatch bool
DidMatch bool
IgnoreDestinationIPCIDRMatch bool
}

func (c *InboundContext) ResetRuleCache() {
Expand Down
3 changes: 3 additions & 0 deletions adapter/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ type DNSRule interface {
Rule
DisableCache() bool
RewriteTTL() *uint32
WithAddressLimit() bool
MatchAddressLimit(metadata *InboundContext) bool
}

type RuleSet interface {
Expand All @@ -97,6 +99,7 @@ type RuleSet interface {
type RuleSetMetadata struct {
ContainsProcessRule bool
ContainsWIFIRule bool
ContainsIPCIDRRule bool
}

type RuleSetStartContext interface {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ require (
github.com/sagernet/quic-go v0.40.1
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691
github.com/sagernet/sing v0.3.1-0.20240105061852-782bc05c5573
github.com/sagernet/sing-dns v0.1.12
github.com/sagernet/sing-dns v0.1.13-0.20240203102504-27e217be9060
github.com/sagernet/sing-mux v0.2.0
github.com/sagernet/sing-quic v0.1.8
github.com/sagernet/sing-shadowsocks v0.2.6
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4Wk
github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo=
github.com/sagernet/sing v0.3.1-0.20240105061852-782bc05c5573 h1:1wGN3eNanp8r+Y3bNBys3ZnAVF5gdtDoDwtosMZEbgA=
github.com/sagernet/sing v0.3.1-0.20240105061852-782bc05c5573/go.mod h1:9pfuAH6mZfgnz/YjP6xu5sxx882rfyjpcrTdUpd6w3g=
github.com/sagernet/sing-dns v0.1.12 h1:1HqZ+ln+Rezx/aJMStaS0d7oPeX2EobSV1NT537kyj4=
github.com/sagernet/sing-dns v0.1.12/go.mod h1:rx/DTOisneQpCgNQ4jbFU/JNEtnz0lYcHXenlVzpjEU=
github.com/sagernet/sing-dns v0.1.13-0.20240203102504-27e217be9060 h1:ah78H3NjlBEov2MGAKC5Wtn71LhFfRatVrJ88PCQPjE=
github.com/sagernet/sing-dns v0.1.13-0.20240203102504-27e217be9060/go.mod h1:IxOqfSb6Zt6UVCy8fJpDxb2XxqzHUytNqeOuJfaiLu8=
github.com/sagernet/sing-mux v0.2.0 h1:4C+vd8HztJCWNYfufvgL49xaOoOHXty2+EAjnzN3IYo=
github.com/sagernet/sing-mux v0.2.0/go.mod h1:khzr9AOPocLa+g53dBplwNDz4gdsyx/YM3swtAhlkHQ=
github.com/sagernet/sing-quic v0.1.8 h1:G4iBXAKIII+uTzd55oZ/9cAQswGjlvHh/0yKMQioDS0=
Expand Down
3 changes: 3 additions & 0 deletions option/rule_dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,10 @@ type DefaultDNSRule struct {
DomainRegex Listable[string] `json:"domain_regex,omitempty"`
Geosite Listable[string] `json:"geosite,omitempty"`
SourceGeoIP Listable[string] `json:"source_geoip,omitempty"`
GeoIP Listable[string] `json:"geoip,omitempty"`
SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"`
IPCIDR Listable[string] `json:"ip_cidr,omitempty"`
IPIsPrivate bool `json:"ip_is_private,omitempty"`
SourceIPIsPrivate bool `json:"source_ip_is_private,omitempty"`
SourcePort Listable[uint16] `json:"source_port,omitempty"`
SourcePortRange Listable[string] `json:"source_port_range,omitempty"`
Expand Down
182 changes: 128 additions & 54 deletions route/router_dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package route

import (
"context"
"errors"
"net/netip"
"strings"
"time"

"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/cache"
E "github.com/sagernet/sing/common/exceptions"
Expand Down Expand Up @@ -37,41 +37,47 @@ func (m *DNSReverseMapping) Query(address netip.Addr) (string, bool) {
return domain, loaded
}

func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool) (context.Context, dns.Transport, dns.DomainStrategy) {
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, index int) (context.Context, dns.Transport, dns.DomainStrategy, adapter.DNSRule, int) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
panic("no context")
}
for i, rule := range r.dnsRules {
metadata.ResetRuleCache()
if rule.Match(metadata) {
detour := rule.Outbound()
transport, loaded := r.transportMap[detour]
if !loaded {
r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour)
continue
}
if _, isFakeIP := transport.(adapter.FakeIPTransport); isFakeIP && !allowFakeIP {
continue
}
r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
if rule.DisableCache() {
ctx = dns.ContextWithDisableCache(ctx, true)
}
if rewriteTTL := rule.RewriteTTL(); rewriteTTL != nil {
ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL)
}
if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
return ctx, transport, domainStrategy
} else {
return ctx, transport, r.defaultDomainStrategy
if index < len(r.dnsRules) {
dnsRules := r.dnsRules
if index != -1 {
dnsRules = dnsRules[index+1:]
}
for ruleIndex, rule := range dnsRules {
metadata.ResetRuleCache()
if rule.Match(metadata) {
detour := rule.Outbound()
transport, loaded := r.transportMap[detour]
if !loaded {
r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour)
continue
}
if _, isFakeIP := transport.(adapter.FakeIPTransport); isFakeIP && !allowFakeIP {
continue
}
r.dnsLogger.DebugContext(ctx, "match[", ruleIndex, "] ", rule.String(), " => ", detour)
if rule.DisableCache() {
ctx = dns.ContextWithDisableCache(ctx, true)
}
if rewriteTTL := rule.RewriteTTL(); rewriteTTL != nil {
ctx = dns.ContextWithRewriteTTL(ctx, *rewriteTTL)
}
if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
return ctx, transport, domainStrategy, rule, ruleIndex
} else {
return ctx, transport, r.defaultDomainStrategy, rule, ruleIndex
}
}
}
}
if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded {
return ctx, r.defaultTransport, domainStrategy
return ctx, r.defaultTransport, domainStrategy, nil, -1
} else {
return ctx, r.defaultTransport, r.defaultDomainStrategy
return ctx, r.defaultTransport, r.defaultDomainStrategy, nil, -1
}
}

Expand All @@ -86,7 +92,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
)
response, cached = r.dnsClient.ExchangeCache(ctx, message)
if !cached {
ctx, metadata := adapter.AppendContext(ctx)
var metadata *adapter.InboundContext
ctx, metadata = adapter.AppendContext(ctx)
if len(message.Question) > 0 {
metadata.QueryType = message.Question[0].Qtype
switch metadata.QueryType {
Expand All @@ -97,17 +104,47 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er
}
metadata.Domain = fqdnToDomain(message.Question[0].Name)
}
ctx, transport, strategy := r.matchDNS(ctx, true)
ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout)
defer cancel()
response, err = r.dnsClient.Exchange(ctx, transport, message, strategy)
if err != nil && len(message.Question) > 0 {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String())))
var (
transport dns.Transport
strategy dns.DomainStrategy
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
var (
dnsCtx context.Context
cancel context.CancelFunc
addressLimit bool
)

dnsCtx, transport, strategy, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex)
dnsCtx, cancel = context.WithTimeout(dnsCtx, C.DNSTimeout)
if rule != nil && rule.WithAddressLimit() && isAddressQuery(message) {
addressLimit = true
response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, strategy, func(response *mDNS.Msg) bool {
metadata.DestinationAddresses, _ = dns.MessageToAddresses(response)
return rule.MatchAddressLimit(metadata)
})
} else {
addressLimit = false
response, err = r.dnsClient.Exchange(dnsCtx, transport, message, strategy)
}
cancel()
if err != nil {
if errors.Is(err, dns.ErrResponseRejected) {
r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String())))
} else if len(message.Question) > 0 {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String())))
} else {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for <empty query>"))
}
}
if !addressLimit || err == nil {
break
}
}
}
if len(message.Question) > 0 && response != nil {
LogDNSAnswers(r.dnsLogger, ctx, message.Question[0].Name, response.Answer)
}
if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 {
for _, answer := range response.Answer {
switch record := answer.(type) {
Expand All @@ -125,22 +162,56 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS
r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
ctx, metadata := adapter.AppendContext(ctx)
metadata.Domain = domain
ctx, transport, transportStrategy := r.matchDNS(ctx, false)
if strategy == dns.DomainStrategyAsIS {
strategy = transportStrategy
var (
transport dns.Transport
transportStrategy dns.DomainStrategy
rule adapter.DNSRule
ruleIndex int
resultAddrs []netip.Addr
err error
)
ruleIndex = -1
for {
var (
dnsCtx context.Context
cancel context.CancelFunc
addressLimit bool
)
metadata.ResetRuleCache()
metadata.DestinationAddresses = nil
dnsCtx, transport, transportStrategy, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex)
if strategy == dns.DomainStrategyAsIS {
strategy = transportStrategy
}
dnsCtx, cancel = context.WithTimeout(dnsCtx, C.DNSTimeout)
if rule != nil && rule.WithAddressLimit() {
addressLimit = true
resultAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, strategy, func(responseAddrs []netip.Addr) bool {
metadata.DestinationAddresses = responseAddrs
return rule.MatchAddressLimit(metadata)
})
} else {
addressLimit = false
resultAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, strategy)
}
cancel()
if err != nil {
if errors.Is(err, dns.ErrResponseRejected) {
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain)
} else {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
}
} else if len(resultAddrs) == 0 {
r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
}
if !addressLimit || err == nil {
break
}
}
ctx, cancel := context.WithTimeout(ctx, C.DNSTimeout)
defer cancel()
addrs, err := r.dnsClient.Lookup(ctx, transport, domain, strategy)
if len(addrs) > 0 {
r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(addrs), " "))
} else if err != nil {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
} else {
r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
err = dns.RCodeNameError
if len(resultAddrs) > 0 {
r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(resultAddrs), " "))
}
return addrs, err
return resultAddrs, err
}

func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
Expand All @@ -154,10 +225,13 @@ func (r *Router) ClearDNSCache() {
}
}

func LogDNSAnswers(logger log.ContextLogger, ctx context.Context, domain string, answers []mDNS.RR) {
for _, answer := range answers {
logger.InfoContext(ctx, "exchanged ", domain, " ", mDNS.Type(answer.Header().Rrtype).String(), " ", formatQuestion(answer.String()))
func isAddressQuery(message *mDNS.Msg) bool {
for _, question := range message.Question {
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
return true
}
}
return false
}

func fqdnToDomain(fqdn string) string {
Expand Down
4 changes: 4 additions & 0 deletions route/router_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,7 @@ func isWIFIDNSRule(rule option.DefaultDNSRule) bool {
func isWIFIHeadlessRule(rule option.DefaultHeadlessRule) bool {
return len(rule.WIFISSID) > 0 || len(rule.WIFIBSSID) > 0
}

func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool {
return len(rule.IPCIDR) > 0 || rule.IPSet != nil
}
24 changes: 23 additions & 1 deletion route/rule_abstract.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type abstractDefaultRule struct {
sourceAddressItems []RuleItem
sourcePortItems []RuleItem
destinationAddressItems []RuleItem
destinationIPCIDRItems []RuleItem
destinationPortItems []RuleItem
allItems []RuleItem
ruleSetItem RuleItem
Expand Down Expand Up @@ -64,6 +65,7 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
}

if len(r.sourceAddressItems) > 0 && !metadata.SourceAddressMatch {
metadata.DidMatch = true
for _, item := range r.sourceAddressItems {
if item.Match(metadata) {
metadata.SourceAddressMatch = true
Expand All @@ -73,6 +75,7 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
}

if len(r.sourcePortItems) > 0 && !metadata.SourcePortMatch {
metadata.DidMatch = true
for _, item := range r.sourcePortItems {
if item.Match(metadata) {
metadata.SourcePortMatch = true
Expand All @@ -82,6 +85,7 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
}

if len(r.destinationAddressItems) > 0 && !metadata.DestinationAddressMatch {
metadata.DidMatch = true
for _, item := range r.destinationAddressItems {
if item.Match(metadata) {
metadata.DestinationAddressMatch = true
Expand All @@ -90,7 +94,18 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
}
}

if !metadata.IgnoreDestinationIPCIDRMatch && len(r.destinationIPCIDRItems) > 0 && !metadata.DestinationAddressMatch {
metadata.DidMatch = true
for _, item := range r.destinationIPCIDRItems {
if item.Match(metadata) {
metadata.DestinationAddressMatch = true
break
}
}
}

if len(r.destinationPortItems) > 0 && !metadata.DestinationPortMatch {
metadata.DidMatch = true
for _, item := range r.destinationPortItems {
if item.Match(metadata) {
metadata.DestinationPortMatch = true
Expand All @@ -100,6 +115,9 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
}

for _, item := range r.items {
if _, isRuleSet := item.(*RuleSetItem); !isRuleSet {
metadata.DidMatch = true
}
if !item.Match(metadata) {
return r.invert
}
Expand All @@ -113,14 +131,18 @@ func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool {
return r.invert
}

if len(r.destinationAddressItems) > 0 && !metadata.DestinationAddressMatch {
if ((!metadata.IgnoreDestinationIPCIDRMatch || len(r.destinationIPCIDRItems) > 0) || len(r.destinationAddressItems) > 0) && !metadata.DestinationAddressMatch {
return r.invert
}

if len(r.destinationPortItems) > 0 && !metadata.DestinationPortMatch {
return r.invert
}

if !metadata.DidMatch {
return false
}

return !r.invert
}

Expand Down
Loading

0 comments on commit fd705d9

Please sign in to comment.