This repository has been archived by the owner on Dec 11, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy path · websocket.go
383 lines (346 loc) · 10.6 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
package kucoin
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"math/rand"
"net/http"
"net/url"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// A WebSocketTokenModel contains a token and some servers for WebSocket feed.
// NOTE(review): presumably decoded from the bullet-public / bullet-private
// API responses — confirm against the callers.
type WebSocketTokenModel struct {
	// Token is sent to the server as the "token" query parameter on connect.
	Token string `json:"token"`
	// Servers holds the candidate endpoints; one is picked at random on connect.
	Servers WebSocketServersModel `json:"instanceServers"`
	// AcceptUserMessage, when true, adds acceptUserMessage=true to the connect URL.
	AcceptUserMessage bool `json:"accept_user_message"`
}

// A WebSocketServerModel contains some servers for WebSocket feed.
type WebSocketServerModel struct {
	// PingInterval is the heartbeat period, in milliseconds.
	PingInterval int64 `json:"pingInterval"`
	// Endpoint is the base WebSocket URL; query parameters are appended to it.
	Endpoint string `json:"endpoint"`
	Protocol string `json:"protocol"`
	Encrypt  bool   `json:"encrypt"`
	// PingTimeout is how long to wait for a pong, in milliseconds.
	PingTimeout int64 `json:"pingTimeout"`
}

// A WebSocketServersModel is the set of *WebSocketServerModel.
type WebSocketServersModel []*WebSocketServerModel

// RandomServer returns a server randomly.
// It fails when the set is empty (including a nil set).
func (s WebSocketServersModel) RandomServer() (*WebSocketServerModel, error) {
	if len(s) == 0 {
		return nil, errors.New("No available server ")
	}
	pick := rand.Intn(len(s))
	return s[pick], nil
}
// WebSocketPublicToken returns the token for public channel.
// It POSTs an empty parameter map to /api/v1/bullet-public.
// NOTE(review): the response data presumably decodes into a
// WebSocketTokenModel — confirm with the callers.
func (as *ApiService) WebSocketPublicToken(ctx context.Context) (*ApiResponse, error) {
	const path = "/api/v1/bullet-public"
	return as.Call(ctx, NewRequest(http.MethodPost, path, map[string]string{}))
}

// WebSocketPrivateToken returns the token for private channel.
// It POSTs an empty parameter map to /api/v1/bullet-private.
func (as *ApiService) WebSocketPrivateToken(ctx context.Context) (*ApiResponse, error) {
	const path = "/api/v1/bullet-private"
	return as.Call(ctx, NewRequest(http.MethodPost, path, map[string]string{}))
}
// All message types of WebSocket.
const (
	// Connection setup and keep-alive frames.
	WelcomeMessage = "welcome"
	PingMessage    = "ping"
	PongMessage    = "pong"

	// Subscription management requests and their acknowledgement.
	SubscribeMessage   = "subscribe"
	AckMessage         = "ack"
	UnsubscribeMessage = "unsubscribe"

	// Error frame; the client treats it as a failure.
	ErrorMessage = "error"

	// Data-bearing downstream frames delivered to the consumer.
	Message = "message"
	Notice  = "notice"
	Command = "command"
)
// A WebSocketMessage represents a message between the WebSocket client and server.
// It carries the fields common to every frame: a correlation id and the frame type.
type WebSocketMessage struct {
	// Id correlates a request with its ack/pong reply.
	Id   string `json:"id"`
	Type string `json:"type"`
}

// A WebSocketSubscribeMessage represents a message to subscribe the public/private channel.
// The embedded *WebSocketMessage supplies the id and type fields.
type WebSocketSubscribeMessage struct {
	*WebSocketMessage
	Topic          string `json:"topic"`
	PrivateChannel bool   `json:"privateChannel"`
	Response       bool   `json:"response"`
}
// NewPingMessage creates a ping message instance.
// Its Id is the current Unix time in nanoseconds, rendered by IntToString.
func NewPingMessage() *WebSocketMessage {
	id := IntToString(time.Now().UnixNano())
	return &WebSocketMessage{Id: id, Type: PingMessage}
}
// NewSubscribeMessage creates a subscribe message instance for topic.
// The Id is the current Unix time in nanoseconds; Response is always true.
func NewSubscribeMessage(topic string, privateChannel bool) *WebSocketSubscribeMessage {
	header := &WebSocketMessage{
		Id:   IntToString(time.Now().UnixNano()),
		Type: SubscribeMessage,
	}
	msg := &WebSocketSubscribeMessage{WebSocketMessage: header}
	msg.Topic = topic
	msg.PrivateChannel = privateChannel
	// NOTE(review): presumably asks the server for an ack, which Subscribe waits for.
	msg.Response = true
	return msg
}
// A WebSocketUnsubscribeMessage represents a message to unsubscribe the public/private channel.
// It has exactly the same shape (and JSON encoding) as WebSocketSubscribeMessage.
type WebSocketUnsubscribeMessage WebSocketSubscribeMessage

