From be34e02f170eeae2617ee13cbb4c9a770f0b12c6 Mon Sep 17 00:00:00 2001 From: "C.C" Date: Mon, 26 Feb 2024 23:59:44 +0800 Subject: [PATCH] fix: conn.ID() from string to uint64 --- pkg/bridge/ai/ai.go | 4 ++-- pkg/bridge/ai/provider.go | 4 ++-- pkg/bridge/ai/provider/azopenai/provider.go | 6 +++--- pkg/bridge/ai/service.go | 4 ++-- pkg/bridge/ai/test/ai_test.go | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pkg/bridge/ai/ai.go b/pkg/bridge/ai/ai.go index fbb6e0ec2..a718507ca 100644 --- a/pkg/bridge/ai/ai.go +++ b/pkg/bridge/ai/ai.go @@ -28,7 +28,7 @@ var ( // ======================= Package Functions ======================= // RegisterFunction registers the tool function -func RegisterFunction(tag uint32, functionDefinition []byte, connID string) error { +func RegisterFunction(tag uint32, functionDefinition []byte, connID uint64) error { provider, err := GetDefaultProvider() if err != nil { return err @@ -43,7 +43,7 @@ func RegisterFunction(tag uint32, functionDefinition []byte, connID string) erro } // UnregisterFunction unregister the tool function -func UnregisterFunction(name string, connID string) error { +func UnregisterFunction(name string, connID uint64) error { provider, err := GetDefaultProvider() if err != nil { return err diff --git a/pkg/bridge/ai/provider.go b/pkg/bridge/ai/provider.go index b2ae3253e..d63077750 100644 --- a/pkg/bridge/ai/provider.go +++ b/pkg/bridge/ai/provider.go @@ -15,9 +15,9 @@ type LLMProvider interface { // GetChatCompletions returns the chat completions GetChatCompletions(prompt string) (*ai.InvokeResponse, error) // RegisterFunction registers the llm function - RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID string) error + RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID uint64) error // UnregisterFunction unregister the llm function - UnregisterFunction(name, connID string) error + UnregisterFunction(name string, connID uint64) error // ListToolCalls lists the llm tool calls ListToolCalls() (map[uint32]ai.ToolCall, error) } diff --git a/pkg/bridge/ai/provider/azopenai/provider.go b/pkg/bridge/ai/provider/azopenai/provider.go index da31bd714..c14005d5f 100644 --- a/pkg/bridge/ai/provider/azopenai/provider.go +++ b/pkg/bridge/ai/provider/azopenai/provider.go @@ -73,7 +73,7 @@ type AzureOpenAIProvider struct { } type connectedFn struct { - connID string + connID uint64 tag uint32 tc ai.ToolCall } @@ -221,7 +221,7 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In } // RegisterFunction register function -func (p *AzureOpenAIProvider) RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID string) error { +func (p *AzureOpenAIProvider) RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID uint64) error { fns.Store(connID, &connectedFn{ connID: connID, tag: tag, @@ -236,7 +236,7 @@ func (p *AzureOpenAIProvider) RegisterFunction(tag uint32, functionDefinition *a // UnregisterFunction unregister function // Be careful: a function can have multiple instances, remove the offline instance only. -func (p *AzureOpenAIProvider) UnregisterFunction(_ string, connID string) error { +func (p *AzureOpenAIProvider) UnregisterFunction(_ string, connID uint64) error { fns.Delete(connID) return nil } diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index 600e23012..40a038293 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -142,8 +142,8 @@ func (s *Service) createReducer() (yomo.StreamFunction, error) { fmt.Fprintf(v.ResponseWriter, "event: result\n") fmt.Fprintf(v.ResponseWriter, "data: %s\n\n", invoke.JSONString()) - fmt.Fprintf(v.ResponseWriter, "event: retrieval_result\n") - fmt.Fprintf(v.ResponseWriter, "data: %s\n\n", invoke.RetrievalResult) + // fmt.Fprintf(v.ResponseWriter, "event: retrieval_result\n") + // fmt.Fprintf(v.ResponseWriter, "data: %s\n\n", invoke.RetrievalResult) // // one json per line, like groq.com did // fmt.Fprintf(v.ResponseWriter, invoke.JSONString()+"\n") diff --git a/pkg/bridge/ai/test/ai_test.go b/pkg/bridge/ai/test/ai_test.go index caaa470e2..9d3eb8c43 100644 --- a/pkg/bridge/ai/test/ai_test.go +++ b/pkg/bridge/ai/test/ai_test.go @@ -23,10 +23,10 @@ func TestAIServer(t *testing.T) { func TestAIToolCalls(t *testing.T) { go startAIServer() functionDefinition := `{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}},"required":["location"]}}` - err := ai.RegisterFunction(1, []byte(functionDefinition), "conn-id-1") + err := ai.RegisterFunction(1, []byte(functionDefinition), 123) assert.NoError(t, err) functionDefinition2 := `{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}},"required":["location"]}}` - err = ai.RegisterFunction(1, []byte(functionDefinition2), "conn-id-1") + err = ai.RegisterFunction(1, []byte(functionDefinition2), 123) assert.NoError(t, err) tools, err := ai.ListToolCalls() assert.NoError(t, err)