diff --git a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md index 635de051..771922e9 100644 --- a/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md +++ b/docs/guides/ORCHESTRATION_CHAT_COMPLETION.md @@ -206,31 +206,31 @@ In this example, the input will be masked before the call to the LLM and will re Use the grounding module to provide additional context to the AI model. ```java - var message = - Message.user( - "{{?groundingInput}} Use the following information as additional context: {{?groundingOutput}}"); - var prompt = - new OrchestrationPrompt(Map.of("groundingInput", "What does Joule do?"), message); - - var filterInner = - DocumentGroundingFilter.create().id("someID").dataRepositoryType(DataRepositoryType.VECTOR); - var groundingConfigConfig = - GroundingModuleConfigConfig.create() - .inputParams(List.of("groundingInput")) - .outputParam("groundingOutput") - .addFiltersItem(filterInner); - - var groundingConfig = - GroundingModuleConfig.create() - .type(GroundingModuleConfig.TypeEnum.DOCUMENT_GROUNDING_SERVICE) - .config(groundingConfigConfig); - var configWithGrounding = config.withGroundingConfig(groundingConfig); - - var result = - new OrchestrationClient().chatCompletion(prompt, configWithGrounding); +// optional filter for collections +var documentMetadata = + SearchDocumentKeyValueListPair.create() + .key("my-collection") + .value("value") + .addSelectModeItem(SearchSelectOptionEnum.IGNORE_IF_KEY_ABSENT); +// optional filter for document chunks +var databaseFilter = + DocumentGroundingFilter.create() + .id("") + .dataRepositoryType(DataRepositoryType.VECTOR) + .addDocumentMetadataItem(documentMetadata); + +var groundingConfig = Grounding.create().filter(databaseFilter); +var prompt = groundingConfig.createGroundingPrompt("What does Joule do?"); +var configWithGrounding = config.withGrounding(groundingConfig); + +var result = client.chatCompletion(prompt, configWithGrounding); ``` -In this example, the AI model is provided with additional context in the form of grounding information. Note, that it is necessary to provide the grounding input via one or more input variables. +In this example, the AI model is provided with additional context in the form of grounding information. + +`Grounding.create()` is by default a document grounding service with a vector data repository. + +Please find [an example in our Spring Boot application](../../sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java). ## Stream chat completion diff --git a/docs/release-notes/release_notes.md b/docs/release-notes/release_notes.md index ef831384..1275365d 100644 --- a/docs/release-notes/release_notes.md +++ b/docs/release-notes/release_notes.md @@ -16,6 +16,7 @@ - New Orchestration features: - [Spring AI integration](../guides/ORCHESTRATION_CHAT_COMPLETION.md#spring-ai-integration) + - [Add Grounding configuration convenience](../guides/ORCHESTRATION_CHAT_COMPLETION.md#grounding) - Images are now supported as input in newly introduced `MultiChatMessage`. - `MultiChatMessage` also allows for multiple content items (text or image) in one object. - Grounding input can be masked with `DPIConfig`. diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/Grounding.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/Grounding.java new file mode 100644 index 00000000..d6087436 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/Grounding.java @@ -0,0 +1,93 @@ +package com.sap.ai.sdk.orchestration; + +import com.google.common.annotations.Beta; +import com.sap.ai.sdk.orchestration.model.DataRepositoryType; +import com.sap.ai.sdk.orchestration.model.DocumentGroundingFilter; +import com.sap.ai.sdk.orchestration.model.GroundingModuleConfig; +import com.sap.ai.sdk.orchestration.model.GroundingModuleConfig.TypeEnum; +import com.sap.ai.sdk.orchestration.model.GroundingModuleConfigConfig; +import com.sap.ai.sdk.orchestration.model.GroundingModuleConfigConfigFiltersInner; +import java.util.List; +import java.util.Map; +import javax.annotation.Nonnull; +import lombok.Setter; +import lombok.experimental.Accessors; +import lombok.val; + +/** + * Grounding integrates external, contextually relevant, domain-specific, or real-time data into AI + * processes. This data supplements the natural language processing capabilities of pre-trained + * models, which are trained on general material. + * + * @link SAP AI + * Core: Orchestration - Grounding + */ +@Beta +@Accessors(fluent = true) +public class Grounding implements GroundingProvider { + + @Nonnull + private List filters = + List.of( + DocumentGroundingFilter.create().id("").dataRepositoryType(DataRepositoryType.VECTOR)); + + @Setter(onMethod_ = {@Nonnull}) + private TypeEnum documentGroundingService = TypeEnum.DOCUMENT_GROUNDING_SERVICE; + + /** + * Create a new default grounding provider. + * + *