// NewUnsubscribeMessage creates a unsubscribe message instance for topic.
// The Id is the current Unix time in nanoseconds; Response is always true.
func NewUnsubscribeMessage(topic string, privateChannel bool) *WebSocketUnsubscribeMessage {
	header := &WebSocketMessage{
		Id:   IntToString(time.Now().UnixNano()),
		Type: UnsubscribeMessage,
	}
	msg := &WebSocketUnsubscribeMessage{WebSocketMessage: header}
	msg.Topic = topic
	msg.PrivateChannel = privateChannel
	msg.Response = true
	return msg
}
// A WebSocketDownstreamMessage represents a message from the WebSocket server to client.
type WebSocketDownstreamMessage struct {
	*WebSocketMessage
	// NOTE(review): Sn is presumably a sequence number — confirm with the API docs.
	Sn      int64  `json:"sn"`
	Topic   string `json:"topic"`
	Subject string `json:"subject"`
	// RawData keeps the "data" payload undecoded; use ReadData to decode it.
	RawData json.RawMessage `json:"data"`
}

// ReadData read the data in channel.
// It JSON-decodes RawData into v and returns any decoding error.
func (m *WebSocketDownstreamMessage) ReadData(v interface{}) error {
	payload := []byte(m.RawData)
	return json.Unmarshal(payload, v)
}
// A WebSocketClient represents a connection to WebSocket server.
//
// After a successful Connect the client runs two goroutines: one reading
// frames from the server and one sending periodic pings. Stop terminates
// both. Subscribe and Unsubscribe share a single ack channel, so they should
// not be called concurrently with each other.
type WebSocketClient struct {
	// Wait all goroutines quit
	wg *sync.WaitGroup
	// Stop subscribing channel; closed exactly once by Stop.
	done chan struct{}
	// stopOnce guards close(done) so Stop may be called more than once.
	stopOnce sync.Once
	// writeMu serializes writes to conn. gorilla/websocket allows at most one
	// concurrent writer, and both the heartbeat goroutine and
	// Subscribe/Unsubscribe write.
	writeMu sync.Mutex
	// Pong channel to check pong message
	pongs chan string
	// ACK channel to check ack message
	acks chan string
	// Error channel
	errors chan error
	// Downstream message channel
	messages chan *WebSocketDownstreamMessage

	conn            *websocket.Conn
	token           *WebSocketTokenModel
	server          *WebSocketServerModel
	enableHeartbeat bool
	skipVerifyTls   bool
	timeout         time.Duration
}

var defaultTimeout = time.Second * 5

// WebSocketClientOpts defines the options for the client
// during the websocket connection.
type WebSocketClientOpts struct {
	// Token holds the connect token and the candidate servers.
	Token *WebSocketTokenModel
	// TLSSkipVerify disables TLS certificate verification when dialing.
	TLSSkipVerify bool
	// Timeout bounds the wait for the welcome message and for each
	// subscribe/unsubscribe ack. Non-positive values use defaultTimeout.
	Timeout time.Duration
}

// NewWebSocketClient creates an instance of WebSocketClient using the
// service's TLS-verification setting and defaultTimeout.
func (as *ApiService) NewWebSocketClient(token *WebSocketTokenModel) *WebSocketClient {
	return as.NewWebSocketClientOpts(WebSocketClientOpts{
		Token:         token,
		TLSSkipVerify: as.apiSkipVerifyTls,
		Timeout:       defaultTimeout,
	})
}

// NewWebSocketClientOpts creates an instance of WebSocketClient with the parsed options.
func (as *ApiService) NewWebSocketClientOpts(opts WebSocketClientOpts) *WebSocketClient {
	timeout := opts.Timeout
	if timeout <= 0 {
		// A zero timeout would make every ack wait expire immediately.
		timeout = defaultTimeout
	}
	return &WebSocketClient{
		wg:            &sync.WaitGroup{},
		done:          make(chan struct{}),
		errors:        make(chan error, 1),
		pongs:         make(chan string, 1),
		acks:          make(chan string, 1),
		token:         opts.Token,
		messages:      make(chan *WebSocketDownstreamMessage, 2048),
		skipVerifyTls: opts.TLSSkipVerify,
		timeout:       timeout,
	}
}

