diff --git a/core/client.go b/core/client.go index 8d18cef3f..f53293753 100644 --- a/core/client.go +++ b/core/client.go @@ -5,15 +5,14 @@ import ( "context" "errors" "fmt" - "io" "reflect" "runtime" "time" - "github.com/quic-go/quic-go" "github.com/yomorun/yomo/core/frame" "github.com/yomorun/yomo/pkg/frame-codec/y3codec" "github.com/yomorun/yomo/pkg/id" + yquic "github.com/yomorun/yomo/pkg/listener/quic" oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/exp/slog" ) @@ -69,35 +68,10 @@ func NewClient(appName, zipperAddr string, clientType ClientType, opts ...Client } } -type connectResult struct { - conn quic.Connection - fs *FrameStream - err error -} - -func newConnectResult(conn quic.Connection, fs *FrameStream, err error) *connectResult { - return &connectResult{ - conn: conn, - fs: fs, - err: err, - } -} - -func (c *Client) connect(ctx context.Context) *connectResult { - conn, err := quic.DialAddr(ctx, c.zipperAddr, c.opts.tlsConfig, c.opts.quicConfig) - if err != nil { - return newConnectResult(conn, nil, err) - } - - stream, err := conn.OpenStream() +func (c *Client) connect(ctx context.Context, addr string) (frame.Conn, error) { + conn, err := yquic.DialAddr(ctx, addr, y3codec.Codec(), y3codec.PacketReadWriter(), c.opts.tlsConfig, c.opts.quicConfig) if err != nil { - return newConnectResult(conn, nil, err) - } - - fs := NewFrameStream(stream, y3codec.Codec(), y3codec.PacketReadWriter()) - - if credential := c.opts.credential; credential != nil { - c.Logger.Info("use credential", "credential_name", credential.Name()) + return conn, err } hf := &frame.HandshakeFrame{ @@ -109,59 +83,56 @@ func (c *Client) connect(ctx context.Context) *connectResult { AuthPayload: c.opts.credential.Payload(), } - if err := fs.WriteFrame(hf); err != nil { - return newConnectResult(conn, nil, err) + if err := conn.WriteFrame(hf); err != nil { + return conn, err } - received, err := fs.ReadFrame() + received, err := conn.ReadFrame() if err != nil { - return newConnectResult(conn, nil, err) + return nil, err } - switch received.Type() { case frame.TypeRejectedFrame: - se := ErrAuthenticateFailed{received.(*frame.RejectedFrame).Message} - return newConnectResult(conn, fs, se) + return conn, ErrAuthenticateFailed{received.(*frame.RejectedFrame).Message} case frame.TypeHandshakeAckFrame: - return newConnectResult(conn, fs, nil) + return conn, nil default: - se := ErrAuthenticateFailed{ + return conn, ErrAuthenticateFailed{ fmt.Sprintf("authentication failed: read unexcepted frame, frame read: %s", received.Type().String()), } - return newConnectResult(conn, fs, se) } } -func (c *Client) runBackground(ctx context.Context, conn quic.Connection, fs *FrameStream) { +func (c *Client) runBackground(ctx context.Context, conn frame.Conn) { reconnection := make(chan struct{}) - go c.handleReadFrames(fs, reconnection) + go c.handleReadFrames(conn, reconnection) + var err error for { select { case <-c.ctx.Done(): - fs.Close() + conn.CloseWithError("yomo: client closed") return case <-ctx.Done(): - fs.Close() + conn.CloseWithError("yomo: parent context canceled") return case f := <-c.writeFrameChan: - if err := fs.WriteFrame(f); err != nil { + if err := conn.WriteFrame(f); err != nil { c.handleFrameError(err, reconnection) } case <-reconnection: reconnect: - cr := c.connect(ctx) - if err := cr.err; err != nil { + conn, err = c.connect(ctx, c.zipperAddr) + if err != nil { if errors.As(err, new(ErrAuthenticateFailed)) { return } - c.Logger.Error("reconnect to zipper error", "err", cr.err) + c.Logger.Error("reconnect to zipper error", "err", err) time.Sleep(time.Second) goto reconnect } - fs = cr.fs - go c.handleReadFrames(fs, reconnection) + go c.handleReadFrames(conn, reconnection) } } } @@ -169,19 +140,19 @@ func (c *Client) runBackground(ctx context.Context, conn quic.Connection, fs *Fr // Connect connect client to server. func (c *Client) Connect(ctx context.Context) error { connect: - result := c.connect(ctx) - if result.err != nil { + fconn, err := c.connect(ctx, c.zipperAddr) + if err != nil { if c.opts.connectUntilSucceed { - c.Logger.Error("failed to connect to zipper, trying to reconnect", "err", result.err) + c.Logger.Error("failed to connect to zipper, trying to reconnect", "err", err) time.Sleep(time.Second) goto connect } - c.Logger.Error("can not connect to zipper", "err", result.err) - return result.err + c.Logger.Error("can not connect to zipper", "err", err) + return err } c.Logger.Info("connected to zipper") - go c.runBackground(ctx, result.conn, result.fs) + go c.runBackground(ctx, fconn) return nil } @@ -238,8 +209,8 @@ func (c *Client) handleFrameError(err error, reconnection chan<- struct{}) { c.errorfn(err) // exit client program if stream has be closed. - if err == io.EOF { - c.ctxCancel(fmt.Errorf("%s: remote shutdown", c.clientType.String())) + if se := new(yquic.ErrConnClosed); errors.As(err, &se) { + c.ctxCancel(fmt.Errorf("%s: shutdown with error=%s", c.clientType.String(), se.Error())) return } @@ -256,9 +227,9 @@ func (c *Client) Wait() { <-c.ctx.Done() } -func (c *Client) handleReadFrames(fs *FrameStream, reconnection chan struct{}) { +func (c *Client) handleReadFrames(fconn frame.Conn, reconnection chan struct{}) { for { - f, err := fs.ReadFrame() + f, err := fconn.ReadFrame() if err != nil { c.handleFrameError(err, reconnection) return diff --git a/core/connection.go b/core/connection.go index 7fb638202..26bc52b91 100644 --- a/core/connection.go +++ b/core/connection.go @@ -1,9 +1,6 @@ package core import ( - "context" - - "github.com/quic-go/quic-go" "github.com/yomorun/yomo/core/frame" "github.com/yomorun/yomo/core/metadata" "golang.org/x/exp/slog" @@ -31,19 +28,16 @@ type Connection struct { clientType ClientType metadata metadata.M observeDataTags []uint32 - conn quic.Connection - fs *FrameStream + fconn frame.Conn Logger *slog.Logger } func newConnection( name string, id string, clientType ClientType, md metadata.M, tags []uint32, - conn quic.Connection, fs *FrameStream, logger *slog.Logger) *Connection { + fconn frame.Conn, logger *slog.Logger, +) *Connection { logger = logger.With("conn_id", id, "conn_name", name) - if conn != nil { - logger.Info("new client connected", "remote_addr", conn.RemoteAddr().String(), "client_type", clientType.String()) - } return &Connection{ name: name, @@ -51,22 +45,11 @@ func newConnection( clientType: clientType, metadata: md, observeDataTags: tags, - conn: conn, - fs: fs, + fconn: fconn, Logger: logger, } } -// Close closes the connection. -func (c *Connection) Close() error { - return c.fs.Close() -} - -// Context returns the context of the connection. -func (c *Connection) Context() context.Context { - return c.fs.Context() -} - // ID returns the connection ID. func (c *Connection) ID() string { return c.id @@ -87,26 +70,10 @@ func (c *Connection) ObserveDataTags() []uint32 { return c.observeDataTags } -// ReadFrame reads a frame from the connection. -func (c *Connection) ReadFrame() (frame.Frame, error) { - return c.fs.ReadFrame() -} - -// ClientType returns the client type of the connection. func (c *Connection) ClientType() ClientType { return c.clientType } -// WriteFrame writes a frame to the connection. -func (c *Connection) WriteFrame(f frame.Frame) error { - return c.fs.WriteFrame(f) +func (c *Connection) FrameConn() frame.Conn { + return c.fconn } - -// CloseWithError closes the connection with error. -func (c *Connection) CloseWithError(errString string) error { - return c.conn.CloseWithError(YomoCloseErrorCode, errString) -} - -// YomoCloseErrorCode is the error code for close quic Connection for yomo. -// If the Connection implemented by quic is closed, the quic ApplicationErrorCode is always 0x13. -const YomoCloseErrorCode = quic.ApplicationErrorCode(0x13) diff --git a/core/connection_test.go b/core/connection_test.go index ae72d8960..daf5de54f 100644 --- a/core/connection_test.go +++ b/core/connection_test.go @@ -1,37 +1,23 @@ package core import ( - "bytes" - "context" - "io" - "sync" "testing" - "time" - "github.com/quic-go/quic-go" "github.com/stretchr/testify/assert" - "github.com/yomorun/yomo/core/frame" "github.com/yomorun/yomo/core/metadata" - "golang.org/x/exp/slog" + "github.com/yomorun/yomo/core/ylog" ) func TestConnection(t *testing.T) { var ( - readBytes = []byte("aaabbbcccdddeeefff") - name = "test-data-connection" - id = "123456" - styp = ClientTypeStreamFunction - observed = []uint32{1, 2, 3} - md metadata.M + name = "test-data-connection" + id = "123456" + styp = ClientTypeStreamFunction + observed = []uint32{1, 2, 3} + md metadata.M ) - // Create a connection that initializes the read buffer with a string that has been split by spaces. - mockStream := newMemByteStream(readBytes) - - // create frame connection. - fs := NewFrameStream(mockStream, &byteCodec{}, &bytePacketReadWriter{}) - - connection := newConnection(name, id, styp, md, observed, nil, fs, slog.Default()) + connection := newConnection(name, id, styp, md, observed, nil, ylog.Default()) t.Run("ConnectionInfo", func(t *testing.T) { assert.Equal(t, id, connection.ID()) @@ -40,60 +26,6 @@ func TestConnection(t *testing.T) { assert.Equal(t, md, connection.Metadata()) assert.Equal(t, observed, connection.ObserveDataTags()) }) - - t.Run("connection read", func(t *testing.T) { - gots := []byte{} - for i := 0; i < len(readBytes)+1; i++ { - f, err := connection.ReadFrame() - if err != nil { - if i == len(readBytes) { - assert.Equal(t, io.EOF, err) - } else { - t.Fatal(err) - } - return - } - - b, err := fs.codec.Encode(f) - assert.NoError(t, err) - - gots = append(gots, b...) - } - assert.Equal(t, readBytes, gots) - }) - - t.Run("connection write", func(t *testing.T) { - dataWrited := []byte("ggghhhiiigggkkklll") - - for _, w := range dataWrited { - err := connection.WriteFrame(byteFrame(w)) - assert.NoError(t, err) - } - - assert.Equal(t, string(mockStream.GetReadBytes()), string(dataWrited)) - }) - - t.Run("connection close", func(t *testing.T) { - err := connection.Close() - assert.NoError(t, err) - - // close twice. - err = connection.Close() - assert.NoError(t, err) - - f, err := connection.ReadFrame() - assert.ErrorIs(t, err, io.EOF) - assert.Nil(t, f) - - err = connection.WriteFrame(byteFrame('a')) - assert.ErrorIs(t, err, io.EOF) - - select { - case <-connection.Context().Done(): - default: - assert.Fail(t, "stream.Context().Done() should be done") - } - }) } func TestClientTypeString(t *testing.T) { @@ -102,139 +34,3 @@ func TestClientTypeString(t *testing.T) { assert.Equal(t, ClientTypeUpstreamZipper.String(), "UpstreamZipper") assert.Equal(t, ClientType(0).String(), "Unknown") } - -// byteFrame implements frame.Frame interface for unittest. -func byteFrame(byt byte) *frame.DataFrame { - return &frame.DataFrame{ - Payload: []byte{byt}, - } -} - -type byteCodec struct{} - -var _ frame.Codec = &byteCodec{} - -// Decode implements frame.Codec -func (*byteCodec) Decode(data []byte, f frame.Frame) error { - df, ok := f.(*frame.DataFrame) - if !ok { - return nil - } - df.Payload = data - - return nil -} - -// Encode implements frame.Codec -func (*byteCodec) Encode(f frame.Frame) ([]byte, error) { - return f.(*frame.DataFrame).Payload, nil -} - -type bytePacketReadWriter struct{} - -// WritePacket implements frame.PacketReadWriter -func (*bytePacketReadWriter) WritePacket(stream io.Writer, ftyp frame.Type, data []byte) error { - _, err := stream.Write(data) - return err -} - -// ReadPacket implements frame.PacketReadWriter -func (*bytePacketReadWriter) ReadPacket(stream io.Reader) (frame.Type, []byte, error) { - var b [1]byte - _, err := stream.Read(b[:]) - if err != nil { - return frame.TypeDataFrame, nil, err - } - return frame.TypeDataFrame, []byte{b[0]}, nil -} - -var _ frame.PacketReadWriter = &bytePacketReadWriter{} - -type memByteStream struct { - ctx context.Context - cancel context.CancelFunc - readBuf *bytes.Buffer - writeBuf *bytes.Buffer - mutex sync.Mutex -} - -// CancelRead implements quic.Stream. -func (*memByteStream) CancelRead(quic.StreamErrorCode) { - panic("unimplemented") -} - -func (*memByteStream) CancelWrite(quic.StreamErrorCode) { - panic("unimplemented") -} - -func (*memByteStream) SetDeadline(t time.Time) error { - panic("unimplemented") -} - -func (*memByteStream) SetReadDeadline(t time.Time) error { - panic("unimplemented") -} - -func (*memByteStream) SetWriteDeadline(t time.Time) error { - panic("unimplemented") -} - -func (*memByteStream) StreamID() quic.StreamID { - panic("unimplemented") -} - -func newMemByteStream(readInitBytes []byte) *memByteStream { - ctx, cancel := context.WithCancel(context.Background()) - return &memByteStream{ - ctx: ctx, - cancel: cancel, - readBuf: bytes.NewBuffer(readInitBytes), - writeBuf: &bytes.Buffer{}, - } -} - -func (rw *memByteStream) Context() context.Context { return rw.ctx } - -func (rw *memByteStream) Read(p []byte) (n int, err error) { - select { - case <-rw.ctx.Done(): - return 0, io.EOF - default: - } - - rw.mutex.Lock() - defer rw.mutex.Unlock() - return rw.readBuf.Read(p) -} - -func (rw *memByteStream) Write(p []byte) (n int, err error) { - select { - case <-rw.ctx.Done(): - return 0, io.EOF - default: - } - - rw.mutex.Lock() - defer rw.mutex.Unlock() - return rw.writeBuf.Write(p) -} - -func (rw *memByteStream) Close() error { - rw.cancel() - select { - case <-rw.ctx.Done(): - return nil - default: - } - - rw.mutex.Lock() - defer rw.mutex.Unlock() - rw.writeBuf.Reset() - return nil -} - -func (rw *memByteStream) GetReadBytes() []byte { - rw.mutex.Lock() - defer rw.mutex.Unlock() - return rw.writeBuf.Bytes() -} diff --git a/core/connector_test.go b/core/connector_test.go index e8822099f..e720f1d1f 100644 --- a/core/connector_test.go +++ b/core/connector_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/yomorun/yomo/core/frame" - "golang.org/x/exp/slog" + "github.com/yomorun/yomo/core/ylog" ) func TestConnector(t *testing.T) { @@ -117,5 +117,5 @@ func TestConnector(t *testing.T) { // mockConn returns a connection that only includes an ID and a name. // This function is used for unit testing purposes. func mockConn(id, name string) *Connection { - return newConnection(name, id, ClientType(0), nil, []frame.Tag{0}, nil, nil, slog.Default()) + return newConnection(name, id, ClientType(0), nil, []frame.Tag{0}, nil, ylog.Default()) } diff --git a/core/context.go b/core/context.go index 986456ceb..37666f922 100644 --- a/core/context.go +++ b/core/context.go @@ -58,13 +58,15 @@ func (c *Context) Get(key string) (any, bool) { var _ context.Context = &Context{} // Done returns nil (chan which will wait forever) when c.Connection.Context() has no Context. -func (c *Context) Done() <-chan struct{} { return c.Connection.Context().Done() } +func (c *Context) Done() <-chan struct{} { return c.Connection.FrameConn().Context().Done() } // Deadline returns that there is no deadline (ok==false) when c.Connection has no Context. -func (c *Context) Deadline() (deadline time.Time, ok bool) { return c.Connection.Context().Deadline() } +func (c *Context) Deadline() (deadline time.Time, ok bool) { + return c.Connection.FrameConn().Context().Deadline() +} // Err returns nil when c.Request has no Context. -func (c *Context) Err() error { return c.Connection.Context().Err() } +func (c *Context) Err() error { return c.Connection.FrameConn().Context().Err() } // Value retrieves the value associated with the specified key within the context. // If no value is found, it returns nil. Subsequent invocations of "Value" with the same key yield identical outcomes. @@ -79,7 +81,7 @@ func (c *Context) Value(key any) any { c.mu.Unlock() // this will not take effect forever. - return c.Connection.Context().Value(key) + return c.Connection.FrameConn().Context().Value(key) } // newContext returns a new YoMo context that implements the standard library `context.Context` interface. @@ -121,7 +123,7 @@ func newContext(conn *Connection, route router.Route, df *frame.DataFrame) (c *C func (c *Context) CloseWithError(errString string) { c.Logger.Debug("connection closed", "err", errString) - err := c.Connection.CloseWithError(errString) + err := c.Connection.FrameConn().CloseWithError(errString) if err == nil { return } diff --git a/core/frame/frame.go b/core/frame/frame.go index 55799213a..24e3fe2d3 100644 --- a/core/frame/frame.go +++ b/core/frame/frame.go @@ -2,8 +2,10 @@ package frame import ( + "context" "fmt" "io" + "net" ) // Frame is the minimum unit required for Yomo to run. @@ -162,29 +164,37 @@ type Codec interface { // Tag tags data and can be used for data routing. type Tag = uint32 -// ReadWriteCloser is the interface that groups the ReadFrame, WriteFrame and Close methods. -type ReadWriteCloser interface { - Reader - Writer - Close() error -} - -// ReadWriter is the interface that groups the ReadFrame and WriteFrame methods. -type ReadWriter interface { - Reader - Writer -} - // Writer is the interface that wraps the WriteFrame method, it writes // frame to the underlying connection. type Writer interface { - // WriteFrame writes frame to underlying stream. + // WriteFrame writes frame to underlying connection. WriteFrame(Frame) error } -// Reader reads frame from underlying stream. -type Reader interface { - // ReadFrame reads a frame, if an error occurs, the returned error will not be empty, - // and the returned frame will be nil. +// Listener accepts Conns. +type Listener interface { + // Accept accepts Conns. + Accept(context.Context) (Conn, error) + // Close closes listener, + // If listener be closed, all Conn accepted will be unavailable. + Close() error +} + +// Conn is a connection that transmits data in frame format. +type Conn interface { + // Context returns Conn.Context. + // The Context can be used to manage the lifecycle of connection and + // retrieve error using `context.Cause(conn.Context())` after calling `CloseWithError()`. + Context() context.Context + // WriteFrame writes a frame to connection. + WriteFrame(Frame) error + // ReadFrame returns a channel from which frames can be received. ReadFrame() (Frame, error) + // RemoteAddr returns the remote address of connection. + RemoteAddr() net.Addr + // LocalAddr returns the local address of connection. + LocalAddr() net.Addr + // CloseWithError closes the connection with an error message. + // It will be unavailable if the connection is closed. the error message should be written to the conn.Context(). + CloseWithError(string) error } diff --git a/core/frame_stream.go b/core/frame_stream.go deleted file mode 100644 index eba04dec6..000000000 --- a/core/frame_stream.go +++ /dev/null @@ -1,89 +0,0 @@ -package core - -import ( - "context" - "io" - "sync" - - "github.com/quic-go/quic-go" - "github.com/yomorun/yomo/core/frame" -) - -// FrameStream is the frame.ReadWriter that goroutinue read write safely. -type FrameStream struct { - codec frame.Codec - packetReadWriter frame.PacketReadWriter - - // mu protected stream write and close - // because of stream write and close is not goroutinue-safely. - mu sync.Mutex - underlying quic.Stream -} - -// NewFrameStream creates a new FrameStream. -func NewFrameStream( - stream quic.Stream, codec frame.Codec, packetReadWriter frame.PacketReadWriter, -) *FrameStream { - return &FrameStream{ - underlying: stream, - codec: codec, - packetReadWriter: packetReadWriter, - } -} - -// Context returns the context of the FrameStream. -func (fs *FrameStream) Context() context.Context { - return fs.underlying.Context() -} - -// ReadFrame reads next frame from underlying stream. -func (fs *FrameStream) ReadFrame() (frame.Frame, error) { - select { - case <-fs.underlying.Context().Done(): - return nil, io.EOF - default: - } - - fType, b, err := fs.packetReadWriter.ReadPacket(fs.underlying) - if err != nil { - return nil, err - } - - f, err := frame.NewFrame(fType) - if err != nil { - return nil, err - } - - if err := fs.codec.Decode(b, f); err != nil { - return nil, err - } - - return f, nil -} - -// WriteFrame writes a frame into underlying stream. -func (fs *FrameStream) WriteFrame(f frame.Frame) error { - select { - case <-fs.underlying.Context().Done(): - return io.EOF - default: - } - - fs.mu.Lock() - defer fs.mu.Unlock() - - b, err := fs.codec.Encode(f) - if err != nil { - return err - } - - return fs.packetReadWriter.WritePacket(fs.underlying, f.Type(), b) -} - -// Close closes the FrameStream and returns an error if any. -func (fs *FrameStream) Close() error { - fs.mu.Lock() - defer fs.mu.Unlock() - - return fs.underlying.Close() -} diff --git a/core/server.go b/core/server.go index 7ec7be1b6..acfe423c4 100644 --- a/core/server.go +++ b/core/server.go @@ -10,7 +10,6 @@ import ( "sync" "sync/atomic" - "github.com/quic-go/quic-go" "github.com/yomorun/yomo/core/auth" "github.com/yomorun/yomo/core/frame" "github.com/yomorun/yomo/core/metadata" @@ -20,6 +19,7 @@ import ( // authentication implements, Currently, only token authentication is implemented _ "github.com/yomorun/yomo/pkg/auth" "github.com/yomorun/yomo/pkg/frame-codec/y3codec" + yquic "github.com/yomorun/yomo/pkg/listener/quic" pkgtls "github.com/yomorun/yomo/pkg/tls" oteltrace "go.opentelemetry.io/otel/trace" ) @@ -56,7 +56,7 @@ type Server struct { opts *serverOptions frameHandler FrameHandler connHandler ConnHandler - listener *quic.Listener + listener frame.Listener logger *slog.Logger tracerProvider oteltrace.TracerProvider } @@ -111,28 +111,68 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) error { return s.Serve(ctx, conn) } -func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (bool, router.Route, *Connection) { +// Serve the server with a net.PacketConn. +func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { + if err := s.validateRouter(); err != nil { + return err + } + + s.connector = NewConnector(ctx) + + tlsConfig := s.opts.tlsConfig + if tlsConfig == nil { + tlsConfig = pkgtls.MustCreateServerTLSConfig(conn.LocalAddr().String()) + } + + // listen the address + listener, err := yquic.Listen(conn, y3codec.Codec(), y3codec.PacketReadWriter(), tlsConfig, s.opts.quicConfig) + if err != nil { + s.logger.Error("failed to listen on quic", "err", err) + return err + } + s.listener = listener + + s.logger.Info( + "zipper is up and running", + "zipper_addr", conn.LocalAddr().String(), "pid", os.Getpid(), "quic", s.opts.quicConfig.Versions, "auth_name", s.authNames()) + + defer closeServer(s.downstreams, s.connector, s.listener, s.router) + + for { + fconn, err := s.listener.Accept(s.ctx) + if err != nil { + if err == s.ctx.Err() { + return ErrServerClosed + } + s.logger.Error("accepted an error when accepting a connection", "err", err) + return err + } + + go s.handleFrameConn(fconn, s.logger) + } +} + +func (s *Server) handshake(fconn frame.Conn) (bool, router.Route, *Connection) { var gerr error defer func() { if gerr == nil { - _ = fs.WriteFrame(&frame.HandshakeAckFrame{}) + _ = fconn.WriteFrame(&frame.HandshakeAckFrame{}) } else { - _ = fs.WriteFrame(&frame.RejectedFrame{Message: gerr.Error()}) + _ = fconn.WriteFrame(&frame.RejectedFrame{Message: gerr.Error()}) } }() - first, err := fs.ReadFrame() + first, err := fconn.ReadFrame() if err != nil { gerr = err return false, nil, nil } - switch first.Type() { case frame.TypeHandshakeFrame: hf := first.(*frame.HandshakeFrame) - conn, err := s.handleHandshakeFrame(qconn, fs, hf) + conn, err := s.handleHandshakeFrame(fconn, hf) if err != nil { gerr = err return false, nil, conn @@ -150,14 +190,10 @@ func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (bool, router } func (s *Server) handleConnRoute(conn *Connection, route router.Route) { - defer func() { - if conn.ClientType() == ClientTypeStreamFunction { - _ = route.Remove(conn.ID()) - } - _ = s.connector.Remove(conn.ID()) - }() + conn.Logger.Info("new client connected", "client_type", conn.ClientType().String()) + for { - f, err := conn.ReadFrame() + f, err := conn.FrameConn().ReadFrame() if err != nil { conn.Logger.Info("failed to read frame", "err", err) return @@ -169,9 +205,10 @@ func (s *Server) handleConnRoute(conn *Connection, route router.Route) { conn.Logger.Info("failed to new context", "err", err) return } - defer c.Release() - s.frameHandler(c) + s.frameHandler(c) // s.handleFrame(c) with middlewares + + c.Release() default: conn.Logger.Info("unexpected frame", "type", f.Type().String()) return @@ -179,32 +216,22 @@ func (s *Server) handleConnRoute(conn *Connection, route router.Route) { } } -func (s *Server) handleQuicConnection(qconn quic.Connection, fs *FrameStream, logger *slog.Logger) { - ok, route, conn := s.handshake(qconn, fs) +func (s *Server) handleFrameConn(fconn frame.Conn, logger *slog.Logger) { + ok, route, conn := s.handshake(fconn) if !ok { logger.Error("handshake failed") return } - s.connHandler(conn, route) -} + s.connHandler(conn, route) // s.handleConn(conn) with middlewares -func (s *Server) addSfnToRoute(hf *frame.HandshakeFrame, md metadata.M) (router.Route, error) { - if hf.ClientType != byte(ClientTypeStreamFunction) { - return nil, nil - } - route := s.router.Route(md) - if route == nil { - return nil, errors.New("yomo: can't find route in handshake metadata") - } - err := route.Add(hf.ID, hf.ObserveDataTags) - if err != nil { - return nil, err + if conn.ClientType() == ClientTypeStreamFunction { + _ = route.Remove(conn.ID()) } - return route, nil + _ = s.connector.Remove(conn.ID()) } -func (s *Server) handleHandshakeFrame(qconn quic.Connection, fs *FrameStream, hf *frame.HandshakeFrame) (*Connection, error) { +func (s *Server) handleHandshakeFrame(fconn frame.Conn, hf *frame.HandshakeFrame) (*Connection, error) { md, ok := auth.Authenticate(s.opts.auths, hf) if !ok { @@ -216,89 +243,24 @@ func (s *Server) handleHandshakeFrame(qconn quic.Connection, fs *FrameStream, hf return nil, fmt.Errorf("authentication failed: client credential name is %s", hf.AuthName) } - conn := newConnection(hf.Name, hf.ID, ClientType(hf.ClientType), md, hf.ObserveDataTags, qconn, fs, s.logger) + conn := newConnection(hf.Name, hf.ID, ClientType(hf.ClientType), md, hf.ObserveDataTags, fconn, s.logger) return conn, s.connector.Store(hf.ID, conn) } -// Serve the server with a net.PacketConn. -func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { - if err := s.validateRouter(); err != nil { - return err +func (s *Server) addSfnToRoute(hf *frame.HandshakeFrame, md metadata.M) (router.Route, error) { + if hf.ClientType != byte(ClientTypeStreamFunction) { + return nil, nil } - - s.connector = NewConnector(ctx) - - tlsConfig := s.opts.tlsConfig - if tlsConfig == nil { - tc, err := pkgtls.CreateServerTLSConfig(conn.LocalAddr().String()) - if err != nil { - return err - } - tlsConfig = tc + route := s.router.Route(md) + if route == nil { + return nil, errors.New("yomo: can't find route in handshake metadata") } - - // listen the address - listener, err := quic.Listen(conn, tlsConfig, s.opts.quicConfig) + err := route.Add(hf.ID, hf.ObserveDataTags) if err != nil { - s.logger.Error("failed to listen on quic", "err", err) - return err - } - s.listener = listener - - s.logger.Info("zipper is up and running", "zipper_addr", s.listener.Addr().String(), "pid", os.Getpid(), "quic", s.opts.quicConfig.Versions, "auth_name", s.authNames()) - - defer closeServer(s.downstreams, s.connector, s.listener, s.router) - - for { - qconn, err := s.listener.Accept(s.ctx) - if err != nil { - if err == s.ctx.Err() { - return ErrServerClosed - } - s.logger.Error("accepted an error when accepting a connection", "err", err) - return err - } - - stream, err := qconn.AcceptStream(ctx) - if err != nil { - continue - } - - fs := NewFrameStream(stream, y3codec.Codec(), y3codec.PacketReadWriter()) - - go s.handleQuicConnection(qconn, fs, s.logger) - } -} - -// Logger returns the logger of server. -func (s *Server) Logger() *slog.Logger { - return s.logger -} - -// Close will shutdown the server. -func (s *Server) Close() error { - s.ctxCancel() - return nil -} - -func closeServer(downstreams map[string]Downstream, connector *Connector, listener *quic.Listener, router router.Router) error { - for _, ds := range downstreams { - ds.Close() - } - // connector - if connector != nil { - connector.Close() - } - // listener - if listener != nil { - listener.Close() - } - // router - if router != nil { - router.Clean() + return nil, err } - return nil + return route, nil } func (s *Server) handleFrame(c *Context) { @@ -368,7 +330,7 @@ func (s *Server) routingDataFrame(c *Context) error { c.Logger.Info("data routing", "tag", dataFrame.Tag, "data_length", data_length, "to_id", toID, "to_name", stream.Name()) // write data frame to stream - if err := stream.WriteFrame(dataFrame); err != nil { + if err := stream.FrameConn().WriteFrame(dataFrame); err != nil { c.Logger.Error("failed to write frame for routing data", "err", err) } } @@ -392,7 +354,7 @@ func (s *Server) handleBackflowFrame(c *Context) error { for _, s := range sources { if s != nil { c.Logger.Info("backflow to source", "source_conn_id", sourceID) - if err := s.WriteFrame(bf); err != nil { + if err := s.FrameConn().WriteFrame(bf); err != nil { c.Logger.Error("failed to write frame for backflow to the source", "err", err) return err } @@ -401,6 +363,53 @@ func (s *Server) handleBackflowFrame(c *Context) error { return nil } +// dispatch every DataFrames to all downstreams +func (s *Server) dispatchToDownstreams(c *Context) error { + dataFrame := c.Frame + if c.Connection.ClientType() == ClientTypeUpstreamZipper { + c.Logger.Debug("ignored client", "client_type", c.Connection.ClientType().String()) + // loop protection + return nil + } + + mdBytes, err := c.FrameMetadata.Encode() + if err != nil { + c.Logger.Error("failed to dispatch to downstream", "err", err) + return err + } + dataFrame.Metadata = mdBytes + + for _, ds := range s.downstreams { + c.Logger.Info( + "dispatching to downstream", + "tag", dataFrame.Tag, "data_length", len(dataFrame.Payload), + "downstream_id", ds.ID(), "downstream_name", ds.LocalName()) + + _ = ds.WriteFrame(dataFrame) + } + + return nil +} + +func closeServer(downstreams map[string]Downstream, connector *Connector, listener frame.Listener, router router.Router) error { + for _, ds := range downstreams { + ds.Close() + } + // connector + if connector != nil { + connector.Close() + } + // listener + if listener != nil { + listener.Close() + } + // router + if router != nil { + router.Clean() + } + return nil +} + // sourceIDTagFindConnectionFunc creates a FindStreamFunc that finds a source type stream matching the specified sourceID and tag. func sourceIDTagFindConnectionFunc(sourceID string, tag frame.Tag) FindConnectionFunc { return func(conn ConnectionInfo) bool { @@ -453,31 +462,14 @@ func (s *Server) AddDownstreamServer(c Downstream) { s.mu.Unlock() } -// dispatch every DataFrames to all downstreams -func (s *Server) dispatchToDownstreams(c *Context) error { - dataFrame := c.Frame - if c.Connection.ClientType() == ClientTypeUpstreamZipper { - c.Logger.Debug("ignored client", "client_type", c.Connection.ClientType().String()) - // loop protection - return nil - } - - mdBytes, err := c.FrameMetadata.Encode() - if err != nil { - c.Logger.Error("failed to dispatch to downstream", "err", err) - return err - } - dataFrame.Metadata = mdBytes - - for _, ds := range s.downstreams { - c.Logger.Info( - "dispatching to downstream", - "tag", dataFrame.Tag, "data_length", len(dataFrame.Payload), - "downstream_id", ds.ID(), "downstream_name", ds.LocalName()) - - _ = ds.WriteFrame(dataFrame) - } +// Logger returns the logger of server. +func (s *Server) Logger() *slog.Logger { + return s.logger +} +// Close will shutdown the server. +func (s *Server) Close() error { + s.ctxCancel() return nil } diff --git a/pkg/listener/quic/quic.go b/pkg/listener/quic/quic.go new file mode 100644 index 000000000..abaec95dc --- /dev/null +++ b/pkg/listener/quic/quic.go @@ -0,0 +1,229 @@ +package yquic + +import ( + "context" + "crypto/tls" + "errors" + "net" + + "github.com/quic-go/quic-go" + "github.com/yomorun/yomo/core/frame" +) + +// ErrConnClosed is returned when the connection is closed. +// If the connection is closed, both the stream and the connection will receive this error. +type ErrConnClosed struct { + Message string +} + +// Error implements the error interface and returns the reason why the connection was closed. +func (e *ErrConnClosed) Error() string { + return e.Message +} + +// FrameConn is an implements of FrameConn, +// It transmits frames upon the first stream from a QUIC connection. +type FrameConn struct { + ctx context.Context + ctxCancel context.CancelCauseFunc + frameCh chan frame.Frame + conn quic.Connection + stream quic.Stream + codec frame.Codec + prw frame.PacketReadWriter +} + +// DialAddr dials the given address and returns a new FrameConn. +func DialAddr( + ctx context.Context, + addr string, + codec frame.Codec, prw frame.PacketReadWriter, + tlsConfig *tls.Config, quicConfig *quic.Config, +) (*FrameConn, error) { + qconn, err := quic.DialAddr(ctx, addr, tlsConfig, quicConfig) + if err != nil { + return nil, err + } + + stream, err := qconn.OpenStream() + if err != nil { + return nil, err + } + + return newFrameConn(qconn, stream, codec, prw), nil +} + +func newFrameConn( + qconn quic.Connection, stream quic.Stream, + codec frame.Codec, prw frame.PacketReadWriter, +) *FrameConn { + ctx, ctxCancel := context.WithCancelCause(context.Background()) + + conn := &FrameConn{ + ctx: ctx, + ctxCancel: ctxCancel, + frameCh: make(chan frame.Frame), + conn: qconn, + stream: stream, + codec: codec, + prw: prw, + } + + go conn.framing() + + return conn +} + +// YomoCloseErrorCode is the error code for close quic Connection for yomo. +// If the Connection implemented by quic is closed, the quic ApplicationErrorCode is always 0x13. +const YomoCloseErrorCode = quic.ApplicationErrorCode(0x13) + +// Context returns the context of the connection. +func (p *FrameConn) Context() context.Context { + return p.ctx +} + +// RemoteAddr returns the remote address of connection. +func (p *FrameConn) RemoteAddr() net.Addr { + return p.conn.RemoteAddr() +} + +// LocalAddr returns the local address of connection. +func (p *FrameConn) LocalAddr() net.Addr { + return p.conn.LocalAddr() +} + +// CloseWithError closes the connection. +func (p *FrameConn) CloseWithError(errString string) error { + select { + case <-p.ctx.Done(): + return nil + default: + } + p.ctxCancel(&ErrConnClosed{errString}) + + // _ = p.stream.Close() + return p.conn.CloseWithError(YomoCloseErrorCode, errString) +} + +func (p *FrameConn) framing() { + for { + fType, b, err := p.prw.ReadPacket(p.stream) + if err != nil { + p.ctxCancel(convertErrorToConnectionClosed(err)) + return + } + + f, err := frame.NewFrame(fType) + if err != nil { + p.ctxCancel(convertErrorToConnectionClosed(err)) + return + } + + if err := p.codec.Decode(b, f); err != nil { + p.ctxCancel(convertErrorToConnectionClosed(err)) + return + } + p.frameCh <- f + } +} + +func convertErrorToConnectionClosed(err error) error { + if se := new(quic.ApplicationError); errors.As(err, &se) { + if se.ErrorCode == 0 && se.ErrorMessage == "" { + return &ErrConnClosed{"yomo: listener closed"} + } + return &ErrConnClosed{se.ErrorMessage} + } + return err +} + +// ReadFrame reads a frame. it usually be called in a for-loop. +func (p *FrameConn) ReadFrame() (frame.Frame, error) { + select { + case <-p.ctx.Done(): + return nil, context.Cause(p.ctx) + case ff := <-p.frameCh: + return ff, nil + } +} + +// WriteFrame writes a frame to connection. +func (p *FrameConn) WriteFrame(f frame.Frame) error { + select { + case <-p.ctx.Done(): + return context.Cause(p.ctx) + default: + b, err := p.codec.Encode(f) + if err != nil { + return err + } + + return p.prw.WritePacket(p.stream, f.Type(), b) + } +} + +// Listener listens a net.PacketConn and accepts connections. +type Listener struct { + underlying *quic.Listener + codec frame.Codec + prw frame.PacketReadWriter +} + +// Listen returns a quic Listener that can accept connections. +func Listen( + conn net.PacketConn, + codec frame.Codec, prw frame.PacketReadWriter, + tlsConfig *tls.Config, quicConfig *quic.Config, +) (*Listener, error) { + ql, err := quic.Listen(conn, tlsConfig, quicConfig) + if err != nil { + return nil, err + } + + listener := &Listener{ + underlying: ql, + codec: codec, + prw: prw, + } + + return listener, err +} + +// ListenAddr listens an address and returns a new Listener. +func ListenAddr( + addr string, + codec frame.Codec, prw frame.PacketReadWriter, + tlsConfig *tls.Config, quicConfig *quic.Config, +) (*Listener, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return nil, err + } + + return Listen(conn, codec, prw, tlsConfig, quicConfig) +} + +// Accept accepts FrameConns. +func (listener *Listener) Accept(ctx context.Context) (frame.Conn, error) { + qconn, err := listener.underlying.Accept(ctx) + if err != nil { + return nil, err + } + stream, err := qconn.AcceptStream(ctx) + if err != nil { + return nil, err + } + + return newFrameConn(qconn, stream, listener.codec, listener.prw), nil +} + +// Close closes listener. +// If listener be closed, all connection receive quic application error that code=0, message="". +func (listener *Listener) Close() error { + return listener.underlying.Close() +} diff --git a/pkg/listener/quic/quic_test.go b/pkg/listener/quic/quic_test.go new file mode 100644 index 000000000..040489656 --- /dev/null +++ b/pkg/listener/quic/quic_test.go @@ -0,0 +1,90 @@ +package yquic + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/yomorun/yomo/core/frame" + "github.com/yomorun/yomo/pkg/frame-codec/y3codec" + pkgtls "github.com/yomorun/yomo/pkg/tls" +) + +const testHost = "localhost:9008" + +const ( + handshakeName = "hello yomo" + streamContent = "hello stream" + CloseMessage = "bye!" +) + +func TestFrameConnection(t *testing.T) { + go func() { + if err := runListener(t); err != nil { + panic(err) + } + }() + + fconn, err := DialAddr(context.TODO(), testHost, + y3codec.Codec(), y3codec.PacketReadWriter(), + pkgtls.MustCreateClientTLSConfig(), nil, + ) + assert.NoError(t, err) + + err = fconn.WriteFrame(&frame.HandshakeAckFrame{}) + assert.NoError(t, err) + + for { + f, err := fconn.ReadFrame() + if err != nil { + se := new(ErrConnClosed) + assert.True(t, errors.As(err, &se)) + assert.Equal(t, &ErrConnClosed{CloseMessage}, err) + return + } + hf := f.(*frame.HandshakeFrame) + assert.Equal(t, handshakeName, hf.Name) + } +} + +func runListener(t *testing.T) error { + listener, err := ListenAddr(testHost, y3codec.Codec(), y3codec.PacketReadWriter(), pkgtls.MustCreateServerTLSConfig(testHost), nil) + if err != nil { + return err + } + + time.AfterFunc(3*time.Second, func() { + listener.Close() + }) + + fconn, err := listener.Accept(context.TODO()) + if err != nil { + return err + } + + f, err := fconn.ReadFrame() + assert.NoError(t, err) + assert.Equal(t, f.Type(), frame.TypeHandshakeAckFrame) + + if err := fconn.WriteFrame(&frame.HandshakeFrame{Name: handshakeName}); err != nil { + return err + } + + time.AfterFunc(time.Second, func() { + err := fconn.CloseWithError(CloseMessage) + assert.NoError(t, err) + + // close twice has no effect. + err = fconn.CloseWithError(CloseMessage) + assert.NoError(t, err) + + err = fconn.WriteFrame(&frame.DataFrame{Payload: []byte("aaaa")}) + assert.Equal(t, &ErrConnClosed{CloseMessage}, err) + + t.Log("close connection done") + }) + + return nil +} diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go index 61aca5766..7b9458f3c 100644 --- a/pkg/tls/tls.go +++ b/pkg/tls/tls.go @@ -52,6 +52,15 @@ func CreateServerTLSConfig(host string) (*tls.Config, error) { }, nil } +// MustCreateServerTLSConfig creates server tls config, It is panic If error here. +func MustCreateServerTLSConfig(host string) *tls.Config { + conf, err := CreateServerTLSConfig(host) + if err != nil { + panic(err) + } + return conf +} + // MustCreateClientTLSConfig creates client tls config, It is panic If error here. func MustCreateClientTLSConfig() *tls.Config { conf, err := CreateClientTLSConfig() diff --git a/zipper.go b/zipper.go index bc5f83264..d1bd11a51 100644 --- a/zipper.go +++ b/zipper.go @@ -66,13 +66,13 @@ func NewZipper(name string, router router.Router, meshConfig map[string]config.D for downstreamName, meshConf := range meshConfig { addr := fmt.Sprintf("%s:%d", meshConf.Host, meshConf.Port) - clientOptions := append( - opts.clientOption, + clientOptions := []core.ClientOption{ core.WithCredential(meshConf.Credential), core.WithNonBlockWrite(), core.WithConnectUntilSucceed(), core.WithLogger(server.Logger().With("downstream_name", downstreamName, "downstream_addr", addr)), - ) + } + clientOptions = append(clientOptions, opts.clientOption...) downstream := &downstream{ localName: downstreamName,