diff --git a/config/common.go b/config/common.go index 03abef9..7440176 100644 --- a/config/common.go +++ b/config/common.go @@ -19,7 +19,7 @@ type commonOpts struct { ForwardsProto string // Policy that define rules that should be applied to incoming or outgoing // connections to the edge. - Policy *policy + Policy Policy } type CommonOptionsFunc func(cfg *commonOpts) diff --git a/config/http.go b/config/http.go index 4d68251..cce9340 100644 --- a/config/http.go +++ b/config/http.go @@ -124,7 +124,7 @@ func (cfg *httpOptions) toProtoConfig() *proto.HTTPEndpoint { opts.WebhookVerification = cfg.WebhookVerification.toProtoConfig() opts.IPRestriction = cfg.commonOpts.CIDRRestrictions.toProtoConfig() opts.UserAgentFilter = cfg.UserAgentFilter.toProtoConfig() - opts.Policy = cfg.Policy.toProtoConfig() + opts.Policy = map[string]any(cfg.Policy) return opts } diff --git a/config/policy.go b/config/policy.go index bc96418..f404e46 100644 --- a/config/policy.go +++ b/config/policy.go @@ -2,16 +2,10 @@ package config import ( "encoding/json" - "errors" - "fmt" - - "golang.ngrok.com/ngrok/internal/pb" - po "golang.ngrok.com/ngrok/policy" ) -type policy po.Policy -type rule po.Rule -type action po.Action +type Policy map[string]any +type withPolicy Policy // WithPolicyString configures this edge with the provided policy configuration // passed as a json string and overwrites any previously-set traffic policy @@ -21,79 +15,33 @@ func WithPolicyString(jsonStr string) interface { TLSEndpointOption TCPEndpointOption } { - p := &policy{} - if err := json.Unmarshal([]byte(jsonStr), p); err != nil { + p := map[string]any{} + if err := json.Unmarshal([]byte(jsonStr), &p); err != nil { panic("invalid json for policy configuration") } - return p + return WithPolicy(Policy(p)) } // WithPolicy configures this edge with the given traffic policy and overwrites any // previously-set traffic policy // https://ngrok.com/docs/http/traffic-policy/ -func WithPolicy(p po.Policy) interface { +func WithPolicy(p Policy) interface { HTTPEndpointOption TLSEndpointOption TCPEndpointOption } { - ret := policy(p) - - return &ret + return withPolicy(p) } -func (p *policy) ApplyTLS(opts *tlsOptions) { - opts.Policy = p -} - -func (p *policy) ApplyHTTP(opts *httpOptions) { - opts.Policy = p -} - -func (p *policy) ApplyTCP(opts *tcpOptions) { - opts.Policy = p -} - -func (p *policy) toProtoConfig() *pb.MiddlewareConfiguration_Policy { - if p == nil { - return nil - } - inbound := make([]*pb.MiddlewareConfiguration_PolicyRule, len(p.Inbound)) - for i, inP := range p.Inbound { - inbound[i] = rule(inP).toProtoConfig() - } - - outbound := make([]*pb.MiddlewareConfiguration_PolicyRule, len(p.Outbound)) - for i, outP := range p.Outbound { - outbound[i] = rule(outP).toProtoConfig() - } - return &pb.MiddlewareConfiguration_Policy{ - Inbound: inbound, - Outbound: outbound, - } +func (p withPolicy) ApplyTLS(opts *tlsOptions) { + opts.Policy = Policy(p) } -func (pr rule) toProtoConfig() *pb.MiddlewareConfiguration_PolicyRule { - actions := make([]*pb.MiddlewareConfiguration_PolicyAction, len(pr.Actions)) - for i, act := range pr.Actions { - actions[i] = action(act).toProtoConfig() - } - - return &pb.MiddlewareConfiguration_PolicyRule{Name: pr.Name, Expressions: pr.Expressions, Actions: actions} +func (p withPolicy) ApplyHTTP(opts *httpOptions) { + opts.Policy = Policy(p) } -func (a action) toProtoConfig() *pb.MiddlewareConfiguration_PolicyAction { - var cfgBytes []byte = nil - if len(a.Config) > 0 { - var err error - cfgBytes, err = json.Marshal(a.Config) - - if err != nil { - panic(errors.New(fmt.Sprintf("failed to parse action configuration due to error: %s", err.Error()))) - } - } - return &pb.MiddlewareConfiguration_PolicyAction{ - Type: a.Type, - Config: cfgBytes, - } +func (p withPolicy) ApplyTCP(opts *tcpOptions) { + opts.Policy = Policy(p) } diff --git a/config/policy_test.go b/config/policy_test.go index bb0b27a..f1707e4 100644 --- a/config/policy_test.go +++ b/config/policy_test.go @@ -5,14 +5,12 @@ import ( "github.com/stretchr/testify/require" - "golang.ngrok.com/ngrok/internal/pb" "golang.ngrok.com/ngrok/internal/tunnel/proto" - po "golang.ngrok.com/ngrok/policy" ) func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, makeOpts func(...OT) Tunnel, - getPolicies func(*O) *pb.MiddlewareConfiguration_Policy, + getPolicies func(*O) map[string]any, ) { optsFunc := func(opts ...any) Tunnel { return makeOpts(assertSlice[OT](opts)...) @@ -30,37 +28,37 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, name: "with policy", opts: optsFunc( WithPolicy( - po.Policy{ - Inbound: []po.Rule{ + Policy{ + "inbound": []map[string]any{ { - Name: "denyPUT", - Expressions: []string{"req.Method == 'PUT'"}, - Actions: []po.Action{ - {Type: "deny"}, + "name": "denyPUT", + "expressions": []string{"req.Method == 'PUT'"}, + "actions": []map[string]any{ + {"type": "deny"}, }, }, { - Name: "logFooHeader", - Expressions: []string{"'foo' in req.Headers"}, - Actions: []po.Action{ + "name": "logFooHeader", + "expressions": []string{"'foo' in req.Headers"}, + "actions": []map[string]any{ { - Type: "log", - Config: map[string]any{"metadata": map[string]any{"key": "val"}}, + "type": "log", + "config": map[string]any{"metadata": map[string]any{"key": "val"}}, }, }, }, }, - Outbound: []po.Rule{ + "outbound": []map[string]any{ { - Name: "InternalErrorWhenFailed", - Expressions: []string{ + "name": "InternalErrorWhenFailed", + "expressions": []string{ "res.StatusCode <= '0'", "res.StatusCode >= '300'", }, - Actions: []po.Action{ + "actions": []map[string]any{ { - Type: "custom-response", - Config: map[string]any{"status_code": 500}, + "type": "custom-response", + "config": map[string]any{"status_code": 500}, }, }, }, @@ -71,11 +69,11 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, expectOpts: func(t *testing.T, opts *O) { actual := getPolicies(opts) require.NotNil(t, actual) - require.Len(t, actual.Inbound, 2) - require.Equal(t, "denyPUT", actual.Inbound[0].Name) - require.Equal(t, actual.Inbound[0].Actions, []*pb.MiddlewareConfiguration_PolicyAction{{Type: "deny"}}) - require.Len(t, actual.Outbound, 1) - require.Len(t, actual.Outbound[0].Expressions, 2) + require.Len(t, actual["inbound"], 2) + require.Equal(t, "denyPUT", actual["inbound"].([]map[string]any)[0]["name"]) + require.Equal(t, actual["inbound"].([]map[string]any)[0]["actions"], []map[string]any{{"type": "deny"}}) + require.Len(t, actual["outbound"], 1) + require.Len(t, actual["outbound"].([]map[string]any)[0]["expressions"], 2) }, }, { @@ -108,12 +106,12 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, expectOpts: func(t *testing.T, opts *O) { actual := getPolicies(opts) require.NotNil(t, actual) - require.Len(t, actual.Inbound, 2) - require.Equal(t, "denyPut", actual.Inbound[0].Name) - require.Equal(t, []*pb.MiddlewareConfiguration_PolicyAction{{Type: "deny"}}, actual.Inbound[0].Actions) - require.Len(t, actual.Outbound, 1) - require.Len(t, actual.Outbound[0].Expressions, 2) - require.Equal(t, []byte(`{"status_code":500}`), actual.Outbound[0].Actions[0].Config) + require.Len(t, actual["inbound"], 2) + require.Equal(t, "denyPut", actual["inbound"].([]map[string]any)[0]["name"]) + require.Equal(t, []map[string]any{{"type": "deny"}}, actual["inbound"].([]map[string]any)[0]["actions"]) + require.Len(t, actual["outbound"], 1) + require.Len(t, actual["outbound"].([]map[string]any)[0]["expressions"], 2) + require.Equal(t, []byte(`{"status_code":500}`), actual["outbound"].([]map[string]any)[0]["actions"].([]map[string]any)[0]["config"]) }, }, } @@ -123,15 +121,15 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T, func TestPolicy(t *testing.T) { testPolicy[*httpOptions](t, HTTPEndpoint, - func(h *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_Policy { + func(h *proto.HTTPEndpoint) map[string]any { return h.Policy }) testPolicy[*tcpOptions](t, TCPEndpoint, - func(h *proto.TCPEndpoint) *pb.MiddlewareConfiguration_Policy { + func(h *proto.TCPEndpoint) map[string]any { return h.Policy }) testPolicy[*tlsOptions](t, TLSEndpoint, - func(h *proto.TLSEndpoint) *pb.MiddlewareConfiguration_Policy { + func(h *proto.TLSEndpoint) map[string]any { return h.Policy }) } diff --git a/config/tcp.go b/config/tcp.go index 8507eab..1b64337 100644 --- a/config/tcp.go +++ b/config/tcp.go @@ -50,7 +50,7 @@ func (cfg *tcpOptions) toProtoConfig() *proto.TCPEndpoint { return &proto.TCPEndpoint{ Addr: cfg.RemoteAddr, IPRestriction: cfg.commonOpts.CIDRRestrictions.toProtoConfig(), - Policy: cfg.commonOpts.Policy.toProtoConfig(), + Policy: cfg.commonOpts.Policy, ProxyProto: proto.ProxyProto(cfg.commonOpts.ProxyProto), } } diff --git a/config/tls.go b/config/tls.go index 45e4903..8e69077 100644 --- a/config/tls.go +++ b/config/tls.go @@ -69,7 +69,7 @@ func (cfg *tlsOptions) toProtoConfig() *proto.TLSEndpoint { } opts.IPRestriction = cfg.commonOpts.CIDRRestrictions.toProtoConfig() - opts.Policy = cfg.commonOpts.Policy.toProtoConfig() + opts.Policy = map[string]any(cfg.commonOpts.Policy) opts.MutualTLSAtEdge = mutualTLSEndpointOption(cfg.MutualTLSCA).toProtoConfig() diff --git a/internal/tunnel/proto/msg.go b/internal/tunnel/proto/msg.go index 849f27c..888c261 100644 --- a/internal/tunnel/proto/msg.go +++ b/internal/tunnel/proto/msg.go @@ -302,7 +302,7 @@ type HTTPEndpoint struct { ResponseHeaders *pb.MiddlewareConfiguration_Headers WebsocketTCPConverter *pb.MiddlewareConfiguration_WebsocketTCPConverter UserAgentFilter *pb.MiddlewareConfiguration_UserAgentFilter - Policy *pb.MiddlewareConfiguration_Policy + Policy map[string]any } type TCPEndpoint struct { @@ -311,7 +311,7 @@ type TCPEndpoint struct { // middleware IPRestriction *pb.MiddlewareConfiguration_IPRestriction - Policy *pb.MiddlewareConfiguration_Policy + Policy map[string]any } type TLSEndpoint struct { @@ -325,7 +325,7 @@ type TLSEndpoint struct { MutualTLSAtEdge *pb.MiddlewareConfiguration_MutualTLS TLSTermination *pb.MiddlewareConfiguration_TLSTermination IPRestriction *pb.MiddlewareConfiguration_IPRestriction - Policy *pb.MiddlewareConfiguration_Policy + Policy map[string]any } type SSHOptions struct {