Gemini support (#98)
* updates in demo

* Gemini
jamesrochabrun authored Nov 12, 2024
1 parent 4895981 commit c3a04bb
Showing 11 changed files with 114 additions and 60 deletions.
@@ -75,9 +75,7 @@ struct ChatDemoView: View {
messages: [.init(
role: .user,
content: content)],
model: .gpt41106Preview,
logProbs: true,
topLogprobs: 1)
model: .gpt4o)
switch selectedSegment {
case .chatCompletion:
try await chatProvider.startChat(parameters: parameters)
@@ -63,12 +63,12 @@ import SwiftOpenAI
// This information is essential for maintaining context in the conversation and for updating
// the chat UI with proper role attributions for each message.
var newDelta = ChatDisplayMessage.Delta(role: "", content: "")
if let firstDelta = firstChatMessageResponseDelta[result.id] {
if let firstDelta = firstChatMessageResponseDelta[result.id ?? ""] {
// If we have already stored the first delta for this result ID, reuse its role.
newDelta.role = firstDelta.role!
} else {
// Otherwise, store the first delta received for future reference.
firstChatMessageResponseDelta[result.id] = choice.delta
firstChatMessageResponseDelta[result.id ?? ""] = choice.delta
}
// Assign the content received in the current message to the newDelta.
newDelta.content = temporalReceivedMessageContent
@@ -111,7 +111,7 @@ extension View {
}
}

extension DeletionStatus: Equatable {
extension DeletionStatus: @retroactive Equatable {
public static func == (lhs: DeletionStatus, rhs: DeletionStatus) -> Bool {
lhs.id == rhs.id
}
@@ -8,13 +8,13 @@
import SwiftUI
import SwiftOpenAI

extension FileObject: Equatable {
extension FileObject: @retroactive Equatable {
public static func == (lhs: FileObject, rhs: FileObject) -> Bool {
lhs.id == rhs.id
}
}

extension FileParameters: Equatable, Identifiable {
extension FileParameters: @retroactive Equatable, @retroactive Identifiable {
public static func == (lhs: FileParameters, rhs: FileParameters) -> Bool {
lhs.file == rhs.file &&
lhs.fileName == rhs.fileName &&
@@ -54,12 +54,12 @@ import SwiftOpenAI
// This information is essential for maintaining context in the conversation and for updating
// the chat UI with proper role attributions for each message.
var newDelta = ChatDisplayMessage.Delta(role: "", content: "")
if let firstDelta = firstChatMessageResponseDelta[result.id] {
if let firstDelta = firstChatMessageResponseDelta[result.id ?? ""] {
// If we have already stored the first delta for this result ID, reuse its role.
newDelta.role = firstDelta.role!
} else {
// Otherwise, store the first delta received for future reference.
firstChatMessageResponseDelta[result.id] = choice.delta
firstChatMessageResponseDelta[result.id ?? ""] = choice.delta
}
// Assign the content received in the current message to the newDelta.
newDelta.content = temporalReceivedMessageContent
61 changes: 53 additions & 8 deletions README.md
@@ -14,6 +14,7 @@ An open-source Swift package designed for effortless interaction with OpenAI's p
- [Description](#description)
- [Getting an API Key](#getting-an-api-key)
- [Installation](#installation)
- [Compatibility](#compatibility)
- [Usage](#usage)
- [Azure OpenAI](#azure-openai)
- [AIProxy](#aiproxy)
@@ -100,6 +101,16 @@ limit, so you should not accept the defaults that Xcode proposes. Instead, enter
tab out of the input box for Xcode to adjust the upper bound. Alternatively, you may select `branch` -> `main`
to stay on the bleeding edge.
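
If you prefer declaring the dependency in a `Package.swift` manifest instead, a minimal sketch looks like this (the SwiftOpenAI version and platform values below are illustrative; pin to the release you actually want):

```swift
// swift-tools-version:5.9
// Minimal manifest sketch; the SwiftOpenAI version and platform values are illustrative.
import PackageDescription

let package = Package(
    name: "MyApp",
    platforms: [.iOS(.v16), .macOS(.v14)],
    dependencies: [
        .package(url: "https://github.com/jamesrochabrun/SwiftOpenAI", from: "3.0.0")
    ],
    targets: [
        .executableTarget(
            name: "MyApp",
            dependencies: [.product(name: "SwiftOpenAI", package: "SwiftOpenAI")])
    ]
)
```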

## Compatibility

SwiftOpenAI supports various providers that are OpenAI-compatible, including but not limited to:

- [Ollama](#ollama)
- [Groq](#groq)
- [Gemini](#gemini)

Check `OpenAIServiceFactory` for convenience initializers that let you provide custom base URLs.
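
For example, a service pointed at a custom OpenAI-compatible endpoint can be created like this (the host below is a placeholder, not a real provider URL):

```swift
// Sketch: pointing SwiftOpenAI at an OpenAI-compatible provider.
// The base URL is a placeholder; substitute your provider's host.
let service = OpenAIServiceFactory.service(
    apiKey: "YOUR_PROVIDER_API_KEY",
    overrideBaseURL: "https://api.your-provider.com")
```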

## Usage

To use SwiftOpenAI in your project, first import the package:
@@ -3217,6 +3228,21 @@ let parameters = ChatCompletionParameters(messages: [.init(role: .user, content:
let chatCompletionObject = service.startStreamedChat(parameters: parameters)
```

⚠️ Note: You can likely use `OpenAIServiceFactory.service(apiKey:overrideBaseURL:proxyPath:)` with any OpenAI-compatible service.

### Resources:

- [Ollama OpenAI compatibility docs.](https://github.com/ollama/ollama/blob/main/docs/openai.md)
- [Ollama OpenAI compatibility blog post.](https://ollama.com/blog/openai-compatibility)

### Notes

You can also use this service constructor to provide any base URL or API key you need.

```swift
let service = OpenAIServiceFactory.service(apiKey: "YOUR_API_KEY", baseURL: "http://localhost:11434")
```
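
As a quick sketch of a request against that local service (the model name is an assumption; use any model you have pulled locally):

```swift
// Sketch: a chat request against the local Ollama endpoint.
// "llama3.1" is an assumption; substitute a model you have pulled with `ollama pull`.
let parameters = ChatCompletionParameters(
    messages: [.init(role: .user, content: .text("Hello from SwiftOpenAI"))],
    model: .custom("llama3.1"))
let completion = try await service.startChat(parameters: parameters)
```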

## Groq

<img width="792" alt="Screenshot 2024-10-11 at 11 49 04 PM" src="https://github.com/user-attachments/assets/7afb36a2-b2d8-4f89-9592-f4cece20d469">
@@ -3231,23 +3257,42 @@ let service = OpenAIServiceFactory.service(apiKey: apiKey, overrideBaseURL: "htt

For the APIs supported through Groq, visit its [documentation](https://console.groq.com/docs/openai).
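
For illustration, a request routed through Groq uses the same `.custom` model approach shown for Gemini below (the model identifier here is an assumption; check Groq's model list for current names):

```swift
// Sketch: chatting with a Groq-hosted model. The model identifier is an assumption.
let parameters = ChatCompletionParameters(
    messages: [.init(role: .user, content: .text("Summarize Swift concurrency in one sentence."))],
    model: .custom("llama-3.1-8b-instant"))
let stream = try await service.startStreamedChat(parameters: parameters)
```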

## Gemini

<img width="982" alt="Screenshot 2024-11-12 at 10 53 43 AM" src="https://github.com/user-attachments/assets/cebc18fe-b96d-4ffe-912e-77d625249cf2">

Gemini is now accessible through the OpenAI Library. See the announcement [here](https://developers.googleblog.com/en/gemini-is-now-accessible-from-the-openai-library/).
SwiftOpenAI supports all OpenAI endpoints. However, please refer to the [Gemini documentation](https://ai.google.dev/gemini-api/docs/openai) to understand which APIs are currently compatible.

You can instantiate an `OpenAIService` with your Gemini API key like this:

```swift
let geminiAPIKey = "your_api_key"
let baseURL = "https://generativelanguage.googleapis.com"
let version = "v1beta"

let service = OpenAIServiceFactory.service(
apiKey: geminiAPIKey,
overrideBaseURL: baseURL,
overrideVersion: version)
```

You can now create a chat request using the `.custom` model parameter, passing the model name as a string.

```swift
let parameters = ChatCompletionParameters(
messages: [.init(
role: .user,
content: content)],
model: .custom("gemini-1.5-flash"))

let stream = try await service.startStreamedChat(parameters: parameters)
```
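
A short sketch of consuming that stream, mirroring the chunk shape used elsewhere in this repository (each chunk exposes `choices`, whose elements carry a `delta`):

```swift
// Sketch: printing streamed content as it arrives.
for try await chunk in stream {
    if let content = chunk.choices.first?.delta.content {
        print(content, terminator: "")
    }
}
```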

## Collaboration
Open a PR for any proposed change, pointing it to the `main` branch. Unit tests are highly appreciated ❤️


81 changes: 43 additions & 38 deletions Sources/OpenAI/Private/Networking/OpenAIAPI.swift
@@ -13,6 +13,7 @@ enum OpenAIAPI {

static var overrideBaseURL: String? = nil
static var proxyPath: String? = nil
static var overrideVersion: String? = nil

case assistant(AssistantCategory) // https://platform.openai.com/docs/api-reference/assistants
case audio(AudioCategory) // https://platform.openai.com/docs/api-reference/audio
@@ -152,81 +153,85 @@ extension OpenAIAPI: Endpoint {
return "/\(proxyPath)\(openAIPath)"
}

var version: String {
Self.overrideVersion ?? "v1"
}

var openAIPath: String {
switch self {
case .assistant(let category):
switch category {
case .create, .list: return "/v1/assistants"
case .retrieve(let assistantID), .modify(let assistantID), .delete(let assistantID): return "/v1/assistants/\(assistantID)"
case .create, .list: return "/\(version)/assistants"
case .retrieve(let assistantID), .modify(let assistantID), .delete(let assistantID): return "/\(version)/assistants/\(assistantID)"
}
case .audio(let category): return "/v1/audio/\(category.rawValue)"
case .audio(let category): return "/\(version)/audio/\(category.rawValue)"
case .batch(let category):
switch category {
case .create, .list: return "v1/batches"
case .retrieve(let batchID): return "v1/batches/\(batchID)"
case .cancel(let batchID): return "v1/batches/\(batchID)/cancel"
case .create, .list: return "\(version)/batches"
case .retrieve(let batchID): return "\(version)/batches/\(batchID)"
case .cancel(let batchID): return "\(version)/batches/\(batchID)/cancel"
}
case .chat: return "/v1/chat/completions"
case .embeddings: return "/v1/embeddings"
case .chat: return "/\(version)/chat/completions"
case .embeddings: return "/\(version)/embeddings"
case .file(let category):
switch category {
case .list, .upload: return "/v1/files"
case .delete(let fileID), .retrieve(let fileID): return "/v1/files/\(fileID)"
case .retrieveFileContent(let fileID): return "/v1/files/\(fileID)/content"
case .list, .upload: return "/\(version)/files"
case .delete(let fileID), .retrieve(let fileID): return "/\(version)/files/\(fileID)"
case .retrieveFileContent(let fileID): return "/\(version)/files/\(fileID)/content"
}
case .fineTuning(let category):
switch category {
case .create, .list: return "/v1/fine_tuning/jobs"
case .retrieve(let jobID): return "/v1/fine_tuning/jobs/\(jobID)"
case .cancel(let jobID): return "/v1/fine_tuning/jobs/\(jobID)/cancel"
case .events(let jobID): return "/v1/fine_tuning/jobs/\(jobID)/events"
case .create, .list: return "/\(version)/fine_tuning/jobs"
case .retrieve(let jobID): return "/\(version)/fine_tuning/jobs/\(jobID)"
case .cancel(let jobID): return "/\(version)/fine_tuning/jobs/\(jobID)/cancel"
case .events(let jobID): return "/\(version)/fine_tuning/jobs/\(jobID)/events"
}
case .images(let category): return "/v1/images/\(category.rawValue)"
case .images(let category): return "/\(version)/images/\(category.rawValue)"
case .message(let category):
switch category {
case .create(let threadID), .list(let threadID): return "/v1/threads/\(threadID)/messages"
case .retrieve(let threadID, let messageID), .modify(let threadID, let messageID), .delete(let threadID, let messageID): return "/v1/threads/\(threadID)/messages/\(messageID)"
case .create(let threadID), .list(let threadID): return "/\(version)/threads/\(threadID)/messages"
case .retrieve(let threadID, let messageID), .modify(let threadID, let messageID), .delete(let threadID, let messageID): return "/\(version)/threads/\(threadID)/messages/\(messageID)"
}
case .model(let category):
switch category {
case .list: return "/v1/models"
case .retrieve(let modelID), .deleteFineTuneModel(let modelID): return "/v1/models/\(modelID)"
case .list: return "/\(version)/models"
case .retrieve(let modelID), .deleteFineTuneModel(let modelID): return "/\(version)/models/\(modelID)"
}
case .moderations: return "/v1/moderations"
case .moderations: return "/\(version)/moderations"
case .run(let category):
switch category {
case .create(let threadID), .list(let threadID): return "/v1/threads/\(threadID)/runs"
case .retrieve(let threadID, let runID), .modify(let threadID, let runID): return "/v1/threads/\(threadID)/runs/\(runID)"
case .cancel(let threadID, let runID): return "/v1/threads/\(threadID)/runs/\(runID)/cancel"
case .submitToolOutput(let threadID, let runID): return "/v1/threads/\(threadID)/runs/\(runID)/submit_tool_outputs"
case .createThreadAndRun: return "/v1/threads/runs"
case .create(let threadID), .list(let threadID): return "/\(version)/threads/\(threadID)/runs"
case .retrieve(let threadID, let runID), .modify(let threadID, let runID): return "/\(version)/threads/\(threadID)/runs/\(runID)"
case .cancel(let threadID, let runID): return "/\(version)/threads/\(threadID)/runs/\(runID)/cancel"
case .submitToolOutput(let threadID, let runID): return "/\(version)/threads/\(threadID)/runs/\(runID)/submit_tool_outputs"
case .createThreadAndRun: return "/\(version)/threads/runs"
}
case .runStep(let category):
switch category {
case .retrieve(let threadID, let runID, let stepID): return "/v1/threads/\(threadID)/runs/\(runID)/steps/\(stepID)"
case .list(let threadID, let runID): return "/v1/threads/\(threadID)/runs/\(runID)/steps"
case .retrieve(let threadID, let runID, let stepID): return "/\(version)/threads/\(threadID)/runs/\(runID)/steps/\(stepID)"
case .list(let threadID, let runID): return "/\(version)/threads/\(threadID)/runs/\(runID)/steps"
}
case .thread(let category):
switch category {
case .create: return "/v1/threads"
case .retrieve(let threadID), .modify(let threadID), .delete(let threadID): return "/v1/threads/\(threadID)"
case .create: return "/\(version)/threads"
case .retrieve(let threadID), .modify(let threadID), .delete(let threadID): return "/\(version)/threads/\(threadID)"
}
case .vectorStore(let category):
switch category {
case .create, .list: return "/v1/vector_stores"
case .retrieve(let vectorStoreID), .modify(let vectorStoreID), .delete(let vectorStoreID): return "/v1/vector_stores/\(vectorStoreID)"
case .create, .list: return "/\(version)/vector_stores"
case .retrieve(let vectorStoreID), .modify(let vectorStoreID), .delete(let vectorStoreID): return "/\(version)/vector_stores/\(vectorStoreID)"
}
case .vectorStoreFile(let category):
switch category {
case .create(let vectorStoreID), .list(let vectorStoreID): return "/v1/vector_stores/\(vectorStoreID)/files"
case .retrieve(let vectorStoreID, let fileID), .delete(let vectorStoreID, let fileID): return "/v1/vector_stores/\(vectorStoreID)/files/\(fileID)"
case .create(let vectorStoreID), .list(let vectorStoreID): return "/\(version)/vector_stores/\(vectorStoreID)/files"
case .retrieve(let vectorStoreID, let fileID), .delete(let vectorStoreID, let fileID): return "/\(version)/vector_stores/\(vectorStoreID)/files/\(fileID)"
}
case .vectorStoreFileBatch(let category):
switch category {
case .create(let vectorStoreID): return"/v1/vector_stores/\(vectorStoreID)/file_batches"
case .retrieve(let vectorStoreID, let batchID): return "v1/vector_stores/\(vectorStoreID)/file_batches/\(batchID)"
case .cancel(let vectorStoreID, let batchID): return "/v1/vector_stores/\(vectorStoreID)/file_batches/\(batchID)/cancel"
case .list(let vectorStoreID, let batchID): return "/v1/vector_stores/\(vectorStoreID)/file_batches/\(batchID)/files"
case .create(let vectorStoreID): return"/\(version)/vector_stores/\(vectorStoreID)/file_batches"
case .retrieve(let vectorStoreID, let batchID): return "\(version)/vector_stores/\(vectorStoreID)/file_batches/\(batchID)"
case .cancel(let vectorStoreID, let batchID): return "/\(version)/vector_stores/\(vectorStoreID)/file_batches/\(batchID)/cancel"
case .list(let vectorStoreID, let batchID): return "/\(version)/vector_stores/\(vectorStoreID)/file_batches/\(batchID)/files"
}
}
}
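
In effect, every previously hard-coded `/v1/` segment is now built from the `version` property, so an override supplied through the factory changes every endpoint path. A rough sketch of the behavior (illustrative only; `OpenAIAPI` is internal to the package, so callers normally set this through `OpenAIServiceFactory.service(apiKey:overrideBaseURL:proxyPath:overrideVersion:debugEnabled:)`):

```swift
// Illustrative only: inside the package, overriding the version rewrites the paths.
OpenAIAPI.overrideVersion = "v1beta"
let chatPath = OpenAIAPI.chat.openAIPath // "/v1beta/chat/completions"
```
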
@@ -429,7 +429,8 @@ public struct ChatCompletionParameters: Encodable {
stop: [String]? = nil,
temperature: Double? = nil,
topProbability: Double? = nil,
user: String? = nil)
user: String? = nil,
streamOptions: StreamOptions? = nil)
{
self.messages = messages
self.model = model.value
@@ -455,5 +456,6 @@ public struct ChatCompletionParameters: Encodable {
self.temperature = temperature
self.topP = topProbability
self.user = user
self.streamOptions = streamOptions
}
}
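
Since `streamOptions` is now a public initializer parameter (and, per the service change below, no longer forced on by `startStreamedChat`), callers who want usage data in streamed responses opt in explicitly. A sketch, assuming the `StreamOptions(includeUsage:)` shape the service used previously:

```swift
// Sketch: explicitly requesting usage data in streamed responses.
let parameters = ChatCompletionParameters(
    messages: [.init(role: .user, content: .text("Hello"))],
    model: .gpt4o,
    streamOptions: .init(includeUsage: true))
```
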
@@ -11,7 +11,7 @@ import Foundation
public struct ChatCompletionChunkObject: Decodable {

/// A unique identifier for the chat completion chunk.
public let id: String
public let id: String?
/// A list of chat completion choices. Can be more than one if n is greater than 1.
public let choices: [ChatChoice]
/// The Unix timestamp (in seconds) of when the chat completion chunk was created.
3 changes: 2 additions & 1 deletion Sources/OpenAI/Public/Service/DefaultOpenAIService.swift
@@ -27,6 +27,7 @@ struct DefaultOpenAIService: OpenAIService {
organizationID: String? = nil,
baseURL: String? = nil,
proxyPath: String? = nil,
overrideVersion: String? = nil,
configuration: URLSessionConfiguration = .default,
decoder: JSONDecoder = .init(),
debugEnabled: Bool)
@@ -37,6 +38,7 @@ struct DefaultOpenAIService: OpenAIService {
self.organizationID = organizationID
OpenAIAPI.overrideBaseURL = baseURL
OpenAIAPI.proxyPath = proxyPath
OpenAIAPI.overrideVersion = overrideVersion
self.debugEnabled = debugEnabled
}

@@ -85,7 +87,6 @@ struct DefaultOpenAIService: OpenAIService {
{
var chatParameters = parameters
chatParameters.stream = true
chatParameters.streamOptions = .init(includeUsage: true)
let request = try OpenAIAPI.chat.request(apiKey: apiKey, organizationID: organizationID, method: .post, params: chatParameters)
return try await fetchStream(debugEnabled: debugEnabled, type: ChatCompletionChunkObject.self, with: request)
}
5 changes: 4 additions & 1 deletion Sources/OpenAI/Public/Service/OpenAIServiceFactory.swift
@@ -130,22 +130,25 @@ public class OpenAIServiceFactory {
///
/// - Parameters:
/// - apiKey: The optional API key required for authentication.
/// - baseURL: The local host URL. defaults to "https://api.groq.com"
/// - baseURL: The local host URL. e.g "https://api.groq.com" or "https://generativelanguage.googleapis.com"
/// - proxyPath: The proxy path e.g `openai`
/// - overrideVersion: The API version. defaults to `V1`
/// - debugEnabled: If `true` service prints event on DEBUG builds, default to `false`.
///
/// - Returns: A fully configured object conforming to `OpenAIService`.
public static func service(
apiKey: String,
overrideBaseURL: String,
proxyPath: String? = nil,
overrideVersion: String? = nil,
debugEnabled: Bool = false)
-> OpenAIService
{
DefaultOpenAIService(
apiKey: apiKey,
baseURL: overrideBaseURL,
proxyPath: proxyPath,
overrideVersion: overrideVersion,
debugEnabled: debugEnabled)
}
}
