From 39cf7968fa8710d578cdddcab6e5dfd93ab6ebba Mon Sep 17 00:00:00 2001 From: Alexej Kubarev Date: Thu, 9 Jan 2020 00:34:56 +0100 Subject: [PATCH] A bit of refactoring and more documentation --- client.go | 5 +++ client_test.go | 4 +- codec.go | 19 +++++++--- codec_response.go | 34 ----------------- decode.go | 93 +++++++++++++++++----------------------------- decode_response.go | 47 +++++++++++++++++++++++ decode_test.go | 10 +++-- doc.go | 19 ++++++++++ encode.go | 59 ++++++++++++++++------------- encode_test.go | 23 +++++++----- fault.go | 5 ++- 11 files changed, 179 insertions(+), 139 deletions(-) delete mode 100644 codec_response.go create mode 100644 decode_response.go create mode 100644 doc.go diff --git a/client.go b/client.go index 8758e78..c0527e3 100644 --- a/client.go +++ b/client.go @@ -7,15 +7,20 @@ import ( "net/url" ) +// Client is responsible for making calls to RPC services with help of underlying rpc.Client. type Client struct { *rpc.Client } +// NewClient creates a Client with http.DefaultClient. +// If provided endpoint is not valid, an error is returned. func NewClient(endpoint string) (*Client, error) { return NewClientWithHttpClient(endpoint, http.DefaultClient) } +// NewClientWithHttpClient allows customization of http.Client used to make RPC calls. +// If provided endpoint is not valid, an error is returned. func NewClientWithHttpClient(endpoint string, httpClient *http.Client) (*Client, error) { // Parse Endpoint URL diff --git a/client_test.go b/client_test.go index 5ca5783..4c24595 100644 --- a/client_test.go +++ b/client_test.go @@ -17,8 +17,8 @@ func TestClient_Call(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m := &struct { - Name string `xml:"methodName"` - Params []*respParam `xml:"params>param"` + Name string `xml:"methodName"` + Params []*ResponseParam `xml:"params>param"` }{} body, err := ioutil.ReadAll(r.Body) assert.NoError(t, err, "test server: read body") diff --git a/codec.go b/codec.go index de8a772..2520713 100644 --- a/codec.go +++ b/codec.go @@ -11,6 +11,8 @@ import ( "sync" ) +// Codec implements methods required by rpc.ClientCodec +// In this implementation Codec is the one performing actual RPC requests with http.Client. type Codec struct { endpoint *url.URL httpClient *http.Client @@ -20,7 +22,9 @@ type Codec struct { pending map[uint64]*rpcCall // Current in-flight response - response *decodableResponse + response *Response + encoder Encoder + decoder Decoder // presents completed requests by sequence ID ready chan uint64 @@ -32,11 +36,16 @@ type rpcCall struct { httpResponse *http.Response } +// NewCodec creates a new Codec bound to provided endpoint. +// Provided client will be used to perform RPC requests. func NewCodec(endpoint *url.URL, httpClient *http.Client) *Codec { return &Codec{ endpoint: endpoint, httpClient: httpClient, + encoder: &StdEncoder{}, + decoder: &StdDecoder{}, + pending: make(map[uint64]*rpcCall), response: nil, ready: make(chan uint64), @@ -46,7 +55,7 @@ func NewCodec(endpoint *url.URL, httpClient *http.Client) *Codec { func (c *Codec) WriteRequest(req *rpc.Request, args interface{}) error { bodyBuffer := new(bytes.Buffer) - err := EncodeMethodCall(bodyBuffer, req.ServiceMethod, args) + err := c.encoder.Encode(bodyBuffer, req.ServiceMethod, args) if err != nil { return err } @@ -105,14 +114,14 @@ func (c *Codec) ReadResponseHeader(resp *rpc.Response) error { return nil } - decodableResponse, err := newDecodableResponse(body) + decodableResponse, err := NewResponse(body) if err != nil { resp.Error = err.Error() return nil } // Return response Fault already a this stage - if err := decodableResponse.Fault(); err != nil { + if err := c.decoder.DecodeFault(decodableResponse); err != nil { resp.Error = err.Error() return nil } @@ -131,7 +140,7 @@ func (c *Codec) ReadResponseBody(v interface{}) error { return errors.New("no in-flight response found") } - return c.response.Decode(v) + return c.decoder.Decode(c.response, v) } func (c *Codec) Close() error { diff --git a/codec_response.go b/codec_response.go deleted file mode 100644 index 65bf943..0000000 --- a/codec_response.go +++ /dev/null @@ -1,34 +0,0 @@ -package xmlrpc - -type decodableResponse struct { - body []byte - wrapper *respWrapper -} - -func newDecodableResponse(body []byte) (*decodableResponse, error) { - - wrapper, err := toRespWrapper(body) - if err != nil { - return nil, err - } - - r := &decodableResponse{ - wrapper: wrapper, - } - - return r, nil -} - -func (r *decodableResponse) Fault() *Fault { - - if r.wrapper.Fault == nil { - return nil - } - - return decodeFault(r.wrapper.Fault) -} - -func (r *decodableResponse) Decode(v interface{}) error { - - return decodeWrapper(r.wrapper, v) -} diff --git a/decode.go b/decode.go index 8790e92..3d47964 100644 --- a/decode.go +++ b/decode.go @@ -2,7 +2,6 @@ package xmlrpc import ( "encoding/base64" - "encoding/xml" "fmt" "reflect" "strconv" @@ -14,73 +13,42 @@ const ( errFormatInvalidFieldType = "invalid field type: expected '%s', got '%s'" ) -type respWrapper struct { - Params []respParam `xml:"params>param"` - Fault *respFault `xml:"fault,omitempty"` +// Decoder implementations provide mechanisms for parsing of XML-RPC responses to native data-types. +type Decoder interface { + DecodeRaw(body []byte, v interface{}) error + Decode(response *Response, v interface{}) error + DecodeFault(response *Response) *Fault } -type respParam struct { - Value respValue `xml:"value"` -} - -type respValue struct { - Array []*respValue `xml:"array>data>value"` - Struct []*respStructMember `xml:"struct>member"` - String string `xml:"string"` - Int string `xml:"int"` - Int4 string `xml:"i4"` - Double string `xml:"double"` - Boolean string `xml:"boolean"` - DateTime string `xml:"dateTime.iso8601"` - Base64 string `xml:"base64"` - - Raw string `xml:",innerxml"` // the value can be default string -} - -type respStructMember struct { - Name string `xml:"name"` - Value respValue `xml:"value"` -} - -type respFault struct { - Value respValue `xml:"value"` -} +// StdDecoder is the default implementation of the Decoder interface. +type StdDecoder struct{} -func DecodeResponse(body []byte, v interface{}) error { +func (d *StdDecoder) DecodeRaw(body []byte, v interface{}) error { - wrapper, err := toRespWrapper(body) + response, err := NewResponse(body) if err != nil { return err } - if wrapper.Fault != nil { - return decodeFault(wrapper.Fault) - } - - return decodeWrapper(wrapper, v) -} - -func toRespWrapper(body []byte) (*respWrapper, error) { - wrapper := &respWrapper{} - if err := xml.Unmarshal(body, wrapper); err != nil { - return nil, err + if response.Fault != nil { + return d.decodeFault(response.Fault) } - return wrapper, nil + return d.Decode(response, v) } -func decodeWrapper(wrapper *respWrapper, v interface{}) error { +func (d *StdDecoder) Decode(response *Response, v interface{}) error { // Validate that v has same number of public fields as response params - if err := fieldsMustEqual(v, len(wrapper.Params)); err != nil { + if err := fieldsMustEqual(v, len(response.Params)); err != nil { return err } vElem := reflect.Indirect(reflect.ValueOf(v)) - for i, param := range wrapper.Params { + for i, param := range response.Params { field := vElem.Field(i) - if err := decodeValue(¶m.Value, &field); err != nil { + if err := d.decodeValue(¶m.Value, &field); err != nil { return err } } @@ -88,7 +56,16 @@ func decodeWrapper(wrapper *respWrapper, v interface{}) error { return nil } -func decodeFault(fault *respFault) *Fault { +func (d *StdDecoder) DecodeFault(response *Response) *Fault { + + if response.Fault == nil { + return nil + } + + return d.decodeFault(response.Fault) +} + +func (d *StdDecoder) decodeFault(fault *ResponseFault) *Fault { f := &Fault{} for _, m := range fault.Value.Struct { @@ -107,7 +84,7 @@ func decodeFault(fault *respFault) *Fault { return f } -func decodeValue(value *respValue, field *reflect.Value) error { +func (d *StdDecoder) decodeValue(value *ResponseValue, field *reflect.Value) error { var val interface{} var err error @@ -124,16 +101,16 @@ func decodeValue(value *respValue, field *reflect.Value) error { val, err = strconv.ParseFloat(value.Double, 64) case value.Boolean != "": - val, err = decodeBoolean(value.Boolean) + val, err = d.decodeBoolean(value.Boolean) case value.String != "": val, err = value.String, nil case value.Base64 != "": - val, err = decodeBase64(value.Base64) + val, err = d.decodeBase64(value.Base64) case value.DateTime != "": - val, err = decodeDateTime(value.DateTime) + val, err = d.decodeDateTime(value.DateTime) // Array decoding case len(value.Array) > 0: @@ -145,7 +122,7 @@ func decodeValue(value *respValue, field *reflect.Value) error { slice := reflect.MakeSlice(reflect.TypeOf(field.Interface()), len(value.Array), len(value.Array)) for i, v := range value.Array { item := slice.Index(i) - if err := decodeValue(v, &item); err != nil { + if err := d.decodeValue(v, &item); err != nil { return fmt.Errorf("failed decoding array item at index %d: %w", i, err) } } @@ -168,7 +145,7 @@ func decodeValue(value *respValue, field *reflect.Value) error { return fmt.Errorf("cannot find field '%s' on struct", fName) } - if err := decodeValue(&m.Value, &f); err != nil { + if err := d.decodeValue(&m.Value, &f); err != nil { return fmt.Errorf("failed decoding struct member '%s': %w", m.Name, err) } } @@ -188,7 +165,7 @@ func decodeValue(value *respValue, field *reflect.Value) error { return nil } -func decodeBoolean(value string) (bool, error) { +func (d *StdDecoder) decodeBoolean(value string) (bool, error) { switch value { case "1", "true", "TRUE", "True": @@ -199,12 +176,12 @@ func decodeBoolean(value string) (bool, error) { return false, fmt.Errorf("unrecognized value '%s' for boolean", value) } -func decodeBase64(value string) ([]byte, error) { +func (d *StdDecoder) decodeBase64(value string) ([]byte, error) { return base64.StdEncoding.DecodeString(value) } -func decodeDateTime(value string) (time.Time, error) { +func (d *StdDecoder) decodeDateTime(value string) (time.Time, error) { return time.Parse(time.RFC3339, value) } diff --git a/decode_response.go b/decode_response.go new file mode 100644 index 0000000..daff1de --- /dev/null +++ b/decode_response.go @@ -0,0 +1,47 @@ +package xmlrpc + +import "encoding/xml" + +// Response is the basic parsed object of the XML-RPC response body. +// While it's not convenient to use this object directly - it contains all the information needed to unmarshal into other data-types. +type Response struct { + Params []ResponseParam `xml:"params>param"` + Fault *ResponseFault `xml:"fault,omitempty"` +} + +// NewResponse creates a Response object from XML body. +// It relies on XML Unmarshaler and if it fails - error is returned. +func NewResponse(body []byte) (*Response, error) { + + response := &Response{} + if err := xml.Unmarshal(body, response); err != nil { + return nil, err + } + + return response, nil +} + +type ResponseParam struct { + Value ResponseValue `xml:"value"` +} + +type ResponseValue struct { + Array []*ResponseValue `xml:"array>data>value"` + Struct []*ResponseStructMember `xml:"struct>member"` + String string `xml:"string"` + Int string `xml:"int"` + Int4 string `xml:"i4"` + Double string `xml:"double"` + Boolean string `xml:"boolean"` + DateTime string `xml:"dateTime.iso8601"` + Base64 string `xml:"base64"` +} + +type ResponseStructMember struct { + Name string `xml:"name"` + Value ResponseValue `xml:"value"` +} + +type ResponseFault struct { + Value ResponseValue `xml:"value"` +} diff --git a/decode_test.go b/decode_test.go index db34281..5a8e29a 100644 --- a/decode_test.go +++ b/decode_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestDecodeResponse(t *testing.T) { +func TestStdDecoder_DecodeRaw(t *testing.T) { tests := []struct { name string testFile string @@ -116,7 +116,8 @@ func TestDecodeResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := DecodeResponse(loadTestFile(t, tt.testFile), tt.v) + dec := &StdDecoder{} + err := dec.DecodeRaw(loadTestFile(t, tt.testFile), tt.v) assert.Equal(t, tt.err, err) if tt.err == nil { assert.EqualValues(t, tt.expect, tt.v) @@ -125,11 +126,12 @@ func TestDecodeResponse(t *testing.T) { } } -func TestDecodeResponse_Fault(t *testing.T) { +func TestStdDecoder_DecodeRaw_Fault(t *testing.T) { decodeTarget := &struct { Ints []int }{} - err := DecodeResponse(loadTestFile(t, "response_fault.xml"), decodeTarget) + dec := &StdDecoder{} + err := dec.DecodeRaw(loadTestFile(t, "response_fault.xml"), decodeTarget) assert.Error(t, err) fT := &Fault{} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..2208ea5 --- /dev/null +++ b/doc.go @@ -0,0 +1,19 @@ +/* +This package includes everything that is required to perform XML-RPC requests by utilizing familiar rpc.Client interface. + +The simplest use-case is creating a client towards an endpoint and making calls: + + + c, _ := NewClient("https://bugzilla.mozilla.org/xmlrpc.cgi") + + resp := &struct { + BugzillaVersion struct { + Version string + } + }{} + + err = c.Call("Bugzilla.version", nil, resp) + fmt.Printf("Version: %s\n", resp.BugzillaVersion.Version) + +*/ +package xmlrpc diff --git a/encode.go b/encode.go index 696f1d6..eb5b6b6 100644 --- a/encode.go +++ b/encode.go @@ -9,12 +9,19 @@ import ( "time" ) -func EncodeMethodCall(w io.Writer, methodName string, args interface{}) error { +// Encoder implementations are responsible for handling encoding of XML-RPC requests to the proper wire format. +type Encoder interface { + Encode(w io.Writer, methodName string, args interface{}) error +} + +// StdEncoder is the default implementation of Encoder interface. +type StdEncoder struct{} +func (e *StdEncoder) Encode(w io.Writer, methodName string, args interface{}) error { _, _ = fmt.Fprintf(w, "%s", methodName) if args != nil { - if err := encodeArgs(w, args); err != nil { + if err := e.encodeArgs(w, args); err != nil { return fmt.Errorf("cannot encoded provided method arguments: %w", err) } } @@ -24,7 +31,7 @@ func EncodeMethodCall(w io.Writer, methodName string, args interface{}) error { return nil } -func encodeArgs(w io.Writer, args interface{}) error { +func (e *StdEncoder) encodeArgs(w io.Writer, args interface{}) error { // Allows reading both pointer and value-structs elem := reflect.Indirect(reflect.ValueOf(args)) @@ -46,7 +53,7 @@ func encodeArgs(w io.Writer, args interface{}) error { } _, _ = fmt.Fprint(w, "") - if err := encodeValue(w, field.Interface()); err != nil { + if err := e.encodeValue(w, field.Interface()); err != nil { return fmt.Errorf("cannot encode argument '%s': %w", elem.Type().Field(fN).Name, err) } _, _ = fmt.Fprint(w, "") @@ -66,7 +73,7 @@ func encodeArgs(w io.Writer, args interface{}) error { // In that case a value is returned. // // See more: https://en.wikipedia.org/wiki/XML-RPC#Data_types -func encodeValue(w io.Writer, value interface{}) error { +func (e *StdEncoder) encodeValue(w io.Writer, value interface{}) error { valueOf := reflect.ValueOf(value) kind := valueOf.Kind() @@ -77,51 +84,51 @@ func encodeValue(w io.Writer, value interface{}) error { _, _ = fmt.Fprint(w, "") return nil } - return encodeValue(w, valueOf.Elem().Interface()) + return e.encodeValue(w, valueOf.Elem().Interface()) } _, _ = fmt.Fprint(w, "") switch kind { case reflect.Bool: - if err := encodeBoolean(w, value.(bool)); err != nil { + if err := e.encodeBoolean(w, value.(bool)); err != nil { return fmt.Errorf("cannot encode boolean value: %w", err) } case reflect.Int: - if err := encodeInteger(w, value.(int)); err != nil { + if err := e.encodeInteger(w, value.(int)); err != nil { return fmt.Errorf("cannot encode integer value: %w", err) } case reflect.Float64: - if err := encodeDouble(w, value.(float64)); err != nil { + if err := e.encodeDouble(w, value.(float64)); err != nil { return fmt.Errorf("cannot encode double value: %w", err) } case reflect.String: - if err := encodeString(w, value.(string)); err != nil { + if err := e.encodeString(w, value.(string)); err != nil { return fmt.Errorf("cannot encode string value: %w", err) } case reflect.Array, reflect.Slice: - if isByteArray(value) { - if err := encodeBase64(w, value.([]byte)); err != nil { + if e.isByteArray(value) { + if err := e.encodeBase64(w, value.([]byte)); err != nil { return fmt.Errorf("cannot encode byte-array value: %w", err) } } else { - if err := encodeArray(w, value); err != nil { + if err := e.encodeArray(w, value); err != nil { return fmt.Errorf("cannot encode array value: %w", err) } } case reflect.Struct: if reflect.TypeOf(value).String() != "time.Time" { - if err := encodeStruct(w, value); err != nil { + if err := e.encodeStruct(w, value); err != nil { return fmt.Errorf("cannot encode struct value: %w", err) } } else { - if err := encodeTime(w, value.(time.Time)); err != nil { + if err := e.encodeTime(w, value.(time.Time)); err != nil { return fmt.Errorf("cannot encode time.Time value: %w", err) } } @@ -131,25 +138,25 @@ func encodeValue(w io.Writer, value interface{}) error { return nil } -func isByteArray(val interface{}) bool { +func (e *StdEncoder) isByteArray(val interface{}) bool { _, ok := val.([]byte) return ok } -func encodeInteger(w io.Writer, val int) error { +func (e *StdEncoder) encodeInteger(w io.Writer, val int) error { _, err := fmt.Fprintf(w, "%d", val) return err } -func encodeDouble(w io.Writer, val float64) error { +func (e *StdEncoder) encodeDouble(w io.Writer, val float64) error { _, err := fmt.Fprintf(w, "%f", val) return err } -func encodeBoolean(w io.Writer, val bool) error { +func (e *StdEncoder) encodeBoolean(w io.Writer, val bool) error { v := 0 if val { @@ -160,7 +167,7 @@ func encodeBoolean(w io.Writer, val bool) error { return err } -func encodeString(w io.Writer, val string) error { +func (e *StdEncoder) encodeString(w io.Writer, val string) error { _, _ = fmt.Fprint(w, "") if err := xml.EscapeText(w, []byte(val)); err != nil { @@ -171,11 +178,11 @@ func encodeString(w io.Writer, val string) error { return nil } -func encodeArray(w io.Writer, val interface{}) error { +func (e *StdEncoder) encodeArray(w io.Writer, val interface{}) error { _, _ = fmt.Fprint(w, "") for i := 0; i < reflect.ValueOf(val).Len(); i++ { - if err := encodeValue(w, reflect.ValueOf(val).Index(i).Interface()); err != nil { + if err := e.encodeValue(w, reflect.ValueOf(val).Index(i).Interface()); err != nil { return fmt.Errorf("cannot encode array element at index %d: %w", i, err) } } @@ -185,7 +192,7 @@ func encodeArray(w io.Writer, val interface{}) error { return nil } -func encodeStruct(w io.Writer, val interface{}) error { +func (e *StdEncoder) encodeStruct(w io.Writer, val interface{}) error { _, _ = fmt.Fprint(w, "") for i := 0; i < reflect.TypeOf(val).NumField(); i++ { @@ -203,7 +210,7 @@ func encodeStruct(w io.Writer, val interface{}) error { } _, _ = fmt.Fprintf(w, "%s", fieldName) - if err := encodeValue(w, field.Interface()); err != nil { + if err := e.encodeValue(w, field.Interface()); err != nil { return fmt.Errorf("cannot encode value of struct field '%s': %w", fieldName, err) } _, _ = fmt.Fprint(w, "") @@ -213,13 +220,13 @@ func encodeStruct(w io.Writer, val interface{}) error { return nil } -func encodeBase64(w io.Writer, val []byte) error { +func (e *StdEncoder) encodeBase64(w io.Writer, val []byte) error { _, err := fmt.Fprintf(w, "%s", base64.StdEncoding.EncodeToString(val)) return err } -func encodeTime(w io.Writer, val time.Time) error { +func (e *StdEncoder) encodeTime(w io.Writer, val time.Time) error { _, err := fmt.Fprintf(w, "%s", val.Format(time.RFC3339)) return err diff --git a/encode_test.go b/encode_test.go index ed14460..f82756a 100644 --- a/encode_test.go +++ b/encode_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestEncodeMethodCall(t *testing.T) { +func TestStdEncoder_Encode(t *testing.T) { tests := []struct { name string methodName string @@ -124,15 +124,15 @@ func TestEncodeMethodCall(t *testing.T) { t.Run(tt.name, func(t *testing.T) { buf := new(strings.Builder) - - err := EncodeMethodCall(buf, tt.methodName, tt.args) + enc := &StdEncoder{} + err := enc.Encode(buf, tt.methodName, tt.args) assert.Equal(t, tt.expect, buf.String()) assert.Equal(t, tt.err, err) }) } } -func Test_isByteArray(t *testing.T) { +func TestStdEncoder_isByteArray(t *testing.T) { tests := []struct { name string input interface{} @@ -167,7 +167,8 @@ func Test_isByteArray(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp := isByteArray(tt.input) + enc := &StdEncoder{} + resp := enc.isByteArray(tt.input) assert.Equal(t, tt.expect, resp) }) } @@ -205,7 +206,8 @@ func Test_encodeArray(t *testing.T) { t.Run(tt.name, func(t *testing.T) { buf := new(strings.Builder) - err := encodeArray(buf, tt.input) + enc := &StdEncoder{} + err := enc.encodeArray(buf, tt.input) assert.Equal(t, tt.err, err) assert.Equal(t, tt.expect, buf.String()) }) @@ -240,7 +242,8 @@ func Test_encodeBase64(t *testing.T) { t.Run(tt.name, func(t *testing.T) { buf := new(strings.Builder) - err := encodeBase64(buf, tt.input) + enc := &StdEncoder{} + err := enc.encodeBase64(buf, tt.input) assert.Equal(t, tt.err, err) assert.Equal(t, tt.expect, buf.String()) }) @@ -311,7 +314,8 @@ func Test_encodeStruct(t *testing.T) { t.Run(tt.name, func(t *testing.T) { buf := new(strings.Builder) - err := encodeStruct(buf, tt.input) + enc := &StdEncoder{} + err := enc.encodeStruct(buf, tt.input) assert.Equal(t, tt.err, err) assert.Equal(t, tt.expect, buf.String()) }) @@ -361,7 +365,8 @@ func Test_encodeTime(t *testing.T) { t.Run(tt.name, func(t *testing.T) { buf := new(strings.Builder) - err := encodeTime(buf, tt.input) + enc := &StdEncoder{} + err := enc.encodeTime(buf, tt.input) assert.Equal(t, tt.err, err) assert.Equal(t, tt.expect, buf.String()) }) diff --git a/fault.go b/fault.go index 9edb5fa..c7ecfd1 100644 --- a/fault.go +++ b/fault.go @@ -2,8 +2,11 @@ package xmlrpc import "fmt" +// Fault is a wrapper for XML-RPC fault object type Fault struct { - Code int + // Code provides numerical failure code + Code int + // String includes more detailed information about the fault, such as error name and cause String string }