Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
prasunanand committed Jan 31, 2025
1 parent 95263d9 commit 606d488
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 70 deletions.
71 changes: 61 additions & 10 deletions websocket/channels.go → kernel/channels.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package websocket
package kernel

import (
"context"
Expand All @@ -7,7 +7,6 @@ import (

"github.com/gorilla/websocket"
"github.com/rs/zerolog/log"
"github.com/zasper-io/zasper/kernel"

"github.com/go-zeromq/zmq4"
)
Expand All @@ -19,17 +18,17 @@ type KernelWebSocketConnection struct {
Conn *websocket.Conn
Send chan []byte
KernelId string
KernelManager kernel.KernelManager
KernelManager KernelManager
LimitRate bool
IOpubMsgRateLimit int
IOpubDataRateLimit int
RateLimitWindow int
Context context.Context
PollingCancel context.CancelFunc
OpenSessions map[string]kernel.KernelSession
OpenSessions map[string]KernelSession
OpenSockets []string
Channels map[string]zmq4.Socket
Session kernel.KernelSession
Session KernelSession
IOPubWindowMsgCount int
IOPubWindowByteCount int
IOPubMsgsExceeded int
Expand Down Expand Up @@ -92,7 +91,7 @@ func (kwsConn *KernelWebSocketConnection) startPolling() { //msg interface{}, bi
kwsConn.pollChannel(iopub_channel, "iopub")
}

func (kwsConn *KernelWebSocketConnection) prepare(sessionId string) {
func (kwsConn *KernelWebSocketConnection) Prepare(sessionId string) {
km := kwsConn.KernelManager
if km.Ready {
log.Info().Msgf("%s", km.Session.Key)
Expand All @@ -102,9 +101,9 @@ func (kwsConn *KernelWebSocketConnection) prepare(sessionId string) {
kwsConn.Session = km.Session
}

func (kwsConn *KernelWebSocketConnection) connect() {
func (kwsConn *KernelWebSocketConnection) Connect() {
log.Info().Msg("notifying connection")
kernel.NotifyConnect()
NotifyConnect()

log.Info().Msg("creating stream")
kwsConn.createStream()
Expand Down Expand Up @@ -171,9 +170,9 @@ func (kwsConn *KernelWebSocketConnection) handleIncomingMessage(messageType int,
return
}

var msg kernel.Message
var msg Message
if kwsConn.Subprotocol == "v1.kernel.websocket.jupyter.org" {
msg = kernel.Message{}
msg = Message{}
} else {
if err := json.Unmarshal([]byte(wsMsg), &msg); err != nil {
log.Info().Msgf("Error unmarshalling message: %s", err)
Expand All @@ -184,3 +183,55 @@ func (kwsConn *KernelWebSocketConnection) handleIncomingMessage(messageType int,
kwsConn.Session.SendStreamMsg(kwsConn.Channels["shell"], msg)
}
}

func (kwsConn *KernelWebSocketConnection) ReadMessagesFromClient(waiter *sync.WaitGroup) {
defer func() {
log.Info().Msg("Closing readMessagesFromClient")
kwsConn.Conn.Close()
waiter.Done()
}()

for {
select {
case <-kwsConn.Context.Done(): // Check if context is canceled
log.Debug().Msgf("Socket closed, Incoming message handler stopped")
return
default:
messageType, data, err := kwsConn.Conn.ReadMessage()
if err != nil {
log.Debug().Msgf("%s", err)
return
}
log.Debug().Msgf("message type => %d", messageType)
kwsConn.handleIncomingMessage(messageType, data)
}

}
}

func (kwsConn *KernelWebSocketConnection) WriteMessages(waiter *sync.WaitGroup) {
defer func() {
kwsConn.Conn.Close()
waiter.Done()
}()
for {
select {
case <-kwsConn.Context.Done(): // Check if context is canceled
log.Debug().Msgf("Socket closed, Incoming message handler stopped")
return
default:
message, ok := <-kwsConn.Send
if !ok {
log.Info().Msg("Send channel closed, closing WebSocket connection")
kwsConn.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
kwsConn.mu.Lock()
if err := kwsConn.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
log.Info().Msgf("Error writing message: %s", err)
return
}
kwsConn.mu.Unlock()
}
}
}
68 changes: 8 additions & 60 deletions websocket/kernel_websocket_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ type APIResponse struct {

var clientsMu sync.Mutex // To handle concurrent access to clients

var ZasperActiveKernelConnections map[string]*KernelWebSocketConnection
var ZasperActiveKernelConnections map[string]*kernel.KernelWebSocketConnection

func SetUpStateKernels() map[string]*KernelWebSocketConnection {
return make(map[string]*KernelWebSocketConnection)
func SetUpStateKernels() map[string]*kernel.KernelWebSocketConnection {
return make(map[string]*kernel.KernelWebSocketConnection)
}

// DELETE handler for /api/kernels/{kernel_id}
Expand Down Expand Up @@ -98,7 +98,7 @@ func HandleWebSocket(w http.ResponseWriter, req *http.Request) {
// Create a new context for the polling operation
ctx, cancel := context.WithCancel(context.Background())

kernelConnection := KernelWebSocketConnection{
kernelConnection := kernel.KernelWebSocketConnection{
KernelId: kernelId,
KernelManager: kernelManager,
Channels: make(map[string]zmq4.Socket),
Expand All @@ -109,10 +109,10 @@ func HandleWebSocket(w http.ResponseWriter, req *http.Request) {
}

log.Info().Msg("preparing kernel connection")
kernelConnection.prepare(sessionId)
kernelConnection.Prepare(sessionId)

log.Info().Msg("connecting kernel")
kernelConnection.connect()
kernelConnection.Connect()

clientsMu.Lock()
ZasperActiveKernelConnections[kernelId] = &kernelConnection
Expand All @@ -121,58 +121,6 @@ func HandleWebSocket(w http.ResponseWriter, req *http.Request) {
var waiter sync.WaitGroup
waiter.Add(2)

go kernelConnection.readMessagesFromClient(&waiter)
go kernelConnection.writeMessages(&waiter)
}

func (kwsConn *KernelWebSocketConnection) readMessagesFromClient(waiter *sync.WaitGroup) {
defer func() {
log.Info().Msg("Closing readMessagesFromClient")
kwsConn.Conn.Close()
waiter.Done()
}()

for {
select {
case <-kwsConn.Context.Done(): // Check if context is canceled
log.Debug().Msgf("Socket closed, Incoming message handler stopped")
return
default:
messageType, data, err := kwsConn.Conn.ReadMessage()
if err != nil {
log.Debug().Msgf("%s", err)
return
}
log.Debug().Msgf("message type => %d", messageType)
kwsConn.handleIncomingMessage(messageType, data)
}

}
}

func (kwsConn *KernelWebSocketConnection) writeMessages(waiter *sync.WaitGroup) {
defer func() {
kwsConn.Conn.Close()
waiter.Done()
}()
for {
select {
case <-kwsConn.Context.Done(): // Check if context is canceled
log.Debug().Msgf("Socket closed, Incoming message handler stopped")
return
default:
message, ok := <-kwsConn.Send
if !ok {
log.Info().Msg("Send channel closed, closing WebSocket connection")
kwsConn.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
kwsConn.mu.Lock()
if err := kwsConn.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
log.Info().Msgf("Error writing message: %s", err)
return
}
kwsConn.mu.Unlock()
}
}
go kernelConnection.ReadMessagesFromClient(&waiter)
go kernelConnection.WriteMessages(&waiter)
}

0 comments on commit 606d488

Please sign in to comment.