diff --git a/examples/socketio_client/socketio_client.go b/examples/socketio_client/socketio_client.go new file mode 100644 index 0000000..c32158c --- /dev/null +++ b/examples/socketio_client/socketio_client.go @@ -0,0 +1,20 @@ +package main + +import ( + "flag" + "fmt" + "time" + + "github.com/timpalpant/go-iex" + "github.com/timpalpant/go-iex/socketio" +) + +func main() { + flag.Parse() + client := socketio.NewClient() + ns := client.GetTOPSNamespace() + go ns.SubscribeTo(func(msg iex.TOPS) { + fmt.Printf("Received message: %+v\n", msg) + }, "fb", "goog") + time.Sleep(30 * time.Second) +} diff --git a/examples/socketio_protocol/iex_socketio_protocol.go b/examples/socketio_protocol/iex_socketio_protocol.go new file mode 100644 index 0000000..7ae6f77 --- /dev/null +++ b/examples/socketio_protocol/iex_socketio_protocol.go @@ -0,0 +1,93 @@ +package main + +import ( + "encoding/json" + "flag" + "io" + "io/ioutil" + "net/http" + "net/http/cookiejar" + "net/url" + "strings" + + "github.com/chilts/sid" + "github.com/golang/glog" + "github.com/gorilla/websocket" +) + +type handshake struct { + Sid string +} + +func makeRequest(client *http.Client, method string, + uri *url.URL, bodyData *string) []byte { + glog.Infof("Making %s request:> %v", method, uri) + + var reader io.Reader + if bodyData != nil { + data := *bodyData + glog.Infof("With data:> %s", data) + reader = strings.NewReader(data) + } + req, _ := http.NewRequest(method, uri.String(), reader) + resp, _ := client.Do(req) + + body, _ := ioutil.ReadAll(resp.Body) + resp.Body.Close() + glog.Infof("Response:> %v", string(body)) + return body +} + +func wsMessage(conn *websocket.Conn, msg []byte) { + glog.Infof("Writing WS message:> %s", string(msg)) + conn.WriteMessage(websocket.TextMessage, msg) +} + +func wsReadMessage(conn *websocket.Conn) { + _, message, err := conn.ReadMessage() + if err != nil { + glog.Fatal(err) + } + glog.Infof("WS Response: %s", string(message)) +} + +func main() { + flag.Parse() + glog.Info("Starting handshake sequence") + + jar, _ := cookiejar.New(nil) + client := &http.Client{Jar: jar} + + uri, _ := url.Parse("https://ws-api.iextrading.com/socket.io/") + values := uri.Query() + values.Set("t", sid.IdBase64()) + values.Set("EIO", "3") + values.Set("transport", "polling") + uri.RawQuery = values.Encode() + + resp := makeRequest(client, "GET", uri, nil) + + var hs handshake + json.Unmarshal(resp[4:], &hs) + values.Set("sid", hs.Sid) + uri.RawQuery = values.Encode() + + makeRequest(client, "GET", uri, nil) + + uri, _ = url.Parse("wss://ws-api.iextrading.com/socket.io/") + values.Set("transport", "websocket") + uri.RawQuery = values.Encode() + glog.Infof("Websocket connecting to:> %s", uri.String()) + conn, _, err := websocket.DefaultDialer.Dial(uri.String(), nil) + if err != nil { + glog.Fatal(err) + } + wsMessage(conn, []byte("5")) + wsMessage(conn, []byte("2")) + wsReadMessage(conn) + wsMessage(conn, []byte("40/1.0/last,")) + wsReadMessage(conn) + wsMessage(conn, []byte("42/1.0/last,[\"subscribe\",\"fb,goog\"]")) + wsReadMessage(conn) + wsReadMessage(conn) +} diff --git a/go.mod b/go.mod index cd4489a..fed96de 100644 --- a/go.mod +++ b/go.mod @@ -1,11 +1,16 @@ module github.com/timpalpant/go-iex require ( + github.com/cheekybits/genny v1.0.0 + github.com/chilts/sid v0.0.0-20190607042430-660e94789ec9 + github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b github.com/google/go-cmp v0.2.0 // indirect github.com/google/go-querystring v0.0.0-20170111101155-53e6ce116135 github.com/google/gopacket v1.1.16-0.20181023151400-a35e09f9f224 + github.com/gorilla/websocket v1.4.1 github.com/johnmccabe/go-bitbar v0.4.0 github.com/mdlayher/raw v0.0.0-20181016155347-fa5ef3332ca9 // indirect - golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519 // indirect - golang.org/x/sys v0.0.0-20181024145615-5cd93ef61a7c // indirect + github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337 ) + +go 1.13 diff --git a/go.sum b/go.sum index cce28b4..07e7238 100644 --- a/go.sum +++ b/go.sum @@ -1,14 +1,33 @@ +github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= +github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= +github.com/chilts/sid v0.0.0-20190607042430-660e94789ec9 h1:z0uK8UQqjMVYzvk4tiiu3obv2B44+XBsvgEJREQfnO8= +github.com/chilts/sid v0.0.0-20190607042430-660e94789ec9/go.mod h1:Jl2neWsQaDanWORdqZ4emBl50J4/aRBBS4FyyG9/PFo= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-querystring v0.0.0-20170111101155-53e6ce116135 h1:zLTLjkaOFEFIOxY5BWLFLwh+cL8vOBW4XJ2aqLE/Tf0= github.com/google/go-querystring v0.0.0-20170111101155-53e6ce116135/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gopacket v1.1.16-0.20181023151400-a35e09f9f224 h1:78xLKlzgK/iEGI5iyrSMXEZu+kRRT+s08QqpSXonq7o= github.com/google/gopacket v1.1.16-0.20181023151400-a35e09f9f224/go.mod h1:UCLx9mCmAwsVbn6qQl1WIEt2SO7Nd2fD0th1TBAsqBw= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/johnmccabe/go-bitbar v0.4.0 h1:n2vBc0btNbDkdyEfovT9YjZE/QJvNUKCSASevTperhg= github.com/johnmccabe/go-bitbar v0.4.0/go.mod h1:i67T2iQ7Ql/v6x4NbPLlW7eTs+3d/vZgVDl12pr03C8= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/mdlayher/raw v0.0.0-20181016155347-fa5ef3332ca9 h1:tOtO8DXiNGj9NshRKHWiZuGlSldPFzFCFYhNtsKTBCs= github.com/mdlayher/raw v0.0.0-20181016155347-fa5ef3332ca9/go.mod h1:rC/yE65s/DoHB6BzVOUBNYBGTg772JVytyAytffIZkY= -golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519 h1:x6rhz8Y9CjbgQkccRGmELH6K+LJj7tOoh3XWeC1yaQM= -golang.org/x/net v0.0.0-20181023162649-9b4f9f5ad519/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/sys v0.0.0-20181024145615-5cd93ef61a7c h1:8QwKN2PcBeeHEiYIX6348SzigNWH9uHHP1EOEs5ExSc= -golang.org/x/sys v0.0.0-20181024145615-5cd93ef61a7c/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337 h1:WN9BUFbdyOsSH/XohnWpXOlq9NBD5sGAB2FciQMUEe8= +github.com/smartystreets/goconvey v0.0.0-20190731233626-505e41936337/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= diff --git a/socketio/.gitignore b/socketio/.gitignore new file mode 100644 index 0000000..e4e5f6c --- /dev/null +++ b/socketio/.gitignore @@ -0,0 +1 @@ +*~ \ No newline at end of file diff --git a/socketio/client.go b/socketio/client.go new file mode 100644 index 0000000..e444091 --- /dev/null +++ b/socketio/client.go @@ -0,0 +1,134 @@ +package socketio + +import ( + "io/ioutil" + "net/http" + "net/http/cookiejar" + "sync" + + "github.com/golang/glog" + "github.com/gorilla/websocket" +) + +const ( + deep string = "/1.0/deep" + last string = "/1.0/last" + tops string = "/1.0/tops" +) + +// Connects to IEX SocketIO endpoints and routes received messages back to the +// correct handlers. +type Client struct { + // Allows reference counting of open namespaces. + CountingSubscriber + // Protects access to namespaces. + sync.Mutex + + // The Transport object used to send and receive SocketIO messages. + transport Transport + // Points to a DEEP namespace. + deepNamespace *IexDEEPNamespace + // Points to a Last namespace. + lastNamespace *IexLastNamespace + // Points to a TOPS namespace. + topsNamespace *IexTOPSNamespace +} + +func (c *Client) closeNamespace(ns string) { + c.Lock() + defer c.Unlock() + c.Unsubscribe(ns) + if !c.Subscribed(ns) { + enc := NewWSEncoder(ns) + r, err := enc.EncodePacket(Message, Disconnect) + if err != nil { + glog.Errorf( + "Error disconnecting from %s: %s", + ns, err) + } + msg, err := ioutil.ReadAll(r) + if err != nil { + glog.Errorf( + "Error disconnecting from %s: %s", + ns, err) + } + if _, err = c.transport.Write(msg); err != nil { + glog.Errorf( + "Error disconnecting from %s: %s", + ns, err) + } + switch ns { + case deep: + c.deepNamespace = nil + case last: + c.lastNamespace = nil + case tops: + c.topsNamespace = nil + } + } +} + +func (c *Client) GetDEEPNamespace() *IexDEEPNamespace { + if c.deepNamespace != nil { + return c.deepNamespace + } + c.deepNamespace = NewIexDEEPNamespace( + c.transport, deepSubUnsubFactory, c.closeNamespace) + return c.deepNamespace +} + +func (c *Client) GetLastNamespace() *IexLastNamespace { + if c.lastNamespace != nil { + return c.lastNamespace + } + c.lastNamespace = NewIexLastNamespace( + c.transport, simpleSubUnsubFactory, c.closeNamespace) + return c.lastNamespace +} + +func (c *Client) GetTOPSNamespace() *IexTOPSNamespace { + if c.topsNamespace != nil { + return c.topsNamespace + } + c.topsNamespace = NewIexTOPSNamespace( + c.transport, simpleSubUnsubFactory, c.closeNamespace) + return c.topsNamespace +} + +type defaultDialerWrapper struct { + dialer *websocket.Dialer +} + +func (d *defaultDialerWrapper) Dial(uri string, hdr http.Header) ( + WSConn, *http.Response, error) { + return d.dialer.Dial(uri, hdr) +} + +// Returns a SocketIO client that will use the passed in transport for +// communication. If it is nil, a default Transport will be created using an +// http.Client and websocket.DefaultDialer. The ability to inject a Tranport +// is mainly meant for testing. +func NewClientWithTransport(conn Transport) *Client { + toReturn := &Client{ + transport: conn, + } + if conn == nil { + wrapper := &defaultDialerWrapper{websocket.DefaultDialer} + jar, err := cookiejar.New(nil) + if err != nil { + glog.Fatalf("Error creating cookie jar: %s", err) + } + transport, err := NewTransport(&http.Client{Jar: jar}, wrapper) + if err != nil { + glog.Fatalf( + "Failed to create default transport: %s", + err) + } + toReturn.transport = transport + } + return toReturn + +} +func NewClient() *Client { + return NewClientWithTransport(nil) +} diff --git a/socketio/client_test.go b/socketio/client_test.go new file mode 100644 index 0000000..cefce62 --- /dev/null +++ b/socketio/client_test.go @@ -0,0 +1,138 @@ +package socketio_test + +import ( + "sync" + "testing" + + "github.com/golang/glog" + . "github.com/smartystreets/goconvey/convey" + "github.com/timpalpant/go-iex" + . "github.com/timpalpant/go-iex/socketio" +) + +type fakeTransport struct { + sync.Mutex + + messages []string + callbacks map[string]map[int]func(PacketData) + nextId int + closed bool +} + +func newFakeTransport() *fakeTransport { + return &fakeTransport{ + messages: make([]string, 0), + callbacks: make(map[string]map[int]func(PacketData)), + nextId: 0, + closed: false, + } +} + +func (f *fakeTransport) Write(data []byte) (int, error) { + glog.Infof("Fake transport writing message: %s", string(data)) + f.messages = append(f.messages, string(data)) + return len(data), nil +} + +func (f *fakeTransport) AddPacketCallback( + namespace string, callback func(PacketData)) (int, error) { + f.Lock() + defer f.Unlock() + f.nextId++ + if _, ok := f.callbacks[namespace]; !ok { + f.callbacks[namespace] = make(map[int]func(PacketData)) + } + f.callbacks[namespace][f.nextId] = callback + return f.nextId, nil +} + +func (f *fakeTransport) RemovePacketCallback(namespace string, id int) error { + f.Lock() + defer f.Unlock() + if _, ok := f.callbacks[namespace]; ok { + delete(f.callbacks[namespace], id) + } + return nil +} + +func (f *fakeTransport) TriggerCallbacks(pkt PacketData) { + f.Lock() + defer f.Unlock() + if ns, ok := f.callbacks[pkt.Namespace]; ok { + for _, callback := range ns { + callback(pkt) + } + } +} + +func (f *fakeTransport) Close() { + f.Lock() + defer f.Unlock() + f.closed = true +} + +func TestClient(t *testing.T) { + Convey("The Client should", t, func() { + ft := newFakeTransport() + Convey("send DEEP connect", func() { + client := NewClientWithTransport(ft) + ns := client.GetDEEPNamespace() + handler := func(msg iex.DEEP) {} + ns.SubscribeTo(handler, "fb") + So(ft.messages[0], ShouldEqual, "40/1.0/deep,") + }) + Convey("send Last connect", func() { + client := NewClientWithTransport(ft) + ns := client.GetLastNamespace() + handler := func(msg iex.Last) {} + ns.SubscribeTo(handler, "fb") + So(ft.messages[0], ShouldEqual, "40/1.0/last,") + }) + Convey("send TOPS connect", func() { + client := NewClientWithTransport(ft) + ns := client.GetTOPSNamespace() + handler := func(msg iex.TOPS) {} + ns.SubscribeTo(handler, "fb") + So(ft.messages[0], ShouldEqual, "40/1.0/tops,") + }) + Convey("close DEEP connect", func() { + client := NewClientWithTransport(ft) + ns := client.GetDEEPNamespace() + handler := func(msg iex.DEEP) {} + closer1, err := ns.SubscribeTo(handler, "fb") + So(err, ShouldBeNil) + closer2, err := ns.SubscribeTo(handler, "goog") + So(err, ShouldBeNil) + closer1() + closer2() + So("40/1.0/deep,", ShouldBeIn, ft.messages) + So("41/1.0/deep,", ShouldBeIn, ft.messages) + }) + Convey("close Last connect", func() { + client := NewClientWithTransport(ft) + ns := client.GetLastNamespace() + handler := func(msg iex.Last) {} + closer1, err := ns.SubscribeTo(handler, "fb") + So(err, ShouldBeNil) + closer2, err := ns.SubscribeTo(handler, "goog") + So(err, ShouldBeNil) + closer1() + closer2() + So("40/1.0/last,", ShouldBeIn, ft.messages) + So("41/1.0/last,", ShouldBeIn, ft.messages) + }) + Convey("close TOPS connect", func() { + client := NewClientWithTransport(ft) + ns := client.GetTOPSNamespace() + handler := func(msg iex.TOPS) {} + closer1, err := ns.SubscribeTo(handler, "fb") + So(err, ShouldBeNil) + closer2, err := ns.SubscribeTo(handler, "goog") + So(err, ShouldBeNil) + closer1() + closer2() + So("40/1.0/tops,", ShouldBeIn, ft.messages) + So("41/1.0/tops,", ShouldBeIn, ft.messages) + }) + }) +} diff --git a/socketio/decoder.go b/socketio/decoder.go new file mode 100644 index 0000000..5f5b350 --- /dev/null +++ b/socketio/decoder.go @@ -0,0 +1,253 @@ +package socketio + +import ( + "encoding/json" + "io" + "io/ioutil" + "reflect" + "strconv" + "strings" + + "github.com/golang/glog" +) + +// SocketIO packet types. +type PacketType int + +// Defined in: https://preview.tinyurl.com/yxcgen7t +const ( + Open PacketType = iota + Close + Ping + Pong + Message + Upgrade + Noop +) + +// SocketIO event types. +type MessageType int + +// Most are unused. Defined in: https://preview.tinyurl.com/y3s4eh2y +const ( + Connect MessageType = iota + Disconnect + Event + Ack + Error + BinaryEvent + BinaryAck +) + +// The general SocketIO packet metadata. +type PacketData struct { + PacketType PacketType + MessageType MessageType + Namespace string + // The JSON string data remaining after the metadata has been parsed + // out. + Data string +} + +// SocketIO data uses a format :. This function splits on the +// first occurrence of ":", attempts to parse as an int, and returns +// . If there is a problem, the original string is returned. The method +// returns a second string parameter containing the remainder of the string if +// any. +func splitOnLength(input string) (string, string) { + parts := strings.SplitN(input, ":", 2) + if len(parts) != 2 { + return input, "" + } + length, err := strconv.Atoi(parts[0]) + if err != nil { + if glog.V(5) { + glog.Warningf("%s is not a length", parts[0]) + } + return input, "" + } + if glog.V(5) { + glog.Infof("Found response of length %d", length) + glog.Infof("Length actual data is %d", len(parts[1])) + } + return parts[1][:length], parts[1][length:] +} + +// Returns true if the first character is a number and sets the field of +// the passed in interface to the retrieved value if it exists. Also, the first +// char is removed from the decoder. Returns false if the first char is not a +// number. +func maybeProcessFirstChar( + name string, data string, v interface{}) bool { + if len(data) == 0 { + return false + } + firstChar := data[0] + number, err := strconv.Atoi(string(firstChar)) + if err != nil { + if glog.V(3) { + glog.Warningf("No %s found", name) + } + return false + } + instance := reflect.ValueOf(v).Elem() + typeOfV := instance.Type() + for i := 0; i < instance.NumField(); i++ { + f := instance.Field(i) + if typeOfV.Field(i).Name == name && f.Kind() == reflect.Int { + if glog.V(3) { + glog.Infof( + "Setting %s to %d", + name, number) + } + f.SetInt(int64(number)) + } + } + return true +} + +// Given a string of data, this method will attempt to parse out a namespace +// prefix. If it finds one and the passed in interface has a Namespace field, +// this method will set the field to the parsed value. Returns the original +// string if no namespace was found. Otherwise, the remaining string data is +// returned. +func maybeProcessNamespace(data string, v interface{}) string { + firstComma := strings.Index(data, ",") + firstOpenBracket := strings.Index(data, "[") + if data[0] == '/' && firstComma > -1 && firstComma < firstOpenBracket { + parts := strings.SplitN(data, ",", 2) + if glog.V(3) { + glog.Infof("Found namespace: %s", parts[0]) + } + instance := reflect.ValueOf(v).Elem() + typeOfV := instance.Type() + for i := 0; i < instance.NumField(); i++ { + f := instance.Field(i) + if typeOfV.Field(i).Name == "Namespace" && + f.Kind() == reflect.String { + if glog.V(3) { + glog.Infof( + "Setting Namespace to %s", + parts[0]) + } + f.SetString(parts[0]) + return parts[1] + } + } + return parts[1] + } + return data +} + +// An error type used when a potential JSON string is invalid. +type NotJsonError struct { + data string +} + +func (n *NotJsonError) Error() string { + return n.data +} + +// Parses the PacketType, MessageType and Namespace out of the passed in data +// string and into the fields of the same name on the passed in type v. The +// remaining string data is returned. If the metadata cannot be found or the +// passed in type v does not have PacketType, MessageType or Namespace fields, +// then no changes are made and the original data is returned. +func ParseMetadata(data string, v interface{}) string { + if len(data) == 0 { + return "" + } + minusTypes := data + if maybeProcessFirstChar("PacketType", minusTypes, v) { + minusTypes = minusTypes[1:] + if maybeProcessFirstChar("MessageType", minusTypes, v) { + minusTypes = minusTypes[1:] + } + } + if len(minusTypes) == 0 { + return "" + } + return maybeProcessNamespace(minusTypes, v) +} + +// Parses the actual JSON message into the passed in message type. The SocketIO +// response seems to alternate between a JSON array and a JSON object. In the +// case of the former, this method attempts to parse the second element of the +// array into v. If an error occurs, it is returned and v may not contain all +// parsed data. +func ParseToJSON(data string, v interface{}) error { + // The resulting is either a JSON object or a JSON array starting with + // the SocketIO event type string and ending with the JSON object. In + // order to handle both of these scenarios, Unmarshal first tries to + // parse a JSON array with the first element being an instance of v. If + // that fails, it tries to parse all the data into v. + if glog.V(5) { + glog.Infof("Checking JSON validity of %s", string(data)) + } + if !json.Valid([]byte(data)) { + return &NotJsonError{"invalid JSON"} + } + // Sometimes, the JSON is an array containing a string event type + // followed by the JSON object. Othertimes, it is just the object. Use + // jsonArray to test for the first case. + var jsonArray []json.RawMessage + err := json.Unmarshal([]byte(data), &jsonArray) + if err != nil { + if glog.V(3) { + glog.Warningf( + "Could not parse response as JSON array: %s", + err) + } + return json.Unmarshal([]byte(data), v) + } + if glog.V(3) { + glog.Infof("Parsed as JSON array: %s", string(jsonArray[1])) + } + jsonPart, err := strconv.Unquote(string(jsonArray[1])) + err = json.Unmarshal([]byte(jsonPart), v) + if err != nil { + if glog.V(3) { + glog.Errorf("Could not unmarshal data: %s", err) + } + return err + } + return nil +} + +// Parses the JSON HTTP SocketIO response from the given Reader into the passed +// in structs. For each of the passed in structs, if they contain MessageType +// or PacketType fields of type int, those fields will be populated with the +// corresponding response values. +func HTTPToJSON(data io.Reader, v []interface{}) error { + bytes, err := ioutil.ReadAll(data) + if err != nil { + glog.Errorf("Could not read input data: %s", err) + } + response := string(bytes) + glog.Infof("Parsing HTTP Response: %s", response) + + fillingIn := 0 + for true { + data, leftover := splitOnLength(response) + if glog.V(3) { + glog.Infof("Subresponse: %s", data) + glog.Infof("Leftover: %s", leftover) + } + remaining := ParseMetadata(data, v[fillingIn]) + if len(remaining) > 0 { + err := ParseToJSON(remaining, v[fillingIn]) + if err != nil { + glog.Warningf( + "Unable to parse message: %s; %s", + data, err) + return err + } + } + if len(leftover) == 0 { + break + } + response = leftover + fillingIn++ + } + return nil +} diff --git a/socketio/decoder_test.go b/socketio/decoder_test.go new file mode 100644 index 0000000..106e247 --- /dev/null +++ b/socketio/decoder_test.go @@ -0,0 +1,151 @@ +package socketio_test + +import ( + "strings" + "testing" + + . "github.com/smartystreets/goconvey/convey" + "github.com/timpalpant/go-iex" + . "github.com/timpalpant/go-iex/socketio" +) + +type fakeData struct { + Foo string + Bar []int +} + +type fakeDataWithTypes struct { + Foo string + Bar []int + MessageType int + PacketType int + Namespace string +} + +func TestUnsuccessfulDecoding(t *testing.T) { + Convey("HTTPToJSON", t, func() { + Convey("should error when the response is not JSON", func() { + data := strings.NewReader("just some data") + parsed := &fakeData{} + err := HTTPToJSON(data, []interface{}{parsed}) + So(err, ShouldNotBeNil) + }) + }) +} + +func TestSuccessfulDecoding(t *testing.T) { + Convey("For a single message, HTTPToJSON", t, func() { + Convey("should populate a single struct", func() { + data := strings.NewReader( + `{"foo": "baz", "bar": [4, 6]}`) + parsed := &fakeData{} + err := HTTPToJSON(data, []interface{}{parsed}) + So(err, ShouldBeNil) + So(parsed, ShouldResemble, + &fakeData{"baz", []int{4, 6}}) + }) + Convey("should populate a single struct without types", func() { + data := strings.NewReader( + `44{"foo": "baz", "bar": [4, 6]}`) + parsed := &fakeData{} + err := HTTPToJSON(data, []interface{}{parsed}) + So(err, ShouldBeNil) + So(parsed, ShouldResemble, + &fakeData{"baz", []int{4, 6}}) + }) + Convey("should populate message type", func() { + data := strings.NewReader( + `44{"foo": "baz", "bar": [4, 6]}`) + parsed := &fakeDataWithTypes{} + err := HTTPToJSON(data, []interface{}{parsed}) + So(err, ShouldBeNil) + So(parsed, ShouldResemble, + &fakeDataWithTypes{ + "baz", []int{4, 6}, 4, 4, ""}) + }) + Convey("should populate message and packet type", func() { + data := strings.NewReader( + `44{"foo": "baz", "bar": [4, 6]}`) + parsed := &fakeDataWithTypes{} + err := HTTPToJSON(data, []interface{}{parsed}) + So(err, ShouldBeNil) + So(parsed, ShouldResemble, + &fakeDataWithTypes{ + "baz", []int{4, 6}, 4, 4, ""}) + }) + Convey("should populate only types", func() { + data := strings.NewReader(`44`) + parsed := &fakeDataWithTypes{} + err := HTTPToJSON(data, []interface{}{parsed}) + So(err, ShouldBeNil) + So(parsed, ShouldResemble, + &fakeDataWithTypes{"", []int(nil), 4, 4, ""}) + }) + Convey("should handle length encoding", func() { + data := strings.NewReader( + `31:44{"foo": "baz", "bar": [4, 6]}`) + parsed := &fakeDataWithTypes{} + err := HTTPToJSON(data, []interface{}{parsed}) + So(err, ShouldBeNil) + So(parsed, ShouldResemble, + &fakeDataWithTypes{ + "baz", []int{4, 6}, 4, 4, ""}) + }) + Convey("should handle length and namespace encoding", func() { + data := strings.NewReader( + `37:42/1.0/tops,{"foo":"baz","bar":[4,6]}`) + parsed := &fakeDataWithTypes{} + err := HTTPToJSON(data, []interface{}{parsed}) + So(err, ShouldBeNil) + So(parsed, ShouldResemble, + &fakeDataWithTypes{"baz", []int{4, 6}, + 2, 4, "/1.0/tops"}) + }) + Convey("should parse json array messages", func() { + data := "[\"message\", \"{\\\"symbol\\\":\\\"fb\\\"}\"]" + parsed := struct { + Symbol string + }{} + err := ParseToJSON(data, &parsed) + So(err, ShouldBeNil) + So(parsed.Symbol, ShouldEqual, "fb") + }) + }) +} +func TestSuccessfulDecodingMultipleMessages(t *testing.T) { + Convey("For a multiple messages, HTTPToJSON", t, func() { + Convey("should populate many structs", func() { + data := strings.NewReader( + `31:44{"foo": "baz", "bar": [4, 6]}31:44{"foo": "baz", "bar": [4, 6]}`) + parsedOne := &fakeData{} + parsedTwo := &fakeDataWithTypes{} + err := HTTPToJSON(data, + []interface{}{parsedOne, parsedTwo}) + So(err, ShouldBeNil) + So(parsedOne, ShouldResemble, + &fakeData{"baz", []int{4, 6}}) + So(parsedTwo, ShouldResemble, + &fakeDataWithTypes{ + "baz", []int{4, 6}, 4, 4, ""}) + }) + + }) +} +func TestDecodeActualTops(t *testing.T) { + Convey("For an actual Tops response, HTTPToJSON", t, func() { + Convey("should populate a Tops message", func() { + data := strings.NewReader(`348:42/1.0/tops,["message","{\"symbol\":\"SNAP\",\"sector\":\"mediaentertainment\",\"securityType\":\"commonstock\",\"bidPrice\":0.0000,\"bidSize\":0,\"askPrice\":0.0000,\"askSize\":0,\"lastUpdated\":1569873716685,\"lastSalePrice\":15.8000,\"lastSaleSize\":100,\"lastSaleTime\":1569873590063,\"volume\":458065,\"marketPercent\":0.02262,\"seq\":26739}"]344:42/1.0/tops,["message","{\"symbol\":\"FB\",\"sector\":\"mediaentertainment\",\"securityType\":\"commonstock\",\"bidPrice\":0.0000,\"bidSize\":0,\"askPrice\":0.0000,\"askSize\":0,\"lastUpdated\":1569876755318,\"lastSalePrice\":178.0750,\"lastSaleSize\":1,\"lastSaleTime\":1569873595907,\"volume\":411341,\"marketPercent\":0.03700,\"seq\":5904}"]325:42/1.0/tops,["message","{\"symbol\":\"AIG+\",\"sector\":\"n/a\",\"securityType\":\"warrant\",\"bidPrice\":0.0000,\"bidSize\":0,\"askPrice\":0.0000,\"askSize\":0,\"lastUpdated\":1569873600001,\"lastSalePrice\":14.3700,\"lastSaleSize\":200,\"lastSaleTime\":1569859449771,\"volume\":211,\"marketPercent\":0.00632,\"seq\":7281}"]`) + parsedOne := &iex.TOPS{} + parsedTwo := &iex.TOPS{} + parsedThree := &iex.TOPS{} + err := HTTPToJSON(data, + []interface{}{ + parsedOne, parsedTwo, parsedThree}) + So(err, ShouldBeNil) + So(parsedOne.Symbol, ShouldEqual, "SNAP") + So(parsedTwo.Symbol, ShouldEqual, "FB") + So(parsedThree.Symbol, ShouldEqual, "AIG+") + }) + + }) +} diff --git a/socketio/encoder.go b/socketio/encoder.go new file mode 100644 index 0000000..62ef4d9 --- /dev/null +++ b/socketio/encoder.go @@ -0,0 +1,197 @@ +package socketio + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "strings" + + "github.com/golang/glog" +) + +var msgTypeToNamespace = map[string]string{ + "IexDEEP": "/1.0/deep", + "IexLast": "/1.0/last", + "IexTOPS": "/1.0/tops", +} + +// Signals a subscribe or unsubscribe event. +type SubOrUnsub string + +const ( + Subscribe SubOrUnsub = "subscribe" + Unsubscribe SubOrUnsub = "unsubscribe" +) + +// A subUnsubMsgFactory takes in a set of string symbols to subscribe or +// unsubscribe to and returns an IEXMsg suitable for passing to an Encoder. This +// is used by namespaces to encode subscriptions and unsubscriptions. +type subUnsubMsgFactory func(signal SubOrUnsub, symbols []string) *IEXMsg + +// Returns a subscribe/unsubscribe struct for use by all endpoints except DEEP. +var simpleSubUnsubFactory = func( + signal SubOrUnsub, symbols []string) *IEXMsg { + return &IEXMsg{ + EventType: signal, + Data: strings.Join(symbols, ","), + } +} + +// Returns a subscribe/unsubscribe struct for use with the DEEP endpoint. Only +// a single symbol at a time can be used. If more than one symbol is passed in +// only the first one is used. +var deepSubUnsubFactory = func( + signal SubOrUnsub, symbols []string) *IEXMsg { + if len(symbols) > 1 { + glog.Error("DEEP can only subscribe to one symbol at a time") + } + json, err := json.Marshal(struct { + Symbols []string `json:"symbols"` + Channels []string `json:"channels"` + }{ + Symbols: symbols, + Channels: []string{"deep"}, + }) + if err != nil { + glog.Errorf("Could not encode DEEP %s", signal) + return nil + } + return &IEXMsg{ + EventType: signal, + Data: string(json), + } +} + +type IEXMsg struct { + // Contains a string representing subscribe or unsubscribe events. + EventType SubOrUnsub + // A string containing data to send. This is specific to a given + // endpoint. + Data string +} + +// Encodes messages for use with IEX SocketIO. MessageType and PacketType are +// defined in decoder.go. If the MessageType or PacketType are less than 0, +// they are not set on the output. +type Encoder interface { + // Encodes only a namespace and packet and message types. + EncodePacket(p PacketType, m MessageType) (io.Reader, error) + // Encodes a namespace, packet and message type and data. + EncodeMessage(p PacketType, m MessageType, msg *IEXMsg) ( + io.Reader, error) +} + +// Wraps a strArrayEncoder and returns its contents prepended by :. +type httpEncoder struct { + content *strArrayEncoder +} + +func (enc *httpEncoder) EncodePacket( + p PacketType, m MessageType) (io.Reader, error) { + inner, err := enc.content.EncodePacket(p, m) + if err != nil { + return nil, err + } + val, err := ioutil.ReadAll(inner) + if err != nil { + if glog.V(3) { + glog.Warningf("Failed to read inner encoding: %q", err) + } + return nil, err + } + if glog.V(3) { + glog.Infof("Encoded packet: %s", val) + } + parts := []string{fmt.Sprintf("%d", len(val)), string(val)} + return strings.NewReader(strings.Join(parts, ":")), nil +} + +func (enc *httpEncoder) EncodeMessage( + p PacketType, m MessageType, msg *IEXMsg) (io.Reader, error) { + inner, err := enc.content.EncodeMessage(p, m, msg) + if err != nil { + return nil, err + } + val, err := ioutil.ReadAll(inner) + if err != nil { + if glog.V(3) { + glog.Warningf("Failed to read inner encoding: %q", err) + } + return nil, err + } + if glog.V(3) { + glog.Infof("Inner encoding: %s", val) + } + parts := []string{fmt.Sprintf("%d", len(val)), string(val)} + return strings.NewReader(strings.Join(parts, ":")), nil +} + +// The base encoder implementation that performs as described by the interface. +type strArrayEncoder struct { + namespace string +} + +// Used to indicate an encoding error. +type encodeError struct { + message string +} + +func (e *encodeError) Error() string { + return e.message +} + +func (enc *strArrayEncoder) EncodePacket( + p PacketType, m MessageType) (io.Reader, error) { + readers := make([]io.Reader, 0) + if p >= 0 { + readers = append(readers, + strings.NewReader(fmt.Sprintf("%d", p))) + } + if m >= 0 { + readers = append(readers, + strings.NewReader(fmt.Sprintf("%d", m))) + } + if len(enc.namespace) > 0 { + readers = append(readers, + strings.NewReader(enc.namespace+",")) + } + return io.MultiReader(readers...), nil +} + +// Encodes a message, msg, of the given PacketType and MessageType. The +// resulting format is: +// ,[msg.Event, msg.Data] +func (enc *strArrayEncoder) EncodeMessage( + p PacketType, m MessageType, msg *IEXMsg) (io.Reader, error) { + reader, err := enc.EncodePacket(p, m) + if err != nil { + return nil, err + } + readers := []io.Reader{reader} + parts := []string{string(msg.EventType), msg.Data} + if glog.V(3) { + glog.Infof("Encoding parts: %v", parts) + } + encoding, err := json.Marshal(parts) + if err != nil { + glog.Errorf("Failed to encode data as JSON: %s", err) + return nil, err + } + if len(parts) > 0 { + readers = append(readers, bytes.NewBuffer(encoding)) + } + return io.MultiReader(readers...), nil + +} + +// Returns an encoder for use with HTTP Post. +func NewHTTPEncoder(namespace string) Encoder { + return &httpEncoder{&strArrayEncoder{namespace}} +} + +// Returns an encoder for use with SocketIO. +func NewWSEncoder(namespace string) Encoder { + return &strArrayEncoder{namespace} +} diff --git a/socketio/encoder_internal_test.go b/socketio/encoder_internal_test.go new file mode 100644 index 0000000..9458cd1 --- /dev/null +++ b/socketio/encoder_internal_test.go @@ -0,0 +1,46 @@ +package socketio + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestSubUnsubMsgFactory(t *testing.T) { + Convey("The simpleSubUnsubFactory should", t, func() { + Convey("returns IEXMsg", func() { + msg := simpleSubUnsubFactory(Subscribe, []string{ + "fb", "snap", + }) + So(msg, ShouldResemble, &IEXMsg{ + EventType: Subscribe, + Data: "fb,snap", + }) + msg = simpleSubUnsubFactory(Unsubscribe, []string{ + "goog", "aig+", + }) + So(msg, ShouldResemble, &IEXMsg{ + EventType: Unsubscribe, + Data: "goog,aig+", + }) + }) + }) + Convey("The deepSubUnsubFactory should", t, func() { + Convey("returns IEXMsg", func() { + msg := deepSubUnsubFactory(Subscribe, []string{ + "fb", + }) + So(msg, ShouldResemble, &IEXMsg{ + EventType: Subscribe, + Data: "{\"symbols\":[\"fb\"],\"channels\":[\"deep\"]}", + }) + msg = deepSubUnsubFactory(Unsubscribe, []string{ + "goog", + }) + So(msg, ShouldResemble, &IEXMsg{ + EventType: Unsubscribe, + Data: "{\"symbols\":[\"goog\"],\"channels\":[\"deep\"]}", + }) + }) + }) +} diff --git a/socketio/encoder_test.go b/socketio/encoder_test.go new file mode 100644 index 0000000..a4f1656 --- /dev/null +++ b/socketio/encoder_test.go @@ -0,0 +1,204 @@ +package socketio_test + +import ( + "encoding/json" + "io/ioutil" + "testing" + + . "github.com/smartystreets/goconvey/convey" + . "github.com/timpalpant/go-iex/socketio" +) + +func TestWebsocketEncoding(t *testing.T) { + Convey("Websocket encoding should", t, func() { + Convey("correctly encode a nil type", func() { + encoder := NewWSEncoder("/") + encoded, err := encoder.EncodePacket(-1, -1) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, `/,`) + }) + Convey("correctly encode an empty type", func() { + encoder := NewWSEncoder("/") + encoded, err := encoder.EncodePacket(-1, -1) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, `/,`) + }) + Convey("correctly send an upgrade request", func() { + encoder := NewWSEncoder("") + encoded, err := encoder.EncodePacket(5, -1) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, `5`) + }) + Convey("correctly encode a simple type", func() { + encoder := NewWSEncoder("") + encoding, err := json.Marshal(struct { + Name string + Ints []int + }{ + Name: "foo", + Ints: []int{1, 2, 3}, + }) + So(err, ShouldBeNil) + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: string(encoding), + } + encoded, err := encoder.EncodeMessage(-1, -1, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, + `["subscribe","{\"Name\":\"foo\",\"Ints\":[1,2,3]}"]`) + }) + Convey("correctly encode a namespace", func() { + encoder := NewWSEncoder("/") + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: "foo", + } + encoded, err := encoder.EncodeMessage(-1, -1, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, `/,["subscribe","foo"]`) + }) + Convey("correctly encode a longer namespace", func() { + encoder := NewWSEncoder("/1.0/tops") + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: "foo", + } + encoded, err := encoder.EncodeMessage(-1, -1, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, + `/1.0/tops,["subscribe","foo"]`) + }) + Convey("correctly encode the packet type", func() { + encoder := NewWSEncoder("/1.0/tops") + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: "foo", + } + encoded, err := encoder.EncodeMessage(4, -1, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, + `4/1.0/tops,["subscribe","foo"]`) + }) + Convey("correctly encode the packet and message type", func() { + encoder := NewWSEncoder("/1.0/tops") + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: "foo", + } + encoded, err := encoder.EncodeMessage(4, 2, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, + `42/1.0/tops,["subscribe","foo"]`) + }) + }) +} + +func TestHTTPEncoding(t *testing.T) { + Convey("HTTP encoding should", t, func() { + Convey("correctly encode a nil type", func() { + encoder := NewHTTPEncoder("/") + encoded, err := encoder.EncodePacket(-1, -1) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, `2:/,`) + }) + Convey("correctly encode an empty type", func() { + encoder := NewHTTPEncoder("/") + encoded, err := encoder.EncodePacket(4, 0) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, `4:40/,`) + }) + Convey("correctly encode a simple type", func() { + encoder := NewHTTPEncoder("") + encoding, err := json.Marshal(struct { + Name string + Ints []int + }{ + Name: "foo", + Ints: []int{1, 2, 3}, + }) + So(err, ShouldBeNil) + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: string(encoding), + } + encoded, err := encoder.EncodeMessage(-1, -1, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, + `51:["subscribe","{\"Name\":\"foo\",\"Ints\":[1,2,3]}"]`) + }) + Convey("correctly encode a namespace", func() { + encoder := NewHTTPEncoder("/") + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: "foo", + } + encoded, err := encoder.EncodeMessage(-1, -1, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, `21:/,["subscribe","foo"]`) + }) + Convey("correctly encode a longer namespace", func() { + encoder := NewHTTPEncoder("/1.0/tops") + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: "foo", + } + encoded, err := encoder.EncodeMessage(-1, -1, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, + `29:/1.0/tops,["subscribe","foo"]`) + }) + Convey("correctly encode the packet type", func() { + encoder := NewHTTPEncoder("/1.0/tops") + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: "foo", + } + encoded, err := encoder.EncodeMessage(4, -1, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, + `30:4/1.0/tops,["subscribe","foo"]`) + }) + Convey("correctly encode the packet and message type", func() { + encoder := NewHTTPEncoder("/1.0/tops") + iexMsg := &IEXMsg{ + EventType: Subscribe, + Data: "foo", + } + encoded, err := encoder.EncodeMessage(4, 2, iexMsg) + So(err, ShouldBeNil) + val, err := ioutil.ReadAll(encoded) + So(err, ShouldBeNil) + So(string(val), ShouldEqual, + `31:42/1.0/tops,["subscribe","foo"]`) + }) + }) +} diff --git a/socketio/endpoints.go b/socketio/endpoints.go new file mode 100644 index 0000000..0a6bee6 --- /dev/null +++ b/socketio/endpoints.go @@ -0,0 +1,73 @@ +package socketio + +import "net/url" + +// HTTP endpoint. +var httpEndpoint, _ = url.Parse("https://ws-api.iextrading.com/socket.io/") + +// Wbsocket endpoint. +var wsEndpoint, _ = url.Parse("wss://ws-api.iextrading.com/socket.io/") + +// An interface for hiding iexEndpoint. +type Endpoint interface { + SetSid(sid string) + GetHTTPUrl() string + GetWSUrl() string +} + +// Provides methods for manipulating the IEX websocket URL. +type iexEndpoint struct { + // URL for making HTTP requests. + httpUrl *url.URL + + // URL for making websocket requests. + wsUrl *url.URL + + // Method for generating unique timestamps. + idg func() string +} + +func (e *iexEndpoint) SetSid(sid string) { + httpValues := e.httpUrl.Query() + httpValues.Set("sid", sid) + e.httpUrl.RawQuery = httpValues.Encode() + + wsValues := e.wsUrl.Query() + wsValues.Set("sid", sid) + e.wsUrl.RawQuery = wsValues.Encode() +} + +func (e *iexEndpoint) GetHTTPUrl() string { + httpValues := e.httpUrl.Query() + httpValues.Set("t", e.idg()) + e.httpUrl.RawQuery = httpValues.Encode() + return e.httpUrl.String() +} + +func (e *iexEndpoint) GetWSUrl() string { + wsValues := e.wsUrl.Query() + wsValues.Set("t", e.idg()) + e.wsUrl.RawQuery = wsValues.Encode() + return e.wsUrl.String() +} + +func (e *iexEndpoint) Initialize() { + // Initialize the HTTP enpoint query params. + httpValues := e.httpUrl.Query() + httpValues.Set("EIO", "3") + httpValues.Set("transport", "polling") + httpValues.Set("b64", "1") + e.httpUrl.RawQuery = httpValues.Encode() + + // Initialize the Websocket enpoint query params. + wsValues := e.wsUrl.Query() + wsValues.Set("EIO", "3") + wsValues.Set("transport", "websocket") + e.wsUrl.RawQuery = wsValues.Encode() +} + +func NewIEXEndpoint(idg func() string) Endpoint { + endpoint := &iexEndpoint{httpEndpoint, wsEndpoint, idg} + endpoint.Initialize() + return endpoint +} diff --git a/socketio/endpoints_test.go b/socketio/endpoints_test.go new file mode 100644 index 0000000..2310198 --- /dev/null +++ b/socketio/endpoints_test.go @@ -0,0 +1,103 @@ +package socketio_test + +import ( + "net/url" + "testing" + + . "github.com/smartystreets/goconvey/convey" + . "github.com/timpalpant/go-iex/socketio" +) + +func TestIEXEndpoint(t *testing.T) { + Convey("The Endpoint", t, func() { + Convey("should have the correct base http URL", func() { + endpoint := NewIEXEndpoint(func() string { + return "123" + }) + to, err := url.Parse(endpoint.GetHTTPUrl()) + So(err, ShouldBeNil) + So(to.Scheme, ShouldEqual, "https") + So(to.Host, ShouldEqual, "ws-api.iextrading.com") + So(to.Path, ShouldEqual, "/socket.io/") + values := to.Query() + So(values.Get("EIO"), ShouldEqual, "3") + So(values.Get("transport"), ShouldEqual, "polling") + So(values.Get("b64"), ShouldEqual, "1") + }) + Convey("should have the correct base Websocket URL", func() { + endpoint := NewIEXEndpoint(func() string { + return "123" + }) + to, err := url.Parse(endpoint.GetWSUrl()) + So(err, ShouldBeNil) + So(to.Scheme, ShouldEqual, "wss") + So(to.Host, ShouldEqual, "ws-api.iextrading.com") + So(to.Path, ShouldEqual, "/socket.io/") + values := to.Query() + So(values.Get("EIO"), ShouldEqual, "3") + So(values.Get("transport"), ShouldEqual, "websocket") + }) + Convey("should set the SID on the HTTP URL", func() { + sid := "4567" + endpoint := NewIEXEndpoint(func() string { + return "123" + }) + endpoint.SetSid(sid) + to, err := url.Parse(endpoint.GetHTTPUrl()) + So(err, ShouldBeNil) + values := to.Query() + So(values.Get("sid"), ShouldEqual, sid) + }) + Convey("should set the SID on the Websocket URL", func() { + sid := "4567" + endpoint := NewIEXEndpoint(func() string { + return "123" + }) + endpoint.SetSid(sid) + to, err := url.Parse(endpoint.GetWSUrl()) + So(err, ShouldBeNil) + values := to.Query() + So(values.Get("sid"), ShouldEqual, sid) + }) + Convey("should change 't' for HTTP URLs", func() { + timestamps := []string{"123", "456"} + index := 0 + endpoint := NewIEXEndpoint(func() string { + timestamp := timestamps[index] + index++ + return timestamp + }) + to, err := url.Parse(endpoint.GetHTTPUrl()) + So(err, ShouldBeNil) + values := to.Query() + first := values.Get("t") + + to, err = url.Parse(endpoint.GetHTTPUrl()) + So(err, ShouldBeNil) + values = to.Query() + second := values.Get("t") + + So(first, ShouldNotEqual, second) + }) + Convey("should change 't' for WS URLs", func() { + timestamps := []string{"123", "456"} + index := 0 + endpoint := NewIEXEndpoint(func() string { + timestamp := timestamps[index] + index++ + return timestamp + }) + to, err := url.Parse(endpoint.GetWSUrl()) + So(err, ShouldBeNil) + values := to.Query() + first := values.Get("t") + + to, err = url.Parse(endpoint.GetWSUrl()) + So(err, ShouldBeNil) + values = to.Query() + second := values.Get("t") + + So(first, ShouldNotEqual, second) + }) + }) +} diff --git a/socketio/gen-namespace.go b/socketio/gen-namespace.go new file mode 100644 index 0000000..3543e2e --- /dev/null +++ b/socketio/gen-namespace.go @@ -0,0 +1,599 @@ +// This file was automatically generated by genny. +// Any changes will be lost if this file is regenerated. +// see https://github.com/cheekybits/genny + +package socketio + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" + "sync" + + "github.com/golang/glog" + "github.com/timpalpant/go-iex" +) + +// The iexMsgTypeNamespace is a generic class built using Genny. +// https://github.com/cheekybits/genny +// Run "go generate" to re-generate the specific namespace types. + +// Contains callbacks and the symbols they correspond to. +type subIexTOPS struct { + Callback func(iex.TOPS) + Symbols map[string]struct{} +} + +// Receives messages for a given namespace and forwards them to endpoints. +type IexTOPSNamespace struct { + // Used to guard access to the fanout channels. + sync.RWMutex + + // A set of symbols that this namespace is currently subscribed to. + // This spans across subcriptions so that unsubscribing from a symbol + // only occurs if there are no subscriptions listening for that symbol. + symbols Subscriber + // The ID to use for the next connection created. + nextId int + // Active subscriptions by ID. + subscriptions map[int]*subIexTOPS + // For encoding outgoing messages in this namespace. + encoder Encoder + // Used for sending messages to the Transport. + writer io.Writer + // The factory function used to generate subscribe/unsubscribe messages. + // Subscribe and unsubscribe messages can differe by IEX namespace. + subUnsubMsgFactory subUnsubMsgFactory + // A function to be called when the namespace has no more endpoints. + closeFunc func(string) +} + +// Sends a subscribe message. This is performed when the number of subscriptions +// goes from 0 to 1. +func (i *IexTOPSNamespace) sendPacket(msgType MessageType) error { + r, err := i.encoder.EncodePacket(Message, msgType) + if err != nil { + return err + } + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(r) + _, err = buffer.WriteTo(i.writer) + return err +} + +// Encodes and sends a subscribe or unsubscribe message on the transport layer. +func (i *IexTOPSNamespace) sendSubUnsub(subUnsubMsg *IEXMsg) error { + r, err := i.encoder.EncodeMessage(Message, Event, subUnsubMsg) + if err != nil { + return fmt.Errorf("Error encoding %+v: %s", subUnsubMsg, err) + } + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(r) + _, err = buffer.WriteTo(i.writer) + return err +} + +// Given a string representing a JSON IEX message type, parse out the symbol and +// message and pass the message to each connection subscribed to the symbol. +func (i *IexTOPSNamespace) fanout(pkt PacketData) { + // This "symbol only" struct is necessary because this class + // is a genny generic. Therefore, even though all IEX messages + // have a "symbol" field, iexMsgType.symbol is not type safe. + var symbol struct { + Symbol string + } + if err := ParseToJSON(pkt.Data, &symbol); err != nil { + glog.Errorf("No symbol found for iexMsgType: %s - %v", + err, pkt) + } + // Now that the symbol has been extraced, the specific message + // can be extracted from the data. + var decoded iex.TOPS + if err := ParseToJSON(pkt.Data, &decoded); err != nil { + glog.Errorf("Could not decode iexMsgType: %s - %v", + err, pkt) + } + if glog.V(5) { + glog.Infof("Extracted symbol: %v", symbol) + glog.Infof("Extracted message: %v", decoded) + } + i.RLock() + defer i.RUnlock() + for _, sub := range i.subscriptions { + if glog.V(5) { + glog.Infof("Checking for subscription to %s", + symbol.Symbol) + } + if _, ok := sub.Symbols[symbol.Symbol]; ok { + if glog.V(5) { + glog.Infof("Calling subscription to %s", + symbol.Symbol) + } + sub.Callback(decoded) + } + } +} + +// Returns a method that is passed to new Connections, to be called when the +// connection is being closed. +func (i *IexTOPSNamespace) getCloseSubscriptionFunc(id int) func() { + return func() { + i.Lock() + unsub := make([]string, 0) + sub := i.subscriptions[id] + // Unsubscribe from the subscription symbols. For any that are + // no longer being listened to by any subscription, send an + // unsubscribe event to IEX. If there are no more subscriptions + // in the namespace, disconnect from the namespace. + for key, _ := range sub.Symbols { + i.symbols.Unsubscribe(key) + if !i.symbols.Subscribed(key) { + unsub = append(unsub, key) + } + } + delete(i.subscriptions, id) + i.Unlock() + for _, symbol := range unsub { + err := i.sendSubUnsub(i.subUnsubMsgFactory( + Unsubscribe, []string{symbol})) + if err != nil { + glog.Errorf("Error unsubscrubing from %v: %s", + unsub, err) + } + } + if len(i.subscriptions) == 0 { + i.closeFunc(msgTypeToNamespace["IexTOPS"]) + } + } +} + +// Receive messages for the passed in symbols using the passed in callback. +// Returns a close function that should be called when the client does not wish +// to receive any further messages. If symbols is empty, an error is returned. +func (i *IexTOPSNamespace) SubscribeTo( + msgReceived func(msg iex.TOPS), symbols ...string) (func(), error) { + if len(symbols) == 0 { + return nil, errors.New( + "Cannot call SubscribeTo with no symbols") + } + i.Lock() + defer i.Unlock() + // Connect to the namespace when adding the first subscription. + if len(i.subscriptions) == 0 { + i.sendPacket(Connect) + } + i.nextId++ + newSub := &subIexTOPS{ + Callback: msgReceived, + Symbols: make(map[string]struct{}), + } + if len(symbols) > 0 { + var err error + for _, symbol := range symbols { + symbol = strings.ToUpper(symbol) + newSub.Symbols[symbol] = struct{}{} + i.symbols.Subscribe(symbol) + // Subscribe to each symbol individually. This allows + // DEEP, which only allows subscribing to a single + // symbol at a time, to use the same path. + err = i.sendSubUnsub(i.subUnsubMsgFactory( + Subscribe, []string{symbol})) + } + if err != nil { + return nil, err + } + } + i.subscriptions[i.nextId] = newSub + return i.getCloseSubscriptionFunc(i.nextId), nil +} + +// Create a new namespace for a specific IEX endpoint. Because the IEX +// namespaces use different message types for representing the received data, +// these classes are represented as generics using Genny. +func NewIexTOPSNamespace( + transport Transport, subUnsubMsgFactory subUnsubMsgFactory, + closeFunc func(string)) *IexTOPSNamespace { + namespace := msgTypeToNamespace["IexTOPS"] + encoder := NewWSEncoder(namespace) + newNs := &IexTOPSNamespace{ + symbols: NewCountingSubscriber(), + nextId: 0, + subscriptions: make(map[int]*subIexTOPS), + encoder: encoder, + writer: transport, + subUnsubMsgFactory: subUnsubMsgFactory, + closeFunc: closeFunc, + } + transport.AddPacketCallback(namespace, newNs.fanout) + return newNs +} + +// The iexMsgTypeNamespace is a generic class built using Genny. +// https://github.com/cheekybits/genny +// Run "go generate" to re-generate the specific namespace types. + +// Contains callbacks and the symbols they correspond to. +type subIexLast struct { + Callback func(iex.Last) + Symbols map[string]struct{} +} + +// Receives messages for a given namespace and forwards them to endpoints. +type IexLastNamespace struct { + // Used to guard access to the fanout channels. + sync.RWMutex + + // A set of symbols that this namespace is currently subscribed to. + // This spans across subcriptions so that unsubscribing from a symbol + // only occurs if there are no subscriptions listening for that symbol. + symbols Subscriber + // The ID to use for the next connection created. + nextId int + // Active subscriptions by ID. + subscriptions map[int]*subIexLast + // For encoding outgoing messages in this namespace. + encoder Encoder + // Used for sending messages to the Transport. + writer io.Writer + // The factory function used to generate subscribe/unsubscribe messages. + // Subscribe and unsubscribe messages can differe by IEX namespace. + subUnsubMsgFactory subUnsubMsgFactory + // A function to be called when the namespace has no more endpoints. + closeFunc func(string) +} + +// Sends a subscribe message. This is performed when the number of subscriptions +// goes from 0 to 1. +func (i *IexLastNamespace) sendPacket(msgType MessageType) error { + r, err := i.encoder.EncodePacket(Message, msgType) + if err != nil { + return err + } + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(r) + _, err = buffer.WriteTo(i.writer) + return err +} + +// Encodes and sends a subscribe or unsubscribe message on the transport layer. +func (i *IexLastNamespace) sendSubUnsub(subUnsubMsg *IEXMsg) error { + r, err := i.encoder.EncodeMessage(Message, Event, subUnsubMsg) + if err != nil { + return fmt.Errorf("Error encoding %+v: %s", subUnsubMsg, err) + } + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(r) + _, err = buffer.WriteTo(i.writer) + return err +} + +// Given a string representing a JSON IEX message type, parse out the symbol and +// message and pass the message to each connection subscribed to the symbol. +func (i *IexLastNamespace) fanout(pkt PacketData) { + // This "symbol only" struct is necessary because this class + // is a genny generic. Therefore, even though all IEX messages + // have a "symbol" field, iexMsgType.symbol is not type safe. + var symbol struct { + Symbol string + } + if err := ParseToJSON(pkt.Data, &symbol); err != nil { + glog.Errorf("No symbol found for iexMsgType: %s - %v", + err, pkt) + } + // Now that the symbol has been extraced, the specific message + // can be extracted from the data. + var decoded iex.Last + if err := ParseToJSON(pkt.Data, &decoded); err != nil { + glog.Errorf("Could not decode iexMsgType: %s - %v", + err, pkt) + } + if glog.V(5) { + glog.Infof("Extracted symbol: %v", symbol) + glog.Infof("Extracted message: %v", decoded) + } + i.RLock() + defer i.RUnlock() + for _, sub := range i.subscriptions { + if glog.V(5) { + glog.Infof("Checking for subscription to %s", + symbol.Symbol) + } + if _, ok := sub.Symbols[symbol.Symbol]; ok { + if glog.V(5) { + glog.Infof("Calling subscription to %s", + symbol.Symbol) + } + sub.Callback(decoded) + } + } +} + +// Returns a method that is passed to new Connections, to be called when the +// connection is being closed. +func (i *IexLastNamespace) getCloseSubscriptionFunc(id int) func() { + return func() { + i.Lock() + unsub := make([]string, 0) + sub := i.subscriptions[id] + // Unsubscribe from the subscription symbols. For any that are + // no longer being listened to by any subscription, send an + // unsubscribe event to IEX. If there are no more subscriptions + // in the namespace, disconnect from the namespace. + for key, _ := range sub.Symbols { + i.symbols.Unsubscribe(key) + if !i.symbols.Subscribed(key) { + unsub = append(unsub, key) + } + } + delete(i.subscriptions, id) + i.Unlock() + for _, symbol := range unsub { + err := i.sendSubUnsub(i.subUnsubMsgFactory( + Unsubscribe, []string{symbol})) + if err != nil { + glog.Errorf("Error unsubscrubing from %v: %s", + unsub, err) + } + } + if len(i.subscriptions) == 0 { + i.closeFunc(msgTypeToNamespace["IexLast"]) + } + } +} + +// Receive messages for the passed in symbols using the passed in callback. +// Returns a close function that should be called when the client does not wish +// to receive any further messages. If symbols is empty, an error is returned. +func (i *IexLastNamespace) SubscribeTo( + msgReceived func(msg iex.Last), symbols ...string) (func(), error) { + if len(symbols) == 0 { + return nil, errors.New( + "Cannot call SubscribeTo with no symbols") + } + i.Lock() + defer i.Unlock() + // Connect to the namespace when adding the first subscription. + if len(i.subscriptions) == 0 { + i.sendPacket(Connect) + } + i.nextId++ + newSub := &subIexLast{ + Callback: msgReceived, + Symbols: make(map[string]struct{}), + } + if len(symbols) > 0 { + var err error + for _, symbol := range symbols { + symbol = strings.ToUpper(symbol) + newSub.Symbols[symbol] = struct{}{} + i.symbols.Subscribe(symbol) + // Subscribe to each symbol individually. This allows + // DEEP, which only allows subscribing to a single + // symbol at a time, to use the same path. + err = i.sendSubUnsub(i.subUnsubMsgFactory( + Subscribe, []string{symbol})) + } + if err != nil { + return nil, err + } + } + i.subscriptions[i.nextId] = newSub + return i.getCloseSubscriptionFunc(i.nextId), nil +} + +// Create a new namespace for a specific IEX endpoint. Because the IEX +// namespaces use different message types for representing the received data, +// these classes are represented as generics using Genny. +func NewIexLastNamespace( + transport Transport, subUnsubMsgFactory subUnsubMsgFactory, + closeFunc func(string)) *IexLastNamespace { + namespace := msgTypeToNamespace["IexLast"] + encoder := NewWSEncoder(namespace) + newNs := &IexLastNamespace{ + symbols: NewCountingSubscriber(), + nextId: 0, + subscriptions: make(map[int]*subIexLast), + encoder: encoder, + writer: transport, + subUnsubMsgFactory: subUnsubMsgFactory, + closeFunc: closeFunc, + } + transport.AddPacketCallback(namespace, newNs.fanout) + return newNs +} + +// The iexMsgTypeNamespace is a generic class built using Genny. +// https://github.com/cheekybits/genny +// Run "go generate" to re-generate the specific namespace types. + +// Contains callbacks and the symbols they correspond to. +type subIexDEEP struct { + Callback func(iex.DEEP) + Symbols map[string]struct{} +} + +// Receives messages for a given namespace and forwards them to endpoints. +type IexDEEPNamespace struct { + // Used to guard access to the fanout channels. + sync.RWMutex + + // A set of symbols that this namespace is currently subscribed to. + // This spans across subcriptions so that unsubscribing from a symbol + // only occurs if there are no subscriptions listening for that symbol. + symbols Subscriber + // The ID to use for the next connection created. + nextId int + // Active subscriptions by ID. + subscriptions map[int]*subIexDEEP + // For encoding outgoing messages in this namespace. + encoder Encoder + // Used for sending messages to the Transport. + writer io.Writer + // The factory function used to generate subscribe/unsubscribe messages. + // Subscribe and unsubscribe messages can differe by IEX namespace. + subUnsubMsgFactory subUnsubMsgFactory + // A function to be called when the namespace has no more endpoints. + closeFunc func(string) +} + +// Sends a subscribe message. This is performed when the number of subscriptions +// goes from 0 to 1. +func (i *IexDEEPNamespace) sendPacket(msgType MessageType) error { + r, err := i.encoder.EncodePacket(Message, msgType) + if err != nil { + return err + } + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(r) + _, err = buffer.WriteTo(i.writer) + return err +} + +// Encodes and sends a subscribe or unsubscribe message on the transport layer. +func (i *IexDEEPNamespace) sendSubUnsub(subUnsubMsg *IEXMsg) error { + r, err := i.encoder.EncodeMessage(Message, Event, subUnsubMsg) + if err != nil { + return fmt.Errorf("Error encoding %+v: %s", subUnsubMsg, err) + } + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(r) + _, err = buffer.WriteTo(i.writer) + return err +} + +// Given a string representing a JSON IEX message type, parse out the symbol and +// message and pass the message to each connection subscribed to the symbol. +func (i *IexDEEPNamespace) fanout(pkt PacketData) { + // This "symbol only" struct is necessary because this class + // is a genny generic. Therefore, even though all IEX messages + // have a "symbol" field, iexMsgType.symbol is not type safe. + var symbol struct { + Symbol string + } + if err := ParseToJSON(pkt.Data, &symbol); err != nil { + glog.Errorf("No symbol found for iexMsgType: %s - %v", + err, pkt) + } + // Now that the symbol has been extraced, the specific message + // can be extracted from the data. + var decoded iex.DEEP + if err := ParseToJSON(pkt.Data, &decoded); err != nil { + glog.Errorf("Could not decode iexMsgType: %s - %v", + err, pkt) + } + if glog.V(5) { + glog.Infof("Extracted symbol: %v", symbol) + glog.Infof("Extracted message: %v", decoded) + } + i.RLock() + defer i.RUnlock() + for _, sub := range i.subscriptions { + if glog.V(5) { + glog.Infof("Checking for subscription to %s", + symbol.Symbol) + } + if _, ok := sub.Symbols[symbol.Symbol]; ok { + if glog.V(5) { + glog.Infof("Calling subscription to %s", + symbol.Symbol) + } + sub.Callback(decoded) + } + } +} + +// Returns a method that is passed to new Connections, to be called when the +// connection is being closed. +func (i *IexDEEPNamespace) getCloseSubscriptionFunc(id int) func() { + return func() { + i.Lock() + unsub := make([]string, 0) + sub := i.subscriptions[id] + // Unsubscribe from the subscription symbols. For any that are + // no longer being listened to by any subscription, send an + // unsubscribe event to IEX. If there are no more subscriptions + // in the namespace, disconnect from the namespace. + for key, _ := range sub.Symbols { + i.symbols.Unsubscribe(key) + if !i.symbols.Subscribed(key) { + unsub = append(unsub, key) + } + } + delete(i.subscriptions, id) + i.Unlock() + for _, symbol := range unsub { + err := i.sendSubUnsub(i.subUnsubMsgFactory( + Unsubscribe, []string{symbol})) + if err != nil { + glog.Errorf("Error unsubscrubing from %v: %s", + unsub, err) + } + } + if len(i.subscriptions) == 0 { + i.closeFunc(msgTypeToNamespace["IexDEEP"]) + } + } +} + +// Receive messages for the passed in symbols using the passed in callback. +// Returns a close function that should be called when the client does not wish +// to receive any further messages. If symbols is empty, an error is returned. +func (i *IexDEEPNamespace) SubscribeTo( + msgReceived func(msg iex.DEEP), symbols ...string) (func(), error) { + if len(symbols) == 0 { + return nil, errors.New( + "Cannot call SubscribeTo with no symbols") + } + i.Lock() + defer i.Unlock() + // Connect to the namespace when adding the first subscription. + if len(i.subscriptions) == 0 { + i.sendPacket(Connect) + } + i.nextId++ + newSub := &subIexDEEP{ + Callback: msgReceived, + Symbols: make(map[string]struct{}), + } + if len(symbols) > 0 { + var err error + for _, symbol := range symbols { + symbol = strings.ToUpper(symbol) + newSub.Symbols[symbol] = struct{}{} + i.symbols.Subscribe(symbol) + // Subscribe to each symbol individually. This allows + // DEEP, which only allows subscribing to a single + // symbol at a time, to use the same path. + err = i.sendSubUnsub(i.subUnsubMsgFactory( + Subscribe, []string{symbol})) + } + if err != nil { + return nil, err + } + } + i.subscriptions[i.nextId] = newSub + return i.getCloseSubscriptionFunc(i.nextId), nil +} + +// Create a new namespace for a specific IEX endpoint. Because the IEX +// namespaces use different message types for representing the received data, +// these classes are represented as generics using Genny. +func NewIexDEEPNamespace( + transport Transport, subUnsubMsgFactory subUnsubMsgFactory, + closeFunc func(string)) *IexDEEPNamespace { + namespace := msgTypeToNamespace["IexDEEP"] + encoder := NewWSEncoder(namespace) + newNs := &IexDEEPNamespace{ + symbols: NewCountingSubscriber(), + nextId: 0, + subscriptions: make(map[int]*subIexDEEP), + encoder: encoder, + writer: transport, + subUnsubMsgFactory: subUnsubMsgFactory, + closeFunc: closeFunc, + } + transport.AddPacketCallback(namespace, newNs.fanout) + return newNs +} diff --git a/socketio/namespace.go b/socketio/namespace.go new file mode 100644 index 0000000..f0b3796 --- /dev/null +++ b/socketio/namespace.go @@ -0,0 +1,212 @@ +package socketio + +// The iexMsgTypeNamespace is a generic class built using Genny. +// https://github.com/cheekybits/genny +// Run "go generate" to re-generate the specific namespace types. + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" + "sync" + + "github.com/cheekybits/genny/generic" + "github.com/golang/glog" +) + +// The generic type representing the IEX message parsed by the namespace. +type IEXMsgType generic.Type + +// Contains callbacks and the symbols they correspond to. +type subIEXMsgType struct { + Callback func(IEXMsgType) + Symbols map[string]struct{} +} + +// Receives messages for a given namespace and forwards them to endpoints. +type IEXMsgTypeNamespace struct { + // Used to guard access to the fanout channels. + sync.RWMutex + + // A set of symbols that this namespace is currently subscribed to. + // This spans across subcriptions so that unsubscribing from a symbol + // only occurs if there are no subscriptions listening for that symbol. + symbols Subscriber + // The ID to use for the next connection created. + nextId int + // Active subscriptions by ID. + subscriptions map[int]*subIEXMsgType + // For encoding outgoing messages in this namespace. + encoder Encoder + // Used for sending messages to the Transport. + writer io.Writer + // The factory function used to generate subscribe/unsubscribe messages. + // Subscribe and unsubscribe messages can differe by IEX namespace. + subUnsubMsgFactory subUnsubMsgFactory + // A function to be called when the namespace has no more endpoints. + closeFunc func(string) +} + +// Sends a subscribe message. This is performed when the number of subscriptions +// goes from 0 to 1. +func (i *IEXMsgTypeNamespace) sendPacket(msgType MessageType) error { + r, err := i.encoder.EncodePacket(Message, msgType) + if err != nil { + return err + } + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(r) + _, err = buffer.WriteTo(i.writer) + return err +} + +// Encodes and sends a subscribe or unsubscribe message on the transport layer. +func (i *IEXMsgTypeNamespace) sendSubUnsub(subUnsubMsg *IEXMsg) error { + r, err := i.encoder.EncodeMessage(Message, Event, subUnsubMsg) + if err != nil { + return fmt.Errorf("Error encoding %+v: %s", subUnsubMsg, err) + } + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(r) + _, err = buffer.WriteTo(i.writer) + return err +} + +// Given a string representing a JSON IEX message type, parse out the symbol and +// message and pass the message to each connection subscribed to the symbol. +func (i *IEXMsgTypeNamespace) fanout(pkt PacketData) { + // This "symbol only" struct is necessary because this class + // is a genny generic. Therefore, even though all IEX messages + // have a "symbol" field, iexMsgType.symbol is not type safe. + var symbol struct { + Symbol string + } + if err := ParseToJSON(pkt.Data, &symbol); err != nil { + glog.Errorf("No symbol found for iexMsgType: %s - %v", + err, pkt) + } + // Now that the symbol has been extraced, the specific message + // can be extracted from the data. + var decoded IEXMsgType + if err := ParseToJSON(pkt.Data, &decoded); err != nil { + glog.Errorf("Could not decode iexMsgType: %s - %v", + err, pkt) + } + if glog.V(5) { + glog.Infof("Extracted symbol: %v", symbol) + glog.Infof("Extracted message: %v", decoded) + } + i.RLock() + defer i.RUnlock() + for _, sub := range i.subscriptions { + if glog.V(5) { + glog.Infof("Checking for subscription to %s", + symbol.Symbol) + } + if _, ok := sub.Symbols[symbol.Symbol]; ok { + if glog.V(5) { + glog.Infof("Calling subscription to %s", + symbol.Symbol) + } + sub.Callback(decoded) + } + } +} + +// Returns a method that is passed to new Connections, to be called when the +// connection is being closed. +func (i *IEXMsgTypeNamespace) getCloseSubscriptionFunc(id int) func() { + return func() { + i.Lock() + unsub := make([]string, 0) + sub := i.subscriptions[id] + // Unsubscribe from the subscription symbols. For any that are + // no longer being listened to by any subscription, send an + // unsubscribe event to IEX. If there are no more subscriptions + // in the namespace, disconnect from the namespace. + for key, _ := range sub.Symbols { + i.symbols.Unsubscribe(key) + if !i.symbols.Subscribed(key) { + unsub = append(unsub, key) + } + } + delete(i.subscriptions, id) + i.Unlock() + for _, symbol := range unsub { + err := i.sendSubUnsub(i.subUnsubMsgFactory( + Unsubscribe, []string{symbol})) + if err != nil { + glog.Errorf("Error unsubscrubing from %v: %s", + unsub, err) + } + } + if len(i.subscriptions) == 0 { + i.closeFunc(msgTypeToNamespace["IEXMsgType"]) + } + } +} + +// Receive messages for the passed in symbols using the passed in callback. +// Returns a close function that should be called when the client does not wish +// to receive any further messages. If symbols is empty, an error is returned. +func (i *IEXMsgTypeNamespace) SubscribeTo( + msgReceived func(msg IEXMsgType), symbols ...string) (func(), error) { + if len(symbols) == 0 { + return nil, errors.New( + "Cannot call SubscribeTo with no symbols") + } + i.Lock() + defer i.Unlock() + // Connect to the namespace when adding the first subscription. + if len(i.subscriptions) == 0 { + i.sendPacket(Connect) + } + i.nextId++ + newSub := &subIEXMsgType{ + Callback: msgReceived, + Symbols: make(map[string]struct{}), + } + if len(symbols) > 0 { + var err error + for _, symbol := range symbols { + symbol = strings.ToUpper(symbol) + newSub.Symbols[symbol] = struct{}{} + i.symbols.Subscribe(symbol) + // Subscribe to each symbol individually. This allows + // DEEP, which only allows subscribing to a single + // symbol at a time, to use the same path. + err = i.sendSubUnsub(i.subUnsubMsgFactory( + Subscribe, []string{symbol})) + } + if err != nil { + return nil, err + } + } + i.subscriptions[i.nextId] = newSub + return i.getCloseSubscriptionFunc(i.nextId), nil +} + +// Create a new namespace for a specific IEX endpoint. Because the IEX +// namespaces use different message types for representing the received data, +// these classes are represented as generics using Genny. +func NewIEXMsgTypeNamespace( + transport Transport, subUnsubMsgFactory subUnsubMsgFactory, + closeFunc func(string)) *IEXMsgTypeNamespace { + namespace := msgTypeToNamespace["IEXMsgType"] + encoder := NewWSEncoder(namespace) + newNs := &IEXMsgTypeNamespace{ + symbols: NewCountingSubscriber(), + nextId: 0, + subscriptions: make(map[int]*subIEXMsgType), + encoder: encoder, + writer: transport, + subUnsubMsgFactory: subUnsubMsgFactory, + closeFunc: closeFunc, + } + transport.AddPacketCallback(namespace, newNs.fanout) + return newNs +} + +//go:generate genny -in=$GOFILE -out=gen-$GOFILE gen "IEXMsgType=iex.TOPS,iex.Last,iex.DEEP" diff --git a/socketio/namespace_test.go b/socketio/namespace_test.go new file mode 100644 index 0000000..144af81 --- /dev/null +++ b/socketio/namespace_test.go @@ -0,0 +1,164 @@ +package socketio_test + +import ( + "strings" + "testing" + + . "github.com/smartystreets/goconvey/convey" + "github.com/timpalpant/go-iex" + . "github.com/timpalpant/go-iex/socketio" +) + +func TestNamespace(t *testing.T) { + Convey("The IexTOPSNamespace should", t, func() { + ft := newFakeTransport() + subFactory := func( + signal SubOrUnsub, symbols []string) *IEXMsg { + return &IEXMsg{ + EventType: signal, + Data: strings.Join(symbols, ","), + } + } + closed := false + closedNamespace := "" + closeFunc := func(namespace string) { + closedNamespace = namespace + closed = true + } + Convey("not send a connect on creation", func() { + NewIexTOPSNamespace(ft, subFactory, closeFunc) + So(ft.messages, ShouldHaveLength, 0) + }) + Convey("error on SubscribeTo with no symbols", func() { + ns := NewIexTOPSNamespace(ft, subFactory, closeFunc) + handler := func(msg iex.TOPS) {} + _, err := ns.SubscribeTo(handler) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, "no symbols") + }) + Convey("send a connect message on first subscription", func() { + ns := NewIexTOPSNamespace(ft, subFactory, closeFunc) + handler := func(msg iex.TOPS) {} + _, err := ns.SubscribeTo(handler, "fb", "snap") + So(err, ShouldBeNil) + So("40/1.0/tops,", ShouldBeIn, ft.messages) + So(`42/1.0/tops,["subscribe","FB"]`, ShouldBeIn, + ft.messages) + So(`42/1.0/tops,["subscribe","SNAP"]`, ShouldBeIn, + ft.messages) + }) + Convey("send unsubscribe messages", func() { + ns := NewIexTOPSNamespace(ft, subFactory, closeFunc) + handler := func(msg iex.TOPS) {} + closer, err := ns.SubscribeTo(handler, "fb", "snap") + So(err, ShouldBeNil) + closer() + So(`42/1.0/tops,["subscribe","FB"]`, + ShouldBeIn, ft.messages) + So(`42/1.0/tops,["subscribe","SNAP"]`, + ShouldBeIn, ft.messages) + So(`42/1.0/tops,["unsubscribe","FB"]`, + ShouldBeIn, ft.messages) + So(`42/1.0/tops,["unsubscribe","SNAP"]`, + ShouldBeIn, ft.messages) + }) + Convey("unsubscribe when all references removed", func() { + ns := NewIexTOPSNamespace(ft, subFactory, closeFunc) + handler := func(msg iex.TOPS) {} + closer1, err := ns.SubscribeTo(handler, "fb", "snap") + So(err, ShouldBeNil) + closer2, err := ns.SubscribeTo(handler, "fb", "goog") + So(err, ShouldBeNil) + So(`42/1.0/tops,["subscribe","FB"]`, + ShouldBeIn, ft.messages) + So(`42/1.0/tops,["subscribe","SNAP"]`, + ShouldBeIn, ft.messages) + closer1() + So(`42/1.0/tops,["unsubscribe","SNAP"]`, + ShouldBeIn, ft.messages) + closer2() + So(`42/1.0/tops,["unsubscribe","FB"]`, + ShouldBeIn, ft.messages) + So(`42/1.0/tops,["unsubscribe","GOOG"]`, + ShouldBeIn, ft.messages) + }) + Convey("call closeFunc when all connections closed", func() { + ns := NewIexTOPSNamespace(ft, subFactory, closeFunc) + handler := func(msg iex.TOPS) {} + closer1, err := ns.SubscribeTo(handler, "fb") + So(err, ShouldBeNil) + closer2, err := ns.SubscribeTo(handler, "fb") + So(err, ShouldBeNil) + closer1() + closer2() + So(closedNamespace, ShouldEqual, "/1.0/tops") + So(closed, ShouldBeTrue) + }) + Convey("fan out messages", func() { + ns := NewIexTOPSNamespace(ft, subFactory, closeFunc) + var msg1 iex.TOPS + handler1 := func(msg iex.TOPS) { + msg1 = msg + } + _, err := ns.SubscribeTo(handler1, "fb") + So(err, ShouldBeNil) + var msg2 iex.TOPS + handler2 := func(msg iex.TOPS) { + msg2 = msg + } + _, err = ns.SubscribeTo(handler2, "fb") + So(err, ShouldBeNil) + ft.callbacks["/1.0/tops"][1](PacketData{ + Data: "{\"symbol\":\"FB\",\"bidsize\":12}", + }) + expected := iex.TOPS{ + Symbol: "FB", + BidSize: 12, + } + So(msg1, ShouldResemble, expected) + So(msg2, ShouldResemble, expected) + }) + Convey("filter based on subscriptions", func() { + ns := NewIexTOPSNamespace(ft, subFactory, closeFunc) + var msg1 iex.TOPS + handler1 := func(msg iex.TOPS) { + msg1 = msg + } + _, err := ns.SubscribeTo(handler1, "fb") + So(err, ShouldBeNil) + var msg2 iex.TOPS + handler2 := func(msg iex.TOPS) { + msg2 = msg + } + _, err = ns.SubscribeTo(handler2, "goog") + So(err, ShouldBeNil) + ft.TriggerCallbacks(PacketData{ + Namespace: "/1.0/tops", + Data: "{\"symbol\":\"FB\",\"bidsize\":12}", + }) + fbExpected := iex.TOPS{ + Symbol: "FB", + BidSize: 12, + } + So(msg1, ShouldResemble, fbExpected) + So(msg2, ShouldResemble, iex.TOPS{}) + ft.TriggerCallbacks(PacketData{ + Namespace: "/1.0/tops", + Data: "{\"symbol\":\"GOOG\",\"bidsize\":11}", + }) + googExpected := iex.TOPS{ + Symbol: "GOOG", + BidSize: 11, + } + So(msg2, ShouldResemble, googExpected) + msg1 = iex.TOPS{} + msg2 = iex.TOPS{} + ft.TriggerCallbacks(PacketData{ + Namespace: "/1.0/tops", + Data: "{\"symbol\":\"AIG+\",\"bidsize\":11}", + }) + So(msg1, ShouldResemble, iex.TOPS{}) + So(msg2, ShouldResemble, iex.TOPS{}) + }) + }) +} diff --git a/socketio/subscribers.go b/socketio/subscribers.go new file mode 100644 index 0000000..d03271c --- /dev/null +++ b/socketio/subscribers.go @@ -0,0 +1,82 @@ +package socketio + +import "sync" + +// Enables subscribing based on string symbols. +type Subscriber interface { + // Subscribes to the given symbol. + Subscribe(symbol string) + // Returns true if the given symbol is currently subscribed to. + Subscribed(symbol string) bool + // Unsubscribed from events for the given symbol. + Unsubscribe(symbol string) +} + +// A Subscriber implementation using simple map presence. +type PresenceSubscriber struct { + // Guards the symbols map. + sync.RWMutex + + // Stores subscribed sympols. + symbols map[string]bool +} + +func (p *PresenceSubscriber) Subscribe(symbol string) { + p.Lock() + defer p.Unlock() + p.symbols[symbol] = true +} + +func (p *PresenceSubscriber) Subscribed(symbol string) bool { + p.RLock() + defer p.RUnlock() + _, ok := p.symbols[symbol] + return ok +} + +func (p *PresenceSubscriber) Unsubscribe(symbol string) { + p.Lock() + defer p.Unlock() + delete(p.symbols, symbol) +} + +func NewPresenceSubscriber() Subscriber { + return &PresenceSubscriber{symbols: make(map[string]bool)} +} + +// A subscriber implementation using a counter. A certain number of Subscribe +// calls for a given symbol must be followed by the same number of Unsubscribe +// calls for the same symbol for Unsubscribed to return false. +type CountingSubscriber struct { + // Guards the symbols map. + sync.RWMutex + + // Stores subscribed sympols. + symbols map[string]int +} + +func (c *CountingSubscriber) Subscribe(symbol string) { + c.Lock() + defer c.Unlock() + c.symbols[symbol]++ +} + +func (c *CountingSubscriber) Subscribed(symbol string) bool { + c.RLock() + defer c.RUnlock() + return c.symbols[symbol] > 0 +} + +func (c *CountingSubscriber) Unsubscribe(symbol string) { + c.Lock() + defer c.Unlock() + if c.symbols[symbol] > 0 { + c.symbols[symbol]-- + } else { + delete(c.symbols, symbol) + } +} + +func NewCountingSubscriber() Subscriber { + return &CountingSubscriber{symbols: make(map[string]int)} +} diff --git a/socketio/subscribers_test.go b/socketio/subscribers_test.go new file mode 100644 index 0000000..969a2d8 --- /dev/null +++ b/socketio/subscribers_test.go @@ -0,0 +1,69 @@ +package socketio_test + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" + . "github.com/timpalpant/go-iex/socketio" +) + +func TestPresenceSubscriber(t *testing.T) { + Convey("The PresenceSubscriber should", t, func() { + subscriber := NewPresenceSubscriber() + Convey("return false by default", func() { + So(subscriber.Subscribed("FB"), ShouldBeFalse) + }) + Convey("returns true after subscription", func() { + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + }) + Convey("returns true after multiple subscriptions", func() { + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + }) + Convey("returns false after unsubscription", func() { + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + subscriber.Unsubscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeFalse) + subscriber.Unsubscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeFalse) + }) + }) +} + +func TestCountingSubscriber(t *testing.T) { + Convey("The CountingSubscriber should", t, func() { + subscriber := NewCountingSubscriber() + Convey("return false by default", func() { + So(subscriber.Subscribed("FB"), ShouldBeFalse) + }) + Convey("returns true after subscription", func() { + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + }) + Convey("returns true after multiple subscriptions", func() { + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + }) + Convey("requires corresponding unsubscriptions", func() { + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + subscriber.Subscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + + subscriber.Unsubscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeTrue) + subscriber.Unsubscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeFalse) + subscriber.Unsubscribe("FB") + So(subscriber.Subscribed("FB"), ShouldBeFalse) + }) + }) +} diff --git a/socketio/transport.go b/socketio/transport.go new file mode 100644 index 0000000..c925730 --- /dev/null +++ b/socketio/transport.go @@ -0,0 +1,418 @@ +package socketio + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "net/http" + "strconv" + "sync" + "time" + + "github.com/chilts/sid" + "github.com/golang/glog" + "github.com/gorilla/websocket" +) + +type handshakeResponse struct { + Sid string + PingInterval int + PingTimeout int + Upgrades []string +} + +// Fulfilled by http.Client#Do. +type doClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// An interface that is fulfilled by websocket.Conn and allows for injecting a +// test connection. +type WSConn interface { + ReadMessage() (int, []byte, error) + WriteMessage(messageType int, data []byte) error + Close() error +} + +// Fulfilled by websocket.DefaultDialer#Dial. +type WSDialer interface { + Dial(urlStr string, reqHeader http.Header) ( + WSConn, *http.Response, error) +} + +// Indicates an error during initialization of the Transport layer. +type transportError struct { + message string +} + +func (t *transportError) Error() string { + return t.message +} + +// A wrapper that provides thread-safe methods for interacting with the +// underlying Websocket layer. +type Transport interface { + // Provides a thread-safe io.Writer Write method. + io.Writer + + // Adds a callback to be triggered when packets from the given + // namespace are received. Returns a unique ID that can be used + // to remove the callback later. If an error occurs, returns -1 + // and the error. + AddPacketCallback( + namespace string, callback func(PacketData)) (int, error) + + // Removes a callback using the ID that was returned when the + // callback was added. If either the namespace or the ID is does + // not exist, this method is a no-op. + RemovePacketCallback(namespace string, id int) error + + // Closes the underlying websocket connection. + Close() +} + +// A set of callbacks used to convey incoming messages to listeners. +type outgoing struct { + sync.RWMutex + + // The next ID to use when adding a callback. + nextId int + // A collection of channels for transmitting messages to consumers. + callbacks map[int]func(PacketData) +} + +func newOutgoing() *outgoing { + return &outgoing{ + nextId: 0, + callbacks: make(map[int]func(PacketData)), + } +} + +// Adds a PacketData callback and returns an identifier to be used when later +// removing the callback. +func (o *outgoing) AddCallback(callback func(PacketData)) int { + o.Lock() + defer o.Unlock() + o.nextId++ + o.callbacks[o.nextId] = callback + return o.nextId +} + +// Deletes the callback associated with the given ID. This function is a no-op +// if the ID is non-existent. +func (o *outgoing) RemoveCallback(id int) { + o.Lock() + defer o.Unlock() + delete(o.callbacks, id) +} + +func (o *outgoing) Callbacks() map[int]func(PacketData) { + o.Lock() + defer o.Unlock() + // Make a copy for thread safety. + copy := make(map[int]func(PacketData)) + for key, val := range o.callbacks { + copy[key] = val + } + return copy +} + +type transport struct { + sync.RWMutex + sync.Once + + // The wrapped Gorilla websocket.Conn. + conn WSConn + // A collection of callbacks keyed by namespace names. + outgoing map[string]*outgoing + // True when this transport has been closed. + closed bool +} + +func (t *transport) Write(message []byte) (int, error) { + t.RLock() + closed := t.closed + t.RUnlock() + if closed { + return 0, &transportError{"Cannot write to a closed transport"} + } + if glog.V(3) { + glog.Infof("Writing message: %s", string(message)) + } + err := t.conn.WriteMessage( + websocket.TextMessage, message) + if err != nil { + glog.Errorf( + "Failed to write message %q: %s", + string(message), err) + } + return len(message), nil +} + +func (t *transport) AddPacketCallback( + namespace string, callback func(PacketData)) (int, error) { + t.Lock() + closed := t.closed + t.Unlock() + if closed { + return -1, &transportError{ + "Cannot add a callback to a closed transport"} + } + t.Lock() + defer t.Unlock() + if _, ok := t.outgoing[namespace]; !ok { + t.outgoing[namespace] = newOutgoing() + } + return t.outgoing[namespace].AddCallback(callback), nil +} + +func (t *transport) RemovePacketCallback(namespace string, id int) error { + t.Lock() + closed := t.closed + t.Unlock() + if closed { + return &transportError{ + "Cannot remove a callback from a closed transport"} + } + t.Lock() + defer t.Unlock() + if val, ok := t.outgoing[namespace]; ok { + val.RemoveCallback(id) + if len(val.Callbacks()) == 0 { + delete(t.outgoing, namespace) + } + } + return nil +} + +func (t *transport) Close() { + t.Do(func() { + // Send the close signal before marking the transport as closed. + sendPacket(t, Close) + t.conn.Close() + t.Lock() + t.closed = true + t.Unlock() + }) +} + +func (t *transport) startReadLoop() { + for { + _, message, err := t.conn.ReadMessage() + if err != nil { + glog.Errorf( + "Error reading from websocket: %s", + err) + return + } + if len(message) == 0 { + continue + } + if glog.V(3) { + glog.Infof( + "Received websocket message: %s", + message) + } + t.RLock() + closed := t.closed + t.RUnlock() + if closed { + if glog.V(3) { + errTxt := "Dropping message %s;" + + "Transport closed" + glog.Warningf(errTxt, message) + } + break + } + var metadata PacketData + remaining := ParseMetadata(string(message), &metadata) + metadata.Data = remaining + if val, ok := t.outgoing[metadata.Namespace]; ok { + callbacks := val.Callbacks() + for _, callback := range callbacks { + go callback(metadata) + } + } + } + +} + +// Starts a go routine that sends a ping message on the given Transport every +// "ping" milliseconds. +func (t *transport) startHeartbeat(pingMillis int) { + duration, err := time.ParseDuration(strconv.Itoa(pingMillis) + "ms") + if err != nil { + glog.Fatalf("Could not start heartbeat: %s", err) + } + heartbeat := time.NewTicker(duration) + go func() { + for { + select { + case time := <-heartbeat.C: + t.RLock() + closed := t.closed + t.RUnlock() + if closed { + if glog.V(5) { + glog.Info("Stop heart beat") + } + return + } + if glog.V(3) { + glog.Infof("Heartbeating at %v", time) + } + sendPacket(t, Ping) + } + } + }() +} + +// Performs an HTTP request and returns the body. If there is an error the +// io.Reader will be nil. +func makeHTTPRequest(client doClient, to string) (io.Reader, error) { + if glog.V(3) { + glog.Infof("Making GET request to: %v", to) + } + req, err := http.NewRequest("GET", to, nil) + + if err != nil { + if glog.V(3) { + glog.Warningf( + "Failed to construct request: %s", err) + } + return nil, err + } + resp, err := client.Do(req) + if err != nil { + if glog.V(3) { + glog.Warningf( + "Failed to make request: %s", err) + } + return nil, err + } + if resp == nil { + return nil, &transportError{fmt.Sprintf( + "No response body from %s", to)} + } + if glog.V(5) { + glog.Infof("Response: %v", resp) + glog.Infof("Status: %v", resp.Status) + glog.Infof("Headers: %v", resp.Header) + } + defer resp.Body.Close() + respBytes, _ := ioutil.ReadAll(resp.Body) + respBuffer := bytes.NewBuffer(respBytes) + return respBuffer, nil +} + +// Performs the initial GET connection to the SocketIO endpoint. If it it +// successful, it will set the session id (sid) parameter on the endpoint. +func connect(endpoint Endpoint, client doClient) (*handshakeResponse, error) { + handshakeUrl := endpoint.GetHTTPUrl() + resp, err := makeHTTPRequest(client, handshakeUrl) + if err != nil { + glog.Errorf("Error connecting to IEX: %s", err) + return nil, err + } + var handshake handshakeResponse + err = HTTPToJSON(resp, []interface{}{&handshake}) + if err != nil { + glog.Errorf("Error parsing handshake response: %s", err) + return nil, err + } + canUpgradeToWs := false + for _, val := range handshake.Upgrades { + if val == "websocket" { + canUpgradeToWs = true + } + } + if !canUpgradeToWs { + return nil, &transportError{ + "Websocket upgrade not found"} + } + endpoint.SetSid(handshake.Sid) + // Making a get request with the SID automatically joins the default + // namespace. + resp, err = makeHTTPRequest(client, endpoint.GetHTTPUrl()) + if err != nil { + glog.Errorf("Error making status GET: %s", err) + return nil, err + } + var packetData PacketData + err = HTTPToJSON(resp, []interface{}{&packetData}) + if err != nil { + glog.Errorf("Error parsing handshake response: %s", err) + return nil, err + } + if packetData.PacketType != Message && + packetData.MessageType != Connect { + return nil, fmt.Errorf("Unexpected namespace response: %v", + packetData) + } + return &handshake, nil +} + +func sendPacket(transport Transport, packetType PacketType) { + encoder := NewWSEncoder("") + reader, err := encoder.EncodePacket(packetType, -1) + if err != nil { + glog.Warningf( + "Could not encode probe message: %s", err) + } + data, err := ioutil.ReadAll(reader) + if err != nil { + glog.Warningf( + "Could not read encoded message: %s", err) + } + _, err = transport.Write(data) + if err != nil { + glog.Warningf( + "Error writing probe message: %s", err) + return + } + if glog.V(3) { + glog.Infof("Sent packet %q", data) + } +} + +// Upgrades from an HTTPS to a Websocket connection. This method starts +// regular probe polling and sends an upgrade message before returning +// the websocket.Conn object. If an error occurs, the returned Transport +// is nil. The ping interval is used to start a hearbeat polling mechanism. +func upgrade(endpoint Endpoint, dialer WSDialer, ping int) (Transport, error) { + to := endpoint.GetWSUrl() + if glog.V(3) { + glog.Infof("Opening websocket connection to: %s", to) + } + conn, _, err := dialer.Dial(to, nil) + if err != nil { + glog.Errorf("Error opening websocket connection: %s", err) + return nil, err + } + if glog.V(3) { + glog.Info("Websocket connection established; sending upgrade") + } + trans := &transport{ + conn: conn, + outgoing: make(map[string]*outgoing), + } + go trans.startReadLoop() + trans.startHeartbeat(ping) + + // Upgrade the websocket connection. + sendPacket(trans, Upgrade) + + return trans, nil +} + +// Returns a new Transport object backed by an open Websocket connection +// or an error if one occurs. +func NewTransport(client doClient, dialer WSDialer) (Transport, error) { + endpoint := NewIEXEndpoint(sid.IdBase64) + handshake, err := connect(endpoint, client) + if err != nil { + return nil, err + } + return upgrade(endpoint, dialer, handshake.PingInterval) +} diff --git a/socketio/transport_test.go b/socketio/transport_test.go new file mode 100644 index 0000000..c0a35a8 --- /dev/null +++ b/socketio/transport_test.go @@ -0,0 +1,544 @@ +package socketio_test + +import ( + "io" + "io/ioutil" + "net/http" + "strconv" + "strings" + "sync" + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + . "github.com/timpalpant/go-iex/socketio" +) + +type response struct { + resp *http.Response + err error +} + +type fakeDoClient struct { + Requests []*http.Request + responses []*response + currentResponse int +} + +func (m *fakeDoClient) Do(req *http.Request) (*http.Response, error) { + m.Requests = append(m.Requests, req) + if len(m.responses) > 0 && m.currentResponse < len(m.responses) { + defer func() { + m.currentResponse++ + }() + if m.responses[m.currentResponse].err != nil { + return nil, m.responses[m.currentResponse].err + } + return m.responses[m.currentResponse].resp, nil + } + return nil, nil +} + +type message struct { + messageType int + message []byte + err error +} + +type fakeConn struct { + sync.Mutex + cond *sync.Cond + incomingMessage []byte + messagesWritten [][]byte + closed bool +} + +func newFakeConn() *fakeConn { + conn := &fakeConn{ + messagesWritten: make([][]byte, 0), + closed: false, + } + conn.cond = sync.NewCond(conn) + return conn +} + +// Calling with nil will set incomingMessage to nil, which will cause +// ReadMessage to return io.EOF. +func (f *fakeConn) SetIncomingMessage(msg []byte) { + f.Lock() + if msg == nil { + f.incomingMessage = nil + } else { + f.incomingMessage = make([]byte, len(msg)) + copy(f.incomingMessage, msg) + } + f.Unlock() + f.cond.Signal() +} + +func (f *fakeConn) ReadMessage() (int, []byte, error) { + f.Lock() + f.cond.Wait() + defer f.Unlock() + if f.incomingMessage == nil { + return 0, nil, io.EOF + } + toReturn := make([]byte, len(f.incomingMessage)) + copy(toReturn, f.incomingMessage) + return len(toReturn), toReturn, nil +} + +func (f *fakeConn) WriteMessage(messageType int, data []byte) error { + f.Lock() + defer f.Unlock() + f.messagesWritten = append(f.messagesWritten, data) + return nil +} + +func (f *fakeConn) Close() error { + f.Lock() + defer f.Unlock() + f.closed = true + return nil +} + +type fakeWsDialer struct { + WsUrl string + resp *http.Response + err error + conn WSConn +} + +func (w *fakeWsDialer) Dial(urlStr string, reqHeader http.Header) ( + WSConn, *http.Response, error) { + w.WsUrl = urlStr + if w.err != nil { + return nil, w.resp, w.err + } + return w.conn, w.resp, nil +} + +type fakeError struct { + message string +} + +func (f *fakeError) Error() string { + return f.message +} + +var hsResponseString = `95:0{"sid":"N1pkgEHs-wEXi4DtAA4m","upgrades":["websocket"],"pingInterval":500,"pingTimeout":60000}` +var hsLongPingResponseString = `98:0{"sid":"N1pkgEHs-wEXi4DtAA4m","upgrades":["websocket"],"pingInterval":100000,"pingTimeout":60000}` +var hsNoUpgradesString = `86:0{"sid":"N1pkgEHs-wEXi4DtAA4m","upgrades":[],"pingInterval":25000,"pingTimeout":60000}` + +var goodJoinResponse = `2:40` +var badJoinResponse = `2:22` + +func TestTransport(t *testing.T) { + Convey("The Transport layer should", t, func() { + Convey("return an error on no response body", func() { + requests := make([]*http.Request, 0) + responses := make([]*response, 0) + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + _, err := NewTransport(fdc, fw) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldStartWith, "No response body") + So(len(fdc.Requests), ShouldEqual, 1) + to := fdc.Requests[0].URL + So(fdc.Requests[0].Method, ShouldEqual, "GET") + So(to.Scheme, ShouldEqual, "https") + So(to.Host, ShouldStartWith, "ws-api.iextrading.com") + So(to.Path, ShouldEqual, "/socket.io/") + }) + Convey("return an error on no handshake response", func() { + requests := make([]*http.Request, 0) + hsResponse := &response{nil, &fakeError{"No connection"}} + responses := []*response{hsResponse} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + _, err := NewTransport(fdc, fw) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, "No connection") + }) + Convey("return an error no websocket upgrade", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsNoUpgradesString)), + } + responses := []*response{&response{ + resp: hsResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + _, err := NewTransport(fdc, fw) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, "Websocket upgrade not found") + }) + Convey("return an error on wrong message type", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(badJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + _, err := NewTransport(fdc, fw) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldStartWith, + "Unexpected namespace response") + }) + Convey("return an error on failure to open websocket", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(goodJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fw := &fakeWsDialer{ + err: &fakeError{"could not open"}, + } + _, err := NewTransport(fdc, fw) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, + "could not open") + }) + Convey("successfully handshake and upgrade", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(goodJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + trans, err := NewTransport(fdc, fw) + So(err, ShouldBeNil) + So(len(fdc.Requests), ShouldEqual, 2) + to := fdc.Requests[1].URL + So(fdc.Requests[1].Method, ShouldEqual, "GET") + So(to.Scheme, ShouldEqual, "https") + So(to.Host, ShouldStartWith, "ws-api.iextrading.com") + So(to.Path, ShouldEqual, "/socket.io/") + So(to.Query().Get("sid"), ShouldEqual, + "N1pkgEHs-wEXi4DtAA4m") + // This should allow at least 2 heartbeats at 500ms. + dur, _ := time.ParseDuration("1.2s") + time.Sleep(dur) + fc.Lock() + So(len(fc.messagesWritten), ShouldEqual, 3) + msgs := fc.messagesWritten + So(string(msgs[0]), ShouldEqual, "5") + So(string(msgs[1]), ShouldEqual, "2") + So(string(msgs[2]), ShouldEqual, "2") + fc.Unlock() + + trans.Close() + + fc.Lock() + msgs = fc.messagesWritten + So(len(fc.messagesWritten), ShouldEqual, 4) + So(string(msgs[3]), ShouldEqual, "1") + So(fc.closed, ShouldEqual, true) + fc.Unlock() + }) + Convey("prevent writing to a closed transport", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(goodJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + trans, err := NewTransport(fdc, fw) + trans.Close() + _, err = trans.Write([]byte("String")) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, + "Cannot write to a closed transport") + }) + Convey("prevent adding callbacks to closed transports", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(goodJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + trans, err := NewTransport(fdc, fw) + trans.Close() + handler := func(pkt PacketData) {} + _, err = trans.AddPacketCallback("/1.0/tops", handler) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, + "Cannot add a callback") + }) + Convey("prevent removing callbacks to closed transports", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(goodJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + trans, err := NewTransport(fdc, fw) + trans.Close() + err = trans.RemovePacketCallback("/1.0/tops", 1) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldContainSubstring, + "Cannot remove a callback") + }) + Convey("successfully write from multiple threads", func() { + requests := make([]*http.Request, 0) + // For the sake of this test, make the heartbeat long to + // prevent from interferring. + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader( + hsLongPingResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(goodJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + trans, err := NewTransport(fdc, fw) + So(err, ShouldBeNil) + var wg sync.WaitGroup + for i := 10; i < 20; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + trans.Write([]byte(strconv.Itoa(i))) + }(i) + } + for i := 20; i < 30; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + trans.Write([]byte(strconv.Itoa(i))) + + }(i) + } + wg.Wait() + trans.Close() + + fc.Lock() + So(fc.messagesWritten, ShouldHaveLength, 22) + for i := 10; i < 30; i++ { + So(fc.messagesWritten, ShouldContain, + []byte(strconv.Itoa(i))) + } + fc.Unlock() + }) + Convey("successfully read from multiple threads", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(goodJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + trans, err := NewTransport(fdc, fw) + + received := make([]PacketData, 0) + receivedLock := &sync.Mutex{} + receivedCond := sync.NewCond(receivedLock) + handler := func(pkt PacketData) { + receivedLock.Lock() + received = append(received, pkt) + receivedLock.Unlock() + receivedCond.Signal() + } + So(err, ShouldBeNil) + _, err = trans.AddPacketCallback("/1.0/last", handler) + So(err, ShouldBeNil) + _, err = trans.AddPacketCallback("/1.0/last", handler) + So(err, ShouldBeNil) + _, err = trans.AddPacketCallback("/1.0/last", handler) + So(err, ShouldBeNil) + message := []byte("42/1.0/last,[\"some\":\"data\"]") + fc.SetIncomingMessage(message) + expected := PacketData{ + PacketType: Message, + MessageType: Event, + Namespace: "/1.0/last", + Data: "[\"some\":\"data\"]", + } + for { + receivedLock.Lock() + if len(received) < 3 { + receivedCond.Wait() + } else { + receivedLock.Unlock() + break + } + receivedLock.Unlock() + + } + So(received[0], ShouldResemble, expected) + So(received[1], ShouldResemble, expected) + So(received[2], ShouldResemble, expected) + }) + Convey("successfully remove callbacks", func() { + requests := make([]*http.Request, 0) + hsResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(hsResponseString)), + } + nspResponse := &http.Response{ + Body: ioutil.NopCloser( + strings.NewReader(goodJoinResponse)), + } + responses := []*response{&response{ + resp: hsResponse, + }, &response{ + resp: nspResponse, + }} + fdc := &fakeDoClient{requests, responses, 0} + fc := newFakeConn() + fw := &fakeWsDialer{ + conn: fc, + } + trans, err := NewTransport(fdc, fw) + + received := make([]PacketData, 0) + receivedLock := &sync.Mutex{} + receivedCond := sync.NewCond(receivedLock) + handler := func(pkt PacketData) { + receivedLock.Lock() + received = append(received, pkt) + receivedLock.Unlock() + receivedCond.Signal() + } + So(err, ShouldBeNil) + _, err = trans.AddPacketCallback("/1.0/last", handler) + So(err, ShouldBeNil) + _, err = trans.AddPacketCallback("/1.0/last", handler) + So(err, ShouldBeNil) + id3, err := trans.AddPacketCallback("/1.0/last", handler) + So(err, ShouldBeNil) + err = trans.RemovePacketCallback("/1.0/last", id3) + So(err, ShouldBeNil) + message := []byte("42/1.0/last,[\"some\":\"data\"]") + fc.SetIncomingMessage(message) + expected := PacketData{ + PacketType: Message, + MessageType: Event, + Namespace: "/1.0/last", + Data: "[\"some\":\"data\"]", + } + for { + receivedLock.Lock() + if len(received) < 2 { + receivedCond.Wait() + } else { + receivedLock.Unlock() + break + } + receivedLock.Unlock() + + } + So(received[0], ShouldResemble, expected) + So(received[1], ShouldResemble, expected) + }) + }) +}