Skip to content

Commit

Permalink
feat(llm-bridge): add x.ai provider (#927)
Browse files Browse the repository at this point in the history
add [x.ai](https://x.ai) support

- docs: https://docs.x.ai/api#introduction

> [!CAUTION]
> Function calling is currently not working due to an API bug.

`/chat/completions` endpoint works, but the function calling feature is
still affected by the x.ai API error.

The [official docs](https://docs.x.ai/docs/guides#function-calling) state
that the `finish_reason` should be `tool_calls`, but the API returns `stop`,
which breaks the system.


![image](https://github.com/user-attachments/assets/ec92123a-53ac-470a-b92a-db8b7ff31afe)

We will fix it as soon as x.ai resolves the API bug.
  • Loading branch information
fanweixiao authored Oct 25, 2024
1 parent 3c2ad67 commit 4dcfc56
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 9 deletions.
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,20 @@ time=2024-03-19T21:43:30.584+08:00 level=INFO msg="register ai function success"
### Done, let's have a try

```sh
$ curl -i -X POST -H "Content-Type: application/json" -d '{"prompt":"compare nike and puma website speed"}' http://127.0.0.1:8000/invoke
$ curl -i http://127.0.0.1:9000/v1/chat/completions -H "Content-Type: application/json" -d '{
"messages": [
{
"role": "system",
"content": "You are a test assistant."
},
{
"role": "user",
"content": "Compare website speed between Nike and Puma"
}
],
"stream": false
}'

HTTP/1.1 200 OK
Content-Length: 944
Connection: keep-alive
Expand Down
3 changes: 3 additions & 0 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/yomorun/yomo/pkg/bridge/ai/provider/githubmodels"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/ollama"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/xai"
)

// serveCmd represents the serve command
Expand Down Expand Up @@ -159,6 +160,8 @@ func registerAIProvider(aiConfig *ai.Config) error {
providerpkg.RegisterProvider(cerebras.NewProvider(provider["api_key"], provider["model"]))
case "anthropic":
providerpkg.RegisterProvider(anthropic.NewProvider(provider["api_key"], provider["model"]))
case "xai":
providerpkg.RegisterProvider(xai.NewProvider(provider["api_key"], provider["model"]))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
Expand Down
43 changes: 40 additions & 3 deletions example/10-ai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,20 @@ cd llm-sfn-get-ip-and-latency && yomo run -m go.mod app.go
## Step 3: Invoke the LLM Function

```bash
$ curl -i -X POST -H "Content-Type: application/json" -d '{"prompt":"compare nike and puma website speed"}' http://127.0.0.1:8000/invoke
$ curl -i http://127.0.0.1:9000/v1/chat/completions -H "Content-Type: application/json" -d '{
"messages": [
{
"role": "system",
"content": "You are a test assistant."
},
{
"role": "user",
"content": "Compare website speed between Nike and Puma"
}
],
"stream": false
}'

HTTP/1.1 200 OK
Content-Length: 944
Connection: keep-alive
Expand All @@ -47,7 +60,19 @@ Proxy-Connection: keep-alive
```

```bash
$ curl -i -X POST -H "Content-Type: application/json" -d '{"prompt":"what is the time in Singapore for Thursday, February 15th, 2024 7:00am and 8:00am (UTC-08:00) Pacific Time"}' http://127.0.0.1:8000/invoke
$ curl -i http://127.0.0.1:9000/v1/chat/completions -H "Content-Type: application/json" -d '{
"messages": [
{
"role": "system",
"content": "You are a test assistant."
},
{
"role": "user",
"content": "what is the time in Singapore for Thursday, February 15th, 2024 7:00am and 8:00am (UTC-08:00) Pacific Time"
}
],
"stream": false
}'
HTTP/1.1 200 OK
Content-Length: 618
Connection: keep-alive
Expand All @@ -70,7 +95,19 @@ Proxy-Connection: keep-alive
```

```bash
$ curl -i -X POST -H "Content-Type: application/json" -d '{"prompt":"How much is 100 usd in Korea and UK currency"}' http://127.0.0.1:8000/invoke
$ curl -i http://127.0.0.1:9000/v1/chat/completions -H "Content-Type: application/json" -d '{
"messages": [
{
"role": "system",
"content": "You are a test assistant."
},
{
"role": "user",
"content": "How much is 100 usd in Korea and UK currency?"
}
],
"stream": false
}'
HTTP/1.1 200 OK
Content-Length: 333
Connection: keep-alive
Expand Down
4 changes: 4 additions & 0 deletions example/10-ai/zipper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,7 @@ bridge:
anthropic:
api_key:
model:

xai:
api_key:
model:
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/caarlos0/env/v6 v6.10.1
github.com/cenkalti/backoff/v4 v4.3.0
github.com/fatih/color v1.17.0
github.com/google/generative-ai-go v0.17.0
github.com/google/generative-ai-go v0.18.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/invopop/jsonschema v0.12.0
github.com/joho/godotenv v1.5.1
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/generative-ai-go v0.17.0 h1:kUmCXUIwJouD7I7ev3OmxzzQVICyhIWAxaXk2yblCMY=
github.com/google/generative-ai-go v0.17.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
github.com/google/generative-ai-go v0.18.0 h1:6ybg9vOCLcI/UpBBYXOTVgvKmcUKFRNj+2Cj3GnebSo=
github.com/google/generative-ai-go v0.18.0/go.mod h1:JYolL13VG7j79kM5BtHz4qwONHkeJQzOCkKXnpqtS/E=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU=
Expand Down
4 changes: 2 additions & 2 deletions pkg/bridge/ai/provider/cerebras/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ var _ provider.LLMProvider = &Provider{}

// Provider is the provider for Cerebras
type Provider struct {
// APIKey is the API key for Cerberas
// APIKey is the API key for Cerebras
APIKey string
// Model is the model for Cerberas
// Model is the model for Cerebras
// eg. "llama3.1-8b", "llama-3.1-70b"
Model string
client *openai.Client
Expand Down
67 changes: 67 additions & 0 deletions pkg/bridge/ai/provider/xai/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Package xai is the x.ai provider
package xai

import (
"context"

_ "github.com/joho/godotenv/autoload"
"github.com/sashabaranov/go-openai"
"github.com/yomorun/yomo/core/metadata"

provider "github.com/yomorun/yomo/pkg/bridge/ai/provider"
)

const (
	// BaseURL is the endpoint of the x.ai OpenAI-compatible API.
	BaseURL = "https://api.x.ai/v1"
	// DefaultModel is the model used when the configuration specifies none.
	DefaultModel = "grok-beta"
)

// check if implements ai.Provider
var _ provider.LLMProvider = &Provider{}

// Provider is the provider for x.ai
type Provider struct {
	// APIKey is the API key for x.ai
	APIKey string
	// Model is the model for x.ai
	// eg. "grok-beta"
	Model string
	// client is the OpenAI-compatible client configured with BaseURL.
	client *openai.Client
}

// NewProvider creates a new x.ai ai provider.
// When model is empty, DefaultModel is used.
func NewProvider(apiKey string, model string) *Provider {
	// Fall back to the default model when none is configured.
	if model == "" {
		model = DefaultModel
	}

	cfg := openai.DefaultConfig(apiKey)
	cfg.BaseURL = BaseURL

	p := &Provider{
		APIKey: apiKey,
		Model:  model,
		client: openai.NewClientWithConfig(cfg),
	}
	return p
}

// Name returns the name of this provider, "xai".
func (p *Provider) Name() string {
	const providerName = "xai"
	return providerName
}

// GetChatCompletions implements ai.LLMProvider.
// It forwards the request to the x.ai chat-completions endpoint,
// defaulting to the provider's configured model when req.Model is empty.
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) {
	if len(req.Model) == 0 {
		req.Model = p.Model
	}
	return p.client.CreateChatCompletion(ctx, req)
}

// GetChatCompletionsStream implements ai.LLMProvider.
// It opens a streaming chat completion against the x.ai endpoint,
// defaulting to the provider's configured model when req.Model is empty.
// NOTE: the x.ai API does not support streaming calls when tools are present.
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) {
	if len(req.Model) == 0 {
		req.Model = p.Model
	}
	return p.client.CreateChatCompletionStream(ctx, req)
}
54 changes: 54 additions & 0 deletions pkg/bridge/ai/provider/xai/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package xai

