Skip to content

Commit

Permalink
feat: version negotiation
Browse files Browse the repository at this point in the history
  • Loading branch information
woorui committed Oct 9, 2023
1 parent 1c34dab commit 7ff1b95
Show file tree
Hide file tree
Showing 10 changed files with 221 additions and 17 deletions.
1 change: 1 addition & 0 deletions core/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func (c *Client) connect(ctx context.Context, addr string) *connectResult {
ObserveDataTags: c.opts.observeDataTags,
AuthName: c.opts.credential.Name(),
AuthPayload: c.opts.credential.Payload(),
Version: Version,
}

if err := fs.WriteFrame(hf); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (c *Context) WithFrame(f frame.Frame) error {
return nil
}

// CloseWithError close dataStream with an error string.
// CloseWithError close connection with an error string.
func (c *Context) CloseWithError(errString string) {
c.Logger.Debug("connection closed", "err", errString)

Expand Down
11 changes: 8 additions & 3 deletions core/frame/frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,23 @@ func (f *DataFrame) Type() Type { return TypeDataFrame }
// It includes essential details required for the creation of a fresh connection.
// The server then generates the connection utilizing this provided information.
type HandshakeFrame struct {
// Name is the name of the dataStream that will be created.
// Name is the name of the connection that will be created.
Name string
// ID is the ID of the dataStream that will be created.
// ID is the ID of the connection that will be created.
ID string
// ClientType is the type of client.
ClientType byte
// ObserveDataTags is the ObserveDataTags of the dataStream that will be created.
// ObserveDataTags is the ObserveDataTags of the connection that will be created.
ObserveDataTags []Tag
// AuthName is the authentication name.
AuthName string
// AuthPayload is the authentication payload.
AuthPayload string
// Version is used by the source/sfn to communicate their version to the server.
// The version format should follow `https://semver.org`. otherwise, the handshake
// will fail. The client‘s MAJOR and MINOR versions should equal to server's,
// otherwise, the zipper will be considered has break-change, the handshake will fail.
Version string
}

// Type returns the type of HandshakeFrame.
Expand Down
42 changes: 34 additions & 8 deletions core/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/yomorun/yomo/pkg/id"
pkgtls "github.com/yomorun/yomo/pkg/tls"
"github.com/yomorun/yomo/pkg/trace"
"github.com/yomorun/yomo/pkg/version"
oteltrace "go.opentelemetry.io/otel/trace"
)

Expand Down Expand Up @@ -103,7 +104,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) error {
return s.Serve(ctx, conn)
}

