OpenAI Streaming (#25)
* OpenAI streaming

* Added homepage and error handling todo

* Renamed vars

* Added todos

* Made stream generic, try-with resources, TEXT_EVENT_STREAM, exception refactored

* Formatting

* close stream correctly

* Formatting

* Created OpenAiStreamOutput

* Formatting

* Renamed stream to streamChatCompletion, Added comments

* Added total output

* Total output is printed

* Formatting

* addDelta is propagated everywhere

* addDelta is propagated everywhere

* forgotten addDeltas

* Added jackson dependencies

* Added Javadoc

* Removed 1 TODO

* PMD

* PMD again

* Added OpenAiClientTest.streamChatCompletion()

* Change return type of stream, added e2e test

* Added documentation

* Added documentation framework-agnostic + throw if finish reason is invalid

* Added error handling test

* Updates from pair review / discussion

* Cleanup + streamChatCompletion doesn't throw

* PMD

* Added errorHandling test

* Apply suggestions from code review

Co-authored-by: Matthias Kuhr <[email protected]>

* Dependency analyze

* Review comments

* Make client static

* Formatting

* PMD

* Fix tests

* Removed exception constructors no args

* Refactor exception message

* Readme sentences

* Remove superfluous call super

* reset httpclient-cache and -factory after each test case

* Very minor code-style improvements in test

* Minor code-style in OpenAIController

* Reduce README sample code

* Update OpenAiStreamingHandler.java (#43)

* Fix import

* Added stream_options to model

* Change Executor#submit() to #execute()

* Added usage testing

* Added beautiful Javadoc to enableStreaming

* typo

* Fix mistake

* streaming readme (#48)

* Reduce sample code

* Format

---------

Co-authored-by: SAP Cloud SDK Bot <[email protected]>
Co-authored-by: Matthias Kuhr <[email protected]>
Co-authored-by: Matthias Kuhr <[email protected]>
Co-authored-by: Alexander Dümont <[email protected]>
Co-authored-by: Alexander Dümont <[email protected]>
6 people authored Sep 4, 2024
1 parent a2832a2 commit 89bb17c
Showing 30 changed files with 1,138 additions and 157 deletions.
67 changes: 66 additions & 1 deletion README.md
@@ -197,13 +197,78 @@ final String resultMessage = result.getChoices().get(0).getMessage().getContent();

See [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java)

### Chat completion with a model not in defined `OpenAiModel`
### Chat completion with a model not defined in `OpenAiModel`

```java
final OpenAiChatCompletionOutput result =
OpenAiClient.forModel(new OpenAiModel("model")).chatCompletion(request);
```

### Stream chat completion

The chat completion can be consumed as a stream of delta elements and forwarded, e.g., from the application backend to the frontend in real time.

#### Stream the chat completion asynchronously
This is a blocking example for streaming and printing directly to the console:
```java
String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";

OpenAiChatCompletionParameters request =
    new OpenAiChatCompletionParameters()
        .setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));

OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);

// try-with-resources on stream ensures the connection will be closed
try (Stream<String> stream = client.streamChatCompletion(request)) {
  stream.forEach(deltaString -> {
    System.out.print(deltaString);
    System.out.flush();
  });
}
```

<details>
<summary>It's also possible to aggregate the total output.</summary>

The following example is non-blocking.
Any asynchronous library can be used, e.g. the classic Thread API.

```java
String msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";

OpenAiChatCompletionParameters request =
    new OpenAiChatCompletionParameters()
        .setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));

OpenAiChatCompletionOutput totalOutput = new OpenAiChatCompletionOutput();
OpenAiClient client = OpenAiClient.forModel(GPT_35_TURBO);

// Do the request before the thread starts to handle exceptions during request initialization
Stream<OpenAiChatCompletionDelta> stream = client.streamChatCompletionDeltas(request);

Thread thread =
    new Thread(
        () -> {
          // try-with-resources ensures the stream is closed
          try (stream) {
            stream.peek(totalOutput::addDelta).forEach(System.out::println);
          }
        });
thread.start(); // non-blocking

thread.join(); // blocking

// access aggregated information from total output, e.g.
Integer tokens = totalOutput.getUsage().getCompletionTokens();
System.out.println("Tokens: " + tokens);
```

</details>
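
The aggregation above relies only on plain `java.util.stream` semantics: `peek` feeds each delta into an accumulator before `forEach` forwards it to the consumer. A self-contained sketch of that pattern, where the `TotalOutput` class and the literal stream are illustrative stand-ins for `OpenAiChatCompletionOutput#addDelta` and `streamChatCompletionDeltas` (not SDK types):

```java
import java.util.List;
import java.util.stream.Stream;

public class DeltaAggregationSketch {

  // Hypothetical accumulator standing in for OpenAiChatCompletionOutput#addDelta
  static final class TotalOutput {
    final StringBuilder content = new StringBuilder();
    int deltaCount;

    void addDelta(String delta) {
      content.append(delta);
      deltaCount++;
    }
  }

  public static void main(String[] args) {
    TotalOutput totalOutput = new TotalOutput();
    // literal stream standing in for client.streamChatCompletionDeltas(request)
    try (Stream<String> stream = List.of("Fib: ", "0, 1, 1, 2, 3").stream()) {
      // peek feeds each delta into the aggregate before forEach forwards it
      stream.peek(totalOutput::addDelta).forEach(System.out::print);
    }
    System.out.println();
    System.out.println("deltas=" + totalOutput.deltaCount + ", total=" + totalOutput.content);
    // prints: deltas=2, total=Fib: 0, 1, 1, 2, 3
  }
}
```

Because the aggregate is updated as a side effect of `peek`, it is complete exactly when the terminal `forEach` returns, which is why the non-blocking variant above must `join()` before reading the usage.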

#### Spring Boot example

Please find [an example in our Spring Boot application](e2e-test-app/src/main/java/com/sap/ai/sdk/app/controllers/OpenAiController.java).
It shows the usage of Spring Boot's `ResponseBodyEmitter` to stream the chat completion delta messages to the frontend in real-time.
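
`ResponseBodyEmitter` is Spring-specific, but the underlying pattern — a chunked `text/event-stream` response flushed once per delta — is framework-agnostic. A minimal sketch using only the JDK's built-in `HttpServer` and `HttpClient`, where the literal `Stream` stands in for `streamChatCompletion` (class and method names here are illustrative, not part of the SDK):

```java
import com.sun.net.httpserver.HttpServer;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.stream.Stream;

public class StreamingSketch {

  // Starts a server that streams three chunks, fetches them, and returns the full body
  static String fetchStreamed() throws Exception {
    HttpServer server = HttpServer.create(new InetSocketAddress(0), 0);
    server.createContext(
        "/stream",
        exchange -> {
          // same content type the Spring controller uses, so browsers render chunks on arrival
          exchange.getResponseHeaders().add("Content-Type", "text/event-stream");
          exchange.sendResponseHeaders(200, 0); // response length 0 => chunked transfer
          try (OutputStream out = exchange.getResponseBody();
              // stand-in for client.streamChatCompletion(request)
              Stream<String> deltas = Stream.of("Hello", ", ", "world")) {
            for (String delta : (Iterable<String>) deltas::iterator) {
              out.write(delta.getBytes(StandardCharsets.UTF_8));
              out.flush(); // flush each delta so the client sees it immediately
            }
          }
        });
    server.start();
    try {
      HttpClient client = HttpClient.newHttpClient();
      URI uri = URI.create("http://localhost:" + server.getAddress().getPort() + "/stream");
      HttpResponse<String> response =
          client.send(HttpRequest.newBuilder(uri).build(), HttpResponse.BodyHandlers.ofString());
      return response.body();
    } finally {
      server.stop(0);
    }
  }

  public static void main(String[] args) throws Exception {
    System.out.println(fetchStreamed()); // prints "Hello, world"
  }
}
```

The Spring controller below follows the same shape: the stream is consumed on a worker thread, each delta is flushed to the client immediately, and the response is completed when the stream is exhausted.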

## Orchestration chat completion

### Prerequisites
13 changes: 13 additions & 0 deletions e2e-test-app/pom.xml
@@ -87,6 +87,11 @@
      <groupId>org.springframework</groupId>
      <artifactId>spring-web</artifactId>
    </dependency>
    <dependency>
      <groupId>org.springframework</groupId>
      <artifactId>spring-webmvc</artifactId>
      <version>${springframework.version}</version>
    </dependency>
    <dependency>
      <groupId>com.google.code.findbugs</groupId>
      <artifactId>jsr305</artifactId>
@@ -95,6 +100,14 @@
      <groupId>org.slf4j</groupId>
      <artifactId>slf4j-api</artifactId>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-databind</artifactId>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-core</artifactId>
    </dependency>
    <!-- scope "runtime" -->
    <dependency>
      <groupId>ch.qos.logback</groupId>
@@ -5,6 +5,8 @@
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_ADA_002;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionFunction;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
@@ -14,13 +16,21 @@
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage.ImageDetailLevel;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;

/** Endpoints for OpenAI operations */
@Slf4j
@RestController
class OpenAiController {
  /**
@@ -38,6 +48,98 @@ public static OpenAiChatCompletionOutput chatCompletion() {
    return OpenAiClient.forModel(GPT_35_TURBO).chatCompletion(request);
  }

  /**
   * Asynchronous stream of an OpenAI chat request
   *
   * @return the emitter that streams the assistant message response
   */
  @SuppressWarnings("unused") // The end-to-end test doesn't use this method
  @GetMapping("/streamChatCompletionDeltas")
  @Nonnull
  public static ResponseEntity<ResponseBodyEmitter> streamChatCompletionDeltas() {
    final var msg = "Can you give me the first 100 numbers of the Fibonacci sequence?";
    final var request =
        new OpenAiChatCompletionParameters()
            .setMessages(List.of(new OpenAiChatUserMessage().addText(msg)));

    final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request);

    final var emitter = new ResponseBodyEmitter();

    final Runnable consumeStream =
        () -> {
          final var totalOutput = new OpenAiChatCompletionOutput();
          // try-with-resources ensures the stream is closed
          try (stream) {
            stream
                .peek(totalOutput::addDelta)
                .forEach(delta -> send(emitter, delta.getDeltaContent()));
          } finally {
            send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput));
            emitter.complete();
          }
        };

    ThreadContextExecutors.getExecutor().execute(consumeStream);

    // TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
    return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
  }

  private static String objectToJson(@Nonnull final Object obj) {
    try {
      return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(obj);
    } catch (final JsonProcessingException ignored) {
      return "Could not parse object to JSON";
    }
  }

  /**
   * Asynchronous stream of an OpenAI chat request
   *
   * @return the emitter that streams the assistant message response
   */
  @SuppressWarnings("unused") // The end-to-end test doesn't use this method
  @GetMapping("/streamChatCompletion")
  @Nonnull
  public static ResponseEntity<ResponseBodyEmitter> streamChatCompletion() {
    final var request =
        new OpenAiChatCompletionParameters()
            .setMessages(
                List.of(
                    new OpenAiChatUserMessage()
                        .addText(
                            "Can you give me the first 100 numbers of the Fibonacci sequence?")));

    final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletion(request);

    final var emitter = new ResponseBodyEmitter();

    final Runnable consumeStream =
        () -> {
          try (stream) {
            stream.forEach(deltaMessage -> send(emitter, deltaMessage));
          } finally {
            emitter.complete();
          }
        };

    ThreadContextExecutors.getExecutor().execute(consumeStream);

    // TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
    return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
  }

  private static void send(
      @Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) {
    try {
      emitter.send(chunk);
    } catch (final IOException e) {
      log.error(Arrays.toString(e.getStackTrace()));
      emitter.completeWithError(e);
    }
  }

  /**
   * Chat request to OpenAI with an image
   *
2 changes: 2 additions & 0 deletions e2e-test-app/src/main/resources/static/index.html
@@ -71,6 +71,8 @@ <h2>Endpoints</h2>
<li><h4>OpenAI</h4></li>
<ul>
<li><a href="/chatCompletion">/chatCompletion</a></li>
<li><a href="/streamChatCompletion">/streamChatCompletion</a></li>
<li><a href="/streamChatCompletionDeltas">/streamChatCompletionDeltas</a></li>
<li><a href="/chatCompletionTool">/chatCompletionTool</a></li>
<li><a href="/chatCompletionImage">/chatCompletionImage</a></li>
<li><a href="/embedding">/embedding</a></li>
@@ -1,9 +1,18 @@
package com.sap.ai.sdk.app.controllers;

import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.GPT_35_TURBO;
import static org.assertj.core.api.Assertions.assertThat;

import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage.OpenAiChatUserMessage;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;

@Slf4j
class OpenAiTest {
  @Test
  void chatCompletion() {
@@ -23,12 +32,44 @@ void chatCompletionImage() {
    assertThat(message.getContent()).isNotEmpty();
  }

  @Test
  void streamChatCompletion() {
    final var request =
        new OpenAiChatCompletionParameters()
            .setMessages(List.of(new OpenAiChatUserMessage().addText("Who is the prettiest?")));

    final var totalOutput = new OpenAiChatCompletionOutput();
    final var emptyDeltaCount = new AtomicInteger(0);
    OpenAiClient.forModel(GPT_35_TURBO)
        .streamChatCompletionDeltas(request)
        .peek(totalOutput::addDelta)
        // forEach consumes all elements, closing the stream at the end
        .forEach(
            delta -> {
              final String deltaContent = delta.getDeltaContent();
              log.info("deltaContent: {}", deltaContent);
              if (deltaContent.isEmpty()) {
                emptyDeltaCount.incrementAndGet();
              }
            });

    // the first two and the last delta don't have any content
    // see OpenAiChatCompletionDelta#getDeltaContent
    assertThat(emptyDeltaCount.get()).isLessThanOrEqualTo(3);

    assertThat(totalOutput.getChoices()).isNotEmpty();
    assertThat(totalOutput.getChoices().get(0).getMessage().getContent()).isNotEmpty();
    assertThat(totalOutput.getPromptFilterResults()).isNotNull();
    assertThat(totalOutput.getChoices().get(0).getContentFilterResults()).isNotNull();
  }

  @Test
  void chatCompletionTools() {
    final var completion = OpenAiController.chatCompletionTools();

    final var message = completion.getChoices().get(0).getMessage();
    assertThat(message.getRole()).isEqualTo("assistant");
    assertThat(message.getTool_calls()).isNotNull();
    assertThat(message.getTool_calls().get(0).getFunction().getName()).isEqualTo("fibonacci");
  }

10 changes: 10 additions & 0 deletions foundation-models/openai/pom.xml
@@ -97,6 +97,11 @@
      <artifactId>junit-jupiter-api</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.junit.jupiter</groupId>
      <artifactId>junit-jupiter-params</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.wiremock</groupId>
      <artifactId>wiremock</artifactId>
@@ -107,5 +112,10 @@
      <artifactId>assertj-core</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.mockito</groupId>
      <artifactId>mockito-core</artifactId>
      <scope>test</scope>
    </dependency>
  </dependencies>
</project>