// Connect connects the WebSocket server.
//
// It dials a randomly chosen server from the token and waits for the welcome
// message. It then starts the read and heartbeat goroutines. The returned
// channels deliver downstream messages and asynchronous errors. The messages
// channel is closed when the read goroutine exits; the errors channel is
// never closed.
func (wc *WebSocketClient) Connect() (<-chan *WebSocketDownstreamMessage, <-chan error, error) {
	if wc.token == nil {
		return wc.messages, wc.errors, errors.New("No WebSocket token")
	}
	// Find out a server
	s, err := wc.token.Servers.RandomServer()
	if err != nil {
		return wc.messages, wc.errors, err
	}
	wc.server = s

	// Concat ws url
	q := url.Values{}
	q.Add("connectId", IntToString(time.Now().UnixNano()))
	q.Add("token", wc.token.Token)
	if wc.token.AcceptUserMessage {
		q.Add("acceptUserMessage", "true")
	}
	u := fmt.Sprintf("%s?%s", s.Endpoint, q.Encode())

	// Dial with a private copy of the default dialer. Mutating the
	// package-level websocket.DefaultDialer is a data race and would leak
	// these TLS/buffer settings to every other user of the library.
	dialer := *websocket.DefaultDialer
	dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: wc.skipVerifyTls}
	dialer.ReadBufferSize = 2048000 // 2000 KB
	wc.conn, _, err = dialer.Dial(u, nil)
	if err != nil {
		return wc.messages, wc.errors, err
	}

	// Must read the first welcome message
	if err := wc.awaitWelcome(); err != nil {
		_ = wc.conn.Close() // don't leak the socket on a failed handshake
		return wc.messages, wc.errors, err
	}

	// Set before the goroutines start so read() observes it without a data race.
	wc.enableHeartbeat = true
	wc.wg.Add(2)
	go wc.read()
	go wc.keepHeartbeat()
	return wc.messages, wc.errors, nil
}

// awaitWelcome reads frames until the server's welcome message arrives.
// A read failure or an error frame aborts the handshake. The wait is bounded
// by a read deadline of wc.timeout when that is positive. The deadline is
// cleared on success because the read goroutine blocks without one.
func (wc *WebSocketClient) awaitWelcome() error {
	if wc.timeout > 0 {
		if err := wc.conn.SetReadDeadline(time.Now().Add(wc.timeout)); err != nil {
			return err
		}
	}
	for {
		m := &WebSocketDownstreamMessage{}
		if err := wc.conn.ReadJSON(m); err != nil {
			return err
		}
		if DebugMode {
			logrus.Debugf("Received a WebSocket message: %s", ToJsonString(m))
		}
		switch messageTypeOf(m) {
		case ErrorMessage:
			return errors.Errorf("Error message: %s", ToJsonString(m))
		case WelcomeMessage:
			return wc.conn.SetReadDeadline(time.Time{})
		}
	}
}

// messageTypeOf returns m's frame type, or "" when the frame had no "id" or
// "type" field. In that case json leaves the embedded *WebSocketMessage nil,
// and reading m.Type directly would panic.
func messageTypeOf(m *WebSocketDownstreamMessage) string {
	if m.WebSocketMessage == nil {
		return ""
	}
	return m.Type
}

// reportError delivers err on the errors channel. It gives up once Stop has
// been called, so a consumer that stopped draining errors cannot deadlock Stop.
func (wc *WebSocketClient) reportError(err error) {
	select {
	case wc.errors <- err:
	case <-wc.done:
	}
}

// writeText sends payload as one text frame, holding writeMu for the duration.
func (wc *WebSocketClient) writeText(payload string) error {
	if wc.conn == nil {
		return errors.New("WebSocket is not connected")
	}
	if DebugMode {
		logrus.Debugf("Sent a WebSocket message: %s", payload)
	}
	wc.writeMu.Lock()
	defer wc.writeMu.Unlock()
	return wc.conn.WriteMessage(websocket.TextMessage, []byte(payload))
}

