Skip to content

Commit

Permalink
add support for TrafficPolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
TheConcierge committed Jun 25, 2024
1 parent 4917562 commit 34c92f9
Show file tree
Hide file tree
Showing 9 changed files with 311 additions and 284 deletions.
5 changes: 4 additions & 1 deletion config/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ type commonOpts struct {
// Currently only relevant for HTTP/1.1 vs HTTP/2, since there's a potential
// change-of-protocol happening at our edge.
ForwardsProto string

// DEPRECATED: use TrafficPolicy instead.
Policy *policy
// Policy that define rules that should be applied to incoming or outgoing
// connections to the edge.
Policy *policy
TrafficPolicy string
}

type CommonOptionsFunc func(cfg *commonOpts)
Expand Down
2 changes: 1 addition & 1 deletion config/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func (cfg *httpOptions) toProtoConfig() *proto.HTTPEndpoint {
opts.WebhookVerification = cfg.WebhookVerification.toProtoConfig()
opts.IPRestriction = cfg.commonOpts.CIDRRestrictions.toProtoConfig()
opts.UserAgentFilter = cfg.UserAgentFilter.toProtoConfig()
opts.Policy = cfg.Policy.toProtoConfig()
opts.TrafficPolicy = cfg.TrafficPolicy

return opts
}
Expand Down
82 changes: 27 additions & 55 deletions config/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,28 @@ import (
"errors"
"fmt"

"golang.ngrok.com/ngrok/internal/pb"
"gopkg.in/yaml.v3"

po "golang.ngrok.com/ngrok/policy"
)

type policy po.Policy
type rule po.Rule
type action po.Action
type trafficPolicy string

