scan for ports
jan-elastic committed Jan 17, 2025
1 parent d9cf6ed commit 83d5009
Showing 4 changed files with 68 additions and 28 deletions.

DefaultEndPointsIT.java

@@ -16,7 +16,9 @@
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
 import org.hamcrest.Matchers;
 import org.junit.After;
+import org.junit.AfterClass;
 import org.junit.Before;
+import org.junit.BeforeClass;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -31,6 +33,16 @@
 
 public class DefaultEndPointsIT extends InferenceBaseRestTest {
 
+    @BeforeClass
+    public static void startModelServer() {
+        mlModelServer.start();
+    }
+
+    @AfterClass
+    public static void stopModelServer() {
+        mlModelServer.stop();
+    }
+
     private TestThreadPool threadPool;
 
     @Before

InferenceBaseRestTest.java

@@ -44,31 +44,20 @@
 
 public class InferenceBaseRestTest extends ESRestTestCase {
 
-    static final int ML_MODEL_SERVER_PORT = 9999;
+    @ClassRule(order = 0)
+    public static MlModelServer mlModelServer = new MlModelServer();
 
-    @ClassRule
+    @ClassRule(order = 1)
     public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
         .distribution(DistributionType.DEFAULT)
         .setting("xpack.license.self_generated.type", "trial")
         .setting("xpack.security.enabled", "true")
-        .setting("xpack.ml.model_repository", "http://localhost:" + ML_MODEL_SERVER_PORT)
+        .setting("xpack.ml.model_repository", "http://localhost:" + mlModelServer.getPort())
         .plugin("inference-service-test")
         .user("x_pack_rest_user", "x-pack-test-password")
         .feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
         .build();
 
-    private static MlModelServer mlModelServer;
-
-    @BeforeClass
-    public static void startModelServer() throws Exception {
-        mlModelServer = new MlModelServer(ML_MODEL_SERVER_PORT);
-    }
-
-    @AfterClass
-    public static void stopModelServer() {
-        mlModelServer.close();
-    }
-
     @Override
     protected String getTestRestCluster() {
         return cluster.getHttpAddresses();

MlModelServer.java

@@ -27,26 +27,53 @@
  * If the file is found, its content is returned, otherwise 404.
  * Respects a range header to serve partial content.
  */
-class MlModelServer {
+public class MlModelServer {
 
     private static final Logger logger = LogManager.getLogger(MlModelServer.class);
 
-    private final HttpServer mlModelServer;
-    private final ExecutorService mlModelServerExecutor;
+    private final int port;
+    private final HttpServer server;
 
-    MlModelServer(int port) throws IOException {
+    private ExecutorService executor;
+
+    public MlModelServer() {
+        try {
+            server = HttpServer.create();
+        } catch (IOException e) {
+            throw new RuntimeException("Could not create server", e);
+        }
+        server.createContext("/", this::handle);
+        port = findUnusedPort();
+    }
+
+    private int findUnusedPort() {
+        Exception exception = null;
+        for (int port = 10000; port < 11000; port++) {
+            try {
+                server.bind(new InetSocketAddress(port), 0);
+                return port;
+            } catch (IOException e) {
+                exception = e;
+            }
+        }
+        throw new RuntimeException("Could not find port", exception);
+    }
+
+    public int getPort() {
+        return port;
+    }
+
+    public void start() {
         logger.info("Starting ML model server on port {}", port);
-        mlModelServer = HttpServer.create(new InetSocketAddress(port), 10);
-        mlModelServer.createContext("/", this::handle);
-        mlModelServerExecutor = Executors.newCachedThreadPool();
-        mlModelServer.setExecutor(mlModelServerExecutor);
-        mlModelServer.start();
+        executor = Executors.newCachedThreadPool();
+        server.setExecutor(executor);
+        server.start();
     }
 
-    void close() {
-        logger.info("Stopping ML model server");
-        mlModelServer.stop(5);
-        mlModelServerExecutor.close();
+    public void stop() {
+        logger.info("Stopping ML model server on port {}", port);
+        server.stop(1);
+        executor.shutdown();
     }
 
     private void handle(HttpExchange exchange) throws IOException {
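
The reworked MlModelServer chooses its port in the constructor: findUnusedPort() binds the HttpServer to the first free port in the 10000-10999 range, so getPort() already returns a usable value when InferenceBaseRestTest builds the xpack.ml.model_repository setting, while start() and stop() only attach the executor and begin or end serving. Below is a minimal standalone sketch of that lifecycle, assuming the sketch class sits in the same package as MlModelServer; the request path is hypothetical and only illustrates the behavior described in the class javadoc (file content if found, otherwise 404):

// Sketch only: drives the new MlModelServer lifecycle by hand instead of through
// the @ClassRule / @BeforeClass / @AfterClass wiring used by the test classes above.
import java.io.InputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class MlModelServerLifecycleSketch {
    public static void main(String[] args) throws Exception {
        MlModelServer server = new MlModelServer();      // constructor scans 10000-10999 and binds
        String repository = "http://localhost:" + server.getPort();  // valid before start()

        server.start();                                  // create the executor and start serving
        try {
            HttpRequest request = HttpRequest.newBuilder(URI.create(repository + "/some-model.pt")).build();  // hypothetical file
            HttpResponse<InputStream> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofInputStream());
            System.out.println("HTTP " + response.statusCode());     // file content if found, otherwise 404
        } finally {
            server.stop();                               // stop(1) and shut the executor down
        }
    }
}

Splitting the port choice (constructor) from serving (start()) is what lets the cluster definition above reference mlModelServer.getPort() while leaving the actual start and stop to each test class's @BeforeClass and @AfterClass hooks.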

TextEmbeddingCrudIT.java

@@ -11,6 +11,8 @@
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.plugins.Platforms;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
 
 import java.io.IOException;
 import java.util.List;
@@ -22,6 +24,16 @@
 // See "https://github.com/elastic/elasticsearch/issues/105198".
 public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
 
+    @BeforeClass
+    public static void startModelServer() {
+        mlModelServer.start();
+    }
+
+    @AfterClass
+    public static void stopModelServer() {
+        mlModelServer.stop();
+    }
+
     public void testPutE5Small_withNoModelVariant() {
         {
             String inferenceEntityId = "testPutE5Small_withNoModelVariant";