// read is the reader goroutine started by Connect. It decodes frames until
// the connection fails, an error frame arrives, or Stop is called. It routes
// pongs to keepHeartbeat, acks to Subscribe/Unsubscribe, and data frames to
// the messages channel. It closes pongs and messages on exit.
func (wc *WebSocketClient) read() {
	defer func() {
		close(wc.pongs)
		close(wc.messages)
		wc.wg.Done()
	}()
	for {
		select {
		case <-wc.done:
			return
		default:
		}
		m := &WebSocketDownstreamMessage{}
		if err := wc.conn.ReadJSON(m); err != nil {
			wc.reportError(err)
			return
		}
		if DebugMode {
			logrus.Debugf("Received a WebSocket message: %s", ToJsonString(m))
		}
		switch t := messageTypeOf(m); t {
		case WelcomeMessage:
			// Already consumed during Connect; later welcomes are ignored.
		case PongMessage:
			if wc.enableHeartbeat {
				// Never block the reader: if a stale pong is still buffered,
				// drop this one (keepHeartbeat waits for one pong at a time).
				select {
				case wc.pongs <- m.Id:
				default:
				}
			}
		case AckMessage:
			// Never block the reader on an ack nobody is waiting for (e.g. one
			// arriving after its Subscribe timed out); drop it if the buffer is full.
			select {
			case wc.acks <- m.Id:
			default:
			}
		case ErrorMessage:
			wc.reportError(errors.Errorf("Error message: %s", ToJsonString(m)))
			return
		case Message, Notice, Command:
			select {
			case wc.messages <- m:
			case <-wc.done:
				return
			}
		default:
			wc.reportError(errors.Errorf("Unknown message type: %s", t))
		}
	}
}

// keepHeartbeat is the heartbeat goroutine started by Connect. It sends a ping
// 200ms before each server-advertised ping interval. It then waits up to the
// server's ping timeout for the pong with the same id. Any failure is reported
// on the errors channel and ends the heartbeat.
func (wc *WebSocketClient) keepHeartbeat() {
	defer wc.wg.Done()

	// Keep the 200ms safety margin when the interval allows it.
	// time.NewTicker panics on a non-positive period, so reject those.
	interval := time.Duration(wc.server.PingInterval) * time.Millisecond
	if interval > 200*time.Millisecond {
		interval -= 200 * time.Millisecond
	}
	if interval <= 0 {
		wc.reportError(errors.Errorf("Invalid ping interval %d ms", wc.server.PingInterval))
		return
	}
	pt := time.NewTicker(interval)
	defer pt.Stop()

	pongTimeout := time.Duration(wc.server.PingTimeout) * time.Millisecond
	for {
		select {
		case <-wc.done:
			return
		case <-pt.C:
		}
		p := NewPingMessage()
		if err := wc.writeText(ToJsonString(p)); err != nil {
			wc.reportError(err)
			return
		}
		// Wait (with timeout) for the server's pong.
		select {
		case pid, ok := <-wc.pongs:
			if !ok {
				// read() has exited and already reported why.
				return
			}
			if pid != p.Id {
				wc.reportError(errors.Errorf("Invalid pong id %s, expect %s", pid, p.Id))
				return
			}
		case <-time.After(pongTimeout):
			wc.reportError(errors.Errorf("Wait pong message timeout in %d ms", wc.server.PingTimeout))
			return
		case <-wc.done:
			return
		}
	}
}

// sendAndAwaitAck writes payload and waits for the ack carrying id.
// It fails on a write error, an ack with a different id, an error received on
// the errors channel (action names the operation in that message), or when
// wc.timeout elapses.
func (wc *WebSocketClient) sendAndAwaitAck(id, payload, action string) error {
	if err := wc.writeText(payload); err != nil {
		return err
	}
	select {
	case ackID := <-wc.acks:
		if ackID != id {
			return errors.Errorf("Invalid ack id %s, expect %s", ackID, id)
		}
		return nil
	case err := <-wc.errors:
		return errors.Errorf("%s failed, %s", action, err.Error())
	case <-time.After(wc.timeout):
		return errors.Errorf("Wait ack message timeout in %v", wc.timeout)
	}
}

// Subscribe subscribes the specified channels one at a time. It waits for each
// ack before sending the next request and returns the first failure.
func (wc *WebSocketClient) Subscribe(channels ...*WebSocketSubscribeMessage) error {
	for _, c := range channels {
		if err := wc.sendAndAwaitAck(c.Id, ToJsonString(c), "Subscribe"); err != nil {
			return err
		}
	}
	return nil
}

// Unsubscribe unsubscribes the specified channels one at a time. It waits for
// each ack before sending the next request and returns the first failure.
func (wc *WebSocketClient) Unsubscribe(channels ...*WebSocketUnsubscribeMessage) error {
	for _, c := range channels {
		if err := wc.sendAndAwaitAck(c.Id, ToJsonString(c), "Unsubscribe"); err != nil {
			return err
		}
	}
	return nil
}

// Stop stops subscribing the specified channel, all goroutines quit.
// It is safe to call more than once, and also after a failed Connect.
func (wc *WebSocketClient) Stop() {
	wc.stopOnce.Do(func() {
		close(wc.done)
		if wc.conn != nil {
			// Closing the connection unblocks the pending ReadJSON in read().
			_ = wc.conn.Close()
		}
	})
	wc.wg.Wait()
}