From af8f5d2a4f48126003c7eddf0c08c7048429a33d Mon Sep 17 00:00:00 2001 From: Jan Kuipers Date: Fri, 17 Jan 2025 22:36:47 +0100 Subject: [PATCH] @ClassRule --- .../xpack/inference/DefaultEndPointsIT.java | 12 --- .../inference/InferenceBaseRestTest.java | 22 +++-- .../xpack/inference/MlModelServer.java | 85 +++++++++---------- .../xpack/inference/TextEmbeddingCrudIT.java | 14 --- 4 files changed, 59 insertions(+), 74 deletions(-) diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java index 3a995db3004d2..068b3e1f4ce04 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java @@ -16,9 +16,7 @@ import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.hamcrest.Matchers; import org.junit.After; -import org.junit.AfterClass; import org.junit.Before; -import org.junit.BeforeClass; import java.io.IOException; import java.util.ArrayList; @@ -33,16 +31,6 @@ public class DefaultEndPointsIT extends InferenceBaseRestTest { - @BeforeClass - public static void startModelServer() { - mlModelServer.start(); - } - - @AfterClass - public static void stopModelServer() { - mlModelServer.stop(); - } - private TestThreadPool threadPool; @Before diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index dd710696a2acd..c00c3a2b37c1f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -26,6 +26,7 @@ import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.junit.Before; import org.junit.ClassRule; import java.io.IOException; @@ -42,20 +43,31 @@ public class InferenceBaseRestTest extends ESRestTestCase { - @ClassRule(order = 0) - public static MlModelServer mlModelServer = new MlModelServer(); - - @ClassRule(order = 1) + @ClassRule public static ElasticsearchCluster cluster = ElasticsearchCluster.local() .distribution(DistributionType.DEFAULT) .setting("xpack.license.self_generated.type", "trial") .setting("xpack.security.enabled", "true") - .setting("xpack.ml.model_repository", "http://localhost:" + mlModelServer.getPort()) .plugin("inference-service-test") .user("x_pack_rest_user", "x-pack-test-password") .feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED) .build(); + @ClassRule + public static MlModelServer mlModelServer = new MlModelServer(); + + @Before + public void setMlModelRepository() throws IOException { + var request = new Request("PUT", "/_cluster/settings"); + request.setJsonEntity(Strings.format(""" + { + "persistent": { + "xpack.ml.model_repository": "http://localhost:%d" + } + }""", mlModelServer.getPort())); + client().performRequest(request); + } + @Override protected String getTestRestCluster() { return cluster.getHttpAddresses(); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java index 399725e689dfa..1536c973dc3d0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MlModelServer.java @@ -13,11 +13,15 @@ import org.apache.http.HttpStatus; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; +import java.util.Random; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -27,56 +31,17 @@ * If the file is found, its content is returned, otherwise 404. * Respects a range header to serve partial content. */ -public class MlModelServer { +public class MlModelServer implements TestRule { private static final Logger logger = LogManager.getLogger(MlModelServer.class); - private final int port; - private final HttpServer server; + private int port; - private ExecutorService executor; - - public MlModelServer() { - try { - server = HttpServer.create(); - } catch (IOException e) { - throw new RuntimeException("Could not create server", e); - } - server.createContext("/", this::handle); - port = findUnusedPort(); - } - - private int findUnusedPort() { - Exception exception = null; - for (int port = 10000; port < 11000; port++) { - try { - server.bind(new InetSocketAddress(port), 0); - return port; - } catch (IOException e) { - exception = e; - } - } - throw new RuntimeException("Could not find port", exception); - } - - public int getPort() { + int getPort() { return port; } - public void start() { - logger.info("Starting ML model server on port {}", port); - executor = Executors.newCachedThreadPool(); - server.setExecutor(executor); - server.start(); - } - - public void stop() { - logger.info("Stopping ML model server in port {}", port); - server.stop(1); - executor.shutdown(); - } - - private void handle(HttpExchange exchange) throws IOException { + private static void handle(HttpExchange exchange) throws IOException { String fileName = exchange.getRequestURI().getPath().substring(1); // If this architecture is requested, serve the default model instead. fileName = fileName.replace("_linux-x86_64", ""); @@ -118,4 +83,38 @@ private void handle(HttpExchange exchange) throws IOException { } } } + + @Override + public Statement apply(Statement statement, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + logger.info("Starting ML model server"); + HttpServer server = HttpServer.create(); + server.createContext("/", MlModelServer::handle); + while (true) { + port = new Random().nextInt(10000, 65536); + try { + server.bind(new InetSocketAddress("localhost", port), 1); + } catch (Exception e) { + continue; + } + break; + } + logger.info("Bound ML model server to port {}", port); + + ExecutorService executor = Executors.newCachedThreadPool(); + server.setExecutor(executor); + server.start(); + + try { + statement.evaluate(); + } finally { + logger.info("Stopping ML model server in port {}", port); + server.stop(1); + executor.shutdown(); + } + } + }; + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java index 9369de2f4c3ef..d8c2d678d0ef9 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java @@ -11,8 +11,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.plugins.Platforms; -import org.junit.AfterClass; -import org.junit.BeforeClass; import java.io.IOException; import java.util.List; @@ -20,20 +18,8 @@ import static org.hamcrest.Matchers.containsString; -// This test was previously disabled in CI due to the models being too large -// See "https://github.com/elastic/elasticsearch/issues/105198". public class TextEmbeddingCrudIT extends InferenceBaseRestTest { - @BeforeClass - public static void startModelServer() { - mlModelServer.start(); - } - - @AfterClass - public static void stopModelServer() { - mlModelServer.stop(); - } - public void testPutE5Small_withNoModelVariant() { { String inferenceEntityId = "testPutE5Small_withNoModelVariant";