import (
"context"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)

// TestXAIProvider_Name verifies the provider reports itself as "xai".
func TestXAIProvider_Name(t *testing.T) {
	p := &Provider{}

	assert.Equal(t, "xai", p.Name())
}

// TestXAIProvider_GetChatCompletions exercises both the blocking and the
// streaming completion paths, with and without an explicit model. No real
// API key is configured, so every call is expected to fail; the test
// asserts that the error surfaces instead of being swallowed.
func TestXAIProvider_GetChatCompletions(t *testing.T) {
	provider := NewProvider("", "")
	msgs := []openai.ChatCompletionMessage{
		{
			Role:    "user",
			Content: "hello",
		},
		{
			Role:    "system",
			Content: "I'm a bot",
		},
	}
	// Request with an explicit model. Note: the x.ai model is "grok-beta"
	// ("groq" is an unrelated vendor).
	req := openai.ChatCompletionRequest{
		Messages: msgs,
		Model:    "grok-beta",
	}

	_, err := provider.GetChatCompletions(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)

	_, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)

	// Request without a model: the provider must fall back to DefaultModel.
	req = openai.ChatCompletionRequest{
		Messages: msgs,
	}

	_, err = provider.GetChatCompletions(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)

	_, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)
}

0 comments on commit 4dcfc56

Please sign in to comment.