-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathquic_client_initial.go
231 lines (191 loc) · 7.4 KB
/
quic_client_initial.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
package clienthellod
import (
"errors"
"fmt"
"runtime"
"sort"
"sync"
"sync/atomic"
"time"
)
// ClientInitial represents a single QUIC Initial packet sent by the client.
//
// Only Header and FrameTypes are exported and carry JSON tags; the decoded
// frames and the raw input bytes are kept in unexported fields.
type ClientInitial struct {
Header *QUICHeader `json:"header,omitempty"` // QUIC header as returned by DecodeQUICHeaderAndFrames
FrameTypes []uint64 `json:"frames,omitempty"` // frame type IDs in order, taken from frames.FrameTypes()
frames QUICFrames // decoded frames in order; AddPacket hands these to the ClientHello reconstructor
raw []byte // the input slice given to UnmarshalQUICClientInitialPacket (stored as-is, not copied)
}
// UnmarshalQUICClientInitialPacket decodes p with DecodeQUICHeaderAndFrames
// and wraps the result in a ClientInitial.
//
// The returned ClientInitial is never nil. When decoding fails, the partially
// populated value (raw set, plus whatever the decoder returned for the header
// and frames) is returned together with the error, and FrameTypes is left
// unset. p is retained by reference, not copied.
func UnmarshalQUICClientInitialPacket(p []byte) (ci *ClientInitial, err error) {
	ci = &ClientInitial{raw: p}

	if ci.Header, ci.frames, err = DecodeQUICHeaderAndFrames(p); err != nil {
		return ci, err
	}
	ci.FrameTypes = ci.frames.FrameTypes()

	// Drop every reference held by the packet once it becomes unreachable.
	// NOTE(review): per the runtime.SetFinalizer documentation an object with
	// a finalizer is only freed on a later GC cycle - confirm this is wanted.
	runtime.SetFinalizer(ci, func(dead *ClientInitial) {
		dead.Header, dead.FrameTypes, dead.frames, dead.raw = nil, nil, nil, nil
	})

	return ci, nil
}
// GatheredClientInitials represents a series of Initial Packets sent by the Client to initiate
// the QUIC handshake, together with the ClientHello reconstructed from their frames
// and the identifiers derived from it.
type GatheredClientInitials struct {
Packets []*ClientInitial `json:"packets,omitempty"` // accepted packets, kept sorted by Header.initialPacketNumber
maxPacketNumber uint64 // if incomingPacketNumber > maxPacketNumber, will reject the packet
maxPacketCount uint64 // if len(Packets) >= maxPacketCount, will reject any new packets
pktsMutex *sync.Mutex // held by AddPacket for its whole duration
clientHelloReconstructor *QUICClientHelloReconstructor // receives the frames of every accepted packet
ClientHello *QUICClientHello `json:"client_hello,omitempty"` // TLS ClientHello, assigned by lockedGatherComplete
TransportParameters *QUICTransportParameters `json:"transport_parameters,omitempty"` // QUIC Transport Parameters, taken from ClientHello.qtp
HexID string `json:"hex_id,omitempty"` // FingerprintID(NumID).AsHex()
NumID uint64 `json:"num_id,omitempty"` // result of calcNumericID, stored on completion
deadline time.Time // consulted by Expired and Wait; set with SetDeadline
completed atomic.Bool // set to true as the last step before completeChan is closed
completeChan chan struct{} // closed on completion (or by the finalizer installed in GatherClientInitials)
completeChanCloseOnce sync.Once // ensures completeChan is closed at most once
}
// Default limits that GatherClientInitials installs on a new gatherer. They
// can be changed per gatherer with SetMaxPacketNumber and SetMaxPacketCount.
const (
	// DEFAULT_MAX_INITIAL_PACKET_NUMBER is the default highest packet number
	// AddPacket accepts.
	DEFAULT_MAX_INITIAL_PACKET_NUMBER = uint64(32)

	// DEFAULT_MAX_INITIAL_PACKET_COUNT is the default number of packets a
	// gatherer holds before AddPacket rejects further ones.
	DEFAULT_MAX_INITIAL_PACKET_COUNT = uint64(4)
)
// Sentinel errors returned by GatheredClientInitials methods.
var (
	// ErrGatheringExpired is returned by AddPacket and Wait once the
	// gathering deadline has passed.
	ErrGatheringExpired = errors.New("ClientInitials gathering has expired")

	// ErrPacketRejected is returned by AddPacket when a packet exceeds the
	// configured packet-number or packet-count limit.
	ErrPacketRejected = errors.New("packet rejected based upon rules")

	// ErrGatheredClientInitialsChannelClosedBeforeCompletion is returned by
	// Wait when the completion channel is closed while the completion flag
	// is still unset.
	ErrGatheredClientInitialsChannelClosedBeforeCompletion = errors.New("completion notification channel closed before setting completion flag")
)
// GatherClientInitials returns an empty GatheredClientInitials that is ready
// to receive packets through AddPacket.
//
// The gatherer starts with the default packet-number and packet-count limits,
// a fresh ClientHello reconstructor and an open completion channel. It does
// not set a deadline; use SetDeadline or GatherClientInitialsWithDeadline.
func GatherClientInitials() *GatheredClientInitials {
	gathered := &GatheredClientInitials{
		Packets:                  make([]*ClientInitial, 0, 4), // expecting 4 packets at max
		maxPacketNumber:          DEFAULT_MAX_INITIAL_PACKET_NUMBER,
		maxPacketCount:           DEFAULT_MAX_INITIAL_PACKET_COUNT,
		pktsMutex:                new(sync.Mutex),
		clientHelloReconstructor: NewQUICClientHelloReconstructor(),
		completed:                atomic.Bool{},
		completeChan:             make(chan struct{}),
		completeChanCloseOnce:    sync.Once{},
	}

	// release runs as the finalizer: it drops the references held by the
	// gatherer and closes the completion channel if nothing closed it before.
	release := func(dead *GatheredClientInitials) {
		dead.Packets = nil
		dead.clientHelloReconstructor = nil
		dead.ClientHello = nil
		dead.TransportParameters = nil
		dead.completeChanCloseOnce.Do(func() { close(dead.completeChan) })
		dead.completeChan = nil
	}
	runtime.SetFinalizer(gathered, release)

	return gathered
}
// GatherClientInitialsWithDeadline creates a gatherer with GatherClientInitials
// and applies deadline to it through SetDeadline before returning it.
func GatherClientInitialsWithDeadline(deadline time.Time) *GatheredClientInitials {
	gathered := GatherClientInitials()
	gathered.SetDeadline(deadline)

	return gathered
}
// AddPacket offers one parsed Client Initial packet to the gatherer.
//
// The whole call runs under pktsMutex. In order, it:
//   - returns ErrGatheringExpired if Expired reports true;
//   - returns nil without touching anything if a ClientHello is already set;
//   - returns ErrPacketRejected if cip or cip.Header is nil, if the packet
//     number exceeds the configured maximum, or if the packet-count limit
//     has been reached;
//   - returns nil for a packet whose number was already gathered;
//   - otherwise stores the packet (keeping Packets sorted by packet number)
//     and feeds its frames to the ClientHello reconstructor.
//
// If the reconstructor reports ErrNeedMoreFrames the call returns nil and
// gathering continues; any other reconstructor error is returned wrapped.
// Once the reconstructor is satisfied, the result of lockedGatherComplete is
// returned.
func (gci *GatheredClientInitials) AddPacket(cip *ClientInitial) error {
	gci.pktsMutex.Lock()
	defer gci.pktsMutex.Unlock()

	if gci.Expired() { // not allowing new packets after expiry
		return ErrGatheringExpired
	}

	if gci.ClientHello != nil { // parse complete, new packet likely to be an ACK-only frame, ignore
		return nil
	}

	// A packet without a header cannot be ordered or de-duplicated. This can
	// happen when the partial result of a failed
	// UnmarshalQUICClientInitialPacket call is passed in; reject it instead
	// of dereferencing a nil pointer below.
	if cip == nil || cip.Header == nil {
		return ErrPacketRejected
	}
	packetNumber := cip.Header.initialPacketNumber

	// check if packet needs to be rejected based upon set maxPacketNumber and maxPacketCount
	if packetNumber > atomic.LoadUint64(&gci.maxPacketNumber) ||
		uint64(len(gci.Packets)) >= atomic.LoadUint64(&gci.maxPacketCount) {
		return ErrPacketRejected
	}

	// check if duplicate packet number was received, if so, discard
	for _, p := range gci.Packets {
		if p.Header.initialPacketNumber == packetNumber {
			return nil
		}
	}

	gci.Packets = append(gci.Packets, cip)

	// sort by initialPacketNumber
	sort.Slice(gci.Packets, func(i, j int) bool {
		return gci.Packets[i].Header.initialPacketNumber < gci.Packets[j].Header.initialPacketNumber
	})

	if err := gci.clientHelloReconstructor.FromFrames(cip.frames); err != nil {
		if errors.Is(err, ErrNeedMoreFrames) {
			return nil // abort early, need more frames before ClientHello can be reconstructed
		}
		return fmt.Errorf("failed to reassemble ClientHello: %w", err)
	}

	return gci.lockedGatherComplete()
}
// Completed reports whether the completion flag has been set, which
// lockedGatherComplete does after the ClientHello has been reconstructed.
func (gci *GatheredClientInitials) Completed() bool {
	done := gci.completed.Load()
	return done
}
// Expired reports whether the gathering deadline has passed.
//
// A zero deadline means that no deadline was configured, and such a gatherer
// never expires. Without this check the zero time.Time compares as already in
// the past, so a gatherer created by GatherClientInitials without a
// subsequent SetDeadline would reject every packet with ErrGatheringExpired.
func (gci *GatheredClientInitials) Expired() bool {
	if gci.deadline.IsZero() {
		return false
	}
	return time.Now().After(gci.deadline)
}
// lockedGatherComplete finishes the gathering once the reconstructor has
// received enough frames. AddPacket calls it while holding pktsMutex.
//
// It stores the reconstructed ClientHello, points TransportParameters at the
// ClientHello's qtp, computes NumID and HexID, sets the completion flag and
// finally closes completeChan (at most once). If reconstruction fails, the
// wrapped error is returned and none of the later steps run.
func (gci *GatheredClientInitials) lockedGatherComplete() error {
	// Step 1: reconstruct the ClientHello. The reconstructor's result is
	// stored as returned, before the error is inspected.
	hello, err := gci.clientHelloReconstructor.Reconstruct()
	gci.ClientHello = hello
	if err != nil {
		return fmt.Errorf("failed to reconstruct ClientHello: %w", err)
	}

	// Step 2: expose the transport parameters carried by the ClientHello.
	gci.TransportParameters = hello.qtp

	// Step 3: derive the numeric ID and its hex form.
	id := gci.calcNumericID()
	atomic.StoreUint64(&gci.NumID, id)
	gci.HexID = FingerprintID(id).AsHex()

	// Step 4: set the flag before closing the channel, so a receiver woken
	// by the close observes completed == true.
	gci.completed.Store(true)
	gci.completeChanCloseOnce.Do(func() { close(gci.completeChan) })

	return nil
}
// SetDeadline replaces the deadline that Expired and Wait measure against.
//
// The field is written without taking any lock.
func (gci *GatheredClientInitials) SetDeadline(deadline time.Time) {
	gci.deadline = deadline
}
// SetMaxPacketNumber atomically updates the highest packet number AddPacket
// will accept. A Client Initial packet carrying a greater packet number is
// rejected with ErrPacketRejected.
//
// Lowering this limit can serve as a precaution against memory exhaustion
// attacks.
func (gci *GatheredClientInitials) SetMaxPacketNumber(maxPacketNumber uint64) {
	limit := &gci.maxPacketNumber
	atomic.StoreUint64(limit, maxPacketNumber)
}
// SetMaxPacketCount atomically updates how many packets the gatherer may
// hold. Once len(Packets) has reached this value, AddPacket rejects further
// Client Initial packets with ErrPacketRejected.
//
// Lowering this limit can serve as a precaution against memory exhaustion
// attacks.
func (gci *GatheredClientInitials) SetMaxPacketCount(maxPacketCount uint64) {
	limit := &gci.maxPacketCount
	atomic.StoreUint64(limit, maxPacketCount)
}
// Wait blocks until the gathering completes or its deadline passes.
//
// It returns nil as soon as the completion flag is set, ErrGatheringExpired
// when the deadline is reached first, and
// ErrGatheredClientInitialsChannelClosedBeforeCompletion when the completion
// channel is closed while the flag is still unset.
//
// For a zero deadline Wait defers to Expired: if Expired reports the gatherer
// as not expired, no timer is armed and Wait blocks until the completion
// channel is closed; otherwise the zero deadline is treated like any other
// deadline in the past. This keeps Wait and Expired from disagreeing.
func (gci *GatheredClientInitials) Wait() error {
	if gci.completed.Load() {
		return nil
	}

	// timeout stays nil when no deadline applies; receiving from a nil
	// channel blocks forever, which disables that select case.
	var timeout <-chan time.Time
	if !gci.deadline.IsZero() || gci.Expired() {
		// Unlike time.After, an explicit timer can be stopped as soon as
		// Wait returns instead of lingering until the deadline.
		timer := time.NewTimer(time.Until(gci.deadline))
		defer timer.Stop()
		timeout = timer.C
	}

	select {
	case <-timeout:
		return ErrGatheringExpired
	case <-gci.completeChan:
		if gci.completed.Load() {
			return nil
		}
		return ErrGatheredClientInitialsChannelClosedBeforeCompletion // divergent state, only possible reason is GC
	}
}