Skip to content

Commit

Permalink
Fix broken implementation AssistantModify implementation (sashabarano…
Browse files Browse the repository at this point in the history
…v#685)

* add custom marshaller, documentation and isolate tests

* fix linter
  • Loading branch information
qhenkart authored and grulex committed Mar 15, 2024
1 parent 14db069 commit 99f1fb8
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 30 deletions.
30 changes: 28 additions & 2 deletions assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openai

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
Expand All @@ -21,7 +22,7 @@ type Assistant struct {
Description *string `json:"description,omitempty"`
Model string `json:"model"`
Instructions *string `json:"instructions,omitempty"`
Tools []AssistantTool `json:"tools,omitempty"`
Tools []AssistantTool `json:"tools"`
FileIDs []string `json:"file_ids,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`

Expand All @@ -41,16 +42,41 @@ type AssistantTool struct {
Function *FunctionDefinition `json:"function,omitempty"`
}

// AssistantRequest provides the assistant request parameters.
// When modifying the tools the API functions as the following:
// If Tools is undefined, no changes are made to the Assistant's tools.
// If Tools is empty slice it will effectively delete all of the Assistant's tools.
// If Tools is populated, it will replace all of the existing Assistant's tools with the provided tools.
type AssistantRequest struct {
Model string `json:"model"`
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
Instructions *string `json:"instructions,omitempty"`
Tools []AssistantTool `json:"tools"`
Tools []AssistantTool `json:"-"`
FileIDs []string `json:"file_ids,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
}

// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases
// If Tools is nil, the field is omitted from the JSON.
// If Tools is an empty slice, it's included in the JSON as an empty array ([]).
// If Tools is populated, it's included in the JSON with the elements.
func (a AssistantRequest) MarshalJSON() ([]byte, error) {
type Alias AssistantRequest
assistantAlias := &struct {
Tools *[]AssistantTool `json:"tools,omitempty"`
*Alias
}{
Alias: (*Alias)(&a),
}

if a.Tools != nil {
assistantAlias.Tools = &a.Tools
}

return json.Marshal(assistantAlias)
}

// AssistantsList is a list of assistants.
type AssistantsList struct {
Assistants []Assistant `json:"data"`
Expand Down
109 changes: 81 additions & 28 deletions assistant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ When asked a question, write and run Python code to answer the question.`
})
fmt.Fprintln(w, string(resBytes))
case http.MethodPost:
var request openai.AssistantRequest
var request openai.Assistant
err := json.NewDecoder(r.Body).Decode(&request)
checks.NoError(t, err, "Decode error")

Expand Down Expand Up @@ -163,44 +163,97 @@ When asked a question, write and run Python code to answer the question.`

ctx := context.Background()

_, err := client.CreateAssistant(ctx, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
t.Run("create_assistant", func(t *testing.T) {
_, err := client.CreateAssistant(ctx, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
})
checks.NoError(t, err, "CreateAssistant error")
})
checks.NoError(t, err, "CreateAssistant error")

_, err = client.RetrieveAssistant(ctx, assistantID)
checks.NoError(t, err, "RetrieveAssistant error")
t.Run("retrieve_assistant", func(t *testing.T) {
_, err := client.RetrieveAssistant(ctx, assistantID)
checks.NoError(t, err, "RetrieveAssistant error")
})

_, err = client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
t.Run("delete_assistant", func(t *testing.T) {
_, err := client.DeleteAssistant(ctx, assistantID)
checks.NoError(t, err, "DeleteAssistant error")
})
checks.NoError(t, err, "ModifyAssistant error")

_, err = client.DeleteAssistant(ctx, assistantID)
checks.NoError(t, err, "DeleteAssistant error")
t.Run("list_assistant", func(t *testing.T) {
_, err := client.ListAssistants(ctx, &limit, &order, &after, &before)
checks.NoError(t, err, "ListAssistants error")
})

_, err = client.ListAssistants(ctx, &limit, &order, &after, &before)
checks.NoError(t, err, "ListAssistants error")
t.Run("create_assistant_file", func(t *testing.T) {
_, err := client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{
FileID: assistantFileID,
})
checks.NoError(t, err, "CreateAssistantFile error")
})

_, err = client.CreateAssistantFile(ctx, assistantID, openai.AssistantFileRequest{
FileID: assistantFileID,
t.Run("list_assistant_files", func(t *testing.T) {
_, err := client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before)
checks.NoError(t, err, "ListAssistantFiles error")
})
checks.NoError(t, err, "CreateAssistantFile error")

_, err = client.ListAssistantFiles(ctx, assistantID, &limit, &order, &after, &before)
checks.NoError(t, err, "ListAssistantFiles error")
t.Run("retrieve_assistant_file", func(t *testing.T) {
_, err := client.RetrieveAssistantFile(ctx, assistantID, assistantFileID)
checks.NoError(t, err, "RetrieveAssistantFile error")
})

_, err = client.RetrieveAssistantFile(ctx, assistantID, assistantFileID)
checks.NoError(t, err, "RetrieveAssistantFile error")
t.Run("delete_assistant_file", func(t *testing.T) {
err := client.DeleteAssistantFile(ctx, assistantID, assistantFileID)
checks.NoError(t, err, "DeleteAssistantFile error")
})

err = client.DeleteAssistantFile(ctx, assistantID, assistantFileID)
checks.NoError(t, err, "DeleteAssistantFile error")
t.Run("modify_assistant_no_tools", func(t *testing.T) {
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
})
checks.NoError(t, err, "ModifyAssistant error")

if assistant.Tools != nil {
t.Errorf("expected nil got %v", assistant.Tools)
}
})

t.Run("modify_assistant_with_tools", func(t *testing.T) {
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
Tools: []openai.AssistantTool{{Type: openai.AssistantToolTypeFunction}},
})
checks.NoError(t, err, "ModifyAssistant error")

if assistant.Tools == nil || len(assistant.Tools) != 1 {
t.Errorf("expected a slice got %v", assistant.Tools)
}
})

t.Run("modify_assistant_empty_tools", func(t *testing.T) {
assistant, err := client.ModifyAssistant(ctx, assistantID, openai.AssistantRequest{
Name: &assistantName,
Description: &assistantDescription,
Model: openai.GPT4TurboPreview,
Instructions: &assistantInstructions,
Tools: make([]openai.AssistantTool, 0),
})

checks.NoError(t, err, "ModifyAssistant error")

if assistant.Tools == nil {
t.Errorf("expected a slice got %v", assistant.Tools)
}
})
}

func TestAzureAssistant(t *testing.T) {
Expand Down

0 comments on commit 99f1fb8

Please sign in to comment.