diff --git a/cli/serve.go b/cli/serve.go index 01eed0be5..23880ccfa 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -56,11 +56,16 @@ var serveCmd = &cobra.Command{ listenAddr := fmt.Sprintf("%s:%d", conf.Host, conf.Port) options := []yomo.ZipperOption{yomo.WithZipperTracerProvider(tp)} + // auth if _, ok := conf.Auth["type"]; ok { if tokenString, ok := conf.Auth["token"]; ok { options = append(options, yomo.WithAuth("token", tokenString)) } } + // stream + if conf.Stream != nil { + options = append(options, yomo.WithZipperStreamChunkSize(conf.Stream.ChunkSize)) + } zipper, err := yomo.NewZipper(conf.Name, conf.Downstreams, options...) if err != nil { diff --git a/core/client.go b/core/client.go index 4a673f78a..8ea2c8cba 100644 --- a/core/client.go +++ b/core/client.go @@ -8,6 +8,8 @@ import ( "io" "reflect" "runtime" + "sync" + "sync/atomic" "time" "github.com/quic-go/quic-go" @@ -18,6 +20,11 @@ import ( "golang.org/x/exp/slog" ) +const ( + // DefaultReconnectInterval is the default interval of reconnecting to zipper. + DefaultReconnectInterval = 1 * time.Second +) + // Client is the abstraction of a YoMo-Client. a YoMo-Client can be // Source, Upstream Zipper or StreamFunction. type Client struct { @@ -36,6 +43,13 @@ type Client struct { ctxCancel context.CancelCauseFunc writeFrameChan chan frame.Frame + + // conn quic.Connection + conn atomic.Pointer[quic.Connection] + // fs frame stream + fs atomic.Pointer[FrameStream] + // data streams + dataStreams sync.Map } // NewClient creates a new YoMo-Client. @@ -155,10 +169,11 @@ func (c *Client) runBackground(ctx context.Context, addr string, conn quic.Conne return } c.logger.Error("reconnect to zipper error", "err", cr.err) - time.Sleep(time.Second) + time.Sleep(DefaultReconnectInterval) goto reconnect } fs = cr.fs + c.setConnection(&cr.conn) go c.handleReadFrames(fs, reconnection) } } @@ -183,8 +198,12 @@ connect: c.logger.Error("can not connect to zipper", "err", result.err) return result.err } + c.logger = c.logger.With("local_addr", result.conn.LocalAddr().String()) c.logger.Info("connected to zipper") + c.setConnection(&result.conn) + c.setFrameStream(result.fs) + go c.runBackground(ctx, addr, result.conn, result.fs) return nil @@ -293,8 +312,11 @@ func (c *Client) handleFrame(f frame.Frame) { c.processor(ff) case *frame.BackflowFrame: c.receiver(ff) + case *frame.StreamFrame: + // TODO: handle stream frame + c.logger.Debug("receive stream frame", "stream_id", ff.StreamID, "conn_id", ff.ClientID, "tag", ff.Tag) default: - c.logger.Warn("received unexpected frame", "frame_type", f.Type().String()) + c.logger.Error("received unexpected frame", "frame_type", f.Type().String()) } } @@ -356,4 +378,149 @@ type ErrAuthenticateFailed struct { } // Error returns a string that represents the ErrAuthenticateFailed error for the implementation of the error interface. -func (e ErrAuthenticateFailed) Error() string { return e.ReasonFromServer } +func (e ErrAuthenticateFailed) Error() string { + return e.ReasonFromServer +} + +// readStream read stream from client. +func (c *Client) readStream() error { +STREAM: + select { + case <-c.ctx.Done(): + return c.ctx.Err() + default: + qconn := c.Connection() + if qconn == nil { + err := errors.New("quic connection is nil") + c.logger.Error(err.Error()) + return err + } + dataStream, err := qconn.AcceptStream(c.ctx) + if err != nil { + c.logger.Error("client accept stream error", "err", err) + return err + } + // close data stream + defer dataStream.Close() + c.logger.Debug("client accept stream success", "stream_id", dataStream.StreamID()) + // read stream frame + fs := NewFrameStream(dataStream, y3codec.Codec(), y3codec.PacketReadWriter()) + f, err := fs.ReadFrame() + if err != nil { + c.logger.Warn("failed to read data stream", "err", err) + return err + } + c.logger.Debug("client read stream frame success", "stream_id", dataStream.StreamID()) + switch f.Type() { + case frame.TypeStreamFrame: + streamFrame := f.(*frame.StreamFrame) + // lookfor data stream + reader, ok := c.dataStreams.Load(streamFrame.ID) + if !ok { + c.logger.Debug( + "data stream is not found", + "stream_id", dataStream.StreamID(), + "datastream_id", streamFrame.ID, + "stream_chunk_size", streamFrame.ChunkSize, + // "datastream_id", dataStreamID, + // "received_id", streamFrame.ID, + "client_id", streamFrame.ClientID, + "tag", streamFrame.Tag, + ) + goto STREAM + } + // clean data stream + defer c.dataStreams.Delete(streamFrame.ID) + // if found, pipe stream + c.logger.Debug( + "client pipe stream is ready", + "remote_addr", qconn.RemoteAddr().String(), + "datastream_id", streamFrame.ID, + "stream_id", dataStream.StreamID(), + "stream_chunk_size", streamFrame.ChunkSize, + "client_id", streamFrame.ClientID, + "tag", streamFrame.Tag, + ) + // pipe stream + stream := reader.(io.Reader) + // TEST: read source stream + // buf, err := io.ReadAll(stream) + // if err != nil { + // c.logger.Error("!!!pipe stream error!!!", "err", err) + // return + // } + // bufString := string(buf) + // l := len(bufString) + // if l > 1000 { + // bufString = bufString[l-1000:] + // } + // c.logger.Info("!!!pipe stream success!!!", + // "remote_addr", qconn.RemoteAddr().String(), + // "datastream_id", streamFrame.ID, + // "stream_id", dataStream.StreamID(), + // "client_id", streamFrame.ClientID, + // "tag", streamFrame.Tag, + // "buf", bufString, + // "len", len(buf), + // ) + buf := make([]byte, streamFrame.ChunkSize) + n, err := io.CopyBuffer(dataStream, stream, buf) + if err != nil { + c.logger.Error("client pipe stream error", "err", err) + return err + } + c.logger.Info("client pipe stream success", + "remote_addr", qconn.RemoteAddr().String(), + "id", streamFrame.ID, + "stream_id", dataStream.StreamID(), + "stream_chunk_size", streamFrame.ChunkSize, + "client_id", streamFrame.ClientID, + "tag", streamFrame.Tag, + "n", n, + ) + default: + c.logger.Error("!!!unexpected frame!!!", "unexpected_frame_type", f.Type().String()) + return errors.New("unexpected frame") + } + } + return nil +} + +// PipeStream pipe a stream to server. +func (c *Client) PipeStream(dataStreamID string, stream io.Reader) error { + c.logger.Debug(fmt.Sprintf("client pipe stream[%s] -- start", dataStreamID)) + c.dataStreams.Store(dataStreamID, stream) + // process all data streams + err := c.readStream() + c.logger.Debug(fmt.Sprintf("client pipe stream[%s] -- end", dataStreamID)) + return err +} + +// Connection returns the connection of client. +func (c *Client) Connection() quic.Connection { + conn := c.conn.Load() + if conn != nil { + return *conn + } + return nil +} + +// setConnection set the connection of client. +func (c *Client) setConnection(conn *quic.Connection) { + c.conn.Store(conn) +} + +// FrameStream returns the FrameStream of client. +func (c *Client) FrameStream() *FrameStream { + return c.fs.Load() +} + +// setFrameStream set the FrameStream of client. +func (c *Client) setFrameStream(fs *FrameStream) { + c.fs.Store(fs) +} + +// DataStreams returns the data streams of client. +func (c *Client) DataStreams() *sync.Map { + return &c.dataStreams +} diff --git a/core/connection.go b/core/connection.go index 87244e3b6..f6bbe32ac 100644 --- a/core/connection.go +++ b/core/connection.go @@ -30,6 +30,8 @@ type Connection interface { frame.ReadWriteCloser // CloseWithError closes the connection with an error string. CloseWithError(string) error + // QuicConnection returns raw quic connection. + QuicConnection() quic.Connection } type connection struct { @@ -44,7 +46,8 @@ type connection struct { func newConnection( name string, id string, clientType ClientType, md metadata.M, tags []uint32, - conn quic.Connection, fs *FrameStream) *connection { + conn quic.Connection, fs *FrameStream, +) *connection { return &connection{ name: name, id: id, @@ -96,6 +99,10 @@ func (c *connection) CloseWithError(errString string) error { return c.conn.CloseWithError(YomoCloseErrorCode, errString) } +func (c *connection) QuicConnection() quic.Connection { + return c.conn +} + // YomoCloseErrorCode is the error code for close quic Connection for yomo. // If the Connection implemented by quic is closed, the quic ApplicationErrorCode is always 0x13. const YomoCloseErrorCode = quic.ApplicationErrorCode(0x13) diff --git a/core/frame/frame.go b/core/frame/frame.go index 55799213a..0a1e59d62 100644 --- a/core/frame/frame.go +++ b/core/frame/frame.go @@ -16,6 +16,7 @@ import ( // 4. BackflowFrame // 5. RejectedFrame // 6. GoawayFrame +// 7. StreamFrame // // Read frame comments to understand the role of the frame. type Frame interface { @@ -98,22 +99,47 @@ type GoawayFrame struct { // Type returns the type of GoawayFrame. func (f *GoawayFrame) Type() Type { return TypeGoawayFrame } +// StreamFrame is used to transmit data across DataStream. +type StreamFrame struct { + ID string + ClientID string + StreamID int64 + ChunkSize int64 + Tag Tag +} + +// Type returns the type of StreamFrame. +func (f *StreamFrame) Type() Type { return TypeStreamFrame } + +// RequestStreamFrame is used to request a stream. +type RequestStreamFrame struct { + ClientID string + Tag Tag +} + +// Type returns the type of RequestStreamFrame. +func (f *RequestStreamFrame) Type() Type { return TypeRequestStreamFrame } + const ( - TypeDataFrame Type = 0x3F // TypeDataFrame is the type of DataFrame. - TypeHandshakeFrame Type = 0x31 // TypeHandshakeFrame is the type of HandshakeFrame. - TypeHandshakeAckFrame Type = 0x29 // TypeHandshakeAckFrame is the type of HandshakeAckFrame. - TypeRejectedFrame Type = 0x39 // TypeRejectedFrame is the type of RejectedFrame. - TypeBackflowFrame Type = 0x2D // TypeBackflowFrame is the type of BackflowFrame. - TypeGoawayFrame Type = 0x2E // TypeGoawayFrame is the type of GoawayFrame. + TypeDataFrame Type = 0x3F // TypeDataFrame is the type of DataFrame. + TypeHandshakeFrame Type = 0x31 // TypeHandshakeFrame is the type of HandshakeFrame. + TypeHandshakeAckFrame Type = 0x29 // TypeHandshakeAckFrame is the type of HandshakeAckFrame. + TypeRejectedFrame Type = 0x39 // TypeRejectedFrame is the type of RejectedFrame. + TypeBackflowFrame Type = 0x2D // TypeBackflowFrame is the type of BackflowFrame. + TypeGoawayFrame Type = 0x2E // TypeGoawayFrame is the type of GoawayFrame. + TypeStreamFrame Type = 0x2F // TypeStreamFrame is the type of StreamFrame. + TypeRequestStreamFrame Type = 0x30 // TypeRequestStreamFrame is the type of RequestStreamFrame. ) var frameTypeStringMap = map[Type]string{ - TypeDataFrame: "DataFrame", - TypeHandshakeFrame: "HandshakeFrame", - TypeHandshakeAckFrame: "HandshakeAckFrame", - TypeRejectedFrame: "RejectedFrame", - TypeBackflowFrame: "BackflowFrame", - TypeGoawayFrame: "GoawayFrame", + TypeDataFrame: "DataFrame", + TypeHandshakeFrame: "HandshakeFrame", + TypeHandshakeAckFrame: "HandshakeAckFrame", + TypeRejectedFrame: "RejectedFrame", + TypeBackflowFrame: "BackflowFrame", + TypeGoawayFrame: "GoawayFrame", + TypeStreamFrame: "StreamFrame", + TypeRequestStreamFrame: "RequestStreamFrame", } // String returns a human-readable string which represents the frame type. @@ -127,12 +153,14 @@ func (f Type) String() string { } var frameTypeNewFuncMap = map[Type]func() Frame{ - TypeDataFrame: func() Frame { return new(DataFrame) }, - TypeHandshakeFrame: func() Frame { return new(HandshakeFrame) }, - TypeHandshakeAckFrame: func() Frame { return new(HandshakeAckFrame) }, - TypeRejectedFrame: func() Frame { return new(RejectedFrame) }, - TypeBackflowFrame: func() Frame { return new(BackflowFrame) }, - TypeGoawayFrame: func() Frame { return new(GoawayFrame) }, + TypeDataFrame: func() Frame { return new(DataFrame) }, + TypeHandshakeFrame: func() Frame { return new(HandshakeFrame) }, + TypeHandshakeAckFrame: func() Frame { return new(HandshakeAckFrame) }, + TypeRejectedFrame: func() Frame { return new(RejectedFrame) }, + TypeBackflowFrame: func() Frame { return new(BackflowFrame) }, + TypeGoawayFrame: func() Frame { return new(GoawayFrame) }, + TypeStreamFrame: func() Frame { return new(StreamFrame) }, + TypeRequestStreamFrame: func() Frame { return new(RequestStreamFrame) }, } // NewFrame creates a new frame from Type. diff --git a/core/frame_stream.go b/core/frame_stream.go index eba04dec6..528a8dc3f 100644 --- a/core/frame_stream.go +++ b/core/frame_stream.go @@ -36,6 +36,11 @@ func (fs *FrameStream) Context() context.Context { return fs.underlying.Context() } +// ReadStream reads the underlying stream. +func (fs *FrameStream) ReadStream() (quic.Stream, error) { + return fs.underlying, nil +} + // ReadFrame reads next frame from underlying stream. func (fs *FrameStream) ReadFrame() (frame.Frame, error) { select { @@ -87,3 +92,10 @@ func (fs *FrameStream) Close() error { return fs.underlying.Close() } + +// Codec returns the codec of the FrameStream. +func (fs *FrameStream) Codec() frame.Codec { + fs.mu.Lock() + defer fs.mu.Unlock() + return fs.codec +} diff --git a/core/metadata.go b/core/metadata.go index a60b6a958..0adfd5ae4 100644 --- a/core/metadata.go +++ b/core/metadata.go @@ -10,10 +10,16 @@ const ( MetadataTIDKey = "yomo-tid" MetadataSIDKey = "yomo-sid" MetaTraced = "yomo-traced" + MetaStreamed = "yomo-streamed" ) // NewDefaultMetadata returns a default metadata. -func NewDefaultMetadata(sourceID string, tid string, sid string, traced bool) metadata.M { +func NewDefaultMetadata(sourceID string, tid string, sid string, traced bool, streamed bool) metadata.M { + // streamed + streamedString := "false" + if streamed { + streamedString = "true" + } tracedString := "false" if traced { tracedString = "true" @@ -23,6 +29,7 @@ func NewDefaultMetadata(sourceID string, tid string, sid string, traced bool) me MetadataTIDKey: tid, MetadataSIDKey: sid, MetaTraced: tracedString, + MetaStreamed: streamedString, } } @@ -50,6 +57,12 @@ func GetTracedFromMetadata(m metadata.M) bool { return traced == "true" } +// GetStreamedFromMetadata gets streamed from metadata. +func GetStreamedFromMetadata(m metadata.M) bool { + streamed, _ := m.Get(MetaStreamed) + return streamed == "true" +} + // SetTIDToMetadata sets tid to metadata. func SetTIDToMetadata(m metadata.M, tid string) { m.Set(MetadataTIDKey, tid) @@ -69,6 +82,15 @@ func SetTracedToMetadata(m metadata.M, traced bool) { m.Set(MetaTraced, tracedString) } +// SetStreamedToMetadata sets streamed to metadata. +func SetStreamedToMetadata(m metadata.M, streamed bool) { + streamedString := "false" + if streamed { + streamedString = "true" + } + m.Set(MetaStreamed, streamedString) +} + // MetadataSlogAttr returns slog.Attr from metadata. func MetadataSlogAttr(md metadata.M) slog.Attr { kvStrings := make([]any, len(md)*2) diff --git a/core/server.go b/core/server.go index 9b28b57ba..d130f586f 100644 --- a/core/server.go +++ b/core/server.go @@ -1,9 +1,11 @@ package core import ( + "bytes" "context" "errors" "fmt" + "io" "net" "os" "reflect" @@ -18,6 +20,7 @@ import ( "golang.org/x/exp/slog" // authentication implements, Currently, only token authentication is implemented + _ "github.com/yomorun/yomo/pkg/auth" "github.com/yomorun/yomo/pkg/frame-codec/y3codec" "github.com/yomorun/yomo/pkg/id" @@ -45,7 +48,7 @@ type Server struct { codec frame.Codec packetReadWriter frame.PacketReadWriter counterOfDataFrame int64 - downstreams map[string]FrameWriterConnection + downstreams map[string]*Client mu sync.Mutex opts *serverOptions startHandlers []FrameHandler @@ -73,7 +76,7 @@ func NewServer(name string, opts ...ServerOption) *Server { ctx: ctx, ctxCancel: ctxCancel, name: name, - downstreams: make(map[string]FrameWriterConnection), + downstreams: make(map[string]*Client), logger: logger, tracerProvider: options.tracerProvider, codec: y3codec.Codec(), @@ -103,7 +106,7 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) error { return s.Serve(ctx, conn) } -func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (bool, router.Route, Connection) { +func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (*frame.HandshakeFrame, error) { var gerr error defer func() { @@ -117,44 +120,48 @@ func (s *Server) handshake(qconn quic.Connection, fs *FrameStream) (bool, router first, err := fs.ReadFrame() if err != nil { gerr = err - return false, nil, nil + return nil, err } switch first.Type() { case frame.TypeHandshakeFrame: - hf := first.(*frame.HandshakeFrame) - - conn, err := s.handleHandshakeFrame(qconn, fs, hf) - if err != nil { - gerr = err - return false, nil, conn - } - - route, err := s.handleRoute(hf, conn.Metadata()) - if err != nil { - gerr = err - } - return true, route, conn + return first.(*frame.HandshakeFrame), nil + + // conn, err := s.handleHandshakeFrame(qconn, fs, hf) + // if err != nil { + // gerr = err + // return nil, conn, err + // } + // + // route, err := s.handleRoute(hf, conn.Metadata()) + // if err != nil { + // gerr = err + // return nil, conn, err + // } + // return true, route, conn + // return route, conn, nil default: gerr = fmt.Errorf("yomo: handshake read unexpected frame, read: %s", first.Type().String()) - return false, nil, nil + return nil, gerr } } -func (s *Server) handleConnection(qconn quic.Connection, fs *FrameStream, logger *slog.Logger) { - ok, route, conn := s.handshake(qconn, fs) - if !ok { - logger.Error("handshake failed") - return - } - - logger = logger.With("conn_id", conn.ID(), "conn_name", conn.Name()) - logger.Info("client connected", "remote_addr", qconn.RemoteAddr().String(), "client_type", conn.ClientType().String()) - - c := newContext(conn, route, logger) - - s.handleContext(c) -} +// func (s *Server) + +// func (s *Server) handleStream(qconn quic.Connection, fs *FrameStream, logger *slog.Logger) { +// route, conn, err := s.handshake(qconn, fs) +// if err != nil { +// logger.error("handshake failed", "err", err) +// return +// } +// +// logger = logger.with("conn_id", conn.id(), "conn_name", conn.name()) +// logger.info("client connected", "remote_addr", qconn.remoteaddr().string(), "client_type", conn.clienttype().string()) +// +// c := newContext(conn, route, logger) +// +// s.handleContext(c) +// } func (s *Server) handleContext(c *Context) { for _, h := range s.startHandlers { @@ -208,7 +215,6 @@ func (s *Server) handleFrames(c *Context) { } } } - } func (s *Server) handleRoute(hf *frame.HandshakeFrame, md metadata.M) (router.Route, error) { @@ -226,7 +232,7 @@ func (s *Server) handleRoute(hf *frame.HandshakeFrame, md metadata.M) (router.Ro return route, nil } -func (s *Server) handleHandshakeFrame(qconn quic.Connection, fs *FrameStream, hf *frame.HandshakeFrame) (Connection, error) { +func (s *Server) NewConnection(qconn quic.Connection, fs *FrameStream, hf *frame.HandshakeFrame) (Connection, error) { md, ok := auth.Authenticate(s.opts.auths, hf) if !ok { @@ -277,16 +283,44 @@ func (s *Server) Serve(ctx context.Context, conn net.PacketConn) error { s.logger.Error("accepted an error when accepting a connection", "err", err) return err } + // handle connection + go s.handleConnection(ctx, qconn) + } +} - stream, err := qconn.AcceptStream(ctx) - if err != nil { - continue - } - - fs := NewFrameStream(stream, y3codec.Codec(), y3codec.PacketReadWriter()) - - go s.handleConnection(qconn, fs, s.logger) +func (s *Server) handleConnection(ctx context.Context, qconn quic.Connection) { + // first stream + stream, err := qconn.AcceptStream(ctx) + if err != nil { + s.logger.Error("failed to accept stream", "err", err) + return + } + s.logger.Info("accept new stream", "remote_addr", qconn.RemoteAddr().String(), "stream_id", stream.StreamID()) + // handshake for connection + fs := NewFrameStream(stream, y3codec.Codec(), y3codec.PacketReadWriter()) + // handshake + hf, err := s.handshake(qconn, fs) + if err != nil { + s.logger.Error("handshake failed", "err", err) + return + } + // connection + conn, err := s.NewConnection(qconn, fs, hf) + if err != nil { + s.logger.Error("failed to create connection", "err", err) + return + } + // route + route, err := s.handleRoute(hf, conn.Metadata()) + if err != nil { + s.logger.Error("failed to handle route", "err", err) + return } + logger := s.logger.With("conn_id", conn.ID(), "conn_name", conn.Name()) + logger.Info("client connected", "remote_addr", qconn.RemoteAddr().String(), "client_type", conn.ClientType().String()) + // context + c := newContext(conn, route, logger) + s.handleContext(c) } // Logger returns the logger of server. @@ -300,7 +334,7 @@ func (s *Server) Close() error { return nil } -func closeServer(downstreams map[string]FrameWriterConnection, connector *Connector, listener *quic.Listener, router router.Router) error { +func closeServer(downstreams map[string]*Client, connector *Connector, listener *quic.Listener, router router.Router) error { for _, ds := range downstreams { ds.Close() } @@ -350,6 +384,7 @@ func (s *Server) handleDataFrame(c *Context) error { sid := GetSIDFromMetadata(c.FrameMetadata) parentTraced := GetTracedFromMetadata(c.FrameMetadata) traced := false + streamed := GetStreamedFromMetadata(c.FrameMetadata) // trace tp := s.TracerProvider() if tp != nil { @@ -383,13 +418,14 @@ func (s *Server) handleDataFrame(c *Context) error { SetTIDToMetadata(c.FrameMetadata, tid) SetSIDToMetadata(c.FrameMetadata, sid) SetTracedToMetadata(c.FrameMetadata, traced || parentTraced) + SetStreamedToMetadata(c.FrameMetadata, streamed) md, err := c.FrameMetadata.Encode() if err != nil { s.logger.Error("encode metadata error", "err", err) return err } dataFrame.Metadata = md - s.logger.Debug("zipper metadata", "tid", tid, "sid", sid, "parentTraced", parentTraced, "traced", traced, "frome_stream_name", from.Name()) + s.logger.Debug("zipper metadata", "tid", tid, "sid", sid, "parentTraced", parentTraced, "traced", traced, "from_name", from.Name()) // route route := s.router.Route(c.FrameMetadata) if route == nil { @@ -406,7 +442,7 @@ func (s *Server) handleDataFrame(c *Context) error { c.Logger.Debug("connector snapshot", "tag", dataFrame.Tag, "sfn_conn_ids", connIDs, "connector", s.connector.Snapshot()) for _, toID := range connIDs { - stream, ok, err := s.connector.Get(toID) + to, ok, err := s.connector.Get(toID) if err != nil { continue } @@ -415,17 +451,204 @@ func (s *Server) handleDataFrame(c *Context) error { continue } - c.Logger.Info("data routing", "tid", tid, "sid", sid, "tag", dataFrame.Tag, "data_length", data_length, "to_id", toID, "to_name", stream.Name()) + c.Logger.Info("data routing", "tid", tid, "sid", sid, "tag", dataFrame.Tag, "data_length", data_length, "to_id", toID, "to_name", to.Name()) // write data frame to stream - if err := stream.WriteFrame(dataFrame); err != nil { + if err := to.WriteFrame(dataFrame); err != nil { c.Logger.Error("failed to write frame for routing data", "err", err) } } + // data stream + if streamed { + if err := s.handleDataStream(c, connIDs); err != nil { + return err + } + } + return nil } +func (s *Server) handleDataStream(c *Context, connIDs []string) error { + dataFrame := c.Frame.(*frame.DataFrame) + dataStreamID := string(dataFrame.Payload) + // create stream for source + sourceStream, err := s.openDataStream(c.Connection, dataFrame) + if err != nil { + return err + } + defer sourceStream.Close() + // forward writer + forwardBuf := bytes.NewBuffer(nil) + writer := io.MultiWriter(forwardBuf) + // forward source stream to sfn + dispatchToSFN := func(c *Context, sourceStream quic.Stream, connIDs []string, writer io.Writer) { + from := c.Connection + conns := len(connIDs) + if conns > 0 { + for _, toID := range connIDs { + to, ok, err := s.connector.Get(toID) + if err != nil { + continue + } + if !ok { + c.Logger.Error( + "can't find forward stream", + "err", "route sfn error", + "from_id", from.ID(), + "from", from.Name(), + "to_id", toID, + "to_name", to.Name(), + ) + continue + } + c.Logger.Info( + "data stream routing", + // "tag", dataFrame.Tag, + "from_id", from.ID(), + "from", from.Name(), + "to_id", toID, + "to_name", to.Name(), + ) + // create stream for sfn + sfnStream, err := s.openDataStream(to, dataFrame) + if err != nil { + // fallback(sourceStream) + c.Logger.Error( + "failed to create data stream for sfn", + "err", err, + "from_id", from.ID(), + "from", from.Name(), + "to_id", toID, + "to_name", to.Name(), + ) + continue + } + defer sfnStream.Close() + writer = io.MultiWriter(writer, sfnStream) + } + } else { + c.Logger.Warn("no connections available, ignored") + // fallback(sourceStream) + } + // write datastream to writers(sfn/downstream buffer) + buf := make([]byte, s.opts.streamChunkSize) + _, err = io.CopyBuffer(writer, sourceStream, buf) + if err != nil { + c.Logger.Error( + "failed to forward source stream to sfn", + "err", err, + "from_id", from.ID(), + "from_name", from.Name(), + // "to_id", toID, + // "to_name", to.Name(), + ) + fallback(sourceStream) + // continue + } + c.Logger.Info( + "forward source stream to sfn", + "from_id", from.ID(), + "from_name", from.Name(), + // "to_id", toID, + // "to_name", to.Name(), + ) + } + // }(c, sourceStream, connIDs, writer) + dispatchToSFN(c, sourceStream, connIDs, writer) + // INFO: dispatch to downstreams + if len(s.downstreams) > 0 { + for _, ds := range s.downstreams { + c.Logger.Info( + "dispatching datastream to downstream", + "downstream_name", ds.Name(), + "datastream_id", dataStreamID, + ) + // PERF: need to optimize + duplicatedBuf := bytes.NewBuffer(nil) + buf := make([]byte, s.opts.streamChunkSize) + _, err := io.CopyBuffer(duplicatedBuf, forwardBuf, buf) + if err != nil { + c.Logger.Error("failed to copy buf", + "err", err, + "dispatching datastream to downstream", + "downstream_name", ds.Name(), + "datastream_id", dataStreamID, + ) + continue + } + c.Logger.Info( + "downstream info", + "downstream_name", ds.Name(), + "datastream_id", dataStreamID, + "is_nil", ds == nil, + "ds.ctx", ds.ctx, + ) + go func(c *Context, dsClient *Client, dataStreamID string, stream io.Reader) { + if err := dsClient.PipeStream(dataStreamID, duplicatedBuf); err != nil { + if c != nil && c.Logger != nil { + c.Logger.Error( + "failed to dispatch datastream to downstream", + "err", err, + "downstream_name", dsClient.Name(), + "datastream_id", dataStreamID, + ) + } + } + }(c, ds, dataStreamID, duplicatedBuf) + } + } else { + fallback(forwardBuf) + } + + return nil +} + +// fallback is used to discard the data stream. +func fallback(reader io.Reader) { + io.Copy(io.Discard, reader) +} + +// openDataStream creates a quic stream for data stream. +func (s *Server) openDataStream(conn Connection, dataFrame *frame.DataFrame) (quic.Stream, error) { + // open data stream + dataStream, err := conn.QuicConnection().OpenStream() + if err != nil { + s.logger.Error("failed to create data stream", "err", err) + return nil, err + } + dataStreamID := string(dataFrame.Payload) + s.logger.Debug("creating data stream", "datastream_id", dataStreamID, "client_id", conn.ID()) + streamFrame := &frame.StreamFrame{ + ID: dataStreamID, + ClientID: conn.ID(), + StreamID: int64(dataStream.StreamID()), + ChunkSize: s.opts.streamChunkSize, + Tag: dataFrame.Tag, + } + // write stream frame to from + err = conn.WriteFrame(streamFrame) + if err != nil { + s.logger.Error("failed to write stream frame to main stream", "err", err) + return nil, err + } + // write stream frame to dataStream + fs := NewFrameStream(dataStream, y3codec.Codec(), y3codec.PacketReadWriter()) + err = fs.WriteFrame(streamFrame) + if err != nil { + s.logger.Error("failed to write stream frame to data stream", "err", err) + return nil, err + } + s.logger.Info( + "created data stream", + "datastream_id", streamFrame.ID, + "stream_id", streamFrame.StreamID, + "stream_chunk_size", streamFrame.ChunkSize, + "client_id", streamFrame.ClientID, + ) + return dataStream, nil +} + func (s *Server) handleBackflowFrame(c *Context) error { dataFrame := c.Frame.(*frame.DataFrame) @@ -497,7 +720,7 @@ func (s *Server) ConfigRouter(router router.Router) { // AddDownstreamServer add a downstream server to this server. all the DataFrames will be // dispatch to all the downstreams. -func (s *Server) AddDownstreamServer(addr string, c FrameWriterConnection) { +func (s *Server) AddDownstreamServer(addr string, c *Client) { s.mu.Lock() s.downstreams[addr] = c s.mu.Unlock() @@ -515,6 +738,7 @@ func (s *Server) dispatchToDownstreams(c *Context) { var ( tid = GetTIDFromMetadata(c.FrameMetadata) sid = GetSIDFromMetadata(c.FrameMetadata) + // streamd = GetStreamedFromMetadata(c.FrameMetadata) ) mdBytes, err := c.FrameMetadata.Encode() if err != nil { @@ -524,7 +748,7 @@ func (s *Server) dispatchToDownstreams(c *Context) { dataFrame.Metadata = mdBytes for _, ds := range s.downstreams { - c.Logger.Info("dispatching to downstream", "tid", tid, "sid", sid, "tag", dataFrame.Tag, "data_length", len(dataFrame.Payload), "downstream_id", ds.ClientID()) + c.Logger.Info("dispatching dataframe to downstream", "tid", tid, "sid", sid, "tag", dataFrame.Tag, "data_length", len(dataFrame.Payload), "downstream_id", ds.ClientID()) _ = ds.WriteFrame(dataFrame) } } diff --git a/core/server_options.go b/core/server_options.go index 63e570455..691ceff21 100644 --- a/core/server_options.go +++ b/core/server_options.go @@ -11,6 +11,11 @@ import ( "golang.org/x/exp/slog" ) +const ( + // DefaultStreamChunkSize is the default stream chunk size. + DefaultStreamChunkSize int64 = 32 * 1024 +) + // DefalutQuicConfig be used when `quicConfig` is nil. var DefalutQuicConfig = &quic.Config{ Versions: []quic.VersionNumber{quic.Version1, quic.Version2}, @@ -30,21 +35,23 @@ type ServerOption func(*serverOptions) // ServerOptions are the options for YoMo server. // TODO: quic alpn function. type serverOptions struct { - quicConfig *quic.Config - tlsConfig *tls.Config - auths map[string]auth.Authentication - logger *slog.Logger - tracerProvider oteltrace.TracerProvider + quicConfig *quic.Config + tlsConfig *tls.Config + auths map[string]auth.Authentication + logger *slog.Logger + tracerProvider oteltrace.TracerProvider + streamChunkSize int64 } func defaultServerOptions() *serverOptions { logger := ylog.Default() opts := &serverOptions{ - quicConfig: DefalutQuicConfig, - tlsConfig: nil, - auths: map[string]auth.Authentication{}, - logger: logger, + quicConfig: DefalutQuicConfig, + tlsConfig: nil, + auths: map[string]auth.Authentication{}, + logger: logger, + streamChunkSize: DefaultStreamChunkSize, } return opts } @@ -89,3 +96,14 @@ func WithServerTracerProvider(tp oteltrace.TracerProvider) ServerOption { o.tracerProvider = tp } } + +// WithServerStreamChunkSize sets stream chunk size for the server. +func WithServerStreamChunkSize(size int64) ServerOption { + return func(o *serverOptions) { + if size > 0 { + o.streamChunkSize = size + } else { + o.streamChunkSize = DefaultStreamChunkSize + } + } +} diff --git a/core/serverless/context.go b/core/serverless/context.go index c03ff6c5e..458ec1358 100644 --- a/core/serverless/context.go +++ b/core/serverless/context.go @@ -2,21 +2,49 @@ package serverless import ( + "context" + "errors" + "fmt" + "io" + + "github.com/quic-go/quic-go" + "github.com/yomorun/yomo/core" "github.com/yomorun/yomo/core/frame" + "github.com/yomorun/yomo/core/metadata" + "github.com/yomorun/yomo/pkg/frame-codec/y3codec" ) // Context sfn handler context type Context struct { - writer frame.Writer + client *core.Client dataFrame *frame.DataFrame + streamed bool + stream io.ReadCloser } // NewContext creates a new serverless Context -func NewContext(writer frame.Writer, dataFrame *frame.DataFrame) *Context { - return &Context{ - writer: writer, +func NewContext(client *core.Client, dataFrame *frame.DataFrame) *Context { + c := &Context{ + client: client, dataFrame: dataFrame, } + // streamed + m, err := metadata.Decode(c.dataFrame.Metadata) + if err != nil { + c.streamed = false + } else { + c.streamed = core.GetStreamedFromMetadata(m) + } + // stream + if c.streamed { + stream, err := c.readStream(context.Background()) + if err == nil { + c.stream = stream + } else { + c.client.Logger().Error("context read stream error", "err", err) + } + } + return c } // Tag returns the tag of the data frame @@ -41,5 +69,84 @@ func (c *Context) Write(tag uint32, data []byte) error { Payload: data, } - return c.writer.WriteFrame(dataFrame) + return c.client.WriteFrame(dataFrame) +} + +// Streamed returns whether the data is streamed. +func (c *Context) Streamed() bool { + return c.streamed +} + +// Stream returns the stream. +func (c *Context) Stream() io.Reader { + defer c.stream.Close() + return c.stream +} + +func (c *Context) readStream(ctx context.Context) (quic.Stream, error) { + client := c.client + dataFrame := c.dataFrame + dataStreamID := string(dataFrame.Payload) + client.Logger().Debug(fmt.Sprintf("context receive stream[%s] -- start", dataStreamID)) + // process data stream +STREAM: + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + qconn := client.Connection() + if qconn == nil { + err := errors.New("quic connection is nil") + client.Logger().Error(err.Error()) + return nil, err + } + dataStream, err := qconn.AcceptStream(ctx) + if err != nil { + client.Logger().Error("context request stream error", "err", err, "datastream_id", dataStreamID) + return nil, err + } + client.DataStreams().Store(dataStreamID, dataStream) + client.Logger().Debug("context accept stream success", "datastream_id", dataStreamID, "stream_id", dataStream.StreamID()) + // read stream frame + fs := core.NewFrameStream(dataStream, y3codec.Codec(), y3codec.PacketReadWriter()) + f, err := fs.ReadFrame() + if err != nil { + client.Logger().Warn("failed to read data stream", "err", err, "datastream_id", dataStreamID) + return nil, err + } + switch f.Type() { + case frame.TypeStreamFrame: + streamFrame := f.(*frame.StreamFrame) + // lookup data stream + // if streamFrame.ID != dataStreamID { + reader, ok := client.DataStreams().Load(dataStreamID) + if !ok { + client.Logger().Debug( + "data strem is not found, continue", + "stream_id", dataStream.StreamID(), + "datastream_id", dataStreamID, + "received_id", streamFrame.ID, + "client_id", streamFrame.ClientID, + "tag", streamFrame.Tag, + ) + goto STREAM + } + defer client.DataStreams().Delete(dataStreamID) + client.Logger().Debug( + "data stream is ready", + "remote_addr", qconn.RemoteAddr().String(), + "datastream_id", streamFrame.ID, + "stream_id", dataStream.StreamID(), + "stream_chunk_szie", streamFrame.ChunkSize, + "client_id", streamFrame.ClientID, + "tag", streamFrame.Tag, + ) + return reader.(quic.Stream), nil + default: + client.Logger().Error("!!!unexpected frame!!!", "unexpected_frame_type", f.Type().String()) + } + client.Logger().Debug(fmt.Sprintf("context receive stream[%s] -- end", dataStreamID)) + + return dataStream, nil + } } diff --git a/example/10-stream/Taskfile.yml b/example/10-stream/Taskfile.yml new file mode 100644 index 000000000..5b12a1303 --- /dev/null +++ b/example/10-stream/Taskfile.yml @@ -0,0 +1,59 @@ +# https://taskfile.dev + +version: "3" + +output: "prefixed" + +env: + YOMO_LOG_LEVEL: debug + +tasks: + run: + desc: run + deps: [zipper, source, sfn] + cmds: + - echo 'stream example run' + + # example cleanup + clean: + desc: clean + cmds: + - rm -rf ./bin + + build: + desc: build source, sfn and zipper + deps: [source-build, sfn-build] + cmds: + - echo 'building done' + internal: true + + source-build: + desc: build source + cmds: + - "go build -o ./bin/source{{exeExt}} source/main.go" + internal: true + + sfn-build: + desc: build sfn + cmds: + - "go build -o ./bin/sfn{{exeExt}} sfn/main.go" + internal: true + + source: + desc: run source + deps: [source-build] + cmds: + - "./bin/source{{exeExt}}" + env: + YOMO_LOG_LEVEL: debug + + sfn: + desc: run sfn + deps: [sfn-build] + cmds: + - "./bin/sfn{{exeExt}}" + + zipper: + desc: run zipper + cmds: + - "yomo serve -c config.yaml" diff --git a/example/10-stream/config.yaml b/example/10-stream/config.yaml new file mode 100644 index 000000000..d8517595b --- /dev/null +++ b/example/10-stream/config.yaml @@ -0,0 +1,13 @@ +name: Service +host: localhost +port: 9000 +# data stream config +stream: + chunksize: 524288 + +downstreams: + zipper-2: + host: 127.0.0.1 + port: 9002 + credential: "token:z2" + diff --git a/example/10-stream/sfn/main.go b/example/10-stream/sfn/main.go new file mode 100644 index 000000000..a89d4780b --- /dev/null +++ b/example/10-stream/sfn/main.go @@ -0,0 +1,75 @@ +package main + +import ( + "io" + "os" + + "github.com/yomorun/yomo" + "github.com/yomorun/yomo/serverless" + "golang.org/x/exp/slog" +) + +type noiseData struct { + Noise float32 `json:"noise"` // Noise value + Time int64 `json:"time"` // Timestamp (ms) + From string `json:"from"` // Source IP +} + +func main() { + addr := "localhost:9000" + if v := os.Getenv("YOMO_ADDR"); v != "" { + addr = v + } + sfn := yomo.NewStreamFunction( + "sfn-stream", + addr, + ) + sfn.SetObserveDataTags(0x33) + defer sfn.Close() + + // set handler + sfn.SetHandler(handler) + // start + err := sfn.Connect() + if err != nil { + slog.Error("[sfn] connect", err) + os.Exit(1) + } + // set the error handler function when server error occurs + sfn.SetErrorHandler(func(err error) { + slog.Error("[sfn] receive server error", "err", err) + }) + + select {} +} + +func handler(ctx serverless.Context) { + if ctx.Streamed() { + handleStream(ctx) + return + } + handleData(ctx) +} + +func handleStream(ctx serverless.Context) { + dataStream := ctx.Stream() + if dataStream != nil { + buf, err := io.ReadAll(dataStream) + if err != nil { + slog.Error("[sfn] failed to read all", "err", err) + return + } + bufString := string(buf) + l := len(buf) + if l > 1000 { + bufString = string(buf[l-1000:]) + } + slog.Info("[sfn] read all", "len", l, "buf", bufString) + } else { + slog.Info("[sfn] dataStream is nil") + } +} + +func handleData(ctx serverless.Context) { + slog.Info("[sfn] got", "data", ctx.Data()) +} diff --git a/example/10-stream/source/main.go b/example/10-stream/source/main.go new file mode 100644 index 000000000..28b00dc9e --- /dev/null +++ b/example/10-stream/source/main.go @@ -0,0 +1,99 @@ +package main + +import ( + "fmt" + "os" + "strconv" + "time" + + "github.com/yomorun/yomo" + "golang.org/x/exp/slog" +) + +func main() { + // connect to YoMo-Zipper. + addr := "localhost:9000" + if v := os.Getenv("YOMO_ADDR"); v != "" { + addr = v + } + source := yomo.NewSource( + "yomo-source", + addr, + ) + err := source.Connect() + if err != nil { + slog.Info("[source] ❌ Emit the data to YoMo-Zipper failure with err", "err", err) + return + } + + defer source.Close() + + // set the error handler function when server error occurs + source.SetErrorHandler(func(err error) { + slog.Error("[source] receive server error", "err", err) + }) + + streamed := true + if v := os.Getenv("YOMO_STREAMED"); v != "" { + s, err := strconv.ParseBool(v) + if err == nil { + streamed = s + } + } + slog.Info(fmt.Sprintf("[source] use stream: %v", streamed)) + + if streamed { + err = pipeStream(source) + } else { + err = write(source) + } + slog.Info("[source] err: ", "err", err) + if err != nil { + slog.Error("[source] >>>> ERR", "err", err) + // os.Exit(0) + } + select {} +} + +func pipeStream(source yomo.Source) error { + for i := 0; ; i++ { + // read data from file. + d := i % 2 + file := fmt.Sprintf("%d.dat", d) + slog.Info(fmt.Sprintf("[source] #%d. pipe stream to YoMo-Zipper", i)) + pipeFile(source, file) + go pipeFile(source, "0.dat") + go pipeFile(source, "1.dat") + time.Sleep(time.Second * 1) + } +} + +func pipeFile(source yomo.Source, file string) error { + reader, err := os.Open(file) + if err != nil { + slog.Error("[source] ❌ Read file failure with err", "err", err) + return err + } + // defer reader.Close() + // send data to YoMo-Zipper. + slog.Info("[source] pipe stream to YoMo-Zipper", "stream", file) + err = source.Pipe(0x33, reader) + if err != nil { + slog.Error("[source] ❌ Emit to YoMo-Zipper failure with err", "err", err) + return err + } + return nil +} + +func write(source yomo.Source) error { + for { + time.Sleep(1000 * time.Millisecond) + n := time.Now().UnixMilli() + data := strconv.FormatInt(n, 10) + slog.Info("[source] write data to YoMo-Zipper", "data", data) + if err := source.Write(0x33, []byte(data)); err != nil { + slog.Error("[source] ❌ Emit to YoMo-Zipper failure with err", "err", err) + return err + } + } +} diff --git a/options.go b/options.go index f419ef693..303bacaee 100644 --- a/options.go +++ b/options.go @@ -110,4 +110,11 @@ var ( o.serverOption = append(o.serverOption, core.WithServerTracerProvider(tp)) } } + + // WithZipperStreamChunkSize sets the chunk size for the zipper. + WithZipperStreamChunkSize = func(size int64) ZipperOption { + return func(o *zipperOptions) { + o.serverOption = append(o.serverOption, core.WithServerStreamChunkSize(size)) + } + } ) diff --git a/pkg/config/config.go b/pkg/config/config.go index 673d14125..5002c0773 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -29,6 +29,13 @@ type Config struct { Auth map[string]string `yaml:"auth"` // Downstreams holds cascading zippers config. the map-key is downstream name. Downstreams map[string]Downstream `yaml:"downstreams"` + + Stream *DataStream `yaml:"stream"` +} + +// DataStream represents the data stream config. +type DataStream struct { + ChunkSize int64 `yaml:"chunksize"` } // Downstream describes a cascading zipper config. diff --git a/pkg/frame-codec/y3codec/codec.go b/pkg/frame-codec/y3codec/codec.go index 46d48e5d2..dc81a0158 100644 --- a/pkg/frame-codec/y3codec/codec.go +++ b/pkg/frame-codec/y3codec/codec.go @@ -51,6 +51,10 @@ func (c *y3codec) Encode(f frame.Frame) ([]byte, error) { return encodeBackflowFrame(ff) case *frame.GoawayFrame: return encodeGoawayFrame(ff) + case *frame.StreamFrame: + return encodeStreamFrame(ff) + case *frame.RequestStreamFrame: + return encodeRequestStreamFrame(ff) default: return nil, ErrUnknownFrame } @@ -70,6 +74,10 @@ func (c *y3codec) Decode(data []byte, f frame.Frame) error { return decodeBackflowFrame(data, ff) case *frame.GoawayFrame: return decodeGoawayFrame(data, ff) + case *frame.StreamFrame: + return decodeStreamFrame(data, ff) + case *frame.RequestStreamFrame: + return decodeRequestStreamFrame(data, ff) default: return ErrUnknownFrame } diff --git a/pkg/frame-codec/y3codec/data_frame.go b/pkg/frame-codec/y3codec/data_frame.go index c813dbef0..f725729bc 100644 --- a/pkg/frame-codec/y3codec/data_frame.go +++ b/pkg/frame-codec/y3codec/data_frame.go @@ -12,7 +12,7 @@ func encodeDataFrame(f *frame.DataFrame) ([]byte, error) { tagBlock.SetUInt32Value(f.Tag) // metadata - metadataBlock := y3.NewPrimitivePacketEncoder(tagDataFramesMetadata) + metadataBlock := y3.NewPrimitivePacketEncoder(tagDataFrameMetadata) metadataBlock.SetBytesValue(f.Metadata) // payload @@ -46,7 +46,7 @@ func decodeDataFrame(data []byte, f *frame.DataFrame) error { } // metadata - if metadataBlock, ok := packet.PrimitivePackets[byte(tagDataFramesMetadata)]; ok { + if metadataBlock, ok := packet.PrimitivePackets[byte(tagDataFrameMetadata)]; ok { metadata := metadataBlock.ToBytes() f.Metadata = metadata } @@ -61,7 +61,7 @@ func decodeDataFrame(data []byte, f *frame.DataFrame) error { } var ( - tagDataFrameTag byte = 0x01 - tagDataFramePayload byte = 0x02 - tagDataFramesMetadata byte = 0x03 + tagDataFrameTag byte = 0x01 + tagDataFramePayload byte = 0x02 + tagDataFrameMetadata byte = 0x03 ) diff --git a/pkg/frame-codec/y3codec/request_stream_frame.go b/pkg/frame-codec/y3codec/request_stream_frame.go new file mode 100644 index 000000000..f71f0aa20 --- /dev/null +++ b/pkg/frame-codec/y3codec/request_stream_frame.go @@ -0,0 +1,54 @@ +package y3codec + +import ( + "github.com/yomorun/y3" + frame "github.com/yomorun/yomo/core/frame" +) + +// encodeRequestStreamFrame RequestStreamframe to Y3 encoded bytes +func encodeRequestStreamFrame(f *frame.RequestStreamFrame) ([]byte, error) { + // client id + clientID := y3.NewPrimitivePacketEncoder(tagRequestStreamFrameClientID) + clientID.SetStringValue(f.ClientID) + // tag + tag := y3.NewPrimitivePacketEncoder(tagRequestStreamFrameTag) + tag.SetUInt32Value(f.Tag) + // encode + node := y3.NewNodePacketEncoder(byte(f.Type())) + node.AddPrimitivePacket(clientID) + node.AddPrimitivePacket(tag) + + return node.Encode(), nil +} + +// decodeRequestStreamFrame decodes Y3 encoded bytes to RequestStreamFrame. +func decodeRequestStreamFrame(data []byte, f *frame.RequestStreamFrame) error { + nodeBlock := y3.NodePacket{} + _, err := y3.DecodeToNodePacket(data, &nodeBlock) + if err != nil { + return err + } + // client id + if p, ok := nodeBlock.PrimitivePackets[tagRequestStreamFrameClientID]; ok { + clientID, err := p.ToUTF8String() + if err != nil { + return err + } + f.ClientID = clientID + } + // tag + if p, ok := nodeBlock.PrimitivePackets[byte(tagRequestStreamFrameTag)]; ok { + tag, err := p.ToUInt32() + if err != nil { + return err + } + f.Tag = tag + } + + return nil +} + +var ( + tagRequestStreamFrameClientID byte = 0x02 + tagRequestStreamFrameTag byte = 0x05 +) diff --git a/pkg/frame-codec/y3codec/stream_frame.go b/pkg/frame-codec/y3codec/stream_frame.go new file mode 100644 index 000000000..79f8381e8 --- /dev/null +++ b/pkg/frame-codec/y3codec/stream_frame.go @@ -0,0 +1,93 @@ +package y3codec + +import ( + "github.com/yomorun/y3" + frame "github.com/yomorun/yomo/core/frame" +) + +// encodeStreamFrame Streamframe to Y3 encoded bytes +func encodeStreamFrame(f *frame.StreamFrame) ([]byte, error) { + // id + id := y3.NewPrimitivePacketEncoder(tagStreamFrameID) + id.SetStringValue(f.ID) + // client id + clientID := y3.NewPrimitivePacketEncoder(tagStreamFrameClientID) + clientID.SetStringValue(f.ClientID) + // stream id + streamID := y3.NewPrimitivePacketEncoder(tagStreamFrameStreamID) + streamID.SetInt64Value(f.StreamID) + // chunk size + chunkSize := y3.NewPrimitivePacketEncoder(tagStreamFrameChunkSize) + chunkSize.SetInt64Value(f.ChunkSize) + // tag + tag := y3.NewPrimitivePacketEncoder(tagStreamFrameTag) + tag.SetUInt32Value(f.Tag) + // encode + node := y3.NewNodePacketEncoder(byte(f.Type())) + node.AddPrimitivePacket(id) + node.AddPrimitivePacket(clientID) + node.AddPrimitivePacket(streamID) + node.AddPrimitivePacket(chunkSize) + node.AddPrimitivePacket(tag) + + return node.Encode(), nil +} + +// decodeStreamFrame decodes Y3 encoded bytes to StreamFrame. +func decodeStreamFrame(data []byte, f *frame.StreamFrame) error { + nodeBlock := y3.NodePacket{} + _, err := y3.DecodeToNodePacket(data, &nodeBlock) + if err != nil { + return err + } + // id + if p, ok := nodeBlock.PrimitivePackets[tagStreamFrameID]; ok { + id, err := p.ToUTF8String() + if err != nil { + return err + } + f.ID = id + } + // client id + if p, ok := nodeBlock.PrimitivePackets[tagStreamFrameClientID]; ok { + clientID, err := p.ToUTF8String() + if err != nil { + return err + } + f.ClientID = clientID + } + // stream id + if p, ok := nodeBlock.PrimitivePackets[tagStreamFrameStreamID]; ok { + steamID, err := p.ToInt64() + if err != nil { + return err + } + f.StreamID = steamID + } + // chunk size + if p, ok := nodeBlock.PrimitivePackets[tagStreamFrameChunkSize]; ok { + chunkSize, err := p.ToInt64() + if err != nil { + return err + } + f.ChunkSize = chunkSize + } + // tag + if p, ok := nodeBlock.PrimitivePackets[byte(tagStreamFrameTag)]; ok { + tag, err := p.ToUInt32() + if err != nil { + return err + } + f.Tag = tag + } + + return nil +} + +var ( + tagStreamFrameID byte = 0x01 + tagStreamFrameClientID byte = 0x02 + tagStreamFrameStreamID byte = 0x03 + tagStreamFrameChunkSize byte = 0x04 + tagStreamFrameTag byte = 0x05 +) diff --git a/serverless/context.go b/serverless/context.go index 9fabe48b9..fc53544e5 100644 --- a/serverless/context.go +++ b/serverless/context.go @@ -1,6 +1,8 @@ // Package serverless defines serverless handler context package serverless +import "io" + // Context sfn handler context type Context interface { // Data incoming data @@ -11,6 +13,10 @@ type Context interface { Write(tag uint32, data []byte) error // HTTP http interface HTTP() HTTP + // Streamed returns whether the data is streamed + Streamed() bool + // Stream returns the stream + Stream() io.Reader } // HTTP http interface diff --git a/sfn.go b/sfn.go index 26b35e1b8..2a54ba363 100644 --- a/sfn.go +++ b/sfn.go @@ -178,7 +178,7 @@ func (s *streamFunction) onDataFrame(dataFrame *frame.DataFrame) { return } - newMd, deferFunc := ExtendTraceMetadata(md, s.client.ClientID(), s.client.Name(), s.client.TracerProvider(), s.client.Logger()) + newMd, deferFunc := ExtendTraceMetadata(md, s.client.ClientID(), s.client.Name(), tp, s.client.Logger()) defer deferFunc() newMetadata, err := newMd.Encode() @@ -187,7 +187,6 @@ func (s *streamFunction) onDataFrame(dataFrame *frame.DataFrame) { return } dataFrame.Metadata = newMetadata - serverlessCtx := serverless.NewContext(s.client, dataFrame) s.fn(serverlessCtx) }(tp, dataFrame) @@ -215,6 +214,7 @@ func ExtendTraceMetadata(md metadata.M, clientID, name string, tp oteltrace.Trac deferFunc := func() {} tid := core.GetTIDFromMetadata(md) sid := core.GetSIDFromMetadata(md) + streamed := core.GetStreamedFromMetadata(md) parentTraced := core.GetTracedFromMetadata(md) traced := false // trace @@ -249,8 +249,9 @@ func ExtendTraceMetadata(md metadata.M, clientID, name string, tp oteltrace.Trac core.SetTIDToMetadata(md, tid) core.SetSIDToMetadata(md, sid) core.SetTracedToMetadata(md, traced) + core.SetStreamedToMetadata(md, streamed) - logger.Debug("sfn metadata", "tid", tid, "sid", sid, "parentTraced", parentTraced, "traced", traced) + logger.Debug("sfn metadata", "tid", tid, "sid", sid, "parentTraced", parentTraced, "traced", traced, "streamed", streamed) return md, deferFunc } diff --git a/source.go b/source.go index 72e8d3b19..61a838b43 100644 --- a/source.go +++ b/source.go @@ -2,6 +2,9 @@ package yomo import ( "context" + "errors" + "io" + "time" "github.com/yomorun/yomo/core" "github.com/yomorun/yomo/core/frame" @@ -24,6 +27,8 @@ type Source interface { SetErrorHandler(fn func(err error)) // [Experimental] SetReceiveHandler set the observe handler function SetReceiveHandler(fn func(tag uint32, data []byte)) + // Pipe pipe the stream data to zipper. + Pipe(tag uint32, reader io.Reader) error } // YoMo-Source @@ -77,7 +82,7 @@ func (s *yomoSource) Connect() error { // Write writes data with specified tag. func (s *yomoSource) Write(tag uint32, data []byte) error { - md, deferFunc := TraceMetadata(s.client.ClientID(), s.name, s.client.TracerProvider(), s.client.Logger()) + md, deferFunc := TraceMetadata(s.client.ClientID(), s.name, false, s.client.TracerProvider(), s.client.Logger()) defer deferFunc() mdBytes, err := md.Encode() @@ -105,8 +110,49 @@ func (s *yomoSource) SetReceiveHandler(fn func(uint32, []byte)) { s.client.Logger().Info("receive hander set for the source") } +// Pipe pipe the stream data to zipper. +func (s *yomoSource) Pipe(tag uint32, reader io.Reader) error { + // NOTE: this is a simple implementation, we will improve it later. +PIPE: + md, deferFunc := TraceMetadata(s.client.ClientID(), s.name, true, s.client.TracerProvider(), s.client.Logger()) + defer deferFunc() + // metadata + mdBytes, err := md.Encode() + // metadata + if err != nil { + return err + } + // write dataframe with data stream id + dataStreamID := id.New() + f := &frame.DataFrame{ + Tag: tag, + Metadata: mdBytes, + Payload: []byte(dataStreamID), + } + // write dataframe to main stream + err = s.client.WriteFrame(f) + if err != nil { + s.client.Logger().Error("source write frame error", "err", err, "datastream_id", dataStreamID) + return err + } + s.client.Logger().Debug("source write stream frame", "tag", tag, "datastream_id", dataStreamID) + s.client.Logger().Debug("source pipe stream...", "tag", tag, "datastream_id", dataStreamID) + err = s.client.PipeStream(dataStreamID, reader) + if err != nil { + // process reconnect + if errors.As(err, new(core.ErrAuthenticateFailed)) { + return err + } + s.client.Logger().Error("source pipe stream error", "err", err, "datastream_id", dataStreamID) + time.Sleep(core.DefaultReconnectInterval) + goto PIPE + } + + return nil +} + // TraceMetadata generates source trace metadata. -func TraceMetadata(clientID, name string, tp oteltrace.TracerProvider, logger *slog.Logger) (metadata.M, func()) { +func TraceMetadata(clientID, name string, streamed bool, tp oteltrace.TracerProvider, logger *slog.Logger) (metadata.M, func()) { deferFunc := func() {} var tid, sid string // trace @@ -130,9 +176,9 @@ func TraceMetadata(clientID, name string, tp oteltrace.TracerProvider, logger *s logger.Debug("source create new sid") sid = id.SID() } - logger.Debug("source metadata", "tid", tid, "sid", sid, "traced", traced) + logger.Debug("source metadata", "tid", tid, "sid", sid, "traced", traced, "streamed", streamed) - md := core.NewDefaultMetadata(clientID, tid, sid, traced) + md := core.NewDefaultMetadata(clientID, tid, sid, traced, streamed) return md, deferFunc }