Skip to content

Commit

Permalink
OPIK-611 refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
idoberko2 committed Jan 6, 2025
1 parent bf709d1 commit 26842dd
Showing 1 changed file with 28 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,7 @@ public Gemini(LlmProviderClientConfig llmProviderClientConfig, String apiKey) {

@Override
public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @NonNull String workspaceId) {
GoogleAiGeminiChatModel client = GoogleAiGeminiChatModel.builder()
.apiKey(apiKey)
.modelName(request.model())
.maxOutputTokens(request.maxCompletionTokens())
.maxRetries(1)
.stopSequences(request.stop())
.temperature(request.temperature())
.topP(request.topP())
.timeout(llmProviderClientConfig.getCallTimeout().toJavaDuration())
.build();
var response = client.generate(request.messages().stream().map(this::toChatMessage).toList());
var response = createClient(request).generate(request.messages().stream().map(this::toChatMessage).toList());

return ChatCompletionResponse.builder()
.model(request.model())
Expand All @@ -62,17 +52,7 @@ public ChatCompletionResponse generate(@NonNull ChatCompletionRequest request, @
public void generateStream(@NonNull ChatCompletionRequest request, @NonNull String workspaceId,
@NonNull Consumer<ChatCompletionResponse> handleMessage, @NonNull Runnable handleClose,
@NonNull Consumer<Throwable> handleError) {
var client = GoogleAiGeminiStreamingChatModel.builder()
.apiKey(apiKey)
.modelName(request.model())
.maxOutputTokens(request.maxCompletionTokens())
.maxRetries(1)
.stopSequences(request.stop())
.temperature(request.temperature())
.topP(request.topP())
.timeout(llmProviderClientConfig.getCallTimeout().toJavaDuration())
.build();
client.generate(request.messages().stream().map(this::toChatMessage).toList(),
createStreamingClient(request).generate(request.messages().stream().map(this::toChatMessage).toList(),
new ChunkedResponseHandler(handleMessage, handleClose, handleError, request.model()));
}

Expand Down Expand Up @@ -109,4 +89,30 @@ private String toStringMessageContent(Object rawContent) {

throw new BadRequestException("only text content is supported");
}

private GoogleAiGeminiChatModel createClient(ChatCompletionRequest request) {
    // Assemble a synchronous Gemini chat client for a single request.
    // Connection-level settings (API key, retries, timeout) come from this
    // provider instance; generation settings come from the request itself.
    var builder = GoogleAiGeminiChatModel.builder()
            .apiKey(apiKey)
            .maxRetries(1) // fail fast: a single retry attempt per call
            .timeout(llmProviderClientConfig.getCallTimeout().toJavaDuration());

    // Request-scoped generation parameters.
    builder.modelName(request.model())
            .maxOutputTokens(request.maxCompletionTokens())
            .stopSequences(request.stop())
            .temperature(request.temperature())
            .topP(request.topP());

    return builder.build();
}

private GoogleAiGeminiStreamingChatModel createStreamingClient(ChatCompletionRequest request) {
    // Assemble a streaming Gemini chat client for a single request.
    // Mirrors createClient(): same connection settings from this provider
    // instance, same generation settings taken from the request.
    var builder = GoogleAiGeminiStreamingChatModel.builder()
            .apiKey(apiKey)
            .maxRetries(1) // fail fast: a single retry attempt per call
            .timeout(llmProviderClientConfig.getCallTimeout().toJavaDuration());

    // Request-scoped generation parameters.
    builder.modelName(request.model())
            .maxOutputTokens(request.maxCompletionTokens())
            .stopSequences(request.stop())
            .temperature(request.temperature())
            .topP(request.topP());

    return builder.build();
}
}

0 comments on commit 26842dd

Please sign in to comment.