@ClassRule
jan-elastic committed Jan 17, 2025
1 parent 4742325 commit af8f5d2
Showing 4 changed files with 59 additions and 74 deletions.
DefaultEndPointsIT.java
@@ -16,9 +16,7 @@
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.ArrayList;
@@ -33,16 +31,6 @@

public class DefaultEndPointsIT extends InferenceBaseRestTest {

@BeforeClass
public static void startModelServer() {
mlModelServer.start();
}

@AfterClass
public static void stopModelServer() {
mlModelServer.stop();
}

private TestThreadPool threadPool;

@Before
InferenceBaseRestTest.java
@@ -26,6 +26,7 @@
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.junit.Before;
import org.junit.ClassRule;

import java.io.IOException;
@@ -42,20 +43,31 @@

public class InferenceBaseRestTest extends ESRestTestCase {

@ClassRule(order = 0)
public static MlModelServer mlModelServer = new MlModelServer();

@ClassRule(order = 1)
@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.setting("xpack.license.self_generated.type", "trial")
.setting("xpack.security.enabled", "true")
.setting("xpack.ml.model_repository", "http://localhost:" + mlModelServer.getPort())
.plugin("inference-service-test")
.user("x_pack_rest_user", "x-pack-test-password")
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
.build();

@ClassRule
public static MlModelServer mlModelServer = new MlModelServer();

@Before
public void setMlModelRepository() throws IOException {
var request = new Request("PUT", "/_cluster/settings");
request.setJsonEntity(Strings.format("""
{
"persistent": {
"xpack.ml.model_repository": "http://localhost:%d"
}
}""", mlModelServer.getPort()));
client().performRequest(request);
}

@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
MlModelServer.java
@@ -13,11 +13,15 @@
import org.apache.http.HttpStatus;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

@@ -27,56 +31,17 @@
* If the file is found, its content is returned, otherwise 404.
* Respects a range header to serve partial content.
*/
public class MlModelServer {
public class MlModelServer implements TestRule {

private static final Logger logger = LogManager.getLogger(MlModelServer.class);

private final int port;
private final HttpServer server;
private int port;

private ExecutorService executor;

public MlModelServer() {
try {
server = HttpServer.create();
} catch (IOException e) {
throw new RuntimeException("Could not create server", e);
}
server.createContext("/", this::handle);
port = findUnusedPort();
}

private int findUnusedPort() {
Exception exception = null;
for (int port = 10000; port < 11000; port++) {
try {
server.bind(new InetSocketAddress(port), 0);
return port;
} catch (IOException e) {
exception = e;
}
}
throw new RuntimeException("Could not find port", exception);
}

public int getPort() {
int getPort() {
return port;
}

public void start() {
logger.info("Starting ML model server on port {}", port);
executor = Executors.newCachedThreadPool();
server.setExecutor(executor);
server.start();
}

public void stop() {
logger.info("Stopping ML model server in port {}", port);
server.stop(1);
executor.shutdown();
}

private void handle(HttpExchange exchange) throws IOException {
private static void handle(HttpExchange exchange) throws IOException {
String fileName = exchange.getRequestURI().getPath().substring(1);
// If this architecture is requested, serve the default model instead.
fileName = fileName.replace("_linux-x86_64", "");
@@ -118,4 +83,38 @@ private void handle(HttpExchange exchange) throws IOException {
}
}
}

@Override
public Statement apply(Statement statement, Description description) {
return new Statement() {
@Override
public void evaluate() throws Throwable {
logger.info("Starting ML model server");
HttpServer server = HttpServer.create();
server.createContext("/", MlModelServer::handle);
while (true) {
port = new Random().nextInt(10000, 65536);
try {
server.bind(new InetSocketAddress("localhost", port), 1);
} catch (Exception e) {
continue;
}
break;
}
logger.info("Bound ML model server to port {}", port);

ExecutorService executor = Executors.newCachedThreadPool();
server.setExecutor(executor);
server.start();

try {
statement.evaluate();
} finally {
logger.info("Stopping ML model server in port {}", port);
server.stop(1);
executor.shutdown();
}
}
};
}
}
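
For context on the Range behavior described in the MlModelServer javadoc above, a minimal client sketch follows (not part of this commit; the port and file name are placeholders, and only standard java.net.http APIs are assumed):

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class RangeRequestSketch {
    public static void main(String[] args) throws Exception {
        // Placeholder port and file name; the real server binds to a random free localhost port.
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request = HttpRequest.newBuilder()
            .uri(URI.create("http://localhost:10000/some-model.pt"))
            .header("Range", "bytes=0-1023") // request only the first KiB
            .GET()
            .build();
        HttpResponse<byte[]> response = client.send(request, HttpResponse.BodyHandlers.ofByteArray());
        // Expect 206 Partial Content when the range is honored, 404 when the file is unknown.
        System.out.println(response.statusCode() + ": " + response.body().length + " bytes");
    }
}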
TextEmbeddingCrudIT.java
@@ -11,29 +11,15 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.Platforms;
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.List;
import java.util.Map;

import static org.hamcrest.Matchers.containsString;

// This test was previously disabled in CI due to the models being too large
// See "https://github.com/elastic/elasticsearch/issues/105198".
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {

@BeforeClass
public static void startModelServer() {
mlModelServer.start();
}

@AfterClass
public static void stopModelServer() {
mlModelServer.stop();
}

public void testPutE5Small_withNoModelVariant() {
{
String inferenceEntityId = "testPutE5Small_withNoModelVariant";