From b02f2529160efc93077dc1d92ae922b5ef806cc5 Mon Sep 17 00:00:00 2001 From: dyhkwong <50692134+dyhkwong@users.noreply.github.com> Date: Sat, 20 May 2023 12:11:00 +0800 Subject: [PATCH] Use api to create windows firewall rules --- go.mod | 2 + go.sum | 5 + internal/winfw/winfw.go | 274 ++++++++++++++++++++++++++++++++++++++++ stack.go | 27 ++-- system.go | 34 +++-- system_windows.go | 46 ++----- 6 files changed, 321 insertions(+), 67 deletions(-) create mode 100644 internal/winfw/winfw.go diff --git a/go.mod b/go.mod index 1952530..aeb2ebb 100644 --- a/go.mod +++ b/go.mod @@ -4,9 +4,11 @@ go 1.18 require ( github.com/fsnotify/fsnotify v1.6.0 + github.com/go-ole/go-ole v1.2.6 github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 github.com/sagernet/sing v0.2.4 + github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 golang.org/x/net v0.9.0 golang.org/x/sys v0.7.0 gvisor.dev/gvisor v0.0.0-20230415003630-3981d5d5e523 diff --git a/go.sum b/go.sum index c8e4e87..d5f1d03 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA= @@ -9,10 +11,13 @@ github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJ github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.2.4 h1:gC8BR5sglbJZX23RtMyFa8EETP9YEUADhfbEzU1yVbo= github.com/sagernet/sing v0.2.4/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w= +github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg= +github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg= github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/winfw/winfw.go b/internal/winfw/winfw.go new file mode 100644 index 0000000..7798fcb --- /dev/null +++ b/internal/winfw/winfw.go @@ -0,0 +1,274 @@ +// Copyright (c) 2018 Samuel Melrose +// SPDX-License-Identifier: MIT +// https://github.com/iamacarpet/go-win64api/blob/ef6dbdd6db97301ae08a55eedea773476985a602/firewall.go + +//go:build windows + +package winfw + +import ( + "fmt" + "runtime" + + "github.com/go-ole/go-ole" + "github.com/go-ole/go-ole/oleutil" + "github.com/scjalliance/comshim" +) + +// Firewall related API constants. +const ( + NET_FW_IP_PROTOCOL_TCP = 6 + NET_FW_IP_PROTOCOL_UDP = 17 + NET_FW_IP_PROTOCOL_ICMPv4 = 1 + NET_FW_IP_PROTOCOL_ICMPv6 = 58 + NET_FW_IP_PROTOCOL_ANY = 256 + + NET_FW_RULE_DIR_IN = 1 + NET_FW_RULE_DIR_OUT = 2 + + NET_FW_ACTION_BLOCK = 0 + NET_FW_ACTION_ALLOW = 1 + + // NET_FW_PROFILE2_CURRENT is not real API constant, just helper used in FW functions. + // It can mean one profile or multiple (even all) profiles. It depends on which profiles + // are currently in use. Every active interface can have it's own profile. F.e.: Public for Wifi, + // Domain for VPN, and Private for LAN. All at the same time. + NET_FW_PROFILE2_CURRENT = 0 + NET_FW_PROFILE2_DOMAIN = 1 + NET_FW_PROFILE2_PRIVATE = 2 + NET_FW_PROFILE2_PUBLIC = 4 + NET_FW_PROFILE2_ALL = 2147483647 +) + +// Firewall Rule Groups +// Use this magical strings instead of group names. It will work on all language Windows versions. +// You can find more string locations here: +// https://windows10dll.nirsoft.net/firewallapi_dll.html +const ( + NET_FW_FILE_AND_PRINTER_SHARING = "@FirewallAPI.dll,-28502" + NET_FW_REMOTE_DESKTOP = "@FirewallAPI.dll,-28752" +) + +// FWRule represents Firewall Rule. +type FWRule struct { + Name, Description, ApplicationName, ServiceName string + LocalPorts, RemotePorts string + // LocalAddresses, RemoteAddresses are always returned with netmask, f.e.: + // `10.10.1.1/255.255.255.0` + LocalAddresses, RemoteAddresses string + // ICMPTypesAndCodes is string. You can find define multiple codes separated by ":" (colon). + // Types are listed here: + // https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml + // So to allow ping set it to: + // "0" + ICMPTypesAndCodes string + Grouping string + // InterfaceTypes can be: + // "LAN", "Wireless", "RemoteAccess", "All" + // You can add multiple deviding with comma: + // "LAN, Wireless" + InterfaceTypes string + Protocol, Direction, Action, Profiles int32 + Enabled, EdgeTraversal bool +} + +// FirewallRuleAddAdvanced allows to modify almost all available FW Rule parameters. +// You probably do not want to use this, as function allows to create any rule, even opening all ports +// in given profile. So use with caution. +func FirewallRuleAddAdvanced(rule FWRule) (bool, error) { + return firewallRuleAdd(rule.Name, rule.Description, rule.Grouping, rule.ApplicationName, rule.ServiceName, + rule.LocalPorts, rule.RemotePorts, rule.LocalAddresses, rule.RemoteAddresses, rule.ICMPTypesAndCodes, + rule.Protocol, rule.Direction, rule.Action, rule.Profiles, rule.Enabled, rule.EdgeTraversal) +} + +// firewallRuleAdd is universal function to add all kinds of rules. +func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, action, profile int32, enabled, edgeTraversal bool) (bool, error) { + if name == "" { + return false, fmt.Errorf("empty FW Rule name, name is mandatory") + } + + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + u, fwPolicy, err := firewallAPIInit() + if err != nil { + return false, err + } + defer firewallAPIRelease(u, fwPolicy) + + if profile == NET_FW_PROFILE2_CURRENT { + currentProfiles, err := oleutil.GetProperty(fwPolicy, "CurrentProfileTypes") + if err != nil { + return false, fmt.Errorf("Failed to get CurrentProfiles: %s", err) + } + profile = currentProfiles.Value().(int32) + } + unknownRules, err := oleutil.GetProperty(fwPolicy, "Rules") + if err != nil { + return false, fmt.Errorf("Failed to get Rules: %s", err) + } + rules := unknownRules.ToIDispatch() + + if ok, err := FirewallRuleExistsByName(rules, name); err != nil { + return false, fmt.Errorf("Error while checking rules for duplicate: %s", err) + } else if ok { + return false, nil + } + + unknown2, err := oleutil.CreateObject("HNetCfg.FWRule") + if err != nil { + return false, fmt.Errorf("Error creating Rule object: %s", err) + } + defer unknown2.Release() + + fwRule, err := unknown2.QueryInterface(ole.IID_IDispatch) + if err != nil { + return false, fmt.Errorf("Error creating Rule object (2): %s", err) + } + defer fwRule.Release() + + if _, err := oleutil.PutProperty(fwRule, "Name", name); err != nil { + return false, fmt.Errorf("Error setting property (Name) of Rule: %s", err) + } + if _, err := oleutil.PutProperty(fwRule, "Description", description); err != nil { + return false, fmt.Errorf("Error setting property (Description) of Rule: %s", err) + } + if appPath != "" { + if _, err := oleutil.PutProperty(fwRule, "Applicationname", appPath); err != nil { + return false, fmt.Errorf("Error setting property (Applicationname) of Rule: %s", err) + } + } + if serviceName != "" { + if _, err := oleutil.PutProperty(fwRule, "ServiceName", serviceName); err != nil { + return false, fmt.Errorf("Error setting property (ServiceName) of Rule: %s", err) + } + } + if protocol != 0 { + if _, err := oleutil.PutProperty(fwRule, "Protocol", protocol); err != nil { + return false, fmt.Errorf("Error setting property (Protocol) of Rule: %s", err) + } + } + if icmpTypes != "" { + if _, err := oleutil.PutProperty(fwRule, "IcmpTypesAndCodes", icmpTypes); err != nil { + return false, fmt.Errorf("Error setting property (IcmpTypesAndCodes) of Rule: %s", err) + } + } + if ports != "" { + if _, err := oleutil.PutProperty(fwRule, "LocalPorts", ports); err != nil { + return false, fmt.Errorf("Error setting property (LocalPorts) of Rule: %s", err) + } + } + if remotePorts != "" { + if _, err := oleutil.PutProperty(fwRule, "RemotePorts", remotePorts); err != nil { + return false, fmt.Errorf("Error setting property (RemotePorts) of Rule: %s", err) + } + } + if localAddresses != "" { + if _, err := oleutil.PutProperty(fwRule, "LocalAddresses", localAddresses); err != nil { + return false, fmt.Errorf("Error setting property (LocalAddresses) of Rule: %s", err) + } + } + if remoteAddresses != "" { + if _, err := oleutil.PutProperty(fwRule, "RemoteAddresses", remoteAddresses); err != nil { + return false, fmt.Errorf("Error setting property (RemoteAddresses) of Rule: %s", err) + } + } + if direction != 0 { + if _, err := oleutil.PutProperty(fwRule, "Direction", direction); err != nil { + return false, fmt.Errorf("Error setting property (Direction) of Rule: %s", err) + } + } + if _, err := oleutil.PutProperty(fwRule, "Enabled", enabled); err != nil { + return false, fmt.Errorf("Error setting property (Enabled) of Rule: %s", err) + } + if _, err := oleutil.PutProperty(fwRule, "Grouping", group); err != nil { + return false, fmt.Errorf("Error setting property (Grouping) of Rule: %s", err) + } + if _, err := oleutil.PutProperty(fwRule, "Profiles", profile); err != nil { + return false, fmt.Errorf("Error setting property (Profiles) of Rule: %s", err) + } + if _, err := oleutil.PutProperty(fwRule, "Action", action); err != nil { + return false, fmt.Errorf("Error setting property (Action) of Rule: %s", err) + } + if edgeTraversal { + if _, err := oleutil.PutProperty(fwRule, "EdgeTraversal", edgeTraversal); err != nil { + return false, fmt.Errorf("Error setting property (EdgeTraversal) of Rule: %s", err) + } + } + + if _, err := oleutil.CallMethod(rules, "Add", fwRule); err != nil { + return false, fmt.Errorf("Error adding Rule: %s", err) + } + + return true, nil +} + +func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) { + enumProperty, err := rules.GetProperty("_NewEnum") + if err != nil { + return false, fmt.Errorf("Failed to get enumeration property on Rules: %s", err) + } + defer enumProperty.Clear() + + enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant) + if err != nil { + return false, fmt.Errorf("Failed to cast enum to correct type: %s", err) + } + if enum == nil { + return false, fmt.Errorf("can't get IEnumVARIANT, enum is nil") + } + + for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) { + if err != nil { + return false, fmt.Errorf("Failed to seek next Rule item: %s", err) + } + + t, err := func() (bool, error) { + item := itemRaw.ToIDispatch() + defer item.Release() + + if item, err := oleutil.GetProperty(item, "Name"); err != nil { + return false, fmt.Errorf("Failed to get Property (Name) of Rule") + } else if item.ToString() == name { + return true, nil + } + + return false, nil + }() + + if err != nil { + return false, err + } else if t { + return true, nil + } + } + + return false, nil +} + +// firewallAPIInit initialize common fw api. +// then: +// dispatch firewallAPIRelease(u, fwp) +func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) { + comshim.Add(1) + + unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2") + if err != nil { + return nil, nil, fmt.Errorf("Failed to create FwPolicy Object: %s", err) + } + + fwPolicy, err := unknown.QueryInterface(ole.IID_IDispatch) + if err != nil { + unknown.Release() + return nil, nil, fmt.Errorf("Failed to create FwPolicy Object (2): %s", err) + } + + return unknown, fwPolicy, nil +} + +// firewallAPIRelease cleans memory. +func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) { + fwp.Release() + u.Release() + comshim.Done() +} diff --git a/stack.go b/stack.go index bdb2818..71a183d 100644 --- a/stack.go +++ b/stack.go @@ -15,20 +15,19 @@ type Stack interface { } type StackOptions struct { - Context context.Context - Tun Tun - Name string - MTU uint32 - Inet4Address []netip.Prefix - Inet6Address []netip.Prefix - EndpointIndependentNat bool - UDPTimeout int64 - Router Router - Handler Handler - Logger logger.Logger - ForwarderBindInterface bool - InterfaceFinder control.InterfaceFinder - ExperimentalFixWindowsFirewall bool + Context context.Context + Tun Tun + Name string + MTU uint32 + Inet4Address []netip.Prefix + Inet6Address []netip.Prefix + EndpointIndependentNat bool + UDPTimeout int64 + Router Router + Handler Handler + Logger logger.Logger + ForwarderBindInterface bool + InterfaceFinder control.InterfaceFinder } func NewStack( diff --git a/system.go b/system.go index 123e1b4..8ab05f9 100644 --- a/system.go +++ b/system.go @@ -42,7 +42,6 @@ type System struct { routeMapping *RouteMapping bindInterface bool interfaceFinder control.InterfaceFinder - fixWindowsFirewall bool } type Session struct { @@ -54,19 +53,18 @@ type Session struct { func NewSystem(options StackOptions) (Stack, error) { stack := &System{ - ctx: options.Context, - tun: options.Tun, - tunName: options.Name, - mtu: options.MTU, - udpTimeout: options.UDPTimeout, - router: options.Router, - handler: options.Handler, - logger: options.Logger, - inet4Prefixes: options.Inet4Address, - inet6Prefixes: options.Inet6Address, - bindInterface: options.ForwarderBindInterface, - interfaceFinder: options.InterfaceFinder, - fixWindowsFirewall: options.ExperimentalFixWindowsFirewall, + ctx: options.Context, + tun: options.Tun, + tunName: options.Name, + mtu: options.MTU, + udpTimeout: options.UDPTimeout, + router: options.Router, + handler: options.Handler, + logger: options.Logger, + inet4Prefixes: options.Inet4Address, + inet6Prefixes: options.Inet6Address, + bindInterface: options.ForwarderBindInterface, + interfaceFinder: options.InterfaceFinder, } if stack.router != nil { stack.routeMapping = NewRouteMapping(options.UDPTimeout) @@ -99,11 +97,9 @@ func (s *System) Close() error { } func (s *System) Start() error { - if s.fixWindowsFirewall { - err := fixWindowsFirewall() - if err != nil { - return E.Cause(err, "fix windows firewall for system stack") - } + err := fixWindowsFirewall() + if err != nil { + return E.Cause(err, "fix windows firewall for system stack") } var listener net.ListenConfig if s.bindInterface { diff --git a/system_windows.go b/system_windows.go index 970f438..39f1877 100644 --- a/system_windows.go +++ b/system_windows.go @@ -2,46 +2,24 @@ package tun import ( "os" - "os/exec" "path/filepath" - E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - "github.com/sagernet/sing/common/shell" + "github.com/sagernet/sing-tun/internal/winfw" ) func fixWindowsFirewall() error { - const shellStringSplit = "\"" - isPWSH := true - powershell, err := exec.LookPath("pwsh.exe") + absPath, err := filepath.Abs(os.Args[0]) if err != nil { - powershell, err = exec.LookPath("powershell.exe") - isPWSH = false + return err } - if err != nil { - return nil - } - ruleName := "sing-tun rule for " + os.Args[0] - commandPrefix := []string{"-NoProfile", "-NonInteractive"} - if isPWSH { - commandPrefix = append(commandPrefix, "-Command") - } - err = shell.Exec(powershell, append(commandPrefix, - F.ToString("Get-NetFirewallRule -Name ", shellStringSplit, ruleName, shellStringSplit))...).Run() - if err == nil { - return nil - } - fileName := filepath.Base(os.Args[0]) - output, err := shell.Exec(powershell, append(commandPrefix, - F.ToString("New-NetFirewallRule", - " -Name ", shellStringSplit, ruleName, shellStringSplit, - " -DisplayName ", shellStringSplit, "sing-tun (", fileName, ")", shellStringSplit, - " -Program ", shellStringSplit, os.Args[0], shellStringSplit, - " -Direction Inbound", - " -Protocol TCP", - " -Action Allow"))...).Read() - if err != nil { - return E.Extend(err, output) + rule := winfw.FWRule{ + Name: "sing-tun (" + absPath + ")", + ApplicationName: absPath, + Enabled: true, + Protocol: winfw.NET_FW_IP_PROTOCOL_TCP, + Direction: winfw.NET_FW_RULE_DIR_IN, + Action: winfw.NET_FW_ACTION_ALLOW, } - return nil + _, err = winfw.FirewallRuleAddAdvanced(rule) + return err }