func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (bool, router.Route, Connection) {
func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (error, router.Route, Connection) {
var gerr error

defer func() {
Expand All @@ -117,7 +118,7 @@ func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (bool, router
first, err := fs.ReadFrame()
if err != nil {
gerr = err
return false, nil, nil
return gerr, nil, nil
}

switch first.Type() {
Expand All @@ -127,24 +128,24 @@ func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (bool, router
conn, err := s.handleHandshakeFrame(qconn, fs, hf)
if err != nil {
gerr = err
return false, nil, conn
return gerr, nil, conn
}

route, err := s.handleRoute(hf, conn.Metadata())
if err != nil {
gerr = err
}
return true, route, conn
return gerr, route, conn
default:
gerr = fmt.Errorf("yomo: handshake read unexpected frame, read: %s", first.Type().String())
return false, nil, nil
return gerr, nil, nil
}
}

func (s *Server) handleConnection(qconn quic.Connection, fs *FrameStream, logger *slog.Logger) {
ok, route, conn := s.handshake(qconn, fs)
if !ok {
logger.Error("handshake failed")
err, route, conn := s.handshake(qconn, fs)
if err != nil {
logger.Error("handshake failed", "err", err)
return
}

Expand Down Expand Up @@ -227,18 +228,43 @@ func (s *Server) handleRoute(hf *frame.HandshakeFrame, md metadata.M) (router.Ro
}

func (s *Server) handleHandshakeFrame(qconn quic.Connection, fs *FrameStream, hf *frame.HandshakeFrame) (Connection, error) {
// 1. authentication
md, ok := auth.Authenticate(s.opts.auths, hf)

if !ok {
s.logger.Warn("authentication failed", "credential", hf.AuthName)
return nil, fmt.Errorf("authentication failed: client credential name is %s", hf.AuthName)
}

// 2. version negotiation
if err := negotiateVersion(hf.Version, Version); err != nil {
return nil, err
}

conn := newConnection(hf.Name, hf.ID, ClientType(hf.ClientType), md, hf.ObserveDataTags, qconn, fs)

return conn, s.connector.Store(hf.ID, conn)
}

func negotiateVersion(cVersion, sVersion string) error {
cv, err := version.Parse(cVersion)
if err != nil {
return err
}

sv, err := version.Parse(sVersion)
if err != nil {
return err
}

// If the Major and Minor versions are equal, the server can serve the client.
if cv.Major == sv.Major && cv.Minor == sv.Minor {
return nil
}

return fmt.Errorf("yomo: version negotiation failed, client=%s, server=%s", cVersion, sVersion)
}

// Serve the server with a net.PacketConn.
func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error {
if err := s.validateRouter(); err != nil {
Expand Down
44 changes: 44 additions & 0 deletions core/server_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -44,3 +45,46 @@ func (s *mockConnectionInfo) Name() string { return s.name }
func (s *mockConnectionInfo) Metadata() metadata.M { return s.metadata }
func (s *mockConnectionInfo) ClientType() ClientType { return s.clientType }
func (s *mockConnectionInfo) ObserveDataTags() []frame.Tag { return s.observed }

func Test_negotiateVersion(t *testing.T) {
type args struct {
cVersion string
sVersion string
}
tests := []struct {
name string
args args
wantErr error
}{
{
name: "ok",
args: args{
cVersion: "1.16.3",
sVersion: "1.16.3",
},
wantErr: nil,
},
{
name: "client empty version",
args: args{
cVersion: "",
sVersion: "1.16.3",
},
wantErr: errors.New("invalid semantic version, params="),
},
{
name: "not ok",
args: args{
cVersion: "1.15.0",
sVersion: "1.16.3",
},
wantErr: errors.New("yomo: version negotiation failed, client=1.15.0, server=1.16.3"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := negotiateVersion(tt.args.cVersion, tt.args.sVersion)
assert.Equal(t, tt.wantErr, err)
})
}
}
4 changes: 4 additions & 0 deletions core/version.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package core

// Version is the current yomo version.
const Version = "1.16.3"
8 changes: 4 additions & 4 deletions pkg/frame-codec/y3codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@ func TestCodec(t *testing.T) {
ObserveDataTags: []uint32{'a', 'b', 'c'},
AuthName: "ddddd",
AuthPayload: "eeeee",
Version: "1.16.3",
},
data: []byte{0xb1, 0x31, 0x1, 0x8, 0x74, 0x68, 0x65, 0x2d, 0x6e, 0x61,
data: []byte{0xb1, 0x39, 0x1, 0x8, 0x74, 0x68, 0x65, 0x2d, 0x6e, 0x61,
0x6d, 0x65, 0x3, 0x6, 0x74, 0x68, 0x65, 0x2d, 0x69, 0x64, 0x2, 0x1,
0x68, 0x6, 0xc, 0x61, 0x0, 0x0, 0x0, 0x62, 0x0, 0x0, 0x0, 0x63, 0x0,
0x0, 0x0, 0x4, 0x5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x5, 0x5, 0x65,
0x65, 0x65, 0x65, 0x65,
},
0x0, 0x0, 0x4, 0x5, 0x64, 0x64, 0x64, 0x64, 0x64, 0x5, 0x5, 0x65, 0x65,
0x65, 0x65, 0x65, 0x7, 0x6, 0x31, 0x2e, 0x31, 0x36, 0x2e, 0x33},
},
},
{
Expand Down
15 changes: 14 additions & 1 deletion pkg/frame-codec/y3codec/handshake_frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ func encodeHandshakeFrame(f *frame.HandshakeFrame) ([]byte, error) {
// auth payload
authPayloadBlock := y3.NewPrimitivePacketEncoder(tagAuthenticationPayload)
authPayloadBlock.SetStringValue(f.AuthPayload)
// version
versionBlock := y3.NewPrimitivePacketEncoder(tagHandshakeVersion)
versionBlock.SetStringValue(f.Version)

// handshake frame
handshake := y3.NewNodePacketEncoder(byte(f.Type()))
Expand All @@ -40,6 +43,7 @@ func encodeHandshakeFrame(f *frame.HandshakeFrame) ([]byte, error) {
handshake.AddPrimitivePacket(observeDataTagsBlock)
handshake.AddPrimitivePacket(authNameBlock)
handshake.AddPrimitivePacket(authPayloadBlock)
handshake.AddPrimitivePacket(versionBlock)

return handshake.Encode(), nil
}
Expand Down Expand Up @@ -98,15 +102,24 @@ func decodeHandshakeFrame(data []byte, f *frame.HandshakeFrame) error {
}
f.AuthPayload = authPayload
}
// version
if versionBlock, ok := node.PrimitivePackets[tagHandshakeVersion]; ok {
version, err := versionBlock.ToUTF8String()
if err != nil {
return err
}
f.Version = version
}

return nil
}

var (
const (
tagHandshakeName byte = 0x01
tagHandshakeClientType byte = 0x02
tagHandshakeID byte = 0x03
tagAuthenticationName byte = 0x04
tagAuthenticationPayload byte = 0x05
tagHandshakeObserveDataTags byte = 0x06
tagHandshakeVersion byte = 0x07
)
44 changes: 44 additions & 0 deletions pkg/version/version.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Package version provides functionality for parsing versions..
package version

import (
"fmt"
"strconv"
"strings"
)

// Version is used by the source/sfn to communicate their version to the server.
type Version struct {
Major int
Minor int
Patch int
}

// Parse parses a string into a Version.The string format must follow the `Major.Minor.Patch`
// formatting, and the Major, Minor, and Patch components must be numeric. If they are not,
// a parse error will be returned.
func Parse(str string) (*Version, error) {
vs := strings.Split(str, ".")
if len(vs) != 3 {
return nil, fmt.Errorf("invalid semantic version, params=%s", str)
}

major, err := strconv.Atoi(vs[0])
if err != nil {
return nil, fmt.Errorf("invalid version major, params=%s", str)
}

minor, err := strconv.Atoi(vs[1])
if err != nil {
return nil, fmt.Errorf("invalid version minor, params=%s", str)
}

patch, err := strconv.Atoi(vs[2])
if err != nil {
return nil, fmt.Errorf("invalid version patch, params=%s", str)
}

ver := &Version{Major: major, Minor: minor, Patch: patch}

return ver, nil
}
67 changes: 67 additions & 0 deletions pkg/version/version_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package version

import (
"errors"
"testing"

"github.com/stretchr/testify/assert"
)

func TestParse(t *testing.T) {
type args struct {
str string
}
tests := []struct {
name string
args args
want *Version
wantErr error
}{
{
name: "ok",
args: args{
str: "1.16.3",
},
want: &Version{Major: 1, Minor: 16, Patch: 3},
},
{
name: "invalid semantic version",
args: args{
str: "1.16.3-beta.1",
},
want: nil,
wantErr: errors.New("invalid semantic version, params=1.16.3-beta.1"),
},
{
name: "invalid version major",
args: args{
str: "xx.16.3",
},
want: nil,
wantErr: errors.New("invalid version major, params=xx.16.3"),
},
{
name: "invalid version minor",
args: args{
str: "1.yy.3",
},
want: nil,
wantErr: errors.New("invalid version minor, params=1.yy.3"),
},
{
name: "invalid version patch",
args: args{
str: "1.16.3-beta",
},
want: nil,
wantErr: errors.New("invalid version patch, params=1.16.3-beta"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, gotErr := Parse(tt.args.str)
assert.Equal(t, tt.wantErr, gotErr)
assert.Equal(t, tt.want, got)
})
}
}

0 comments on commit 7ff1b95

Please sign in to comment.