Skip to content

Commit

Permalink
feat: cerebras llm provider (#895)
Browse files Browse the repository at this point in the history
# Description

Implement [Cerebras Inference](https://cerebras.ai) cloud support.

## Known Issues 

- It does not support streaming calls when `tools` are present.
- The name of the function to be called must consist only of a-z, A-Z, 0-9, or underscores ~and dashes~, with a maximum length of 64 characters.
  • Loading branch information
venjiang authored Sep 5, 2024
1 parent 9ef70f6 commit cd6b972
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 10 deletions.
3 changes: 3 additions & 0 deletions cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/yomorun/yomo/pkg/bridge/ai"
providerpkg "github.com/yomorun/yomo/pkg/bridge/ai/provider"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/cerebras"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfazure"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/cfopenai"
"github.com/yomorun/yomo/pkg/bridge/ai/provider/gemini"
Expand Down Expand Up @@ -153,6 +154,8 @@ func registerAIProvider(aiConfig *ai.Config) error {
providerpkg.RegisterProvider(gemini.NewProvider(provider["api_key"]))
case "githubmodels":
providerpkg.RegisterProvider(githubmodels.NewProvider(provider["api_key"], provider["model"]))
case "cerebras":
providerpkg.RegisterProvider(cerebras.NewProvider(provider["api_key"], provider["model"]))
default:
log.WarningStatusEvent(os.Stdout, "unknown provider: %s", name)
}
Expand Down
23 changes: 13 additions & 10 deletions example/10-ai/zipper.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ bridge:
ai:
server:
addr: localhost:8000
provider: azopenai
provider: cerebras

providers:
azopenai:
api_endpoint:
deployment_id:
api_key:
api_version:
api_endpoint:
deployment_id:
api_key:
api_version:

openai:
api_key:
Expand All @@ -27,9 +27,12 @@ bridge:
api_key:

cloudflare_azure:
endpoint:
api_key:
resource:
deployment_id:
api_version:
endpoint:
api_key:
resource:
deployment_id:
api_version:

cerebras:
api_key:
model:
85 changes: 85 additions & 0 deletions pkg/bridge/ai/provider/cerebras/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Package cerebras is the Cerebras llm provider
package cerebras

import (
"context"
"os"

_ "github.com/joho/godotenv/autoload"
"github.com/sashabaranov/go-openai"
"github.com/yomorun/yomo/core/metadata"
"github.com/yomorun/yomo/core/ylog"

provider "github.com/yomorun/yomo/pkg/bridge/ai/provider"
)

// BaseURL is the Cerebras OpenAI-compatible API endpoint.
const BaseURL = "https://api.cerebras.ai/v1"

// check if implements ai.Provider
var _ provider.LLMProvider = &Provider{}

// Provider is the provider for Cerebras
type Provider struct {
	// APIKey is the API key for Cerebras
	APIKey string
	// Model is the model for Cerebras; used as the default when a request omits one
	// eg. "llama3.1-8b", "llama-3.1-70b"
	Model string
	// client is the OpenAI-compatible client configured with BaseURL
	client *openai.Client
}

// NewProvider creates a new cerebras ai provider.
// Empty arguments fall back to the CEREBRAS_API_KEY and CEREBRAS_MODEL
// environment variables; the model defaults to "llama3.1-8b" when neither
// is set.
func NewProvider(apiKey string, model string) *Provider {
	if apiKey == "" {
		if apiKey = os.Getenv("CEREBRAS_API_KEY"); apiKey == "" {
			ylog.Error("CEREBRAS_API_KEY is empty, cerebras provider will not work properly")
		}
	}
	if model == "" {
		if model = os.Getenv("CEREBRAS_MODEL"); model == "" {
			model = "llama3.1-8b"
		}
	}

	cfg := openai.DefaultConfig(apiKey)
	cfg.BaseURL = BaseURL

	return &Provider{
		APIKey: apiKey,
		Model:  model,
		client: openai.NewClientWithConfig(cfg),
	}
}

// Name returns the name of the provider.
func (p *Provider) Name() string {
	const providerName = "cerebras"
	return providerName
}

// GetChatCompletions implements ai.LLMProvider.
// The provider's configured model is substituted when the request does
// not specify one; the request is then forwarded unchanged.
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) {
	if len(req.Model) == 0 {
		req.Model = p.Model
	}

	return p.client.CreateChatCompletion(ctx, req)
}

// GetChatCompletionsStream implements ai.LLMProvider.
// The provider's configured model is substituted when the request does
// not specify one; the request is then forwarded unchanged.
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) {
	if len(req.Model) == 0 {
		req.Model = p.Model
	}

	// NOTE: the Cerebras API returns a 400 error when any of the following
	// fields are supplied:
	//   frequency_penalty, logit_bias, logprobs, presence_penalty,
	//   parallel_tool_calls, service_tier
	// It also does not support streaming calls when tools are present.
	return p.client.CreateChatCompletionStream(ctx, req)
}
53 changes: 53 additions & 0 deletions pkg/bridge/ai/provider/cerebras/provider_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package cerebras

import (
"context"
"os"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)

func TestNewProvider(t *testing.T) {
// Set environment variables for testing
os.Setenv("CEREBRAS_API_KEY", "test_api_key")
os.Setenv("CEREBRAS_MODEL", "llama3.1-70b")

provider := NewProvider("", "")
assert.Equal(t, "test_api_key", provider.APIKey)
assert.Equal(t, "llama3.1-70b", provider.Model)
}

// TestCerebrasProvider_Name checks the provider's registered name.
func TestCerebrasProvider_Name(t *testing.T) {
	p := &Provider{}

	assert.Equal(t, "cerebras", p.Name())
}

// TestCerebrasProvider_GetChatCompletions exercises both the blocking and
// streaming completion paths with a bogus API key, expecting an error
// from each call.
func TestCerebrasProvider_GetChatCompletions(t *testing.T) {
	// Set the key explicitly so the test is self-contained and does not
	// depend on environment state left behind by other tests.
	t.Setenv("CEREBRAS_API_KEY", "test_api_key")

	provider := NewProvider("", "")
	msgs := []openai.ChatCompletionMessage{
		{
			Role:    "user",
			Content: "hello",
		},
		{
			Role:    "system",
			Content: "I'm a bot",
		},
	}
	req := openai.ChatCompletionRequest{
		Messages: msgs,
		Model:    "llama3.1-8b",
	}

	_, err := provider.GetChatCompletions(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)

	_, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)
}

0 comments on commit cd6b972

Please sign in to comment.