From 4dcfc56a99fe850193376f93050e848fae16adb2 Mon Sep 17 00:00:00 2001 From: "C.C." Date: Fri, 25 Oct 2024 11:52:07 +0800 Subject: [PATCH] feat(llm-bridge): add x.ai provider (#927) add [x.ai](https://x.ai) support - docs: https://docs.x.ai/api#introduction > [!CAUTION] > Function calling is currently not working due to an API bug. `/chat/completions` endpoint works, but the function calling feature is still affected by the x.ai API error. [Official Docs](https://docs.x.ai/docs/guides#function-calling) state that the `finish_reason` should be `tool_calls`, but API returns `stop`, this breaks the system. ![image](https://github.com/user-attachments/assets/ec92123a-53ac-470a-b92a-db8b7ff31afe) We will fix it as soon as x.ai resolves the API bug. --- README.md | 15 ++++- cli/serve.go | 3 + example/10-ai/README.md | 43 ++++++++++++- example/10-ai/zipper.yaml | 4 ++ go.mod | 2 +- go.sum | 4 +- pkg/bridge/ai/provider/cerebras/provider.go | 4 +- pkg/bridge/ai/provider/xai/provider.go | 67 +++++++++++++++++++++ pkg/bridge/ai/provider/xai/provider_test.go | 54 +++++++++++++++++ 9 files changed, 187 insertions(+), 9 deletions(-) create mode 100644 pkg/bridge/ai/provider/xai/provider.go create mode 100644 pkg/bridge/ai/provider/xai/provider_test.go diff --git a/README.md b/README.md index b6a6e9fd5..90caea101 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,20 @@ time=2024-03-19T21:43:30.584+08:00 level=INFO msg="register ai function success" ### Done, let's have a try ```sh -$ curl -i -X POST -H "Content-Type: application/json" -d '{"prompt":"compare nike and puma website speed"}' http://127.0.0.1:8000/invoke +$ curl -i http://127.0.0.1:9000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "messages": [ + { + "role": "system", + "content": "You are a test assistant." 
+ }, + { + "role": "user", + "content": "Compare website speed between Nike and Puma" + } + ], + "stream": false +}' + HTTP/1.1 200 OK Content-Length: 944 Connection: keep-alive diff --git a/cli/serve.go b/cli/serve.go index 7cafff3a5..768e3ec2a 100644 --- a/cli/serve.go +++ b/cli/serve.go @@ -38,6 +38,7 @@ import ( "github.com/yomorun/yomo/pkg/bridge/ai/provider/githubmodels" "github.com/yomorun/yomo/pkg/bridge/ai/provider/ollama" "github.com/yomorun/yomo/pkg/bridge/ai/provider/openai" + "github.com/yomorun/yomo/pkg/bridge/ai/provider/xai" ) // serveCmd represents the serve command @@ -159,6 +160,8 @@ func registerAIProvider(aiConfig *ai.Config) error { providerpkg.RegisterProvider(cerebras.NewProvider(provider["api_key"], provider["model"])) case "anthropic": providerpkg.RegisterProvider(anthropic.NewProvider(provider["api_key"], provider["model"])) + case "xai": + providerpkg.RegisterProvider(xai.NewProvider(provider["api_key"], provider["model"])) default: log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name) } diff --git a/example/10-ai/README.md b/example/10-ai/README.md index 2514daf90..e3400db08 100644 --- a/example/10-ai/README.md +++ b/example/10-ai/README.md @@ -24,7 +24,20 @@ cd llm-sfn-get-ip-and-latency && yomo run -m go.mod app.go ## Step 3: Invoke the LLM Function ```bash -$ curl -i -X POST -H "Content-Type: application/json" -d '{"prompt":"compare nike and puma website speed"}' http://127.0.0.1:8000/invoke +$ curl -i http://127.0.0.1:9000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "messages": [ + { + "role": "system", + "content": "You are a test assistant." 
+ }, + { + "role": "user", + "content": "Compare website speed between Nike and Puma" + } + ], + "stream": false +}' + HTTP/1.1 200 OK Content-Length: 944 Connection: keep-alive @@ -47,7 +60,19 @@ Proxy-Connection: keep-alive ``` ```bash -$ curl -i -X POST -H "Content-Type: application/json" -d '{"prompt":"what is the time in Singapore for Thursday, February 15th, 2024 7:00am and 8:00am (UTC-08:00) Pacific Time"}' http://127.0.0.1:8000/invoke +$ curl -i http://127.0.0.1:9000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "messages": [ + { + "role": "system", + "content": "You are a test assistant." + }, + { + "role": "user", + "content": "what is the time in Singapore for Thursday, February 15th, 2024 7:00am and 8:00am (UTC-08:00) Pacific Time" + } + ], + "stream": false +}' HTTP/1.1 200 OK Content-Length: 618 Connection: keep-alive @@ -70,7 +95,19 @@ Proxy-Connection: keep-alive ``` ```bash -$ curl -i -X POST -H "Content-Type: application/json" -d '{"prompt":"How much is 100 usd in Korea and UK currency"}' http://127.0.0.1:8000/invoke +$ curl -i http://127.0.0.1:9000/v1/chat/completions -H "Content-Type: application/json" -d '{ + "messages": [ + { + "role": "system", + "content": "You are a test assistant." + }, + { + "role": "user", + "content": "How much is 100 usd in Korea and UK currency?" 
+ } + ], + "stream": false +}' HTTP/1.1 200 OK Content-Length: 333 Connection: keep-alive diff --git a/example/10-ai/zipper.yaml b/example/10-ai/zipper.yaml index ac4903af1..c8fc79733 100644 --- a/example/10-ai/zipper.yaml +++ b/example/10-ai/zipper.yaml @@ -44,3 +44,7 @@ bridge: anthropic: api_key: model: + + xai: + api_key: + model: diff --git a/go.mod b/go.mod index 3608a98b3..4578bae1b 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/caarlos0/env/v6 v6.10.1 github.com/cenkalti/backoff/v4 v4.3.0 github.com/fatih/color v1.17.0 - github.com/google/generative-ai-go v0.17.0 + github.com/google/generative-ai-go v0.18.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/invopop/jsonschema v0.12.0 github.com/joho/godotenv v1.5.1 diff --git a/go.sum b/go.sum index 3993d250b..34a26a07f 100644 --- a/go.sum +++ b/go.sum @@ -103,8 +103,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= -github.com/google/generative-ai-go v0.17.0 h1:kUmCXUIwJouD7I7ev3OmxzzQVICyhIWAxaXk2yblCMY= -github.com/google/generative-ai-go v0.17.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= +github.com/google/generative-ai-go v0.18.0 h1:6ybg9vOCLcI/UpBBYXOTVgvKmcUKFRNj+2Cj3GnebSo= +github.com/google/generative-ai-go v0.18.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= diff --git a/pkg/bridge/ai/provider/cerebras/provider.go b/pkg/bridge/ai/provider/cerebras/provider.go index 
4905b3832..19df2e43a 100644 --- a/pkg/bridge/ai/provider/cerebras/provider.go +++ b/pkg/bridge/ai/provider/cerebras/provider.go @@ -20,9 +20,9 @@ var _ provider.LLMProvider = &Provider{} // Provider is the provider for Cerebras type Provider struct { - // APIKey is the API key for Cerberas + // APIKey is the API key for Cerebras APIKey string - // Model is the model for Cerberas + // Model is the model for Cerebras // eg. "llama3.1-8b", "llama-3.1-70b" Model string client *openai.Client diff --git a/pkg/bridge/ai/provider/xai/provider.go b/pkg/bridge/ai/provider/xai/provider.go new file mode 100644 index 000000000..ea2726c77 --- /dev/null +++ b/pkg/bridge/ai/provider/xai/provider.go @@ -0,0 +1,67 @@ +// Package xai is the x.ai provider +package xai + +import ( + "context" + + _ "github.com/joho/godotenv/autoload" + "github.com/sashabaranov/go-openai" + "github.com/yomorun/yomo/core/metadata" + + provider "github.com/yomorun/yomo/pkg/bridge/ai/provider" +) + +const BaseURL = "https://api.x.ai/v1" +const DefaultModel = "grok-beta" + +// check if implements ai.Provider +var _ provider.LLMProvider = &Provider{} + +// Provider is the provider for x.ai +type Provider struct { + // APIKey is the API key for x.ai + APIKey string + // Model is the model for x.ai + // eg. "grok-beta" + Model string + client *openai.Client +} + +// NewProvider creates a new x.ai ai provider +func NewProvider(apiKey string, model string) *Provider { + if model == "" { + model = DefaultModel + } + + c := openai.DefaultConfig(apiKey) + c.BaseURL = BaseURL + + return &Provider{ + APIKey: apiKey, + Model: model, + client: openai.NewClientWithConfig(c), + } +} + +// Name returns the name of the provider +func (p *Provider) Name() string { + return "xai" +} + +// GetChatCompletions implements ai.LLMProvider.
+func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) { + if req.Model == "" { + req.Model = p.Model + } + + return p.client.CreateChatCompletion(ctx, req) +} + +// GetChatCompletionsStream implements ai.LLMProvider. +func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) { + if req.Model == "" { + req.Model = p.Model + } + // it does not support streaming calls when tools are present + return p.client.CreateChatCompletionStream(ctx, req) +} diff --git a/pkg/bridge/ai/provider/xai/provider_test.go b/pkg/bridge/ai/provider/xai/provider_test.go new file mode 100644 index 000000000..58efebb35 --- /dev/null +++ b/pkg/bridge/ai/provider/xai/provider_test.go @@ -0,0 +1,54 @@ +package xai + +import ( + "context" + "testing" + + "github.com/sashabaranov/go-openai" + "github.com/stretchr/testify/assert" +) + +func TestXAIProvider_Name(t *testing.T) { + provider := &Provider{} + name := provider.Name() + + assert.Equal(t, "xai", name) +} + +func TestXAIProvider_GetChatCompletions(t *testing.T) { + provider := NewProvider("", "") + msgs := []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "hello", + }, + { + Role: "system", + Content: "I'm a bot", + }, + } + req := openai.ChatCompletionRequest{ + Messages: msgs, + Model: "grok-beta", + } + + _, err := provider.GetChatCompletions(context.TODO(), req, nil) + assert.Error(t, err) + t.Log(err) + + _, err = provider.GetChatCompletionsStream(context.TODO(), req, nil) + assert.Error(t, err) + t.Log(err) + + req = openai.ChatCompletionRequest{ + Messages: msgs, + } + + _, err = provider.GetChatCompletions(context.TODO(), req, nil) + assert.Error(t, err) + t.Log(err) + + _, err = provider.GetChatCompletionsStream(context.TODO(), req, nil) + assert.Error(t, err) + t.Log(err) +}