diff --git a/core/client.go b/core/client.go index e3bc54dd5..93b32cc6d 100644 --- a/core/client.go +++ b/core/client.go @@ -157,6 +157,7 @@ func (c *Client) connect(ctx context.Context, addr string) (frame.Conn, error) { ObserveDataTags: c.opts.observeDataTags, AuthName: c.opts.credential.Name(), AuthPayload: c.opts.credential.Payload(), + Version: Version, } if err := conn.WriteFrame(hf); err != nil { diff --git a/core/context.go b/core/context.go index 37666f922..57425e02d 100644 --- a/core/context.go +++ b/core/context.go @@ -119,7 +119,7 @@ func newContext(conn *Connection, route router.Route, df *frame.DataFrame) (c *C return } -// CloseWithError close dataStream with an error string. +// CloseWithError close connection with an error string. func (c *Context) CloseWithError(errString string) { c.Logger.Debug("connection closed", "err", errString) diff --git a/core/frame/frame.go b/core/frame/frame.go index ef51f2700..766f77cca 100644 --- a/core/frame/frame.go +++ b/core/frame/frame.go @@ -45,18 +45,23 @@ func (f *DataFrame) Type() Type { return TypeDataFrame } // It includes essential details required for the creation of a fresh connection. // The server then generates the connection utilizing this provided information. type HandshakeFrame struct { - // Name is the name of the dataStream that will be created. + // Name is the name of the connection that will be created. Name string - // ID is the ID of the dataStream that will be created. + // ID is the ID of the connection that will be created. ID string // ClientType is the type of client. ClientType byte - // ObserveDataTags is the ObserveDataTags of the dataStream that will be created. + // ObserveDataTags is the ObserveDataTags of the connection that will be created. ObserveDataTags []Tag // AuthName is the authentication name. AuthName string // AuthPayload is the authentication payload. AuthPayload string + // Version is used by the source/sfn to communicate their version to the server. + // The Version format must follow the `Major.Minor.Patch`. otherwise, the handshake + // will fail. The client‘s MAJOR and MINOR versions should equal to server's, + // otherwise, the zipper will be considered has break-change, the handshake will fail. + Version string } // Type returns the type of HandshakeFrame. diff --git a/core/server.go b/core/server.go index fe8b6e190..f42cb275c 100644 --- a/core/server.go +++ b/core/server.go @@ -21,6 +21,7 @@ import ( "github.com/yomorun/yomo/pkg/frame-codec/y3codec" yquic "github.com/yomorun/yomo/pkg/listener/quic" pkgtls "github.com/yomorun/yomo/pkg/tls" + "github.com/yomorun/yomo/pkg/version" oteltrace "go.opentelemetry.io/otel/trace" ) @@ -152,7 +153,22 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { } } -func (s *Server) handshake(fconn frame.Conn) (bool, router.Route, *Connection) { +func (s *Server) handleFrameConn(fconn frame.Conn, logger *slog.Logger) { + route, conn, err := s.handshake(fconn) + if err != nil { + logger.Error("handshake failed", "err", err) + return + } + + s.connHandler(conn, route) // s.handleConnRoute(conn, route) with middlewares + + if conn.ClientType() == ClientTypeStreamFunction { + _ = route.Remove(conn.ID()) + } + _ = s.connector.Remove(conn.ID()) +} + +func (s *Server) handshake(fconn frame.Conn) (router.Route, *Connection, error) { var gerr error defer func() { @@ -166,7 +182,7 @@ func (s *Server) handshake(fconn frame.Conn) (bool, router.Route, *Connection) { first, err := fconn.ReadFrame() if err != nil { gerr = err - return false, nil, nil + return nil, nil, gerr } switch first.Type() { case frame.TypeHandshakeFrame: @@ -175,17 +191,17 @@ func (s *Server) handshake(fconn frame.Conn) (bool, router.Route, *Connection) { conn, err := s.handleHandshakeFrame(fconn, hf) if err != nil { gerr = err - return false, nil, conn + return nil, conn, gerr } route, err := s.addSfnToRoute(hf, conn.Metadata()) if err != nil { gerr = err } - return true, route, conn + return route, conn, gerr default: gerr = fmt.Errorf("yomo: handshake read unexpected frame, read: %s", first.Type().String()) - return false, nil, nil + return nil, nil, gerr } } @@ -216,22 +232,8 @@ func (s *Server) handleConnRoute(conn *Connection, route router.Route) { } } -func (s *Server) handleFrameConn(fconn frame.Conn, logger *slog.Logger) { - ok, route, conn := s.handshake(fconn) - if !ok { - logger.Error("handshake failed") - return - } - - s.connHandler(conn, route) // s.handleConnRoute(conn, route) with middlewares - - if conn.ClientType() == ClientTypeStreamFunction { - _ = route.Remove(conn.ID()) - } - _ = s.connector.Remove(conn.ID()) -} - func (s *Server) handleHandshakeFrame(fconn frame.Conn, hf *frame.HandshakeFrame) (*Connection, error) { + // 1. authentication md, ok := auth.Authenticate(s.opts.auths, hf) if !ok { @@ -243,6 +245,11 @@ func (s *Server) handleHandshakeFrame(fconn frame.Conn, hf *frame.HandshakeFrame return nil, fmt.Errorf("authentication failed: client credential type is %s", hf.AuthName) } + // 2. version negotiation + if err := negotiateVersion(hf.Version, Version); err != nil { + return nil, err + } + conn := newConnection(hf.Name, hf.ID, ClientType(hf.ClientType), md, hf.ObserveDataTags, fconn, s.logger) return conn, s.connector.Store(hf.ID, conn) @@ -263,6 +270,25 @@ func (s *Server) addSfnToRoute(hf *frame.HandshakeFrame, md metadata.M) (router. return route, nil } +func negotiateVersion(cVersion, sVersion string) error { + cv, err := version.Parse(cVersion) + if err != nil { + return err + } + + sv, err := version.Parse(sVersion) + if err != nil { + return err + } + + // If the Major and Minor versions are equal, the server can serve the client. + if cv.Major == sv.Major && cv.Minor == sv.Minor { + return nil + } + + return fmt.Errorf("yomo: version negotiation failed, client=%s, server=%s", cVersion, sVersion) +} + func (s *Server) handleFrame(c *Context) { // routing data frame. if err := s.routingDataFrame(c); err != nil { diff --git a/core/server_test.go b/core/server_test.go index 85bd471e2..83027e671 100644 --- a/core/server_test.go +++ b/core/server_test.go @@ -1,6 +1,7 @@ package core import ( + "errors" "testing" "github.com/stretchr/testify/assert" @@ -44,3 +45,46 @@ func (s *mockConnectionInfo) Name() string { return s.name } func (s *mockConnectionInfo) Metadata() metadata.M { return s.metadata } func (s *mockConnectionInfo) ClientType() ClientType { return s.clientType } func (s *mockConnectionInfo) ObserveDataTags() []frame.Tag { return s.observed } + +func Test_negotiateVersion(t *testing.T) { + type args struct { + cVersion string + sVersion string + } + tests := []struct { + name string + args args + wantErr error + }{ + { + name: "ok", + args: args{ + cVersion: "1.16.3", + sVersion: "1.16.3", + }, + wantErr: nil, + }, + { + name: "client empty version", + args: args{ + cVersion: "", + sVersion: "1.16.3", + }, + wantErr: errors.New("invalid semantic version, params="), + }, + { + name: "not ok", + args: args{ + cVersion: "1.15.0", + sVersion: "1.16.3", + }, + wantErr: errors.New("yomo: version negotiation failed, client=1.15.0, server=1.16.3"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := negotiateVersion(tt.args.cVersion, tt.args.sVersion) + assert.Equal(t, tt.wantErr, err) + }) + } +} diff --git a/core/version.go b/core/version.go new file mode 100644 index 000000000..ee0e2dfca --- /dev/null +++ b/core/version.go @@ -0,0 +1,4 @@ +package core + +// Version is the current yomo version. +const Version = "1.17.0" diff --git a/pkg/frame-codec/y3codec/codec_test.go b/pkg/frame-codec/y3codec/codec_test.go index e5aa4febc..290df2d0d 100644 --- a/pkg/frame-codec/y3codec/codec_test.go +++ b/pkg/frame-codec/y3codec/codec_test.go @@ -76,13 +76,13 @@ func TestCodec(t *testing.T) { ObserveDataTags: []uint32{'a', 'b', 'c'}, AuthName: "ddddd", AuthPayload: "eeeee", + Version: "1.16.3", }, - data: []byte{0xb1, 0x31, 0x1, 0x8, 0x74, 0x68, 0x65, 0x2d, 0x6e, 0x61, + data: []byte{0xb1, 0x39, 0x1, 0x8, 0x74, 0x68, 0x65, 0x2d, 0x6e, 0x61, 0x6d, 0x65, 0x3, 0x6, 0x74, 0x68, 0x65, 0x2d, 0x69, 0x64, 0x2, 0x1, 0x68, 0x6, 0xc, 0x61, 0x0, 0x0, 0x0, 0x62, 0x0, 0x0, 0x0, 0x63, 0x0, - 0x0, 0x0, 0x4, 0x5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x5, 0x5, 0x65, - 0x65, 0x65, 0x65, 0x65, - }, + 0x0, 0x0, 0x4, 0x5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x5, 0x5, 0x65, 0x65, + 0x65, 0x65, 0x65, 0x7, 0x6, 0x31, 0x2e, 0x31, 0x36, 0x2e, 0x33}, }, }, { diff --git a/pkg/frame-codec/y3codec/handshake_frame.go b/pkg/frame-codec/y3codec/handshake_frame.go index 2527d0767..2cf9e7e71 100644 --- a/pkg/frame-codec/y3codec/handshake_frame.go +++ b/pkg/frame-codec/y3codec/handshake_frame.go @@ -31,6 +31,9 @@ func encodeHandshakeFrame(f *frame.HandshakeFrame) ([]byte, error) { // auth payload authPayloadBlock := y3.NewPrimitivePacketEncoder(tagAuthenticationPayload) authPayloadBlock.SetStringValue(f.AuthPayload) + // version + versionBlock := y3.NewPrimitivePacketEncoder(tagHandshakeVersion) + versionBlock.SetStringValue(f.Version) // handshake frame handshake := y3.NewNodePacketEncoder(byte(f.Type())) @@ -40,6 +43,7 @@ func encodeHandshakeFrame(f *frame.HandshakeFrame) ([]byte, error) { handshake.AddPrimitivePacket(observeDataTagsBlock) handshake.AddPrimitivePacket(authNameBlock) handshake.AddPrimitivePacket(authPayloadBlock) + handshake.AddPrimitivePacket(versionBlock) return handshake.Encode(), nil } @@ -98,15 +102,24 @@ func decodeHandshakeFrame(data []byte, f *frame.HandshakeFrame) error { } f.AuthPayload = authPayload } + // version + if versionBlock, ok := node.PrimitivePackets[tagHandshakeVersion]; ok { + version, err := versionBlock.ToUTF8String() + if err != nil { + return err + } + f.Version = version + } return nil } -var ( +const ( tagHandshakeName byte = 0x01 tagHandshakeClientType byte = 0x02 tagHandshakeID byte = 0x03 tagAuthenticationName byte = 0x04 tagAuthenticationPayload byte = 0x05 tagHandshakeObserveDataTags byte = 0x06 + tagHandshakeVersion byte = 0x07 ) diff --git a/pkg/version/version.go b/pkg/version/version.go new file mode 100644 index 000000000..64e99edf8 --- /dev/null +++ b/pkg/version/version.go @@ -0,0 +1,44 @@ +// Package version provides functionality for parsing versions.. +package version + +import ( + "fmt" + "strconv" + "strings" +) + +// Version is used by the source/sfn to communicate their version to the server. +type Version struct { + Major int + Minor int + Patch int +} + +// Parse parses a string into a Version. The string format must follow the `Major.Minor.Patch` +// formatting, and the Major, Minor, and Patch components must be numeric. If they are not, +// a parse error will be returned. +func Parse(str string) (*Version, error) { + vs := strings.Split(str, ".") + if len(vs) != 3 { + return nil, fmt.Errorf("invalid semantic version, params=%s", str) + } + + major, err := strconv.Atoi(vs[0]) + if err != nil { + return nil, fmt.Errorf("invalid version major, params=%s", str) + } + + minor, err := strconv.Atoi(vs[1]) + if err != nil { + return nil, fmt.Errorf("invalid version minor, params=%s", str) + } + + patch, err := strconv.Atoi(vs[2]) + if err != nil { + return nil, fmt.Errorf("invalid version patch, params=%s", str) + } + + ver := &Version{Major: major, Minor: minor, Patch: patch} + + return ver, nil +} diff --git a/pkg/version/version_test.go b/pkg/version/version_test.go new file mode 100644 index 000000000..b9e58e00b --- /dev/null +++ b/pkg/version/version_test.go @@ -0,0 +1,67 @@ +package version + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + type args struct { + str string + } + tests := []struct { + name string + args args + want *Version + wantErr error + }{ + { + name: "ok", + args: args{ + str: "1.16.3", + }, + want: &Version{Major: 1, Minor: 16, Patch: 3}, + }, + { + name: "invalid semantic version", + args: args{ + str: "1.16.3-beta.1", + }, + want: nil, + wantErr: errors.New("invalid semantic version, params=1.16.3-beta.1"), + }, + { + name: "invalid version major", + args: args{ + str: "xx.16.3", + }, + want: nil, + wantErr: errors.New("invalid version major, params=xx.16.3"), + }, + { + name: "invalid version minor", + args: args{ + str: "1.yy.3", + }, + want: nil, + wantErr: errors.New("invalid version minor, params=1.yy.3"), + }, + { + name: "invalid version patch", + args: args{ + str: "1.16.3-beta", + }, + want: nil, + wantErr: errors.New("invalid version patch, params=1.16.3-beta"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotErr := Parse(tt.args.str) + assert.Equal(t, tt.wantErr, gotErr) + assert.Equal(t, tt.want, got) + }) + } +}