Skip to content

Commit

Permalink
test: improve comments and unit test (#737)
Browse files Browse the repository at this point in the history
related to #738
  • Loading branch information
fanweixiao authored Mar 2, 2024
1 parent cad974b commit 7e9636a
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 26 deletions.
7 changes: 2 additions & 5 deletions pkg/bridge/ai/provider/azopenai/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"sync"

// automatically load .env file
_ "github.com/joho/godotenv/autoload"

"github.com/yomorun/yomo/ai"
Expand Down Expand Up @@ -78,10 +79,6 @@ type connectedFn struct {
tc ai.ToolCall
}

func init() {
fns = sync.Map{}
}

// NewProvider creates a new AzureOpenAIProvider
func NewProvider(apiKey string, apiEndpoint string, deploymentID string, apiVersion string) *AzureOpenAIProvider {
if apiKey == "" {
Expand Down Expand Up @@ -201,7 +198,7 @@ func (p *AzureOpenAIProvider) GetChatCompletions(userInstruction string) (*ai.In
// functions may be more than one
// slog.Info("tool calls", "calls", calls, "mapTools", mapTools)
for _, call := range calls {
fns.Range(func(key, value interface{}) bool {
fns.Range(func(_, value interface{}) bool {
fn := value.(*connectedFn)
if fn.tc.Equal(&call) {
// Use toolCalls because tool_id is required in the following llm request
Expand Down
104 changes: 104 additions & 0 deletions pkg/bridge/ai/provider/azopenai/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package azopenai

import (
"os"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo/ai"
)

func TestNewProvider(t *testing.T) {
// Set environment variables for testing
os.Setenv("AZURE_OPENAI_API_KEY", "test_api_key")
os.Setenv("AZURE_OPENAI_API_ENDPOINT", "test_api_endpoint")
os.Setenv("AZURE_OPENAI_DEPLOYMENT_ID", "test_deployment_id")
os.Setenv("AZURE_OPENAI_API_VERSION", "test_api_version")

provider := NewProvider("", "", "", "")

assert.Equal(t, "test_api_key", provider.APIKey)
assert.Equal(t, "test_api_endpoint", provider.APIEndpoint)
assert.Equal(t, "test_deployment_id", provider.DeploymentID)
assert.Equal(t, "test_api_version", provider.APIVersion)
}

func TestAzureOpenAIProvider_Name(t *testing.T) {
provider := &AzureOpenAIProvider{}

name := provider.Name()

assert.Equal(t, "azopenai", name)
}

func TestAzureOpenAIProvider_RegisterFunction(t *testing.T) {
fns = sync.Map{}
provider := &AzureOpenAIProvider{}
tag := uint32(66)
functionDefinition := &ai.FunctionDefinition{
Name: "TestFunction",
}
connID := uint64(88)

err := provider.RegisterFunction(tag, functionDefinition, connID)
assert.NoError(t, err)

fn, ok := fns.Load(connID)
assert.True(t, ok)
assert.Equal(t, connID, fn.(*connectedFn).connID)
assert.Equal(t, tag, fn.(*connectedFn).tag)
assert.Equal(t, "function", fn.(*connectedFn).tc.Type)
assert.Equal(t, functionDefinition.Name, fn.(*connectedFn).tc.Function.Name)

}

func TestAzureOpenAIProvider_UnregisterFunction(t *testing.T) {
provider := &AzureOpenAIProvider{}
err := provider.UnregisterFunction("", 1)
assert.NoError(t, err)
_, ok := fns.Load(1)
assert.False(t, ok)
}

func TestAzureOpenAIProvider_ListToolCalls(t *testing.T) {
fns = sync.Map{}
provider := &AzureOpenAIProvider{}

// Add a connectedFn to fns for testing
fns.Store(1, &connectedFn{
tag: 0x16,
tc: ai.ToolCall{
Type: "function",
Function: &ai.FunctionDefinition{
Name: "TestFunction",
},
},
})

toolCalls, err := provider.ListToolCalls()

assert.NoError(t, err)
assert.NotNil(t, toolCalls[0x16])
assert.Equal(t, toolCalls[0x16].Function.Name, "TestFunction")
}

func TestAzureOpenAIProvider_GetOverview(t *testing.T) {
fns = sync.Map{}
provider := &AzureOpenAIProvider{}

// Add a connectedFn to fns for testing
fns.Store(1, &connectedFn{
tag: 0x16,
tc: ai.ToolCall{Function: &ai.FunctionDefinition{
Name: "TestFunction",
}},
})

overview, err := provider.GetOverview()

assert.NoError(t, err)
assert.NotNil(t, overview)
assert.NotNil(t, overview.Functions[0x16])
assert.Equal(t, overview.Functions[0x16].Name, "TestFunction")
}
56 changes: 35 additions & 21 deletions pkg/bridge/ai/provider/openai/provider.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Package openai is the OpenAI llm provider
package openai

import (
Expand All @@ -9,45 +10,58 @@ import (
"os"
"sync"

// automatically load .env file
_ "github.com/joho/godotenv/autoload"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/ylog"
)

// APIEndpoint is the endpoint for OpenAI
const APIEndpoint = "https://api.openai.com/v1/chat/completions"

var fns sync.Map

// Message
// ChatCompletionMessage describes `messages` for /chat/completions
type ChatCompletionMessage struct {
Role string `json:"role"`
// Role is the messages author
Role string `json:"role"`
// Content of the message
Content string `json:"content"`
// - https://github.com/openai/openai-python/blob/main/chatml.md
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// Name describes participant, provides the model information to differentiate
// between participants of the same role.
Name string `json:"name,omitempty"`
// MultiContent []ChatMessagePart
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
// ToolCalls describes the tool calls generated by the model.
ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"`
// ToolCallID is the ID of the tool call
ToolCallID string `json:"tool_call_id,omitempty"`
}

// RequestBody is the request body
// ReqBody is the request body
type ReqBody struct {
Model string `json:"model"`
// Model describes the ID of the model to use for the completion.
Model string `json:"model"`
// Messages describes the messages in the conversation.
Messages []ChatCompletionMessage `json:"messages"`
Tools []ai.ToolCall `json:"tools"` // chatCompletionTool
// ToolChoice string `json:"tool_choice"` // chatCompletionFunction
// Tools describes the tool calls generated by the model.
Tools []ai.ToolCall `json:"tools"` // chatCompletionTool
}

// Resp is the response body
// RespBody is the response body
type RespBody struct {
ID string `json:"id"`
Object string `json:"object"`
Created int `json:"created"`
Model string `json:"model"`
Choices []RespChoice `json:"choices"`
Usage RespUsage `json:"usage"`
SystemFingerprint string `json:"system_fingerprint"`
// ID is the unique identifier for the chat completion.
ID string `json:"id"`
// Object describes the object type, it is always "chat.completion".
Object string `json:"object"`
// Created describes the timestamp when the chat completion was created.
Created int `json:"created"`
// Model describes the model used for the chat completion.
Model string `json:"model"`
// Choices describes the choices made by the model, can more than one if `n`>1
Choices []RespChoice `json:"choices"`
// Usage describes the token usage statistics for the chat completion request.
Usage RespUsage `json:"usage"`
// SystemFingerprint describes the system fingerprint of the chat completion.
SystemFingerprint string `json:"system_fingerprint"`
}

// RespMessage is the message in Response
Expand Down Expand Up @@ -222,7 +236,7 @@ func (p *OpenAIProvider) RegisterFunction(tag uint32, functionDefinition *ai.Fun

// UnregisterFunction unregister function
// Be careful: a function can have multiple instances, remove the offline instance only.
func (p *OpenAIProvider) UnregisterFunction(name string, connID uint64) error {
func (p *OpenAIProvider) UnregisterFunction(_ string, connID uint64) error {
fns.Delete(connID)
return nil
}
Expand Down
92 changes: 92 additions & 0 deletions pkg/bridge/ai/provider/openai/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package openai

import (
"os"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo/ai"
)

func TestOpenAIProvider_RegisterFunction(t *testing.T) {
fns = sync.Map{}
provider := &OpenAIProvider{}
tag := uint32(66)
functionDefinition := &ai.FunctionDefinition{
Name: "TestFunction",
}
connID := uint64(88)

err := provider.RegisterFunction(tag, functionDefinition, connID)
assert.NoError(t, err)

fn, ok := fns.Load(connID)
assert.True(t, ok)
assert.Equal(t, connID, fn.(*connectedFn).connID)
assert.Equal(t, tag, fn.(*connectedFn).tag)
assert.Equal(t, "function", fn.(*connectedFn).tc.Type)
assert.Equal(t, functionDefinition.Name, fn.(*connectedFn).tc.Function.Name)
}

func TestOpenAIProvider_UnregisterFunction(t *testing.T) {
provider := &OpenAIProvider{}
connID := uint64(1)

// Assuming a function is already registered with connID
err := provider.UnregisterFunction("", connID)
assert.NoError(t, err)

_, ok := fns.Load(connID)
assert.False(t, ok)
}

func TestOpenAIProvider_ListToolCalls(t *testing.T) {
provider := &OpenAIProvider{}

// Assuming some functions are already registered
toolCalls, err := provider.ListToolCalls()
assert.NoError(t, err)

// Replace with your own checks
assert.NotEmpty(t, toolCalls)
}

func TestOpenAIProvider_GetOverview(t *testing.T) {
provider := &OpenAIProvider{}

// Assuming some functions are already registered
overview, err := provider.GetOverview()
assert.NoError(t, err)

// Replace with your own checks
assert.NotEmpty(t, overview.Functions)
}

func TestHasToolCalls(t *testing.T) {
// Assuming some functions are already registered
toolCalls, hasCalls := hasToolCalls()

// Replace with your own checks
assert.True(t, hasCalls)
assert.NotEmpty(t, toolCalls)
}

func TestNewProvider(t *testing.T) {
// Set environment variables for testing
os.Setenv("OPENAI_API_KEY", "test_api_key")
os.Setenv("OPENAI_MODEL", "test_model")

provider := NewProvider("", "")

assert.Equal(t, "test_api_key", provider.APIKey)
assert.Equal(t, "test_model", provider.Model)
}

func TestOpenAIProvider_Name(t *testing.T) {
provider := &OpenAIProvider{}

name := provider.Name()

assert.Equal(t, "openai", name)
}

0 comments on commit 7e9636a

Please sign in to comment.