-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathdata.go
157 lines (137 loc) · 4.46 KB
/
data.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
// Copyright 2021-2024 Nokia
// Licensed under the BSD 3-Clause License.
// SPDX-License-Identifier: BSD-3-Clause
package restful
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"github.com/gorilla/schema"
"github.com/nokia/restful/messagepack"
log "github.com/sirupsen/logrus"
)
// formDecoder decodes URL query values and form fields into Go structs.
// It is configured to ignore keys that have no matching struct field, so that
// extra parameters sent by a client do not make decoding fail.
var formDecoder = func() *schema.Decoder {
	d := schema.NewDecoder()
	d.IgnoreUnknownKeys(true)
	return d
}()
// GetDataBytes reads and returns the whole body, then closes it.
// A nil ioBody yields (nil, nil); this happens e.g. with httptest requests that have no body.
// If maxBytes > 0 then a body larger than maxBytes is dropped (nil is returned) together with an error:
//   - if the Content-Length header already exceeds maxBytes, the body is drained to io.Discard
//     without being buffered;
//   - otherwise at most maxBytes+1 bytes are buffered and any excess is drained to io.Discard,
//     so memory use stays bounded even when Content-Length is missing or wrong.
//
// maxBytes <= 0 means no limit. Read errors are wrapped, so errors.Is / errors.As work on them.
func GetDataBytes(headers http.Header, ioBody io.ReadCloser, maxBytes int) (body []byte, err error) {
	if ioBody == nil { // On using httptest req.Body may be missing.
		return nil, nil
	}
	// The body is always closed, whichever way this function returns.
	defer func() { _ = ioBody.Close() }()

	if maxBytes <= 0 { // No limit requested.
		body, err = io.ReadAll(ioBody)
		if err != nil {
			return body, fmt.Errorf("body read error: %w", err)
		}
		return body, nil
	}

	// Cheap early rejection based on the declared size. A missing or malformed header is
	// not an error here; the real size is enforced below.
	if cl, convErr := strconv.Atoi(headers.Get("Content-Length")); convErr == nil && cl > maxBytes {
		// Drain without buffering so an oversized body cannot exhaust memory.
		_, _ = io.Copy(io.Discard, ioBody)
		return nil, fmt.Errorf("too big Content-Length: %d > %d", cl, maxBytes)
	}

	// Buffering one byte beyond the limit is enough to detect an oversized body.
	limit := int64(maxBytes) + 1
	if limit < 0 { // int64 overflow when maxBytes is the largest int64 value
		limit = int64(maxBytes)
	}
	body, err = io.ReadAll(io.LimitReader(ioBody, limit))
	if err != nil {
		return body, fmt.Errorf("body read error: %w", err)
	}
	if len(body) > maxBytes {
		// Discard (rather than buffer) the remainder, counting it for the error message.
		rest, drainErr := io.Copy(io.Discard, ioBody)
		if drainErr != nil {
			return nil, fmt.Errorf("body read error: %w", drainErr)
		}
		return nil, fmt.Errorf("too long content: %d > %d", int64(len(body))+rest, maxBytes)
	}
	return body, nil
}
// GetDataBytesForContentType reads the body like GetDataBytes and then, when the body is
// non-empty and expectedContentType is set, verifies that the base Content-Type found in
// headers equals expectedContentType.
// An empty body or an empty expectedContentType skips the Content-Type check entirely.
// On a mismatch the body is still returned, together with an error joining
// ErrUnexpectedContentType and a description of the received/expected types.
// If maxBytes > 0 then a larger body is rejected by GetDataBytes.
func GetDataBytesForContentType(headers http.Header, ioBody io.ReadCloser, maxBytes int, expectedContentType string) (body []byte, err error) {
	if body, err = GetDataBytes(headers, ioBody, maxBytes); err != nil {
		return body, err
	}

	// Nothing to type-check, or the caller accepts any Content-Type.
	if len(body) == 0 || expectedContentType == "" {
		return body, nil
	}

	if got := GetBaseContentType(headers); got != expectedContentType {
		return body, errors.Join(ErrUnexpectedContentType, fmt.Errorf("received: '%s'; expected: %s", got, expectedContentType))
	}
	return body, nil
}
// getData reads the body via GetDataBytes and unmarshals it into data according to the
// base Content-Type in headers: types accepted by isMsgPackContentType go through
// messagepack.Unmarshal, types accepted by isJSONContentType through json.Unmarshal,
// and anything else is rejected.
// If data is nil the body is closed without being read. ioBody may be nil, as
// GetDataBytes also tolerates.
// When request is true, errors are wrapped with NewError carrying a suggested HTTP status:
// 500 for body read failures, 400 for a missing body or undecodable/unsupported content.
// When request is false, plain errors are returned and an empty body is not an error.
func getData(headers http.Header, ioBody io.ReadCloser, maxBytes int, data any, request bool) error {
	if data == nil {
		// The caller does not want the content; just release the body if there is one.
		if ioBody != nil {
			_ = ioBody.Close()
		}
		return nil
	}

	body, err := GetDataBytes(headers, ioBody, maxBytes)
	if err != nil {
		if request {
			// NOTE(review): size-limit violations reported by GetDataBytes also land here
			// as 500; 413 may be more apt.
			return NewError(err, http.StatusInternalServerError, "Failed to read request")
		}
		return err
	}
	if len(body) == 0 {
		if request {
			return NewError(nil, http.StatusBadRequest, "body expected")
		}
		return nil
	}

	recvdContentType := GetBaseContentType(headers)
	if isMsgPackContentType(recvdContentType) {
		err = messagepack.Unmarshal(body, data)
		if err != nil && request {
			return NewError(err, http.StatusBadRequest, "Invalid msgpack content")
		}
		return err
	}
	if !isJSONContentType(recvdContentType) {
		// Assign to the outer err instead of shadowing it with :=.
		err = fmt.Errorf("unexpected Content-Type: '%s'; not JSON", recvdContentType)
		if request {
			return NewError(err, http.StatusBadRequest)
		}
		return err
	}
	if recvdContentType == ContentTypeProblemJSON {
		// Problem details bodies are logged at debug level for troubleshooting.
		log.Debug("Problem: ", string(body))
	}
	err = json.Unmarshal(body, data)
	if err != nil && request {
		return NewError(err, http.StatusBadRequest, "Invalid JSON content")
	}
	return err
}
// GetRequestData decodes data from HTTP request req.
// The data source depends on the base Content-Type (CT):
//   - no CT: for GET, URL query parameters are decoded; other methods decode nothing;
//   - form or multipart form CT: the parsed form fields (req.PostForm) are decoded;
//   - any other CT: the body is decoded by getData (JSON or msgpack).
//
// If maxBytes > 0 it blocks parsing exceedingly huge data, which could be used for DoS or memory overflow attacks.
// For form content types this is enforced by replacing req.Body with an http.MaxBytesReader wrapper.
// If error is returned then suggested HTTP status may be encapsulated in it, available via GetErrStatusCode.
func GetRequestData(req *http.Request, maxBytes int, data any) error {
	ct := GetBaseContentType(req.Header)
	if maxBytes > 0 && req.Body != nil && (ct == ContentTypeForm || ct == ContentTypeMultipartForm) {
		// ParseForm / ParseMultipartForm do not honor maxBytes by themselves: url-encoded
		// bodies are read up to 10 MB, and multipart file parts beyond maxMemory spill to
		// disk with no total cap. MaxBytesReader makes parsing fail once the limit is crossed.
		req.Body = http.MaxBytesReader(nil, req.Body, int64(maxBytes))
	}

	switch ct {
	case "":
		// Without a Content-Type only GET is decoded, from the URL query.
		if req.Method == http.MethodGet {
			return formDecoder.Decode(data, req.URL.Query())
		}
		return nil
	case ContentTypeForm:
		// NOTE(review): 406 Not Acceptable is kept for compatibility; 400/413 may be more apt.
		if err := req.ParseForm(); err != nil {
			return NewError(err, http.StatusNotAcceptable, "Bad form")
		}
		return formDecoder.Decode(data, req.PostForm)
	case ContentTypeMultipartForm:
		// maxBytes also serves as the in-memory budget for file parts.
		if err := req.ParseMultipartForm(int64(maxBytes)); err != nil {
			return NewError(err, http.StatusNotAcceptable, "Bad form")
		}
		return formDecoder.Decode(data, req.PostForm)
	}
	return getData(req.Header, req.Body, maxBytes, data, true)
}
// GetResponseData decodes the body of HTTP response resp into data.
// The decoding format follows the response Content-Type (JSON or msgpack, via getData).
// Unlike request decoding, errors are returned without an HTTP status wrapper and an
// empty body is not an error.
// If maxBytes > 0 it blocks parsing exceedingly huge data, which could be used for DoS or memory overflow attacks.
func GetResponseData(resp *http.Response, maxBytes int, data any) error {
	const isRequest = false // response mode for getData
	hdr, body := resp.Header, resp.Body
	return getData(hdr, body, maxBytes, data, isRequest)
}