It is by default a document grounding service with a vector data repository. + * + * @return The grounding provider. + */ + @Nonnull + public static Grounding create() { + return new Grounding(); + } + + /** + * Set filters for grounding. + * + * @param filters List of filters to set. + * @return The modified grounding configuration. + */ + @Nonnull + public Grounding filters(@Nonnull final GroundingModuleConfigConfigFiltersInner... filters) { + if (filters.length != 0) { + this.filters = List.of(filters); + } + return this; + } + + /** + * Create a prompt with grounding parameters included in the message. + * + *

It uses the inputParams {@code userMessage} for the user message and {@code + * groundingContext} for the grounding context. + * + * @param message The user message. + * @return The prompt with grounding. + */ + @Nonnull + public OrchestrationPrompt createGroundingPrompt(@Nonnull final String message) { + return new OrchestrationPrompt( + Map.of("userMessage", message), + Message.user( + "{{?userMessage}} Use the following information as additional context: {{?groundingContext}}")); + } + + @Nonnull + @Override + public GroundingModuleConfig createConfig() { + val groundingConfigConfig = + GroundingModuleConfigConfig.create() + .inputParams(List.of("userMessage")) + .outputParam("groundingContext") + .filters(filters); + + return GroundingModuleConfig.create() + .type(documentGroundingService) + .config(groundingConfigConfig); + } +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/GroundingProvider.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/GroundingProvider.java new file mode 100644 index 00000000..b19b4c01 --- /dev/null +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/GroundingProvider.java @@ -0,0 +1,22 @@ +package com.sap.ai.sdk.orchestration; + +import com.sap.ai.sdk.orchestration.model.GroundingModuleConfig; +import javax.annotation.Nonnull; + +/** + * Interface for grounding configurations. + * + * @link SAP AI + * Core: Orchestration - Grounding + */ +@FunctionalInterface +public interface GroundingProvider { + + /** + * Create a grounding configuration. + * + * @return the grounding configuration + */ + @Nonnull + GroundingModuleConfig createConfig(); +} diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/MaskingProvider.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/MaskingProvider.java index 001db2e2..96f932ce 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/MaskingProvider.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/MaskingProvider.java @@ -12,9 +12,9 @@ public interface MaskingProvider { /** - * Create a masking provider for the configuration. + * Create a masking configuration. * - * @return the masking provider + * @return the masking configuration */ @Nonnull MaskingProviderConfig createConfig(); diff --git a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java index 2d603678..ff6a52f3 100644 --- a/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java +++ b/orchestration/src/main/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfig.java @@ -197,4 +197,18 @@ public OrchestrationModuleConfig withOutputFiltering( return this.withFilteringConfig(newFilteringConfig); } + + /** + * Creates a new configuration with the given grounding configuration. + * + * @link SAP + * AI Core: Orchestration - Grounding + * @param groundingProvider The grounding configuration to use. + * @return A new configuration with the given grounding configuration. + */ + @Nonnull + public OrchestrationModuleConfig withGrounding( + @Nonnull final GroundingProvider groundingProvider) { + return this.withGroundingConfig(groundingProvider.createConfig()); + } } diff --git a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java index 8e2f2529..bdfa1f02 100644 --- a/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java +++ b/orchestration/src/test/java/com/sap/ai/sdk/orchestration/OrchestrationModuleConfigTest.java @@ -3,11 +3,17 @@ import static com.sap.ai.sdk.orchestration.AzureFilterThreshold.ALLOW_SAFE_LOW_MEDIUM; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.GPT_4O; import static com.sap.ai.sdk.orchestration.OrchestrationAiModel.Parameter.MAX_TOKENS; +import static com.sap.ai.sdk.orchestration.model.DataRepositoryType.VECTOR; +import static com.sap.ai.sdk.orchestration.model.GroundingModuleConfig.TypeEnum.DOCUMENT_GROUNDING_SERVICE; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import com.sap.ai.sdk.orchestration.model.DPIConfig; import com.sap.ai.sdk.orchestration.model.DPIEntities; +import com.sap.ai.sdk.orchestration.model.DocumentGroundingFilter; +import com.sap.ai.sdk.orchestration.model.GroundingModuleConfigConfig; +import com.sap.ai.sdk.orchestration.model.GroundingModuleConfigConfigFiltersInner; +import java.util.List; import java.util.Map; import org.junit.jupiter.api.Test; @@ -121,4 +127,50 @@ void testLLMConfig() { .withFailMessage("Static models should be unchanged") .isEqualTo("latest"); } + + @Test + void testGroundingConfig() { + var groundingConfig = Grounding.create(); + var config = + new OrchestrationModuleConfig().withLlmConfig(GPT_4O).withGrounding(groundingConfig); + + assertThat(config.getGroundingConfig()).isNotNull(); + assertThat(config.getGroundingConfig().getType()).isEqualTo(DOCUMENT_GROUNDING_SERVICE); + + GroundingModuleConfigConfig configConfig = config.getGroundingConfig().getConfig(); + assertThat(configConfig).isNotNull(); + assertThat(configConfig.getInputParams()).containsExactly("userMessage"); + assertThat(configConfig.getOutputParam()).isEqualTo("groundingContext"); + + List filters = configConfig.getFilters(); + assertThat(filters).hasSize(1); + DocumentGroundingFilter filter = (DocumentGroundingFilter) filters.get(0); + assertThat(filter.getId()).isEqualTo(""); + assertThat(filter.getDataRepositoryType()).isEqualTo(VECTOR); + } + + @Test + void testGroundingConfigWithFilters() { + var filter1 = DocumentGroundingFilter.create().id("123").dataRepositoryType(VECTOR); + var filter2 = DocumentGroundingFilter.create().id("234").dataRepositoryType(VECTOR); + var groundingConfig = Grounding.create().filters(filter1, filter2); + var config = + new OrchestrationModuleConfig().withLlmConfig(GPT_4O).withGrounding(groundingConfig); + + assertThat(config.getGroundingConfig()).isNotNull(); + var configConfig = config.getGroundingConfig().getConfig(); + assertThat(configConfig).isNotNull(); + + assertThat(config.getGroundingConfig().getConfig().getFilters()).hasSize(2); + } + + @Test + void testGroundingPrompt() { + var prompt = Grounding.create().createGroundingPrompt("Hello, World!"); + assertThat(prompt.getMessages()).hasSize(1); + var message = prompt.getMessages().get(0); + assertThat(message.content()) + .isEqualTo( + "{{?userMessage}} Use the following information as additional context: {{?groundingContext}}"); + } } diff --git a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java index d8746f9c..88271418 100644 --- a/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java +++ b/sample-code/spring-app/src/main/java/com/sap/ai/sdk/app/services/OrchestrationService.java @@ -7,6 +7,7 @@ import com.sap.ai.sdk.orchestration.AzureContentFilter; import com.sap.ai.sdk.orchestration.AzureFilterThreshold; import com.sap.ai.sdk.orchestration.DpiMasking; +import com.sap.ai.sdk.orchestration.Grounding; import com.sap.ai.sdk.orchestration.Message; import com.sap.ai.sdk.orchestration.OrchestrationChatResponse; import com.sap.ai.sdk.orchestration.OrchestrationClient; @@ -16,8 +17,9 @@ import com.sap.ai.sdk.orchestration.model.DPIEntities; import com.sap.ai.sdk.orchestration.model.DataRepositoryType; import com.sap.ai.sdk.orchestration.model.DocumentGroundingFilter; -import com.sap.ai.sdk.orchestration.model.GroundingModuleConfig; -import com.sap.ai.sdk.orchestration.model.GroundingModuleConfigConfig; +import com.sap.ai.sdk.orchestration.model.GroundingFilterSearchConfiguration; +import com.sap.ai.sdk.orchestration.model.SearchDocumentKeyValueListPair; +import com.sap.ai.sdk.orchestration.model.SearchSelectOptionEnum; import com.sap.ai.sdk.orchestration.model.Template; import java.util.List; import java.util.Map; @@ -25,6 +27,7 @@ import javax.annotation.Nonnull; import lombok.Getter; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.springframework.stereotype.Service; /** Service class for the Orchestration service */ @@ -44,7 +47,7 @@ public class OrchestrationService { */ @Nonnull public OrchestrationChatResponse completion(@Nonnull final String famousPhrase) { - final var prompt = new OrchestrationPrompt(famousPhrase + " Why is this phrase so famous?"); + val prompt = new OrchestrationPrompt(famousPhrase + " Why is this phrase so famous?"); return client.chatCompletion(prompt, config); } @@ -55,7 +58,7 @@ public OrchestrationChatResponse completion(@Nonnull final String famousPhrase) */ @Nonnull public Stream streamChatCompletion(@Nonnull final String topic) { - final var prompt = + val prompt = new OrchestrationPrompt( "Please create a small story about " + topic + " with around 700 words."); return client.streamChatCompletion(prompt, config); @@ -70,13 +73,12 @@ public Stream streamChatCompletion(@Nonnull final String topic) { */ @Nonnull public OrchestrationChatResponse template(@Nonnull final String language) { - final var template = - Message.user("Reply with 'Orchestration Service is working!' in {{?language}}"); - final var templatingConfig = Template.create().template(List.of(template.createChatMessage())); - final var configWithTemplate = config.withTemplateConfig(templatingConfig); + val template = Message.user("Reply with 'Orchestration Service is working!' in {{?language}}"); + val templatingConfig = Template.create().template(List.of(template.createChatMessage())); + val configWithTemplate = config.withTemplateConfig(templatingConfig); - final var inputParams = Map.of("language", language); - final var prompt = new OrchestrationPrompt(inputParams); + val inputParams = Map.of("language", language); + val prompt = new OrchestrationPrompt(inputParams); return client.chatCompletion(prompt, configWithTemplate); } @@ -88,12 +90,12 @@ public OrchestrationChatResponse template(@Nonnull final String language) { */ @Nonnull public OrchestrationChatResponse messagesHistory(@Nonnull final String prevMessage) { - final var prompt = new OrchestrationPrompt(Message.user(prevMessage)); + val prompt = new OrchestrationPrompt(Message.user(prevMessage)); - final var result = client.chatCompletion(prompt, config); + val result = client.chatCompletion(prompt, config); // Let's presume a user asks the following follow-up question - final var nextPrompt = + val nextPrompt = new OrchestrationPrompt(Message.user("What is the typical food there?")) .messageHistory(result.getAllMessages()); @@ -113,12 +115,12 @@ public OrchestrationChatResponse messagesHistory(@Nonnull final String prevMessa @Nonnull public OrchestrationChatResponse inputFiltering(@Nonnull final AzureFilterThreshold policy) throws OrchestrationClientException { - final var prompt = + val prompt = new OrchestrationPrompt("'We shall spill blood tonight', said the operation in-charge."); - final var filterConfig = + val filterConfig = new AzureContentFilter().hate(policy).selfHarm(policy).sexual(policy).violence(policy); - final var configWithFilter = config.withInputFiltering(filterConfig); + val configWithFilter = config.withInputFiltering(filterConfig); return client.chatCompletion(prompt, configWithFilter); } @@ -135,16 +137,16 @@ public OrchestrationChatResponse inputFiltering(@Nonnull final AzureFilterThresh @Nonnull public OrchestrationChatResponse outputFiltering(@Nonnull final AzureFilterThreshold policy) { - final var systemMessage = Message.system("Give three paraphrases for the following sentence"); + val systemMessage = Message.system("Give three paraphrases for the following sentence"); // Reliably triggering the content filter of models fine-tuned for ethical compliance // is difficult. The prompt below may be rendered ineffective in the future. - final var prompt = + val prompt = new OrchestrationPrompt("'We shall spill blood tonight', said the operation in-charge.") .messageHistory(List.of(systemMessage)); - final var filterConfig = + val filterConfig = new AzureContentFilter().hate(policy).selfHarm(policy).sexual(policy).violence(policy); - final var configWithFilter = config.withOutputFiltering(filterConfig); + val configWithFilter = config.withOutputFiltering(filterConfig); return client.chatCompletion(prompt, configWithFilter); } @@ -160,19 +162,19 @@ public OrchestrationChatResponse outputFiltering(@Nonnull final AzureFilterThres */ @Nonnull public OrchestrationChatResponse maskingAnonymization(@Nonnull final DPIEntities entity) { - final var systemMessage = + val systemMessage = Message.system( "Please evaluate the following user feedback and judge if the sentiment is positive or negative."); - final var userMessage = + val userMessage = Message.user( """ I think the SDK is good, but could use some further enhancements. My architect Alice and manager Bob pointed out that we need the grounding capabilities, which aren't supported yet. """); - final var prompt = new OrchestrationPrompt(systemMessage, userMessage); - final var maskingConfig = DpiMasking.anonymization().withEntities(entity); - final var configWithMasking = config.withMaskingConfig(maskingConfig); + val prompt = new OrchestrationPrompt(systemMessage, userMessage); + val maskingConfig = DpiMasking.anonymization().withEntities(entity); + val configWithMasking = config.withMaskingConfig(maskingConfig); return client.chatCompletion(prompt, configWithMasking); } @@ -185,11 +187,11 @@ public OrchestrationChatResponse maskingAnonymization(@Nonnull final DPIEntities @Nonnull public OrchestrationChatResponse completionWithResourceGroup( @Nonnull final String resourceGroup, @Nonnull final String famousPhrase) { - final var destination = + val destination = new AiCoreService().getInferenceDestination(resourceGroup).forScenario("orchestration"); - final var clientWithResourceGroup = new OrchestrationClient(destination); + val clientWithResourceGroup = new OrchestrationClient(destination); - final var prompt = new OrchestrationPrompt(famousPhrase + " Why is this phrase so famous?"); + val prompt = new OrchestrationPrompt(famousPhrase + " Why is this phrase so famous?"); return clientWithResourceGroup.chatCompletion(prompt, config); } @@ -205,13 +207,13 @@ public OrchestrationChatResponse completionWithResourceGroup( */ @Nonnull public OrchestrationChatResponse maskingPseudonymization(@Nonnull final DPIEntities entity) { - final var systemMessage = + val systemMessage = Message.system( """ Please write an initial response to the below user feedback, stating that we are working on the feedback and will get back to them soon. Please make sure to address the user in person and end with "Best regards, the AI SDK team". """); - final var userMessage = + val userMessage = Message.user( """ Username: Mallory @@ -222,9 +224,9 @@ public OrchestrationChatResponse maskingPseudonymization(@Nonnull final DPIEntit My architect Alice and manager Bob pointed out that we need the grounding capabilities, which aren't supported yet. """); - final var prompt = new OrchestrationPrompt(systemMessage, userMessage); - final var maskingConfig = DpiMasking.pseudonymization().withEntities(entity, DPIEntities.EMAIL); - final var configWithMasking = config.withMaskingConfig(maskingConfig); + val prompt = new OrchestrationPrompt(systemMessage, userMessage); + val maskingConfig = DpiMasking.pseudonymization().withEntities(entity, DPIEntities.EMAIL); + val configWithMasking = config.withMaskingConfig(maskingConfig); return client.chatCompletion(prompt, configWithMasking); } @@ -234,27 +236,28 @@ public OrchestrationChatResponse maskingPseudonymization(@Nonnull final DPIEntit * * @link SAP * AI Core: Orchestration - Grounding + * @param userMessage the user message to provide grounding for * @return the assistant response object */ @Nonnull - public OrchestrationChatResponse grounding(@Nonnull final String groundingInput) { - final var message = - Message.user( - "{{?groundingInput}} Use the following information as additional context: {{?groundingOutput}}"); - final var prompt = new OrchestrationPrompt(Map.of("groundingInput", groundingInput), message); - - final var filterInner = - DocumentGroundingFilter.create().id("someID").dataRepositoryType(DataRepositoryType.VECTOR); - final var groundingConfigConfig = - GroundingModuleConfigConfig.create() - .inputParams(List.of("groundingInput")) - .outputParam("groundingOutput") - .addFiltersItem(filterInner); - final var groundingConfig = - GroundingModuleConfig.create() - .type(GroundingModuleConfig.TypeEnum.DOCUMENT_GROUNDING_SERVICE) - .config(groundingConfigConfig); - final var configWithGrounding = config.withGroundingConfig(groundingConfig); + public OrchestrationChatResponse grounding(@Nonnull final String userMessage) { + // optional filter for collections + val documentMetadata = + SearchDocumentKeyValueListPair.create() + .key("document metadata") + .value("2") + .addSelectModeItem(SearchSelectOptionEnum.IGNORE_IF_KEY_ABSENT); + // optional filter for document chunks + val databaseFilter = + DocumentGroundingFilter.create() + .id("") + .dataRepositoryType(DataRepositoryType.VECTOR) + .searchConfig(GroundingFilterSearchConfiguration.create().maxChunkCount(1)) + .addDocumentMetadataItem(documentMetadata); + + val groundingConfig = Grounding.create().filters(databaseFilter); + val prompt = groundingConfig.createGroundingPrompt(userMessage); + val configWithGrounding = config.withGrounding(groundingConfig); return client.chatCompletion(prompt, configWithGrounding); }