diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b82bc6b..14a289c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,40 +1,35 @@ on: push: - branches: [ main ] + branches: [main] pull_request: name: Test jobs: test: strategy: matrix: - go-version: [1.18.x] + go-version: [1.21.x, 1.22.x] platform: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.platform }} steps: - - name: Install Go - uses: actions/setup-go@v3 - with: - go-version: ${{ matrix.go-version }} - - name: Install staticcheck - run: go install honnef.co/go/tools/cmd/staticcheck@latest - shell: bash - - name: Install golint - run: go install golang.org/x/lint/golint@latest - shell: bash - - name: Update PATH - run: echo "$(go env GOPATH)/bin" >> $GITHUB_PATH - shell: bash - - name: Checkout code - uses: actions/checkout@v1 - - name: Fmt - if: matrix.platform != 'windows-latest' # :( - run: "diff <(gofmt -d .) <(printf '')" - shell: bash - - name: Vet - run: go vet ./... - - name: Staticcheck - run: staticcheck ./... - - name: Lint - run: golint ./... - - name: Test - run: go test -race ./... + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: ${{ matrix.go-version }} + - name: Install staticcheck + run: go install honnef.co/go/tools/cmd/staticcheck@latest + shell: bash + - name: Update PATH + run: echo "$(go env GOPATH)/bin" >> $GITHUB_PATH + shell: bash + - name: Checkout code + uses: actions/checkout@v4 + - name: Fmt + if: matrix.platform != 'windows-latest' # :( + run: "diff <(gofmt -d .) <(printf '')" + shell: bash + - name: Vet + run: go vet ./... + - name: Staticcheck + run: staticcheck ./... + - name: Test + run: go test -race ./... diff --git a/README.md b/README.md index 58a4cf8..04d9032 100644 --- a/README.md +++ b/README.md @@ -7,25 +7,55 @@ This library implements a simple, custom [RPC protocol](https://en.wikipedia.org A strongly typed client may look like this: ```go +// Define the request, message and receipt types for the RPC call. client, err := execrpc.StartClient( - execrpc.ClientOptions[model.ExampleRequest, model.ExampleResponse]{ + execrpc.ClientOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ ClientRawOptions: execrpc.ClientRawOptions{ Version: 1, Cmd: "go", - Dir: "./examples/servers/typed" + Dir: "./examples/servers/typed", Args: []string{"run", "."}, + Env: env, + Timeout: 30 * time.Second, }, - Codec: codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}, + Codec: codec, }, ) -result, _ := client.Execute(model.ExampleRequest{Text: "world"}) +if err != nil { + logg.Fatal(err) +} + + +// Consume standalone messages (e.g. log messages) in its own goroutine. +go func() { + for msg := range client.MessagesRaw() { + fmt.Println("got message", string(msg.Body)) + } +}() + +// Execute the request. +result := client.Execute(model.ExampleRequest{Text: "world"}) -fmt.Println(result.Hello) +// Check for errors. +if err; result.Err(); err != nil { + logg.Fatal(err) +} + +// Consume the messages. +for m := range result.Messages() { + fmt.Println(m) +} -//... +// Wait for the receipt. +receipt := result.Receipt() + +// Check again for errors. +if err; result.Err(); err != nil { + logg.Fatal(err) +} -client.Close() +fmt.Println(receipt.Text) ``` @@ -35,41 +65,75 @@ And the server side of the above: ```go func main() { - server, _ := execrpc.NewServer( - execrpc.ServerOptions[model.ExampleRequest, model.ExampleResponse]{ - Call: func(d execrpc.Dispatcher, req model.ExampleRequest) model.ExampleResponse { - return model.ExampleResponse{ - Hello: "Hello " + req.Text + "!", - } + getHasher := func() hash.Hash { + return fnv.New64a() + } + + server, err := execrpc.NewServer( + execrpc.ServerOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ + // Optional function to get a hasher for the ETag. + GetHasher: getHasher, + + // Allows you to delay message delivery, and drop + // them after reading the receipt (e.g. the ETag matches the ETag seen by client). + DelayDelivery: false, + + // Handle the incoming call. + Handle: func(c *execrpc.Call[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) { + // Raw messages are passed directly to the client, + // typically used for log messages. + c.SendRaw( + execrpc.Message{ + Header: execrpc.Header{ + Version: 32, + Status: 150, + }, + Body: []byte("a log message"), + }, + ) + + // Enqueue one or more messages. + c.Enqueue( + model.ExampleMessage{ + Hello: "Hello 1!", + }, + model.ExampleMessage{ + Hello: "Hello 2!", + }, + ) + + c.Enqueue( + model.ExampleMessage{ + Hello: "Hello 3!", + }, + ) + + // Wait for the framework generated receipt. + receipt := <-c.Receipt() + + // ETag provided by the framework. + // A hash of all message bodies. + fmt.Println("Receipt:", receipt.ETag) + + // Modify if needed. + receipt.Size = uint32(123) + + // Close the message stream. + c.Close(false, receipt) }, }, ) + if err != nil { + log.Fatal(err) + } + + // Start the server. This will block. if err := server.Start(); err != nil { - // ... handle error + log.Fatal(err) } - _ = server.Wait() } ``` -Of the included codecs, JSON seems to win by a small margin (but only tested with small requests/responses): - -```bsh -name time/op -Client/JSON-10 4.89µs ± 0% -Client/TOML-10 5.51µs ± 0% -Client/Gob-10 17.0µs ± 0% - -name alloc/op -Client/JSON-10 922B ± 0% -Client/TOML-10 1.67kB ± 0% -Client/Gob-10 9.22kB ± 0% - -name allocs/op -Client/JSON-10 19.0 ± 0% -Client/TOML-10 28.0 ± 0% -Client/Gob-10 227 ± 0% -``` - ## Status Codes The status codes in the header between 1 and 99 are reserved for the system. This will typically be used to catch decoding/encoding errors on the server. \ No newline at end of file diff --git a/client.go b/client.go index 800150e..635a0f7 100644 --- a/client.go +++ b/client.go @@ -24,7 +24,7 @@ const ( ) // StartClient starts a client for the given options. -func StartClient[Q, R any](opts ClientOptions[Q, R]) (*Client[Q, R], error) { +func StartClient[Q, M, R any](opts ClientOptions[Q, M, R]) (*Client[Q, M, R], error) { if opts.Codec == nil { return nil, errors.New("opts: Codec is required") } @@ -37,52 +37,129 @@ func StartClient[Q, R any](opts ClientOptions[Q, R]) (*Client[Q, R], error) { return nil, err } - return &Client[Q, R]{ + return &Client[Q, M, R]{ rawClient: rawClient, - codec: opts.Codec, + opts: opts, }, nil } // Client is a strongly typed RPC client. -type Client[Q, R any] struct { +type Client[Q, M, R any] struct { rawClient *ClientRaw - codec codecs.Codec[Q, R] + opts ClientOptions[Q, M, R] } -// Execute encodes and sends r the server and returns the response object. -// It's safe to call Execute from multiple goroutines. -func (c *Client[Q, R]) Execute(r Q) (R, error) { - body, err := c.codec.Encode(r) - var resp R - if err != nil { - return resp, err - } - message, err := c.rawClient.Execute(body) - if err != nil { - return resp, err +// Result is the result of a request +// with zero or more messages and the receipt. +type Result[M, R any] struct { + messages chan M + receipt chan R + errc chan error +} + +// Messages returns the messages from the server. +func (r Result[M, R]) Messages() <-chan M { + return r.messages +} + +// Receipt returns the receipt from the server. +func (r Result[M, R]) Receipt() <-chan R { + return r.receipt +} + +// Err returns any error. +func (r Result[M, R]) Err() error { + select { + case err := <-r.errc: + return err + default: + return nil } +} - if message.Header.Status > MessageStatusOK && message.Header.Status <= MessageStatusSystemReservedMax { - // All of these are currently error situations produced by the server. - return resp, fmt.Errorf("%s (error code %d)", message.Body, message.Header.Status) +func (r Result[M, R]) close() { + close(r.messages) + close(r.receipt) +} + +// MessagesRaw returns the raw messages from the server. +// These are not connected to the request-response flow, +// typically used for log messages etc. +func (c *Client[Q, M, R]) MessagesRaw() <-chan Message { + return c.rawClient.Messages +} + +// Execute sends the request to the server and returns the result. +// You should check Err() both before and after reading from the messages and receipt channels. +func (c *Client[Q, M, R]) Execute(r Q) Result[M, R] { + result := Result[M, R]{ + messages: make(chan M, 10), + receipt: make(chan R, 1), + errc: make(chan error, 1), } - err = c.codec.Decode(message.Body, &resp) + body, err := c.opts.Codec.Encode(r) if err != nil { - return resp, err + result.errc <- fmt.Errorf("failed to encode request: %w", err) + result.close() + return result } - return resp, nil + + go func() { + defer func() { + result.close() + }() + + messagesRaw := make(chan Message, 10) + go func() { + err := c.rawClient.Execute(body, messagesRaw) + if err != nil { + result.errc <- fmt.Errorf("failed to execute: %w", err) + } + }() + + for message := range messagesRaw { + if message.Header.Status > MessageStatusContinue && message.Header.Status <= MessageStatusSystemReservedMax { + // All of these are currently error situations produced by the server. + result.errc <- fmt.Errorf("%s (error code %d)", message.Body, message.Header.Status) + return + } + + if message.Header.Status == MessageStatusContinue { + var resp M + err = c.opts.Codec.Decode(message.Body, &resp) + if err != nil { + result.errc <- err + return + } + result.messages <- resp + } else { + // Receipt. + var rec R + err = c.opts.Codec.Decode(message.Body, &rec) + if err != nil { + result.errc <- err + return + } + result.receipt <- rec + return + } + + } + }() + + return result } // Close closes the client. -func (c *Client[Q, R]) Close() error { +func (c *Client[Q, M, R]) Close() error { return c.rawClient.Close() } // StartClientRaw starts a untyped client client for the given options. func StartClientRaw(opts ClientRawOptions) (*ClientRaw, error) { if opts.Timeout == 0 { - opts.Timeout = time.Second * 10 + opts.Timeout = time.Second * 30 } cmd := exec.Command(opts.Cmd, opts.Args...) @@ -109,18 +186,12 @@ func StartClientRaw(opts ClientRawOptions) (*ClientRaw, error) { return nil, fmt.Errorf("failed to start server: %s: %s", err, conn.stdErr.String()) } - if opts.OnMessage == nil { - opts.OnMessage = func(Message) { - - } - } - client := &ClientRaw{ - version: opts.Version, - timeout: opts.Timeout, - onMessage: opts.OnMessage, - conn: conn, - pending: make(map[uint32]*call), + version: opts.Version, + timeout: opts.Timeout, + conn: conn, + pending: make(map[uint32]*call), + Messages: make(chan Message, 10), } go client.input() @@ -138,7 +209,8 @@ type ClientRaw struct { closing bool shutdown bool - onMessage func(Message) + // Messages from the server that are not part of the request-response flow. + Messages chan Message timeout time.Duration @@ -155,6 +227,8 @@ func (c *ClientRaw) Close() error { if c == nil { return nil } + defer close(c.Messages) + c.sendMu.Lock() defer c.sendMu.Unlock() c.mu.Lock() @@ -170,32 +244,38 @@ func (c *ClientRaw) Close() error { return err } -// Execute sends body to the server and returns the Message it receives. +// Execute sends body to the server and sends any messages to the messages channel. // It's safe to call Execute from multiple goroutines. -func (c *ClientRaw) Execute(body []byte) (Message, error) { - call, err := c.newCall(body) +// The messages channel wil be closed when the call is done. +func (c *ClientRaw) Execute(body []byte, messages chan<- Message) error { + defer close(messages) + + call, err := c.newCall(body, messages) if err != nil { - return Message{}, err + return err } + timer := time.NewTimer(c.timeout) + defer timer.Stop() + select { case call = <-call.Done: - case <-time.After(c.timeout): - return Message{}, ErrTimeoutWaitingForServer + case <-timer.C: + return ErrTimeoutWaitingForCall } if call.Error != nil { - return call.Response, c.addErrContext("execute", call.Error) + return c.addErrContext("execute", call.Error) } - return call.Response, nil + return nil } func (c *ClientRaw) addErrContext(op string, err error) error { return fmt.Errorf("%s: %s %s", op, err, c.conn.stdErr.String()) } -func (c *ClientRaw) newCall(body []byte) (*call, error) { +func (c *ClientRaw) newCall(body []byte, messages chan<- Message) (*call, error) { c.mu.Lock() c.seq++ id := c.seq @@ -209,6 +289,7 @@ func (c *ClientRaw) newCall(body []byte) (*call, error) { }, Body: body, }, + Messages: messages, } if c.shutdown || c.closing { @@ -234,23 +315,34 @@ func (c *ClientRaw) input() { if err != nil { break } + id := message.Header.ID if id == 0 { // A message with ID 0 is a standalone message (e.g. log message) - c.onMessage(message) + // and not part of the request-response flow. + c.Messages <- message continue } // Attach it to the correct pending call. c.mu.Lock() - call := c.pending[id] + call, found := c.pending[id] + if !found { + panic(fmt.Sprintf("call with ID %d not found", id)) + } + if message.Header.Status == MessageStatusContinue { + call.Messages <- message + c.mu.Unlock() + continue + } + delete(c.pending, id) c.mu.Unlock() if call == nil { err = fmt.Errorf("call with ID %d not found", id) break } - call.Response = message + call.Messages <- message call.done() } @@ -274,7 +366,6 @@ func (c *ClientRaw) input() { call.Error = err call.done() } - } func (c *ClientRaw) send(call *call) error { @@ -290,9 +381,9 @@ func (c *ClientRaw) send(call *call) error { } // ClientOptions are options for the client. -type ClientOptions[Q, R any] struct { +type ClientOptions[Q, M, R any] struct { ClientRawOptions - Codec codecs.Codec[Q, R] + Codec codecs.Codec } // ClientRawOptions are options for the raw part of the client. @@ -317,16 +408,74 @@ type ClientRawOptions struct { // calling process's current directory. Dir string - // Callback for messages received from server without an ID (e.g. log message). - OnMessage func(Message) - // The timeout for the client. Timeout time.Duration } +var ( + _ TagProvider = &Identity{} + _ LastModifiedProvider = &Identity{} + _ SizeProvider = &Identity{} +) + +// Identity holds the modified time (Unix seconds) and a 64-bit checksum. +type Identity struct { + LastModified int64 `json:"lastModified"` + ETag string `json:"eTag"` + Size uint32 `json:"size"` +} + +// GetETag returns the checksum. +func (i Identity) GetETag() string { + return i.ETag +} + +// SetETag sets the checksum. +func (i *Identity) SetETag(s string) { + i.ETag = s +} + +// GetELastModified returns the last modified time. +func (i Identity) GetELastModified() int64 { + return i.LastModified +} + +// SetELastModified sets the last modified time. +func (i *Identity) SetELastModified(t int64) { + i.LastModified = t +} + +// GetESize returns the size. +func (i Identity) GetESize() uint32 { + return i.Size +} + +// SetESize sets the size. +func (i *Identity) SetESize(s uint32) { + i.Size = s +} + +// TagProvider is the interface for a type that can provide a eTag. +type TagProvider interface { + GetETag() string + SetETag(string) +} + +// LastModifiedProvider is the interface for a type that can provide a last modified time. +type LastModifiedProvider interface { + GetELastModified() int64 + SetELastModified(int64) +} + +// SizeProvider is the interface for a type that can provide a size. +type SizeProvider interface { + GetESize() uint32 + SetESize(uint32) +} + type call struct { Request Message - Response Message + Messages chan<- Message Error error Done chan *call } diff --git a/client_test.go b/client_test.go index 814170f..260abe5 100644 --- a/client_test.go +++ b/client_test.go @@ -14,24 +14,38 @@ import ( func TestExecRaw(t *testing.T) { c := qt.New(t) - client, err := execrpc.StartClientRaw( - execrpc.ClientRawOptions{ - Version: 1, - Cmd: "go", - Dir: "./examples/servers/raw", - Args: []string{"run", "."}, - }) - c.Assert(err, qt.IsNil) + newClient := func(c *qt.C) *execrpc.ClientRaw { + client, err := execrpc.StartClientRaw( + execrpc.ClientRawOptions{ + Version: 1, + Cmd: "go", + Dir: "./examples/servers/raw", + Args: []string{"run", "."}, + }) - defer func() { - c.Assert(client.Close(), qt.IsNil) - }() + c.Assert(err, qt.IsNil) + + return client + } c.Run("OK", func(c *qt.C) { - result, err := client.Execute([]byte("hello")) - c.Assert(err, qt.IsNil) - c.Assert(string(result.Body), qt.Equals, "echo: hello") + client := newClient(c) + defer client.Close() + messages := make(chan execrpc.Message) + var g errgroup.Group + g.Go(func() error { + return client.Execute([]byte("hello"), messages) + }) + var i int + for msg := range messages { + if i == 0 { + c.Assert(string(msg.Body), qt.Equals, "echo: hello") + } + i++ + } + c.Assert(i, qt.Equals, 1) + c.Assert(g.Wait(), qt.IsNil) }) } @@ -48,18 +62,18 @@ func TestExecStartFailed(t *testing.T) { c.Assert(err, qt.IsNotNil) c.Assert(err.Error(), qt.Contains, "failed to start server: chdir ./examples/servers/doesnotexist") c.Assert(client.Close(), qt.IsNil) - } -func newTestClient(t testing.TB, codec codecs.Codec[model.ExampleRequest, model.ExampleResponse], env ...string) *execrpc.Client[model.ExampleRequest, model.ExampleResponse] { +func newTestClient(t testing.TB, codec codecs.Codec, env ...string) *execrpc.Client[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt] { client, err := execrpc.StartClient( - execrpc.ClientOptions[model.ExampleRequest, model.ExampleResponse]{ + execrpc.ClientOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ ClientRawOptions: execrpc.ClientRawOptions{ Version: 1, Cmd: "go", Dir: "./examples/servers/typed", Args: []string{"run", "."}, Env: env, + Timeout: 30 * time.Second, }, Codec: codec, }, @@ -67,122 +81,206 @@ func newTestClient(t testing.TB, codec codecs.Codec[model.ExampleRequest, model. if err != nil { t.Fatal(err) } + + t.Cleanup(func() { + if err := client.Close(); err != nil { + t.Fatal(err) + } + }) + return client } func TestExecTyped(t *testing.T) { c := qt.New(t) - newClient := func(t testing.TB, codec codecs.Codec[model.ExampleRequest, model.ExampleResponse], env ...string) *execrpc.Client[model.ExampleRequest, model.ExampleResponse] { - client, err := execrpc.StartClient( - execrpc.ClientOptions[model.ExampleRequest, model.ExampleResponse]{ - ClientRawOptions: execrpc.ClientRawOptions{ - Version: 1, - Cmd: "go", - Dir: "./examples/servers/typed", - Args: []string{"run", "."}, - Env: env, - Timeout: 4 * time.Second, - }, - Codec: codec, - }, - ) - if err != nil { - t.Fatal(err) + runBasicTestForClient := func(c *qt.C, client *execrpc.Client[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) execrpc.Result[model.ExampleMessage, model.ExampleReceipt] { + result := client.Execute(model.ExampleRequest{Text: "world"}) + c.Assert(result.Err(), qt.IsNil) + return result + } + + assertMessages := func(c *qt.C, result execrpc.Result[model.ExampleMessage, model.ExampleReceipt], expected int) { + var i int + for m := range result.Messages() { + expect := fmt.Sprintf("%d: Hello world!", i) + c.Assert(string(m.Hello), qt.Equals, expect) + i++ } - return client + c.Assert(i, qt.Equals, expected) + c.Assert(result.Err(), qt.IsNil) } - runBasicTestForClient := func(c *qt.C, client *execrpc.Client[model.ExampleRequest, model.ExampleResponse]) model.ExampleResponse { - result, err := client.Execute(model.ExampleRequest{Text: "world"}) - c.Assert(err, qt.IsNil) + c.Run("One message", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}) + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1) + receipt := <-result.Receipt() + c.Assert(receipt.GetESize(), qt.Equals, uint32(123)) + }) + + c.Run("100 messages", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NUM_MESSAGES=100") + result := runBasicTestForClient(c, client) + assertMessages(c, result, 100) + receipt := <-result.Receipt() + c.Assert(receipt.LastModified, qt.Not(qt.Equals), int64(0)) + c.Assert(receipt.ETag, qt.Equals, "15b8164b761923b7") + }) + + c.Run("1234 messages", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NUM_MESSAGES=1234") + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1234) + receipt := <-result.Receipt() c.Assert(result.Err(), qt.IsNil) - c.Assert(string(result.Hello), qt.Equals, "Hello world!") - c.Assert(client.Close(), qt.IsNil) - return result + c.Assert(receipt.LastModified, qt.Not(qt.Equals), int64(0)) + c.Assert(receipt.ETag, qt.Equals, "43940b97841cc686") + }) - } + c.Run("Delay delivery", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_DELAY_DELIVERY=true") + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1) + receipt := <-result.Receipt() + c.Assert(receipt.GetESize(), qt.Equals, uint32(123)) + }) - c.Run("JSON", func(c *qt.C) { - client := newClient(c, codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}) - runBasicTestForClient(c, client) + c.Run("Delay delivery, drop messages", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_DELAY_DELIVERY=true", "EXECRPC_DROP_MESSAGES=true") + result := runBasicTestForClient(c, client) + assertMessages(c, result, 0) + receipt := <-result.Receipt() + // This is a little confusing. We always get a receipt even if the messages are dropped, + // and the server can create whatever protocol it wants. + c.Assert(receipt.GetESize(), qt.Equals, uint32(123)) }) - c.Run("TOML", func(c *qt.C) { - client := newClient(c, codecs.TOMLCodec[model.ExampleRequest, model.ExampleResponse]{}) - runBasicTestForClient(c, client) + c.Run("No Close", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NO_CLOSE=true") + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1) + receipt := <-result.Receipt() + // Empty receipt. + c.Assert(receipt.LastModified, qt.Equals, int64(0)) }) - c.Run("Gob", func(c *qt.C) { - client := newClient(c, codecs.GobCodec[model.ExampleRequest, model.ExampleResponse]{}) - runBasicTestForClient(c, client) + c.Run("Receipt", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}) + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1) + receipt := <-result.Receipt() + c.Assert(receipt.LastModified, qt.Not(qt.Equals), int64(0)) + c.Assert(receipt.ETag, qt.Equals, "2d5537627636b58a") + + // Set by the server. + c.Assert(receipt.Text, qt.Equals, "echoed: world") + c.Assert(receipt.Size, qt.Equals, uint32(123)) + }) + + c.Run("No hasher", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NO_HASHER=true") + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1) + receipt := <-result.Receipt() + c.Assert(receipt.ETag, qt.Equals, "") + }) + + c.Run("No reading Receipt", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NO_READING_RECEIPT=true") + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1) + receipt := <-result.Receipt() + // Empty receipt. + c.Assert(receipt.LastModified, qt.Equals, int64(0)) + }) + + c.Run("No reading Receipt, no Close", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_NO_READING_RECEIPT=true", "EXECRPC_NO_CLOSE=true") + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1) + receipt := <-result.Receipt() + // Empty receipt. + c.Assert(receipt.LastModified, qt.Equals, int64(0)) }) c.Run("Send log message from server", func(c *qt.C) { var logMessages []execrpc.Message client, err := execrpc.StartClient( - execrpc.ClientOptions[model.ExampleRequest, model.ExampleResponse]{ + execrpc.ClientOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ ClientRawOptions: execrpc.ClientRawOptions{ Version: 1, Cmd: "go", Dir: "./examples/servers/typed", Args: []string{"run", "."}, Env: []string{"EXECRPC_SEND_TWO_LOG_MESSAGES=true"}, - Timeout: 4 * time.Second, - OnMessage: func(msg execrpc.Message) { - logMessages = append(logMessages, msg) - }, + Timeout: 30 * time.Second, }, - Codec: codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}, + Codec: codecs.JSONCodec{}, }, ) if err != nil { c.Fatal(err) } - _, err = client.Execute(model.ExampleRequest{Text: "world"}) - c.Assert(err, qt.IsNil) - c.Assert(len(logMessages), qt.Equals, 2) + var wg errgroup.Group + wg.Go(func() error { + for msg := range client.MessagesRaw() { + fmt.Println("got message", string(msg.Body)) + logMessages = append(logMessages, msg) + } + return nil + }) + result := client.Execute(model.ExampleRequest{Text: "world"}) + c.Assert(result.Err(), qt.IsNil) + assertMessages(c, result, 1) + c.Assert(client.Close(), qt.IsNil) + c.Assert(wg.Wait(), qt.IsNil) c.Assert(string(logMessages[0].Body), qt.Equals, "first log message") c.Assert(logMessages[0].Header.Status, qt.Equals, uint16(150)) c.Assert(logMessages[0].Header.Version, qt.Equals, uint16(32)) - c.Assert(client.Close(), qt.IsNil) }) - c.Run("Error", func(c *qt.C) { - client := newClient(c, codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}, "EXECRPC_CALL_SHOULD_FAIL=true") - result, err := client.Execute(model.ExampleRequest{Text: "hello"}) - c.Assert(err, qt.IsNil) - c.Assert(result.Err(), qt.IsNotNil) - c.Assert(client.Close(), qt.IsNil) + c.Run("TOML", func(c *qt.C) { + client := newTestClient(c, codecs.TOMLCodec{}) + result := runBasicTestForClient(c, client) + assertMessages(c, result, 1) + }) + + c.Run("Error in receipt", func(c *qt.C) { + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_CALL_SHOULD_FAIL=true") + result := client.Execute(model.ExampleRequest{Text: "hello"}) + c.Assert(result.Err(), qt.IsNil) + receipt := <-result.Receipt() + c.Assert(receipt.Error, qt.Not(qt.IsNil)) + assertMessages(c, result, 0) }) // The "stdout print tests" are just to make sure that the server behaves and does not hang. c.Run("Print to stdout outside server before", func(c *qt.C) { - client := newClient(c, codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}, "EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE=true") + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE=true") runBasicTestForClient(c, client) }) c.Run("Print to stdout inside server", func(c *qt.C) { - client := newClient(c, codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}, "EXECRPC_PRINT_INSIDE_SERVER=true") + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_PRINT_INSIDE_SERVER=true") runBasicTestForClient(c, client) }) c.Run("Print to stdout outside server before", func(c *qt.C) { - client := newClient(c, codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}, "EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE=true") + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_PRINT_OUTSIDE_SERVER_BEFORE=true") runBasicTestForClient(c, client) }) c.Run("Print to stdout inside after", func(c *qt.C) { - client := newClient(c, codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}, "EXECRPC_PRINT_OUTSIDE_SERVER_AFTER=true") + client := newTestClient(c, codecs.JSONCodec{}, "EXECRPC_PRINT_OUTSIDE_SERVER_AFTER=true") runBasicTestForClient(c, client) }) - } func TestExecTypedConcurrent(t *testing.T) { - client := newTestClient(t, codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}) + client := newTestClient(t, codecs.JSONCodec{}) var g errgroup.Group for i := 0; i < 100; i++ { @@ -190,16 +288,21 @@ func TestExecTypedConcurrent(t *testing.T) { g.Go(func() error { for j := 0; j < 10; j++ { text := fmt.Sprintf("%d-%d", i, j) - result, err := client.Execute(model.ExampleRequest{Text: text}) - if err != nil { + result := client.Execute(model.ExampleRequest{Text: text}) + if err := result.Err(); err != nil { return err } - if result.Err() != nil { - return result.Err() + var k int + for response := range result.Messages() { + expect := fmt.Sprintf("%d: Hello %s!", k, text) + if string(response.Hello) != expect { + return fmt.Errorf("unexpected result: %s", response.Hello) + } + k++ } - expect := fmt.Sprintf("Hello %s!", text) - if string(result.Hello) != expect { - return fmt.Errorf("unexpected result: %s", result.Hello) + receipt := <-result.Receipt() + if receipt.Text != "echoed: "+text { + return fmt.Errorf("unexpected receipt: %s", receipt.Text) } } return nil @@ -209,47 +312,36 @@ func TestExecTypedConcurrent(t *testing.T) { if err := g.Wait(); err != nil { t.Fatal(err) } - } func BenchmarkClient(b *testing.B) { - const word = "World" - b.Run("JSON", func(b *testing.B) { - client := newTestClient(b, codecs.JSONCodec[model.ExampleRequest, model.ExampleResponse]{}) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := client.Execute(model.ExampleRequest{Text: word}) - if err != nil { - b.Fatal(err) + runBenchmark := func(name string, codec codecs.Codec, env ...string) { + b.Run(name, func(b *testing.B) { + client := newTestClient(b, codec, env...) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + result := client.Execute(model.ExampleRequest{Text: word}) + if err := result.Err(); err != nil { + b.Fatal(err) + } + for range result.Messages() { + } + <-result.Receipt() + if err := result.Err(); err != nil { + b.Fatal(err) + } } - } + }) }) - }) - - b.Run("TOML", func(b *testing.B) { - client := newTestClient(b, codecs.TOMLCodec[model.ExampleRequest, model.ExampleResponse]{}) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := client.Execute(model.ExampleRequest{Text: word}) - if err != nil { - b.Fatal(err) - } - } - }) - }) - - b.Run("Gob", func(b *testing.B) { - client := newTestClient(b, codecs.GobCodec[model.ExampleRequest, model.ExampleResponse]{}) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - _, err := client.Execute(model.ExampleRequest{Text: word}) - if err != nil { - b.Fatal(err) - } - } - }) - }) + } + runBenchmarksForCodec := func(codec codecs.Codec) { + runBenchmark("1 message "+codec.Name(), codec) + runBenchmark("100 messages "+codec.Name(), codec, "EXECRPC_NUM_MESSAGES=100") + } + runBenchmarksForCodec(codecs.JSONCodec{}) + runBenchmark("100 messages JSON, no hasher ", codecs.JSONCodec{}, "EXECRPC_NUM_MESSAGES=100", "EXECRPC_NO_HASHER=true") + runBenchmarksForCodec(codecs.TOMLCodec{}) } diff --git a/codecs/codecs.go b/codecs/codecs.go index 73c2a80..3d9c6b8 100644 --- a/codecs/codecs.go +++ b/codecs/codecs.go @@ -2,7 +2,6 @@ package codecs import ( "bytes" - "encoding/gob" "encoding/json" "errors" "strings" @@ -11,9 +10,9 @@ import ( ) // Codec defines the interface for a two way conversion between Q and R. -type Codec[Q, R any] interface { - Encode(Q) ([]byte, error) - Decode([]byte, *R) error +type Codec interface { + Encode(any) ([]byte, error) + Decode([]byte, any) error Name() string } @@ -21,27 +20,25 @@ type Codec[Q, R any] interface { var ErrUnknownCodec = errors.New("unknown codec") // ForName returns the codec for the given name or ErrUnknownCodec if no codec is found. -func ForName[Q, R any](name string) (Codec[Q, R], error) { +func ForName(name string) (Codec, error) { switch strings.ToLower(name) { case "toml": - return TOMLCodec[Q, R]{}, nil + return TOMLCodec{}, nil case "json": - return JSONCodec[Q, R]{}, nil - case "gob": - return GobCodec[Q, R]{}, nil + return JSONCodec{}, nil default: return nil, ErrUnknownCodec } } // TOMLCodec is a Codec that uses TOML as the underlying format. -type TOMLCodec[Q, R any] struct{} +type TOMLCodec struct{} -func (c TOMLCodec[Q, R]) Decode(b []byte, r *R) error { +func (c TOMLCodec) Decode(b []byte, r any) error { return toml.Unmarshal(b, r) } -func (c TOMLCodec[Q, R]) Encode(q Q) ([]byte, error) { +func (c TOMLCodec) Encode(q any) ([]byte, error) { var b bytes.Buffer enc := toml.NewEncoder(&b) if err := enc.Encode(q); err != nil { @@ -50,43 +47,21 @@ func (c TOMLCodec[Q, R]) Encode(q Q) ([]byte, error) { return b.Bytes(), nil } -func (c TOMLCodec[Q, R]) Name() string { +func (c TOMLCodec) Name() string { return "TOML" } // JSONCodec is a Codec that uses JSON as the underlying format. -type JSONCodec[Q, R any] struct{} +type JSONCodec struct{} -func (c JSONCodec[Q, R]) Decode(b []byte, r *R) error { +func (c JSONCodec) Decode(b []byte, r any) error { return json.Unmarshal(b, r) } -func (c JSONCodec[Q, R]) Encode(q Q) ([]byte, error) { +func (c JSONCodec) Encode(q any) ([]byte, error) { return json.Marshal(q) } -func (c JSONCodec[Q, R]) Name() string { +func (c JSONCodec) Name() string { return "JSON" } - -// GobCodec is a Codec that uses gob as the underlying format. -type GobCodec[Q, R any] struct{} - -func (c GobCodec[Q, R]) Decode(b []byte, r *R) error { - dec := gob.NewDecoder(bytes.NewReader(b)) - return dec.Decode(r) -} - -func (c GobCodec[Q, R]) Encode(q Q) ([]byte, error) { - var b bytes.Buffer - enc := gob.NewEncoder(&b) - err := enc.Encode(q) - if err != nil { - return nil, err - } - return b.Bytes(), nil -} - -func (c GobCodec[Q, R]) Name() string { - return "Gob" -} diff --git a/conn.go b/conn.go index b5fe625..266a1f3 100644 --- a/conn.go +++ b/conn.go @@ -15,8 +15,12 @@ import ( "golang.org/x/sync/errgroup" ) -// ErrTimeoutWaitingForServer is returned on timeouts starting the server. -var ErrTimeoutWaitingForServer = errors.New("timed out waiting for server to start") +var ( + // ErrTimeoutWaitingForServer is returned on timeouts starting the server. + ErrTimeoutWaitingForServer = errors.New("timed out waiting for server to start") + // ErrTimeoutWaitingForCall is returned on timeouts waiting for a call to complete. + ErrTimeoutWaitingForCall = errors.New("timed out waiting for call to complete") +) var brokenPipeRe = regexp.MustCompile("Broken pipe|pipe is being closed") @@ -83,7 +87,7 @@ func (c conn) Start() error { g, ctx := errgroup.WithContext(ctx) g.Go(func() error { - // THe server will announce when it's ready to read from stdin + // The server will announce when it's ready to read from stdin // by writing a special string to stdout. for { select { @@ -137,6 +141,8 @@ func (c conn) Start() error { // time to do so. func (c conn) waitWithTimeout() error { result := make(chan error, 1) + timer := time.NewTimer(c.timeout) + defer timer.Stop() go func() { result <- c.cmd.Wait() }() select { case err := <-result: @@ -146,7 +152,7 @@ func (c conn) waitWithTimeout() error { } } return err - case <-time.After(time.Second): + case <-timer.C: return errors.New("timed out waiting for server to finish") } } diff --git a/examples/model/model.go b/examples/model/model.go index a582ef0..7a7085e 100644 --- a/examples/model/model.go +++ b/examples/model/model.go @@ -1,18 +1,26 @@ package model +import "github.com/bep/execrpc" + // ExampleRequest is just a simple example request. type ExampleRequest struct { Text string `json:"text"` } -// ExampleResponse is just a simple example response. -type ExampleResponse struct { +// ExampleMessage is just a simple example message. +type ExampleMessage struct { Hello string `json:"hello"` +} + +// ExampleReceipt is just a simple receipt. +type ExampleReceipt struct { + execrpc.Identity Error *Error `json:"err"` + Text string `json:"text"` } // Err is just a simple example error. -func (r ExampleResponse) Err() error { +func (r ExampleReceipt) Err() error { if r.Error == nil { // Make sure that resp.Err() == nil. return nil diff --git a/examples/servers/raw/go.sum b/examples/servers/raw/go.sum index e259118..fc21b62 100644 --- a/examples/servers/raw/go.sum +++ b/examples/servers/raw/go.sum @@ -1,18 +1,16 @@ -github.com/bep/execrpc v0.3.0 h1:cZCobakmpgFSWRtJBmOmksSirSQZzz3m3+Ojc0YroGs= -github.com/bep/execrpc v0.3.0/go.mod h1:wDgrl/LwaUp6+EFCx0w1uvqP1cbwln/4OhB608TIREk= github.com/bep/helpers v0.1.0 h1:HFLG+W6axHackmKMk0houEnz9G2aiBrDMZyOvL9J0WM= github.com/bep/helpers v0.1.0/go.mod h1:/QpHdmcPagDw7+RjkLFCvnlUc8lQ5kg4KDrEkb2Yyco= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= diff --git a/examples/servers/raw/main.go b/examples/servers/raw/main.go index b586c81..5f4a0b1 100644 --- a/examples/servers/raw/main.go +++ b/examples/servers/raw/main.go @@ -9,15 +9,21 @@ import ( func main() { server, err := execrpc.NewServerRaw( execrpc.ServerRawOptions{ - Call: func(d execrpc.Dispatcher, message execrpc.Message) (execrpc.Message, error) { - return execrpc.Message{ - Header: message.Header, - Body: append([]byte("echo: "), message.Body...), - }, nil + Call: func(req execrpc.Message, d execrpc.Dispatcher) error { + header := req.Header + // execrpc.MessageStatusOK will complete the exchange. + // Setting it to execrpc.MessageStatusContinue will continue the conversation. + header.Status = execrpc.MessageStatusOK + d.SendMessage( + execrpc.Message{ + Header: header, + Body: append([]byte("echo: "), req.Body...), + }, + ) + return nil }, }, ) - if err != nil { handleErr(err) } @@ -25,7 +31,6 @@ func main() { if err := server.Start(); err != nil { handleErr(err) } - _ = server.Wait() } func handleErr(err error) { diff --git a/examples/servers/typed/go.mod b/examples/servers/typed/go.mod index 3211b69..7076eff 100644 --- a/examples/servers/typed/go.mod +++ b/examples/servers/typed/go.mod @@ -2,8 +2,9 @@ module github.com/bep/execrpc/examples/servers/typed go 1.19 +require github.com/bep/execrpc v0.3.0 + require ( - github.com/bep/execrpc v0.3.0 // indirect github.com/bep/helpers v0.1.0 // indirect github.com/pelletier/go-toml/v2 v2.0.2 // indirect golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect diff --git a/examples/servers/typed/go.sum b/examples/servers/typed/go.sum index 32fdbd8..fc21b62 100644 --- a/examples/servers/typed/go.sum +++ b/examples/servers/typed/go.sum @@ -1,14 +1,21 @@ -github.com/bep/execrpc v0.3.0 h1:cZCobakmpgFSWRtJBmOmksSirSQZzz3m3+Ojc0YroGs= -github.com/bep/execrpc v0.3.0/go.mod h1:wDgrl/LwaUp6+EFCx0w1uvqP1cbwln/4OhB608TIREk= github.com/bep/helpers v0.1.0 h1:HFLG+W6axHackmKMk0houEnz9G2aiBrDMZyOvL9J0WM= github.com/bep/helpers v0.1.0/go.mod h1:/QpHdmcPagDw7+RjkLFCvnlUc8lQ5kg4KDrEkb2Yyco= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuwRkXbG66OzopI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/servers/typed/main.go b/examples/servers/typed/main.go index 6c3cc64..594d2bc 100644 --- a/examples/servers/typed/main.go +++ b/examples/servers/typed/main.go @@ -2,8 +2,13 @@ package main import ( "fmt" + "hash" + "hash/fnv" "log" "os" + "strconv" + "sync/atomic" + "time" "github.com/bep/execrpc" "github.com/bep/execrpc/examples/model" @@ -20,26 +25,54 @@ func main() { printInsideServer = os.Getenv("EXECRPC_PRINT_INSIDE_SERVER") != "" callShouldFail = os.Getenv("EXECRPC_CALL_SHOULD_FAIL") != "" sendLogMessage = os.Getenv("EXECRPC_SEND_TWO_LOG_MESSAGES") != "" + noClose = os.Getenv("EXECRPC_NO_CLOSE") != "" + noReadingReceipt = os.Getenv("EXECRPC_NO_READING_RECEIPT") != "" + numMessagesStr = os.Getenv("EXECRPC_NUM_MESSAGES") + numMessages = 1 + delayDelivery = os.Getenv("EXECRPC_DELAY_DELIVERY") != "" + dropMessages = os.Getenv("EXECRPC_DROP_MESSAGES") != "" + noHasher = os.Getenv("EXECRPC_NO_HASHER") != "" ) + if numMessagesStr != "" { + numMessages, _ = strconv.Atoi(numMessagesStr) + if numMessages < 1 { + numMessages = 1 + } + } + if printOutsideServerBefore { fmt.Println("Printing outside server before") } + var getHasher func() hash.Hash + + if !noHasher { + getHasher = func() hash.Hash { + return fnv.New64a() + } + } + server, err := execrpc.NewServer( - execrpc.ServerOptions[model.ExampleRequest, model.ExampleResponse]{ - Call: func(d execrpc.Dispatcher, req model.ExampleRequest) model.ExampleResponse { + execrpc.ServerOptions[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]{ + GetHasher: getHasher, + DelayDelivery: delayDelivery, + Handle: func(c *execrpc.Call[model.ExampleRequest, model.ExampleMessage, model.ExampleReceipt]) { if printInsideServer { fmt.Println("Printing inside server") } if callShouldFail { - return model.ExampleResponse{ - Error: &model.Error{Msg: "failed to echo"}, - } + c.Close( + false, + model.ExampleReceipt{ + Error: &model.Error{Msg: "failed to echo"}, + }, + ) + return } if sendLogMessage { - d.Send( + c.SendRaw( execrpc.Message{ Header: execrpc.Header{ Version: 32, @@ -53,17 +86,45 @@ func main() { Status: 150, }, Body: []byte("second log message"), - }) + }, + ) } - return model.ExampleResponse{ - Hello: "Hello " + req.Text + "!", + for i := 0; i < numMessages; i++ { + c.Enqueue( + model.ExampleMessage{ + Hello: strconv.Itoa(i) + ": Hello " + c.Request.Text + "!", + }, + ) } + if !noClose { + var receipt model.ExampleReceipt + if !noReadingReceipt { + var receiptSeen atomic.Bool + go func() { + time.Sleep(1 * time.Second) + if !receiptSeen.Load() { + log.Fatalf("expected receipt to be seen") + } + }() + + receipt = <-c.Receipt() + receipt.Text = "echoed: " + c.Request.Text + receipt.Size = uint32(123) + + receiptSeen.Store(true) + + if getHasher != nil && receipt.ETag == "" { + log.Fatalf("expected receipt eTag to be set") + } + } + + c.Close(dropMessages, receipt) + } }, }, ) - if err != nil { handleErr(err) } @@ -74,13 +135,9 @@ func main() { if printOutsideServerAfter { fmt.Println("Printing outside server after") - } - _ = server.Wait() - } func handleErr(err error) { log.Fatalf("error: failed to start typed echo server: %s", err) - } diff --git a/go.mod b/go.mod index ad6d931..313d81d 100644 --- a/go.mod +++ b/go.mod @@ -4,14 +4,14 @@ go 1.18 require ( github.com/bep/helpers v0.1.0 - github.com/frankban/quicktest v1.14.3 + github.com/frankban/quicktest v1.14.6 github.com/pelletier/go-toml/v2 v2.0.2 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 ) require ( - github.com/google/go-cmp v0.5.8 // indirect - github.com/kr/pretty v0.3.0 // indirect + github.com/google/go-cmp v0.5.9 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect - github.com/rogpeppe/go-internal v1.8.1 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect ) diff --git a/go.sum b/go.sum index 7e8a47f..9558582 100644 --- a/go.sum +++ b/go.sum @@ -3,16 +3,12 @@ github.com/bep/helpers v0.1.0/go.mod h1:/QpHdmcPagDw7+RjkLFCvnlUc8lQ5kg4KDrEkb2Y github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= -github.com/frankban/quicktest v1.14.3/go.mod h1:mgiwOwqx65TmIk1wJ6Q7wvnVMocbUorkibMOrVTHZps= -github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= -github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pelletier/go-toml/v2 v2.0.2 h1:+jQXlF3scKIcSEKkdHzXhCTDLPFi5r1wnK6yPS+49Gw= @@ -20,17 +16,13 @@ github.com/pelletier/go-toml/v2 v2.0.2/go.mod h1:MovirKjgVRESsAvNZlAjtFwV867yGuw github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg= -github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/message_test.go b/message_test.go index db67b42..8d312ac 100644 --- a/message_test.go +++ b/message_test.go @@ -30,5 +30,4 @@ func TestMessage(t *testing.T) { c.Assert(m2.Read(&b), qt.IsNil) c.Assert(m2, qt.DeepEquals, m1) - } diff --git a/server.go b/server.go index 9dc4fa1..2700115 100644 --- a/server.go +++ b/server.go @@ -1,18 +1,24 @@ package execrpc import ( + "encoding/hex" "fmt" + "hash" "io" "os" - "sync" + "time" "github.com/bep/execrpc/codecs" "golang.org/x/sync/errgroup" ) const ( - // MessageStatusOK is the status code for a successful message. + // MessageStatusOK is the status code for a successful and complete message exchange. MessageStatusOK = iota + + // MessageStatusContinue is the status code for a message that should continue the conversation. + MessageStatusContinue + // MessageStatusErrDecodeFailed is the status code for a message that failed to decode. MessageStatusErrDecodeFailed // MessageStatusErrEncodeFailed is the status code for a message that failed to encode. @@ -30,30 +36,33 @@ func NewServerRaw(opts ServerRawOptions) (*ServerRaw, error) { s := &ServerRaw{ call: opts.Call, } - s.dispatcher = &messageDispatcher{ + s.dispatcher = messageDispatcher{ s: s, } return s, nil } // NewServer creates a new Server. using the given options. -func NewServer[Q, R any](opts ServerOptions[Q, R]) (*Server[Q, R], error) { - if opts.Call == nil { - return nil, fmt.Errorf("opts: Call function is required") +func NewServer[Q, M, R comparable](opts ServerOptions[Q, M, R]) (*Server[Q, M, R], error) { + if opts.Handle == nil { + return nil, fmt.Errorf("opts: Handle function is required") } if opts.Codec == nil { codecName := os.Getenv(envClientCodec) var err error - opts.Codec, err = codecs.ForName[R, Q](codecName) + opts.Codec, err = codecs.ForName(codecName) if err != nil { return nil, fmt.Errorf("failed to resolve codec from env variable %s with value %q (set by client); it can optionally be set in ServerOptions", envClientCodec, codecName) } } - var rawServer *ServerRaw + var ( + rawServer *ServerRaw + messagesRaw = make(chan Message, 10) + ) - call := func(d Dispatcher, message Message) (Message, error) { + callRaw := func(message Message, d Dispatcher) error { var q Q err := opts.Codec.Decode(message.Body, &q) if err != nil { @@ -62,64 +71,198 @@ func NewServer[Q, R any](opts ServerOptions[Q, R]) (*Server[Q, R], error) { Body: []byte(fmt.Sprintf("failed to decode request: %s. Check that client and server uses the same codec.", err)), } m.Header.Status = MessageStatusErrDecodeFailed - return m, nil + d.SendMessage(m) + return nil } - r := opts.Call(rawServer.dispatcher, q) - b, err := opts.Codec.Encode(r) - if err != nil { - m := Message{ - Header: message.Header, - Body: []byte(fmt.Sprintf("failed to encode response: %s. Check that client and server uses the same codec.", err)), + + call := &Call[Q, M, R]{ + Request: q, + messagesRaw: messagesRaw, + messages: make(chan M, 10), + receiptToServer: make(chan R, 1), + receiptFromServer: make(chan R, 1), + } + + go func() { + opts.Handle(call) + if !call.closed1 { + // The server returned without fetching the Receipt. + call.closeMessages() + } + if !call.closed2 { + // The server did not call Close, + // just send an empty receipt. + var r R + call.Close(false, r) } - m.Header.Status = MessageStatusErrEncodeFailed - return m, nil + }() + + var size uint32 + var hasher hash.Hash + if opts.GetHasher != nil { + hasher = opts.GetHasher() } - return Message{ - Header: message.Header, - Body: b, - }, nil + + var shouldHash bool + if hasher != nil { + // Avoid hashing if the receipt does not implement Sum64Provider. + var r *R + _, shouldHash = any(r).(TagProvider) + } + + var ( + checksum string + messageBuff []Message + ) + + defer func() { + receipt := <-call.receiptFromServer + + // Send any buffered message before the receipt. + if opts.DelayDelivery && !call.drop { + for _, m := range messageBuff { + d.SendMessage(m) + } + } + + b, err := opts.Codec.Encode(receipt) + h := message.Header + h.Status = MessageStatusOK + d.SendMessage(createMessage(b, err, h, MessageStatusErrEncodeFailed)) + }() + + for m := range call.messages { + b, err := opts.Codec.Encode(m) + h := message.Header + h.Status = MessageStatusContinue + m := createMessage(b, err, h, MessageStatusErrEncodeFailed) + if opts.DelayDelivery { + messageBuff = append(messageBuff, m) + } else { + d.SendMessage(m) + } + if shouldHash { + hasher.Write(m.Body) + } + size += uint32(len(m.Body)) + } + if shouldHash { + checksum = hex.EncodeToString(hasher.Sum(nil)) + } + + var receipt R + setReceiptValuesIfNotSet(size, checksum, &receipt) + + call.receiptToServer <- receipt + + return nil } var err error rawServer, err = NewServerRaw( ServerRawOptions{ - Call: call, + Call: callRaw, }, ) - if err != nil { return nil, err } - return &Server[Q, R]{ - ServerRaw: rawServer, - }, nil + s := &Server[Q, M, R]{ + messagesRaw: messagesRaw, + ServerRaw: rawServer, + } + + // Handle standalone messages in its own goroutine. + go func() { + for message := range s.messagesRaw { + rawServer.dispatcher.SendMessage(message) + } + }() + + return s, nil +} + +func setReceiptValuesIfNotSet(size uint32, checksum string, r any) { + if m, ok := any(r).(LastModifiedProvider); ok && m.GetELastModified() == 0 { + m.SetELastModified(time.Now().Unix()) + } + if size != 0 { + if m, ok := any(r).(SizeProvider); ok && m.GetESize() == 0 { + m.SetESize(size) + } + } + if checksum != "" { + if m, ok := any(r).(TagProvider); ok && m.GetETag() == "" { + m.SetETag(checksum) + } + } +} + +func createMessage(b []byte, err error, h Header, failureStatus uint16) Message { + var m Message + if err != nil { + m = Message{ + Header: h, + Body: []byte(fmt.Sprintf("failed create message: %s. Check that client and server uses the same codec.", err)), + } + m.Header.Status = failureStatus + } else { + m = Message{ + Header: h, + Body: b, + } + } + return m } // ServerOptions is the options for a server. -type ServerOptions[Q, R any] struct { - // Call is the function that will be called when a request is received. - Call func(Dispatcher, Q) R +type ServerOptions[Q, M, R any] struct { + // Handle is the function that will be called when a request is received. + Handle func(*Call[Q, M, R]) - // Codec is the codec that will be used to encode and decode requests and responses. + // Codec is the codec that will be used to encode and decode requests, messages and receipts. // The client will tell the server what codec is in use, so in most cases you should just leave this unset. - Codec codecs.Codec[R, Q] + Codec codecs.Codec + + // GetHasher returns the hash instance to be used for the response body + // If it's not set or it returns nil, no hash will be calculated. + GetHasher func() hash.Hash + + // Delay delivery of messages to the client until Close is called. + // Close takes a drop parameter that will drop any buffered messages. + // This can be useful if you want to check the server generated ETag, + // maybe the client already has this data. + DelayDelivery bool } // Server is a stringly typed server for requests of type Q and responses of tye R. -type Server[Q, R any] struct { +type Server[Q, M, R any] struct { + messagesRaw chan Message *ServerRaw } +func (s *Server[Q, M, R]) Start() error { + err := s.ServerRaw.Start() + + // Close the standalone message channel. + close(s.messagesRaw) + + if err == io.EOF { + return nil + } + + return err +} + // ServerRaw is a RPC server handling raw messages with a header and []byte body. // See Server for a generic, typed version. type ServerRaw struct { - call func(Dispatcher, Message) (Message, error) - dispatcher *messageDispatcher + call func(Message, Dispatcher) error + dispatcher messageDispatcher - startInit sync.Once - started bool - onStop func() + started bool + onStop func() in io.Reader out io.Writer @@ -131,89 +274,67 @@ type ServerRaw struct { var serverStarted = []byte("_server_started") // Start sets upt the server communication and starts the server loop. -// It's safe to call Start multiple times, but only the first call will start the server. func (s *ServerRaw) Start() error { - var initErr error - s.startInit.Do(func() { - defer func() { - s.started = true - }() - - // os.Stdout is where the client will listen for a specific byte stream, - // and any writes to stdout outside of this protocol (e.g. fmt.Println("hello world!") will - // freeze the server. - // - // To prevent that, we preserve the original stdout for the server and redirect user output to stderr. - origStdout := os.Stdout - done := make(chan bool) - - r, w, err := os.Pipe() - if err != nil { - initErr = err - return - } - - os.Stdout = w - - go func() { - // Copy all output from the pipe to stderr. - _, _ = io.Copy(os.Stderr, r) - // Done when the pipe is closed. - done <- true - }() + if s.started { + panic("server already started") + } + s.started = true - s.in = os.Stdin - s.out = origStdout - s.onStop = func() { - // Close one side of the pipe. - _ = w.Close() - <-done - } + // os.Stdout is where the client will listen for a specific byte stream, + // and any writes to stdout outside of this protocol (e.g. fmt.Println("hello world!") will + // freeze the server. + // + // To prevent that, we preserve the original stdout for the server and redirect user output to stderr. + origStdout := os.Stdout + done := make(chan bool) - s.g = &errgroup.Group{} + r, w, err := os.Pipe() + if err != nil { + return err + } - // Signal to client that the server is ready. - fmt.Fprint(s.out, string(serverStarted)+"\n") + os.Stdout = w + + go func() { + // Copy all output from the pipe to stderr. + _, _ = io.Copy(os.Stderr, r) + // Done when the pipe is closed. + done <- true + }() + + s.in = os.Stdin + s.out = origStdout + s.onStop = func() { + // Close one side of the pipe. + _ = w.Close() + <-done + } - s.g.Go(func() error { - return s.inputOutput() - }) + s.g = &errgroup.Group{} - s.g.Go(func() error { + // Signal to client that the server is ready. + fmt.Fprint(s.out, string(serverStarted)+"\n") - return nil - }) + s.g.Go(func() error { + return s.inputOutput() }) - return initErr -} - -// Wait waits for the server to stop. -// This happens when it gets disconnected from the client. -func (s *ServerRaw) Wait() error { - s.checkStarted() - err := s.g.Wait() + err = s.g.Wait() if s.onStop != nil { s.onStop() } - return err -} -func (s *ServerRaw) checkStarted() { - if !s.started { - panic("server not started") - } + return err } // inputOutput reads messages from the stdin and calls the server's call function. // The response is written to stdout. func (s *ServerRaw) inputOutput() error { - // We currently treat all errors in here as stop signals. // This means that the server will stop taking requests and // needs to be restarted. // Server implementations should communicate client error situations - // via the response message. + // via the messages. var err error for err == nil { var header Header @@ -226,25 +347,17 @@ func (s *ServerRaw) inputOutput() error { break } - var response Message - response, err = s.call( - s.dispatcher, + err = s.call( Message{ Header: header, Body: body, }, + s.dispatcher, ) - if err != nil { break } - response.Header.Size = uint32(len(response.Body)) - - if err = response.Write(s.out); err != nil { - break - } - } return err @@ -255,32 +368,79 @@ type ServerRawOptions struct { // Call is the message exhcange between the client and server. // Note that any error returned by this function will be treated as a fatal error and the server is stopped. // Validation errors etc. should be returned in the response message. - // The Dispatcher can be used to send messages to the client outside of the request/response loop, e.g. log messages. - // Note that these messages must have ID 0. - Call func(Dispatcher, Message) (Message, error) + // Message passed to the Dispatcher as part of the request/response must + // use the same ID as the request. + // ID 0 is reserved for standalone messages (e.g. log messages). + Call func(Message, Dispatcher) error } type messageDispatcher struct { s *ServerRaw } -// Dispatcher is the interface for dispatching standalone messages to the client, e.g. log messages. -type Dispatcher interface { - // Send sends one or more message back to the client. - // This is normally used for log messages and similar, - // and these messages should have a zero (0) ID. - Send(...Message) error +// Call is the request/response exchange between the client and server. +// Note that the stream parameter S is optional, set it to any if not used. +type Call[Q, M, R any] struct { + Request Q + messagesRaw chan Message + messages chan M + receiptFromServer chan R + receiptToServer chan R + + closed1 bool // No more messages. + closed2 bool // Receipt set. + drop bool // Drop buffered messages. } -func (s *messageDispatcher) Send(messages ...Message) error { - for _, message := range messages { - if message.Header.ID != 0 { - return fmt.Errorf("message ID must be 0") +// SendRaw sends one or more messages back to the client +// that is not part of the request/response exchange. +// These messages must have ID 0. +func (c *Call[Q, M, R]) SendRaw(ms ...Message) { + for _, m := range ms { + if m.Header.ID != 0 { + panic("message ID must be 0 for standalone messages") } - message.Header.Size = uint32(len(message.Body)) - if err := message.Write(s.s.out); err != nil { - return err + c.messagesRaw <- m + } +} + +// Enqueue enqueues one or more messages to be sent back to the client. +func (c *Call[Q, M, R]) Enqueue(rr ...M) { + for _, r := range rr { + c.messages <- r + } +} + +func (c *Call[Q, M, R]) Receipt() <-chan R { + c.closeMessages() + return c.receiptToServer +} + +// Close closes the call and sends andy buffered messages and the receipt back to the client. +// If drop is true, the buffered messages are dropped. +// Note that drop is only relevant if the server is configured with DelayDelivery set to true. +func (c *Call[Q, M, R]) Close(drop bool, r R) { + c.drop = drop + c.closed2 = true + c.receiptFromServer <- r +} + +func (c *Call[Q, M, R]) closeMessages() { + c.closed1 = true + close(c.messages) +} + +// Dispatcher is the interface for dispatching messages to the client. +type Dispatcher interface { + // SendMessage sends one or more message back to the client. + SendMessage(...Message) +} + +func (s messageDispatcher) SendMessage(ms ...Message) { + for _, m := range ms { + m.Header.Size = uint32(len(m.Body)) + if err := m.Write(s.s.out); err != nil { + panic(err) } } - return nil }