diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/AbstractRollingUpgradeTestCase.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/AbstractRollingUpgradeTestCase.java index a1ad7178c..ecdf4cff8 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/AbstractRollingUpgradeTestCase.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/AbstractRollingUpgradeTestCase.java @@ -8,6 +8,7 @@ import java.nio.file.Path; import java.util.Locale; import java.util.Optional; + import org.junit.Before; import org.opensearch.common.settings.Settings; import org.opensearch.neuralsearch.BaseNeuralSearchIT; diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/SemanticSearchIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/SemanticSearchIT.java index 989d53897..97ba00e4d 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/SemanticSearchIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/rolling/SemanticSearchIT.java @@ -8,8 +8,6 @@ import java.nio.file.Path; import java.util.Map; import static org.opensearch.neuralsearch.util.TestUtils.NODES_BWC_CLUSTER; -import static org.opensearch.neuralsearch.util.TestUtils.TEXT_EMBEDDING_PROCESSOR; -import static org.opensearch.neuralsearch.util.TestUtils.getModelId; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; public class SemanticSearchIT extends AbstractRollingUpgradeTestCase { @@ -19,7 +17,6 @@ public class SemanticSearchIT extends AbstractRollingUpgradeTestCase { private static final String TEXT_MIXED = "Hello world mixed"; private static final String TEXT_UPGRADED = "Hello world upgraded"; private static final int NUM_DOCS_PER_ROUND = 1; - private static String modelId = ""; // Test rolling-upgrade Semantic Search // Create Text Embedding Processor, Ingestion Pipeline and add document @@ -28,9 +25,8 @@ public void testSemanticSearch_E2EFlow() throws Exception { waitForClusterHealthGreen(NODES_BWC_CLUSTER, 90); switch (getClusterType()) { case OLD: - modelId = uploadTextEmbeddingModel(); - loadModel(modelId); - createPipelineProcessor(modelId, PIPELINE_NAME); + loadModel(textEmbeddingModelId); + createPipelineProcessor(textEmbeddingModelId, PIPELINE_NAME); createIndexWithConfiguration( getIndexNameForTest(), Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())), @@ -39,26 +35,24 @@ public void testSemanticSearch_E2EFlow() throws Exception { addDocument(getIndexNameForTest(), "0", TEST_FIELD, TEXT, null, null); break; case MIXED: - modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR); int totalDocsCountMixed; if (isFirstMixedRound()) { totalDocsCountMixed = NUM_DOCS_PER_ROUND; - validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, TEXT); + validateTestIndexOnUpgrade(totalDocsCountMixed, textEmbeddingModelId, TEXT); addDocument(getIndexNameForTest(), "1", TEST_FIELD, TEXT_MIXED, null, null); } else { totalDocsCountMixed = 2 * NUM_DOCS_PER_ROUND; - validateTestIndexOnUpgrade(totalDocsCountMixed, modelId, TEXT_MIXED); + validateTestIndexOnUpgrade(totalDocsCountMixed, textEmbeddingModelId, TEXT_MIXED); } break; case UPGRADED: try { - modelId = getModelId(getIngestionPipeline(PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR); int totalDocsCountUpgraded = 3 * NUM_DOCS_PER_ROUND; - loadModel(modelId); + loadModel(textEmbeddingModelId); addDocument(getIndexNameForTest(), "2", TEST_FIELD, TEXT_UPGRADED, null, null); - validateTestIndexOnUpgrade(totalDocsCountUpgraded, modelId, TEXT_UPGRADED); + validateTestIndexOnUpgrade(totalDocsCountUpgraded, textEmbeddingModelId, TEXT_UPGRADED); } finally { - wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, modelId, null); + wipeOfTestResources(getIndexNameForTest(), PIPELINE_NAME, null, null); } break; default: diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 2182d8d79..4b880181c 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch; +import joptsimple.internal.Strings; import org.opensearch.ml.common.model.MLModelState; import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; @@ -69,6 +70,7 @@ import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; import static org.opensearch.neuralsearch.util.TestUtils.MAX_RETRY; import static org.opensearch.neuralsearch.util.TestUtils.MAX_TIME_OUT_INTERVAL; +import static org.opensearch.neuralsearch.util.TestUtils.generateModelId; import lombok.AllArgsConstructor; import lombok.Getter; @@ -78,6 +80,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { protected static final Locale LOCALE = Locale.ROOT; + protected static String textEmbeddingModelId = Strings.EMPTY; protected static final Map PIPELINE_CONFIGS_BY_TYPE = Map.of( ProcessorType.TEXT_EMBEDDING, @@ -108,6 +111,12 @@ public void setupSettings() { updateClusterSettings(); } NeuralSearchClusterUtil.instance().initialize(clusterService); + setUpModels(); + } + + @SneakyThrows + protected void setUpModels() { + textEmbeddingModelId = uploadTextEmbedding(); } protected ThreadPool setUpThreadPool() { @@ -122,6 +131,19 @@ public static ClusterService createClusterService(ThreadPool threadPool) { return ClusterServiceUtils.createClusterService(threadPool); } + protected String uploadTextEmbedding() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); + return registerModel(requestBody); + } + + protected String registerModel(String requestBody) throws Exception { + String modelGroupRegisterRequestBody = Files.readString( + Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI()) + ); + String modelGroupId = registerModelGroup(String.format(LOCALE, modelGroupRegisterRequestBody, generateModelId())); + return uploadModel(String.format(LOCALE, requestBody, modelGroupId)); + } + protected void updateClusterSettings() { updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false); // default threshold for native circuit breaker is 90, it may be not enough on test runner machine