From e8271e5bc13d1862ebaa54c0e98808da3cac5273 Mon Sep 17 00:00:00 2001
From: Anders Swanson
Date: Mon, 2 Dec 2024 12:36:41 -0800
Subject: [PATCH] feat: oci genai chat models

Signed-off-by: Anders Swanson
---
 pkg/ai/ocigenai.go | 135 +++++++++++++++++++++++++++++++++++----------
 1 file changed, 107 insertions(+), 28 deletions(-)

diff --git a/pkg/ai/ocigenai.go b/pkg/ai/ocigenai.go
index 53c7076289..9e7a39adfa 100644
--- a/pkg/ai/ocigenai.go
+++ b/pkg/ai/ocigenai.go
@@ -16,21 +16,32 @@ package ai
 import (
 	"context"
 	"errors"
+	"fmt"
 	"github.com/oracle/oci-go-sdk/v65/common"
+	"github.com/oracle/oci-go-sdk/v65/generativeai"
 	"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
-	"strings"
+	"reflect"
 )
 
 const ociClientName = "oci"
 
+type ociModelVendor string
+
+const (
+	vendorCohere = "cohere"
+	vendorMeta   = "meta"
+)
+
 type OCIGenAIClient struct {
 	nopCloser
 
 	client        *generativeaiinference.GenerativeAiInferenceClient
-	model         string
+	model         *generativeai.Model
+	modelID       string
 	compartmentId string
 	temperature   float32
 	topP          float32
+	topK          int32
 	maxTokens     int
 }
 
@@ -40,9 +51,10 @@ func (c *OCIGenAIClient) GetName() string {
 
 func (c *OCIGenAIClient) Configure(config IAIConfig) error {
 	config.GetEndpointName()
-	c.model = config.GetModel()
+	c.modelID = config.GetModel()
 	c.temperature = config.GetTemperature()
 	c.topP = config.GetTopP()
+	c.topK = config.GetTopK()
 	c.maxTokens = config.GetMaxTokens()
 	c.compartmentId = config.GetCompartmentId()
 	provider := common.DefaultConfigProvider()
@@ -55,43 +67,110 @@ func (c *OCIGenAIClient) Configure(config IAIConfig) error {
 }
 
 func (c *OCIGenAIClient) GetCompletion(ctx context.Context, prompt string) (string, error) {
-	generateTextRequest := c.newGenerateTextRequest(prompt)
-	generateTextResponse, err := c.client.GenerateText(ctx, generateTextRequest)
+	request := c.newChatRequest(prompt)
+	response, err := c.client.Chat(ctx, request)
 	if err != nil {
 		return "", err
 	}
-	return extractGeneratedText(generateTextResponse.InferenceResponse)
+	if err != nil {
+		return "", err
+	}
+	return extractGeneratedText(response.ChatResponse)
 }
 
-func (c *OCIGenAIClient) newGenerateTextRequest(prompt string) generativeaiinference.GenerateTextRequest {
+func (c *OCIGenAIClient) newChatRequest(prompt string) generativeaiinference.ChatRequest {
+	return generativeaiinference.ChatRequest{
+		ChatDetails: generativeaiinference.ChatDetails{
+			CompartmentId: &c.compartmentId,
+			ServingMode:   c.getServingMode(),
+			ChatRequest:   c.getChatModelRequest(prompt),
+		},
+	}
+}
+
+func (c *OCIGenAIClient) getChatModelRequest(prompt string) generativeaiinference.BaseChatRequest {
 	temperatureF64 := float64(c.temperature)
 	topPF64 := float64(c.topP)
-	return generativeaiinference.GenerateTextRequest{
-		GenerateTextDetails: generativeaiinference.GenerateTextDetails{
-			CompartmentId: &c.compartmentId,
-			ServingMode: generativeaiinference.OnDemandServingMode{
-				ModelId: &c.model,
-			},
-			InferenceRequest: generativeaiinference.CohereLlmInferenceRequest{
-				Prompt:      &prompt,
-				MaxTokens:   &c.maxTokens,
-				Temperature: &temperatureF64,
-				TopP:        &topPF64,
+	topK := int(c.topK)
+
+	switch c.getVendor() {
+	case vendorMeta:
+		messages := []generativeaiinference.Message{
+			generativeaiinference.UserMessage{
+				Content: []generativeaiinference.ChatContent{
+					generativeaiinference.TextContent{
+						Text: &prompt,
+					},
+				},
 			},
-		},
+		}
+		return generativeaiinference.GenericChatRequest{
+			Messages:    messages,
+			TopK:        &topK,
+			TopP:        &topPF64,
+			Temperature: &temperatureF64,
+			MaxTokens:   &c.maxTokens,
+		}
+	default: // Default to cohere
+		return generativeaiinference.CohereChatRequest{
+			Message:     &prompt,
+			MaxTokens:   &c.maxTokens,
+			Temperature: &temperatureF64,
+			TopK:        &topK,
+			TopP:        &topPF64,
+		}
 	}
 }
 
-func extractGeneratedText(llmInferenceResponse generativeaiinference.LlmInferenceResponse) (string, error) {
-	response, ok := llmInferenceResponse.(generativeaiinference.CohereLlmInferenceResponse)
-	if !ok {
-		return "", errors.New("failed to extract generated text from backed response")
+func extractGeneratedText(llmInferenceResponse generativeaiinference.BaseChatResponse) (string, error) {
+	switch response := llmInferenceResponse.(type) {
+	case generativeaiinference.GenericChatResponse:
+		if len(response.Choices) > 0 && len(response.Choices[0].Message.GetContent()) > 0 {
+			if content, ok := response.Choices[0].Message.GetContent()[0].(generativeaiinference.TextContent); ok {
+				return *content.Text, nil
+			}
+		}
+		return "", errors.New("no text found in oci response")
+	case generativeaiinference.CohereChatResponse:
+		return *response.Text, nil
+	default:
+		return "", fmt.Errorf("unknown oci response type: %s", reflect.TypeOf(llmInferenceResponse).Name())
 	}
-	sb := strings.Builder{}
-	for _, text := range response.GeneratedTexts {
-		if text.Text != nil {
-			sb.WriteString(*text.Text)
+}
+
+func (c *OCIGenAIClient) getServingMode() generativeaiinference.ServingMode {
+	if c.isBaseModel() {
+		return generativeaiinference.OnDemandServingMode{
+			ModelId: &c.modelID,
 		}
 	}
-	return sb.String(), nil
+	return generativeaiinference.DedicatedServingMode{
+		EndpointId: &c.modelID,
+	}
+}
+
+func (c *OCIGenAIClient) getModel(provider common.ConfigurationProvider) (*generativeai.Model, error) {
+	client, err := generativeai.NewGenerativeAiClientWithConfigurationProvider(provider)
+	if err != nil {
+		return nil, err
+	}
+	response, err := client.GetModel(context.Background(), generativeai.GetModelRequest{
+		ModelId: &c.modelID,
+	})
+	if err != nil {
+		return nil, err
+	}
+	return &response.Model, nil
+}
+
+func (c *OCIGenAIClient) isBaseModel() bool {
+	return c.model != nil && c.model.Type == generativeai.ModelTypeBase
+}
+
+func (c *OCIGenAIClient) getVendor() ociModelVendor {
+	if c.model == nil || c.model.Vendor == nil {
+		return ""
+	}
+	return ociModelVendor(*c.model.Vendor)
 }
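
Reviewer note (not part of the patch): the sketch below is a minimal, standalone illustration of the Chat API path this change adopts, assuming credentials come from the default OCI config provider and showing only the cohere (default) branch with an on-demand base model. The compartment OCID, model OCID, prompt, and sampling values are placeholders, not values taken from the patch.

// Standalone sketch: exercises the same generativeaiinference.Chat surface the
// patched code uses. Replace the placeholder OCIDs before running.
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/oracle/oci-go-sdk/v65/common"
	"github.com/oracle/oci-go-sdk/v65/generativeaiinference"
)

func main() {
	// Same constructor style the plugin relies on: credentials and region
	// come from the default OCI configuration provider (~/.oci/config).
	client, err := generativeaiinference.NewGenerativeAiInferenceClientWithConfigurationProvider(common.DefaultConfigProvider())
	if err != nil {
		log.Fatal(err)
	}

	compartment := "ocid1.compartment.oc1..example"    // placeholder
	modelID := "ocid1.generativeaimodel.oc1..example"  // placeholder (on-demand cohere base model)

	// Mirrors newChatRequest/getChatModelRequest for the cohere default branch.
	req := generativeaiinference.ChatRequest{
		ChatDetails: generativeaiinference.ChatDetails{
			CompartmentId: common.String(compartment),
			ServingMode: generativeaiinference.OnDemandServingMode{
				ModelId: common.String(modelID),
			},
			ChatRequest: generativeaiinference.CohereChatRequest{
				Message:     common.String("Why is my pod in CrashLoopBackOff?"),
				MaxTokens:   common.Int(400),
				Temperature: common.Float64(0.5),
				TopP:        common.Float64(0.75),
			},
		},
	}

	resp, err := client.Chat(context.Background(), req)
	if err != nil {
		log.Fatal(err)
	}

	// Mirrors extractGeneratedText for the cohere response shape.
	if cohere, ok := resp.ChatResponse.(generativeaiinference.CohereChatResponse); ok {
		fmt.Println(*cohere.Text)
	}
}

When the resolved model vendor is meta, the patch's vendor switch would instead build a GenericChatRequest from UserMessage/TextContent, and extractGeneratedText would read the first TextContent of the first choice.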