// WithPolicyString configures this edge with the provided policy configuration
// passed as a json string and overwrites any previously-set traffic policy
// passed as a json or yaml string and overwrites any previously-set traffic policy
// https://ngrok.com/docs/http/traffic-policy
func WithPolicyString(jsonStr string) interface {
func WithPolicyString(str string) interface {
HTTPEndpointOption
TLSEndpointOption
TCPEndpointOption
} {
p := &policy{}
if err := json.Unmarshal([]byte(jsonStr), p); err != nil {
panic("invalid json for policy configuration")
if !isJsonString(str) && !isYamlStr(str) {
panic(errors.New("provided string is neither valid JSON nor valid YAML"))
}

return p
return trafficPolicy(str)
}

// WithPolicy configures this edge with the given traffic policy and overwrites any
Expand All @@ -37,63 +37,35 @@ func WithPolicy(p po.Policy) interface {
TLSEndpointOption
TCPEndpointOption
} {
ret := policy(p)
fmt.Println("WithPolicy has been deprecated. Please use WithPolicyString instead, as WithPolicy will stop working soon.")

return &ret
}
val, err := json.Marshal(p)
if err != nil {
panic(errors.New(fmt.Sprintf("failed to parse action configuration due to error: %s", err.Error())))
}
fmt.Printf("%s\n", string(val))

func (p *policy) ApplyTLS(opts *tlsOptions) {
opts.Policy = p
return trafficPolicy(string(val))
}

func (p *policy) ApplyHTTP(opts *httpOptions) {
opts.Policy = p
func (p trafficPolicy) ApplyTLS(opts *tlsOptions) {
opts.TrafficPolicy = string(p)
}

func (p *policy) ApplyTCP(opts *tcpOptions) {
opts.Policy = p
func (p trafficPolicy) ApplyHTTP(opts *httpOptions) {
opts.TrafficPolicy = string(p)
}

func (p *policy) toProtoConfig() *pb.MiddlewareConfiguration_Policy {
if p == nil {
return nil
}
inbound := make([]*pb.MiddlewareConfiguration_PolicyRule, len(p.Inbound))
for i, inP := range p.Inbound {
inbound[i] = rule(inP).toProtoConfig()
}

outbound := make([]*pb.MiddlewareConfiguration_PolicyRule, len(p.Outbound))
for i, outP := range p.Outbound {
outbound[i] = rule(outP).toProtoConfig()
}
return &pb.MiddlewareConfiguration_Policy{
Inbound: inbound,
Outbound: outbound,
}
func (p trafficPolicy) ApplyTCP(opts *tcpOptions) {
opts.TrafficPolicy = string(p)
}

func (pr rule) toProtoConfig() *pb.MiddlewareConfiguration_PolicyRule {
actions := make([]*pb.MiddlewareConfiguration_PolicyAction, len(pr.Actions))
for i, act := range pr.Actions {
actions[i] = action(act).toProtoConfig()
}

return &pb.MiddlewareConfiguration_PolicyRule{Name: pr.Name, Expressions: pr.Expressions, Actions: actions}
func isJsonString(jsonStr string) bool {
var js json.RawMessage
return json.Unmarshal([]byte(jsonStr), &js) == nil
}

func (a action) toProtoConfig() *pb.MiddlewareConfiguration_PolicyAction {
var cfgBytes []byte = nil
if len(a.Config) > 0 {
var err error
cfgBytes, err = json.Marshal(a.Config)

if err != nil {
panic(errors.New(fmt.Sprintf("failed to parse action configuration due to error: %s", err.Error())))
}
}
return &pb.MiddlewareConfiguration_PolicyAction{
Type: a.Type,
Config: cfgBytes,
}
func isYamlStr(yamlStr string) bool {
var yml map[string]any
return yaml.Unmarshal([]byte(yamlStr), &yml) == nil
}
82 changes: 59 additions & 23 deletions config/policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,37 @@ import (

"github.com/stretchr/testify/require"

"golang.ngrok.com/ngrok/internal/pb"
"golang.ngrok.com/ngrok/internal/tunnel/proto"
po "golang.ngrok.com/ngrok/policy"
)

func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T,
makeOpts func(...OT) Tunnel,
getPolicies func(*O) *pb.MiddlewareConfiguration_Policy,
getPolicies func(*O) string,
) {

// putting yaml string up here as the formatting makes the test
// cases messy
yamlPolicy := `---
inbound:
- name: DenyAll
actions:
- type: deny
config:
status_code: 446
`

optsFunc := func(opts ...any) Tunnel {
return makeOpts(assertSlice[OT](opts)...)
}

cases := testCases[T, O]{
{
name: "absent",
opts: optsFunc(),
expectOpts: func(t *testing.T, opts *O) {
actual := getPolicies(opts)
require.Nil(t, actual)
require.Empty(t, actual)
},
},
{
Expand Down Expand Up @@ -70,16 +82,12 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T,
),
expectOpts: func(t *testing.T, opts *O) {
actual := getPolicies(opts)
require.NotNil(t, actual)
require.Len(t, actual.Inbound, 2)
require.Equal(t, "denyPUT", actual.Inbound[0].Name)
require.Equal(t, actual.Inbound[0].Actions, []*pb.MiddlewareConfiguration_PolicyAction{{Type: "deny"}})
require.Len(t, actual.Outbound, 1)
require.Len(t, actual.Outbound[0].Expressions, 2)
require.NotEmpty(t, actual)
require.Equal(t, actual, "{\"inbound\":[{\"name\":\"denyPUT\",\"expressions\":[\"req.Method == 'PUT'\"],\"actions\":[{\"type\":\"deny\"}]},{\"name\":\"logFooHeader\",\"expressions\":[\"'foo' in req.Headers\"],\"actions\":[{\"type\":\"log\",\"config\":{\"metadata\":{\"key\":\"val\"}}}]}],\"outbound\":[{\"name\":\"InternalErrorWhenFailed\",\"expressions\":[\"res.StatusCode \\u003c= '0'\",\"res.StatusCode \\u003e= '300'\"],\"actions\":[{\"type\":\"custom-response\",\"config\":{\"status_code\":500}}]}]}")
},
},
{
name: "with policy string",
name: "with valid JSON policy string",
opts: optsFunc(
WithPolicyString(`
{
Expand Down Expand Up @@ -107,13 +115,41 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T,
}`)),
expectOpts: func(t *testing.T, opts *O) {
actual := getPolicies(opts)
require.NotNil(t, actual)
require.Len(t, actual.Inbound, 2)
require.Equal(t, "denyPut", actual.Inbound[0].Name)
require.Equal(t, []*pb.MiddlewareConfiguration_PolicyAction{{Type: "deny"}}, actual.Inbound[0].Actions)
require.Len(t, actual.Outbound, 1)
require.Len(t, actual.Outbound[0].Expressions, 2)
require.Equal(t, []byte(`{"status_code":500}`), actual.Outbound[0].Actions[0].Config)
require.NotEmpty(t, actual)
require.Equal(t, actual, `
{
"inbound":[
{
"name":"denyPut",
"expressions":["req.Method == 'PUT'"],
"actions":[{"type":"deny"}]
},
{
"name":"logFooHeader",
"expressions":["'foo' in req.Headers"],
"actions":[
{"type":"log","config":{"metadata":{"key":"val"}}}
]
}
],
"outbound":[
{
"name":"500ForFailures",
"expressions":["res.StatusCode <= 0", "res.StatusCode >= 300"],
"actions":[{"type":"custom-response", "config":{"status_code":500}}]
}
]
}`)
},
},
{
name: "with valid YAML policy string",
opts: optsFunc(
WithPolicyString(yamlPolicy)),
expectOpts: func(t *testing.T, opts *O) {
actual := getPolicies(opts)
require.NotEmpty(t, actual)
require.Equal(t, actual, yamlPolicy)
},
},
}
Expand All @@ -123,15 +159,15 @@ func testPolicy[T tunnelConfigPrivate, O any, OT any](t *testing.T,

func TestPolicy(t *testing.T) {
testPolicy[*httpOptions](t, HTTPEndpoint,
func(h *proto.HTTPEndpoint) *pb.MiddlewareConfiguration_Policy {
return h.Policy
func(h *proto.HTTPEndpoint) string {
return h.TrafficPolicy
})
testPolicy[*tcpOptions](t, TCPEndpoint,
func(h *proto.TCPEndpoint) *pb.MiddlewareConfiguration_Policy {
return h.Policy
func(h *proto.TCPEndpoint) string {
return h.TrafficPolicy
})
testPolicy[*tlsOptions](t, TLSEndpoint,
func(h *proto.TLSEndpoint) *pb.MiddlewareConfiguration_Policy {
return h.Policy
func(h *proto.TLSEndpoint) string {
return h.TrafficPolicy
})
}
2 changes: 1 addition & 1 deletion config/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func (cfg *tcpOptions) toProtoConfig() *proto.TCPEndpoint {
return &proto.TCPEndpoint{
Addr: cfg.RemoteAddr,
IPRestriction: cfg.commonOpts.CIDRRestrictions.toProtoConfig(),
Policy: cfg.commonOpts.Policy.toProtoConfig(),
ProxyProto: proto.ProxyProto(cfg.commonOpts.ProxyProto),
TrafficPolicy: cfg.commonOpts.TrafficPolicy,
}
}

Expand Down
2 changes: 1 addition & 1 deletion config/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (cfg *tlsOptions) toProtoConfig() *proto.TLSEndpoint {
}

opts.IPRestriction = cfg.commonOpts.CIDRRestrictions.toProtoConfig()
opts.Policy = cfg.commonOpts.Policy.toProtoConfig()
opts.TrafficPolicy = cfg.commonOpts.TrafficPolicy

opts.MutualTLSAtEdge = mutualTLSEndpointOption(cfg.MutualTLSCA).toProtoConfig()

Expand Down
Loading

0 comments on commit 34c92f9

Please sign in to comment.