# Description

Implement [Cerebras Inference](https://cerebras.ai) cloud support.

## Known Issues

- Streaming calls are not supported when `tools` are present.
- The name of the function to be called must consist of a-z, A-Z, 0-9, or underscores ~and dashes~, with a maximum length of 64 (see the validation sketch below).
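As a hedged illustration of the second known issue, a caller might validate tool names before sending a request. `validToolName` below is a hypothetical helper for this sketch, not part of the commit.

```go
package main

import (
	"fmt"
	"regexp"
)

// validToolName is a hypothetical check mirroring the constraint above:
// a-z, A-Z, 0-9, and underscores only, at most 64 characters (dashes are rejected).
var validToolName = regexp.MustCompile(`^[a-zA-Z0-9_]{1,64}$`)

func main() {
	fmt.Println(validToolName.MatchString("get_weather")) // true
	fmt.Println(validToolName.MatchString("get-weather")) // false: dashes are not accepted
	fmt.Println(validToolName.MatchString(""))            // false: empty name
}
```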
Showing 4 changed files with 154 additions and 10 deletions.
@@ -0,0 +1,85 @@
// Package cerebras is the Cerebras LLM provider
package cerebras

import (
	"context"
	"os"

	_ "github.com/joho/godotenv/autoload"
	"github.com/sashabaranov/go-openai"
	"github.com/yomorun/yomo/core/metadata"
	"github.com/yomorun/yomo/core/ylog"

	provider "github.com/yomorun/yomo/pkg/bridge/ai/provider"
)

// BaseURL is the base URL of the Cerebras Inference API
const BaseURL = "https://api.cerebras.ai/v1"

// check that Provider implements the LLMProvider interface
var _ provider.LLMProvider = &Provider{}

// Provider is the provider for Cerebras
type Provider struct {
	// APIKey is the API key for Cerebras
	APIKey string
	// Model is the model for Cerebras
	// eg. "llama3.1-8b", "llama3.1-70b"
	Model string
	client *openai.Client
}

// NewProvider creates a new cerebras ai provider
func NewProvider(apiKey string, model string) *Provider {
	if apiKey == "" {
		apiKey = os.Getenv("CEREBRAS_API_KEY")
		if apiKey == "" {
			ylog.Error("CEREBRAS_API_KEY is empty, cerebras provider will not work properly")
		}
	}
	if model == "" {
		model = os.Getenv("CEREBRAS_MODEL")
		if model == "" {
			model = "llama3.1-8b"
		}
	}
	c := openai.DefaultConfig(apiKey)
	c.BaseURL = BaseURL

	return &Provider{
		APIKey: apiKey,
		Model:  model,
		client: openai.NewClientWithConfig(c),
	}
}

// Name returns the name of the provider
func (p *Provider) Name() string {
	return "cerebras"
}

// GetChatCompletions implements ai.LLMProvider.
func (p *Provider) GetChatCompletions(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (openai.ChatCompletionResponse, error) {
	if req.Model == "" {
		req.Model = p.Model
	}

	return p.client.CreateChatCompletion(ctx, req)
}

// GetChatCompletionsStream implements ai.LLMProvider.
func (p *Provider) GetChatCompletionsStream(ctx context.Context, req openai.ChatCompletionRequest, _ metadata.M) (provider.ResponseRecver, error) {
	if req.Model == "" {
		req.Model = p.Model
	}
	// The following fields are currently not supported by Cerebras and will
	// result in a 400 error if they are supplied:
	//   frequency_penalty
	//   logit_bias
	//   logprobs
	//   presence_penalty
	//   parallel_tool_calls
	//   service_tier

	// Cerebras does not support streaming calls when tools are present.

	return p.client.CreateChatCompletionStream(ctx, req)
}
@@ -0,0 +1,53 @@
package cerebras

import (
	"context"
	"os"
	"testing"

	"github.com/sashabaranov/go-openai"
	"github.com/stretchr/testify/assert"
)

func TestNewProvider(t *testing.T) {
	// Set environment variables for testing
	os.Setenv("CEREBRAS_API_KEY", "test_api_key")
	os.Setenv("CEREBRAS_MODEL", "llama3.1-70b")

	provider := NewProvider("", "")
	assert.Equal(t, "test_api_key", provider.APIKey)
	assert.Equal(t, "llama3.1-70b", provider.Model)
}

func TestCerebrasProvider_Name(t *testing.T) {
	provider := &Provider{}
	name := provider.Name()

	assert.Equal(t, "cerebras", name)
}

func TestCerebrasProvider_GetChatCompletions(t *testing.T) {
	provider := NewProvider("", "")
	msgs := []openai.ChatCompletionMessage{
		{
			Role:    "user",
			Content: "hello",
		},
		{
			Role:    "system",
			Content: "I'm a bot",
		},
	}
	req := openai.ChatCompletionRequest{
		Messages: msgs,
		Model:    "llama3.1-8b",
	}

	// Both calls are expected to fail because the API key configured in the
	// test environment is not a real Cerebras key.
	_, err := provider.GetChatCompletions(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)

	_, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
	assert.Error(t, err)
	t.Log(err)
}
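Since Cerebras rejects streaming calls when tools are present (the first known issue), a caller might branch on the request before choosing between the two methods. The sketch below is hypothetical: the tool definition, import path, and logging are illustrative and not part of this commit.

```go
package main

import (
	"context"
	"log"

	"github.com/sashabaranov/go-openai"
	"github.com/yomorun/yomo/pkg/bridge/ai/provider/cerebras"
)

func main() {
	p := cerebras.NewProvider("", "")

	req := openai.ChatCompletionRequest{
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "What is the weather in Paris?"},
		},
		Tools: []openai.Tool{
			{
				Type: openai.ToolTypeFunction,
				Function: &openai.FunctionDefinition{
					Name:        "get_weather", // a-z, A-Z, 0-9, underscores only, max 64 chars
					Description: "Get the current weather for a city",
				},
			},
		},
	}

	if len(req.Tools) > 0 {
		// Tools present: Cerebras rejects streaming here, so fall back to a
		// blocking completion.
		resp, err := p.GetChatCompletions(context.Background(), req, nil)
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("%+v", resp.Choices[0].Message.ToolCalls)
		return
	}

	// No tools: streaming is fine.
	recver, err := p.GetChatCompletionsStream(context.Background(), req, nil)
	if err != nil {
		log.Fatal(err)
	}
	_ = recver // the caller would consume the stream here
}
```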