diff --git a/config/common.go b/config/common.go index 03abef9..7440176 100644 --- a/config/common.go +++ b/config/common.go @@ -19,7 +19,7 @@ type commonOpts struct { ForwardsProto string // Policy that define rules that should be applied to incoming or outgoing // connections to the edge. - Policy *policy + Policy Policy } type CommonOptionsFunc func(cfg *commonOpts) diff --git a/config/http.go b/config/http.go index 4d68251..ab9733b 100644 --- a/config/http.go +++ b/config/http.go @@ -124,7 +124,7 @@ func (cfg *httpOptions) toProtoConfig() *proto.HTTPEndpoint { opts.WebhookVerification = cfg.WebhookVerification.toProtoConfig() opts.IPRestriction = cfg.commonOpts.CIDRRestrictions.toProtoConfig() opts.UserAgentFilter = cfg.UserAgentFilter.toProtoConfig() - opts.Policy = cfg.Policy.toProtoConfig() + opts.Policy = cfg.Policy return opts } diff --git a/config/policy.go b/config/policy.go index bc96418..418c7bd 100644 --- a/config/policy.go +++ b/config/policy.go @@ -2,16 +2,12 @@ package config import ( "encoding/json" - "errors" - "fmt" - - "golang.ngrok.com/ngrok/internal/pb" - po "golang.ngrok.com/ngrok/policy" ) -type policy po.Policy -type rule po.Rule -type action po.Action +type Policy any +type withPolicy struct { + policy Policy +} // WithPolicyString configures this edge with the provided policy configuration // passed as a json string and overwrites any previously-set traffic policy @@ -21,79 +17,33 @@ func WithPolicyString(jsonStr string) interface { TLSEndpointOption TCPEndpointOption } { - p := &policy{} - if err := json.Unmarshal([]byte(jsonStr), p); err != nil { + p := map[string]any{} + if err := json.Unmarshal([]byte(jsonStr), &p); err != nil { panic("invalid json for policy configuration") } - return p + return WithPolicy(Policy(p)) } // WithPolicy configures this edge with the given traffic policy and overwrites any // previously-set traffic policy // https://ngrok.com/docs/http/traffic-policy/ -func WithPolicy(p po.Policy) interface { +func WithPolicy(p Policy) interface { HTTPEndpointOption TLSEndpointOption TCPEndpointOption } { - ret := policy(p) - - return &ret -} - -func (p *policy) ApplyTLS(opts *tlsOptions) { - opts.Policy = p + return withPolicy{p} } -func (p *policy) ApplyHTTP(opts *httpOptions) { - opts.Policy = p +func (p withPolicy) ApplyTLS(opts *tlsOptions) { + opts.Policy = p.policy } -func (p *policy) ApplyTCP(opts *tcpOptions) { - opts.Policy = p +func (p withPolicy) ApplyHTTP(opts *httpOptions) { + opts.Policy = p.policy } -func (p *policy) toProtoConfig() *pb.MiddlewareConfiguration_Policy { - if p == nil { - return nil - } - inbound := make([]*pb.MiddlewareConfiguration_PolicyRule, len(p.Inbound)) - for i, inP := range p.Inbound { - inbound[i] = rule(inP).toProtoConfig() - } - - outbound := make([]*pb.MiddlewareConfiguration_PolicyRule, len(p.Outbound)) - for i, outP := range p.Outbound { - outbound[i] = rule(outP).toProtoConfig() - } - return &pb.MiddlewareConfiguration_Policy{ - Inbound: inbound, - Outbound: outbound, - } -} - -func (pr rule) toProtoConfig() *pb.MiddlewareConfiguration_PolicyRule { - actions := make([]*pb.MiddlewareConfiguration_PolicyAction, len(pr.Actions)) - for i, act := range pr.Actions { - actions[i] = action(act).toProtoConfig() - } - - return &pb.MiddlewareConfiguration_PolicyRule{Name: pr.Name, Expressions: pr.Expressions, Actions: actions} -} - -func (a action) toProtoConfig() *pb.MiddlewareConfiguration_PolicyAction { - var cfgBytes []byte = nil - if len(a.Config) > 0 { - var err error - cfgBytes, err = json.Marshal(a.Config) - - if err != nil { - panic(errors.New(fmt.Sprintf("failed to parse action configuration due to error: %s", err.Error()))) - } - } - return &pb.MiddlewareConfiguration_PolicyAction{ - Type: a.Type, - Config: cfgBytes, - } +func (p withPolicy) ApplyTCP(opts *tcpOptions) { + opts.Policy = p.policy } diff --git a/config/policy_test.go b/config/policy_test.go index bb0b27a..1b6d7b2 100644 --- a/config/policy_test.go +++ b/config/policy_test.go @@ -1,18 +1,18 @@ package config import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" - "golang.ngrok.com/ngrok/internal/pb" "golang.ngrok.com/ngrok/internal/tunnel/proto" po "golang.ngrok.com/ngrok/policy" ) func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, makeOpts func(...OT) Tunnel, - getPolicies func(*O) *pb.MiddlewareConfiguration_Policy, + getPolicies func(*O) any, ) { optsFunc := func(opts ...any) Tunnel { return makeOpts(assertSlice[OT](opts)...) @@ -69,11 +69,14 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, ), ), expectOpts: func(t *testing.T, opts *O) { - actual := getPolicies(opts) - require.NotNil(t, actual) + actualAny := getPolicies(opts) + require.NotNil(t, actualAny) + + actual := actualAny.(po.Policy) + require.Len(t, actual.Inbound, 2) require.Equal(t, "denyPUT", actual.Inbound[0].Name) - require.Equal(t, actual.Inbound[0].Actions, []*pb.MiddlewareConfiguration_PolicyAction{{Type: "deny"}}) + require.Equal(t, actual.Inbound[0].Actions, []po.Action{{Type: "deny"}}) require.Len(t, actual.Outbound, 1) require.Len(t, actual.Outbound[0].Expressions, 2) }, @@ -106,14 +109,20 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, ] }`)), expectOpts: func(t *testing.T, opts *O) { - actual := getPolicies(opts) + actualAny := getPolicies(opts) + actualSer, err := json.Marshal(actualAny) + require.NoError(t, err) + + var actual po.Policy + require.NoError(t, json.Unmarshal(actualSer, &actual)) + require.NotNil(t, actual) require.Len(t, actual.Inbound, 2) require.Equal(t, "denyPut", actual.Inbound[0].Name) - require.Equal(t, []*pb.MiddlewareConfiguration_PolicyAction{{Type: "deny"}}, actual.Inbound[0].Actions) + require.Equal(t, []po.Action{{Type: "deny"}}, actual.Inbound[0].Actions) require.Len(t, actual.Outbound, 1) require.Len(t, actual.Outbound[0].Expressions, 2) - require.Equal(t, []byte(`{"status_code":500}`), actual.Outbound[0].Actions[0].Config) + require.Equal(t, map[string]any{"status_code": 500.}, actual.Outbound[0].Actions[0].Config) }, }, } @@ -123,15 +132,15 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, func TestPolicy(t *testing.T) { testPolicy[*httpOptions](t, HTTPEndpoint, - func(h *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_Policy { + func(h *proto.HTTPEndpoint) any { return h.Policy }) testPolicy[*tcpOptions](t, TCPEndpoint, - func(h *proto.TCPEndpoint) *pb.MiddlewareConfiguration_Policy { + func(h *proto.TCPEndpoint) any { return h.Policy }) testPolicy[*tlsOptions](t, TLSEndpoint, - func(h *proto.TLSEndpoint) *pb.MiddlewareConfiguration_Policy { + func(h *proto.TLSEndpoint) any { return h.Policy }) } diff --git a/config/tcp.go b/config/tcp.go index 8507eab..1b64337 100644 --- a/config/tcp.go +++ b/config/tcp.go @@ -50,7 +50,7 @@ func (cfg *tcpOptions) toProtoConfig() *proto.TCPEndpoint { return &proto.TCPEndpoint{ Addr: cfg.RemoteAddr, IPRestriction: cfg.commonOpts.CIDRRestrictions.toProtoConfig(), - Policy: cfg.commonOpts.Policy.toProtoConfig(), + Policy: cfg.commonOpts.Policy, ProxyProto: proto.ProxyProto(cfg.commonOpts.ProxyProto), } } diff --git a/config/tls.go b/config/tls.go index 45e4903..e3e6b29 100644 --- a/config/tls.go +++ b/config/tls.go @@ -69,7 +69,7 @@ func (cfg *tlsOptions) toProtoConfig() *proto.TLSEndpoint { } opts.IPRestriction = cfg.commonOpts.CIDRRestrictions.toProtoConfig() - opts.Policy = cfg.commonOpts.Policy.toProtoConfig() + opts.Policy = cfg.commonOpts.Policy opts.MutualTLSAtEdge = mutualTLSEndpointOption(cfg.MutualTLSCA).toProtoConfig() diff --git a/internal/tunnel/proto/msg.go b/internal/tunnel/proto/msg.go index 849f27c..8dd8241 100644 --- a/internal/tunnel/proto/msg.go +++ b/internal/tunnel/proto/msg.go @@ -302,7 +302,7 @@ type HTTPEndpoint struct { ResponseHeaders *pb.MiddlewareConfiguration_Headers WebsocketTCPConverter *pb.MiddlewareConfiguration_WebsocketTCPConverter UserAgentFilter *pb.MiddlewareConfiguration_UserAgentFilter - Policy *pb.MiddlewareConfiguration_Policy + Policy any } type TCPEndpoint struct { @@ -311,7 +311,7 @@ type TCPEndpoint struct { // middleware IPRestriction *pb.MiddlewareConfiguration_IPRestriction - Policy *pb.MiddlewareConfiguration_Policy + Policy any } type TLSEndpoint struct { @@ -325,7 +325,7 @@ type TLSEndpoint struct { MutualTLSAtEdge *pb.MiddlewareConfiguration_MutualTLS TLSTermination *pb.MiddlewareConfiguration_TLSTermination IPRestriction *pb.MiddlewareConfiguration_IPRestriction - Policy *pb.MiddlewareConfiguration_Policy + Policy any } type SSHOptions struct {