Skip to content

Commit

Permalink
feat(router): enable setting request header from context (#1371)
Browse files Browse the repository at this point in the history
  • Loading branch information
df-wg authored Nov 15, 2024
1 parent 3a55440 commit c96485d
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 69 deletions.
49 changes: 49 additions & 0 deletions router-tests/headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,55 @@ func TestForwardHeaders(t *testing.T) {
})
})

t.Run("SetHeadersFromContext", func(t *testing.T) {
setRequestDynamicAttribute := func(headerName, contextField string) []core.Option {
return []core.Option{
core.WithHeaderRules(config.HeaderRules{
All: &config.GlobalHeaderRule{
Request: []*config.RequestHeaderRule{
{
Operation: config.HeaderRuleOperationSet,
Name: headerName,
ValueFrom: &config.CustomDynamicAttribute{
ContextField: contextField,
}}}}})}
}
opNameHeader := "x-operation-info"

t.Run("successfully sets operation name header", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
RouterOptions: setRequestDynamicAttribute(opNameHeader, core.ContextFieldOperationName),
},
func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query myQuery { headerValue(name:"` + opNameHeader + `") }`,
})
headerVal := "myQuery"
require.Equal(t, `{"data":{"headerValue":"`+headerVal+`"}}`, res.Body)
})
})

t.Run("set dynamic header overwrites explicit header", func(t *testing.T) {
t.Parallel()

testenv.Run(t, &testenv.Config{
RouterOptions: setRequestDynamicAttribute(opNameHeader, core.ContextFieldOperationName),
},
func(t *testing.T, xEnv *testenv.Environment) {
res, err := xEnv.MakeGraphQLRequestWithHeaders(testenv.GraphQLRequest{
Query: `query myQuery { headerValue(name:"` + opNameHeader + `") }`,
}, map[string]string{
opNameHeader: "not-myQuery",
})
require.NoError(t, err)
headerVal := "myQuery"
require.Equal(t, `{"data":{"headerValue":"`+headerVal+`"}}`, res.Body)
})
})
})

t.Run("HTTP with client extension", func(t *testing.T) {
t.Parallel()

Expand Down
65 changes: 5 additions & 60 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -868,68 +867,14 @@ func (s *graphServer) accessLogsFieldHandler(panicError any, request *http.Reque
if field.ValueFrom != nil && field.ValueFrom.RequestHeader != "" {
resFields = append(resFields, NewStringLogField(request.Header.Get(field.ValueFrom.RequestHeader), field))
} else if field.ValueFrom != nil && field.ValueFrom.ContextField != "" && reqContext.operation != nil {
switch field.ValueFrom.ContextField {
case ContextFieldOperationName:
if v := NewStringLogField(reqContext.operation.name, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldOperationType:
if v := NewStringLogField(reqContext.operation.opType, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldOperationPlanningTime:
if v := NewDurationLogField(reqContext.operation.planningTime, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldOperationNormalizationTime:
if v := NewDurationLogField(reqContext.operation.normalizationTime, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldOperationParsingTime:
if v := NewDurationLogField(reqContext.operation.parsingTime, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldOperationValidationTime:
if v := NewDurationLogField(reqContext.operation.validationTime, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldOperationSha256:
if v := NewStringLogField(reqContext.operation.sha256Hash, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldOperationHash:
if reqContext.operation.hash != 0 {
if v := NewStringLogField(strconv.FormatUint(reqContext.operation.hash, 10), field); v != zap.Skip() {
resFields = append(resFields, v)
}
}
case ContextFieldPersistedOperationSha256:
if v := NewStringLogField(reqContext.operation.persistedID, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldResponseErrorMessage:
var errMessage string
if panicError != nil {
errMessage = fmt.Sprintf("%v", panicError)
} else if reqContext.error != nil {
errMessage = reqContext.error.Error()
}
if field.ValueFrom.ContextField == ContextFieldResponseErrorMessage && panicError != nil {
errMessage := fmt.Sprintf("%v", panicError)
if v := NewStringLogField(errMessage, field); v != zap.Skip() {
resFields = append(resFields, v)
}

case ContextFieldOperationServices:
if v := NewStringSliceLogField(reqContext.dataSourceNames, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldGraphQLErrorServices:
if v := NewStringSliceLogField(reqContext.graphQLErrorServices, field); v != zap.Skip() {
resFields = append(resFields, v)
}
case ContextFieldGraphQLErrorCodes:
if v := NewStringSliceLogField(reqContext.graphQLErrorCodes, field); v != zap.Skip() {
resFields = append(resFields, v)
}
}
if v := GetLogFieldFromCustomAttribute(field, reqContext); v != zap.Skip() {
resFields = append(resFields, v)
}
} else if field.Default != "" {
resFields = append(resFields, NewStringLogField(field.Default, field))
Expand Down
11 changes: 10 additions & 1 deletion router/core/header_rule_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,15 @@ func (h *HeaderPropagation) applyResponseRuleKeyValue(res *http.Response, propag

func (h *HeaderPropagation) applyRequestRule(ctx RequestContext, request *http.Request, rule *config.RequestHeaderRule) {
if rule.Operation == config.HeaderRuleOperationSet {
if rule.ValueFrom != nil && rule.ValueFrom.ContextField != "" {
val := getCustomDynamicAttributeValue(rule.ValueFrom, getRequestContext(request.Context()))
value := fmt.Sprintf("%v", val)
if value != "" {
request.Header.Set(rule.Name, value)
}
return
}

request.Header.Set(rule.Name, rule.Value)
return
}
Expand Down Expand Up @@ -623,7 +632,7 @@ func PropagatedHeaders(rules []*config.RequestHeaderRule) (headerNames []string,
for _, rule := range rules {
switch rule.Operation {
case config.HeaderRuleOperationSet:
if rule.Name == "" || rule.Value == "" {
if rule.Name == "" || (rule.Value == "" && rule.ValueFrom == nil) {
return nil, nil, fmt.Errorf("invalid header set rule %+v, no header name/value combination", rule)
}
headerNames = append(headerNames, rule.Name)
Expand Down
56 changes: 56 additions & 0 deletions router/core/request_context_fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,59 @@ func NewDurationLogField(val time.Duration, attribute config.CustomAttribute) za
}
return zap.Skip()
}

func GetLogFieldFromCustomAttribute(field config.CustomAttribute, req *requestContext) zap.Field {
val := getCustomDynamicAttributeValue(field.ValueFrom, req)
switch v := val.(type) {
case string:
return NewStringLogField(v, field)
case []string:
return NewStringSliceLogField(v, field)
case time.Duration:
return NewDurationLogField(v, field)
}

return zap.Skip()
}

func getCustomDynamicAttributeValue(attribute *config.CustomDynamicAttribute, reqContext *requestContext) interface{} {
if attribute.ContextField == "" {
return ""
}

switch attribute.ContextField {
case ContextFieldOperationName:
return reqContext.operation.Name()
case ContextFieldOperationType:
return reqContext.operation.Type()
case ContextFieldOperationPlanningTime:
return reqContext.operation.planningTime
case ContextFieldOperationNormalizationTime:
return reqContext.operation.normalizationTime
case ContextFieldOperationParsingTime:
return reqContext.operation.parsingTime
case ContextFieldOperationValidationTime:
return reqContext.operation.validationTime
case ContextFieldOperationSha256:
return reqContext.operation.sha256Hash
case ContextFieldOperationHash:
if reqContext.operation.hash != 0 {
return strconv.FormatUint(reqContext.operation.hash, 10)
}
return reqContext.operation.Hash()
case ContextFieldPersistedOperationSha256:
return reqContext.operation.persistedID
case ContextFieldResponseErrorMessage:
if reqContext.error != nil {
return reqContext.error.Error()
}
case ContextFieldOperationServices:
return reqContext.dataSourceNames
case ContextFieldGraphQLErrorServices:
return reqContext.graphQLErrorServices
case ContextFieldGraphQLErrorCodes:
return reqContext.graphQLErrorCodes
}

return ""
}
6 changes: 4 additions & 2 deletions router/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ type CustomStaticAttribute struct {
}

type CustomDynamicAttribute struct {
RequestHeader string `yaml:"request_header"`
ContextField string `yaml:"context_field"`
RequestHeader string `yaml:"request_header,omitempty"`
ContextField string `yaml:"context_field,omitempty"`
}

type CustomAttribute struct {
Expand Down Expand Up @@ -227,6 +227,8 @@ type RequestHeaderRule struct {
Name string `yaml:"name"`
// Value is the value of the header to set
Value string `yaml:"value"`
// ValueFrom is the context field to get the value from, in propagating to subgraphs
ValueFrom *CustomDynamicAttribute `yaml:"value_from,omitempty"`
}

func (r *RequestHeaderRule) GetOperation() HeaderRuleOperation {
Expand Down
17 changes: 16 additions & 1 deletion router/pkg/config/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2156,9 +2156,24 @@
"type": "string",
"examples": ["My-Secret-Value"],
"description": "The value to set for the header. This can include environment variables."
},
"value_from": {
"type": "object",
"description": "The configuration for the value from. The value from is used to extract a value from a request context and propagate it to subgraphs. This is currently only valid in requests",
"additionalProperties": false,
"required": ["context_field"],
"properties": {
"context_field": {
"type": "string",
"description": "The field name of the context from which to extract the value. The value is only extracted when a context is available otherwise the default value is used.",
"enum": [
"operation_name"
]
}
}
}
},
"required": ["op", "name", "value"]
"required": ["op", "name"]
}
}
}
4 changes: 4 additions & 0 deletions router/pkg/config/fixtures/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ headers:
- op: "set"
name: "X-API-Key"
value: "some-secret"
- op: "set"
name: "x-operation-name"
value_from:
context_field: "operation_name"
response:
- op: "propagate"
algorithm: "append"
Expand Down
28 changes: 23 additions & 5 deletions router/pkg/config/testdata/config_full.json
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@
"Rename": "",
"Default": "",
"Name": "",
"Value": ""
"Value": "",
"ValueFrom": null
},
{
"Operation": "propagate",
Expand All @@ -141,7 +142,8 @@
"Rename": "",
"Default": "",
"Name": "",
"Value": ""
"Value": "",
"ValueFrom": null
},
{
"Operation": "propagate",
Expand All @@ -150,7 +152,8 @@
"Rename": "",
"Default": "123",
"Name": "",
"Value": ""
"Value": "",
"ValueFrom": null
},
{
"Operation": "set",
Expand All @@ -159,7 +162,21 @@
"Rename": "",
"Default": "",
"Name": "X-API-Key",
"Value": "some-secret"
"Value": "some-secret",
"ValueFrom": null
},
{
"Operation": "set",
"Matching": "",
"Named": "",
"Rename": "",
"Default": "",
"Name": "x-operation-name",
"Value": "",
"ValueFrom": {
"RequestHeader": "",
"ContextField": "operation_name"
}
}
],
"Response": [
Expand All @@ -185,7 +202,8 @@
"Rename": "",
"Default": "some-secret",
"Name": "",
"Value": ""
"Value": "",
"ValueFrom": null
}
],
"Response": [
Expand Down

0 comments on commit c96485d

Please sign in to comment.