Skip to content

Commit

Permalink
Add LTR License Check on PUT for Enterprise Licensing (elastic#111248) (
Browse files Browse the repository at this point in the history
elastic#111460)

* add isLicenseAllowedForAction trained model config

* fixup tests - trial is allowed

* fix license tests

* update tests for validate model static method

* add validateModel test; update license check
  • Loading branch information
markjhoy authored Jul 30, 2024
1 parent 5d0eb2a commit ade5b13
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.license.License;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xpack.core.ml.MlConfigVersion;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
Expand All @@ -22,6 +24,7 @@
import java.util.Arrays;

import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.xpack.core.ml.MachineLearningField.ML_API_FEATURE;

public interface InferenceConfig extends NamedXContentObject, VersionedNamedWriteable {

Expand Down Expand Up @@ -114,4 +117,12 @@ default ElasticsearchStatusException incompatibleUpdateException(String updateNa
updateName
);
}

default License.OperationMode getMinLicenseSupported() {
return ML_API_FEATURE.getMinimumOperationMode();
}

default License.OperationMode getMinLicenseSupportedForAction(RestRequest.Method method) {
return getMinLicenseSupported();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.Rewriteable;
import org.elasticsearch.license.License;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xcontent.ObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -226,6 +228,14 @@ public TransportVersion getMinimalSupportedTransportVersion() {
return MIN_SUPPORTED_TRANSPORT_VERSION;
}

@Override
public License.OperationMode getMinLicenseSupportedForAction(RestRequest.Method method) {
if (method == RestRequest.Method.PUT) {
return License.OperationMode.ENTERPRISE;
}
return super.getMinLicenseSupportedForAction(method);
}

@Override
public LearningToRankConfig rewrite(QueryRewriteContext ctx) throws IOException {
if (this.featureExtractorBuilders.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.license.License;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.NamedXContentRegistry;
Expand All @@ -36,6 +38,7 @@
import java.util.stream.Stream;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.hamcrest.Matchers.is;

public class LearningToRankConfigTests extends InferenceConfigItemTestCase<LearningToRankConfig> {
private boolean lenient;
Expand Down Expand Up @@ -140,6 +143,16 @@ public void testDuplicateFeatureNames() {
expectThrows(IllegalArgumentException.class, () -> builder.build());
}

public void testLicenseSupport_ForPutAction_RequiresEnterprise() {
var config = randomLearningToRankConfig();
assertThat(config.getMinLicenseSupportedForAction(RestRequest.Method.PUT), is(License.OperationMode.ENTERPRISE));
}

public void testLicenseSupport_ForGetAction_RequiresPlatinum() {
var config = randomLearningToRankConfig();
assertThat(config.getMinLicenseSupportedForAction(RestRequest.Method.GET), is(License.OperationMode.PLATINUM));
}

@Override
protected NamedXContentRegistry xContentRegistry() {
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.TransportVersion;
Expand Down Expand Up @@ -37,6 +38,7 @@
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.Task;
Expand Down Expand Up @@ -143,61 +145,7 @@ protected void masterOperation(
// NOTE: hasModelDefinition is false if we don't parse it. But, if the fully parsed model was already provided, continue
boolean hasModelDefinition = config.getModelDefinition() != null;
if (hasModelDefinition) {
try {
config.getModelDefinition().getTrainedModel().validate();
} catch (ElasticsearchException ex) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", ex, config.getModelId())
);
return;
}

TrainedModelType trainedModelType = TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel());
if (trainedModelType == null) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Unknown trained model definition class [{}]",
config.getModelDefinition().getTrainedModel().getName()
)
);
return;
}

if (config.getModelType() == null) {
// Set the model type from the definition
config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build();
} else if (trainedModelType != config.getModelType()) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"{} [{}] does not match the model definition type [{}]",
TrainedModelConfig.MODEL_TYPE.getPreferredName(),
config.getModelType(),
trainedModelType
)
);
return;
}

if (config.getInferenceConfig().isTargetTypeSupported(config.getModelDefinition().getTrainedModel().targetType()) == false) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Model [{}] inference config type [{}] does not support definition target type [{}]",
config.getModelId(),
config.getInferenceConfig().getName(),
config.getModelDefinition().getTrainedModel().targetType()
)
);
return;
}

TransportVersion minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion();
if (state.getMinTransportVersion().before(minCompatibilityVersion)) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Cannot create model [{}] while cluster upgrade is in progress.",
config.getModelId()
)
);
if (validateModelDefinition(config, state, licenseState, finalResponseListener) == false) {
return;
}
}
Expand Down Expand Up @@ -507,6 +455,85 @@ private void checkTagsAgainstModelIds(List<String> tags, ActionListener<Void> li
);
}

public static boolean validateModelDefinition(
TrainedModelConfig config,
ClusterState state,
XPackLicenseState licenseState,
ActionListener<Response> finalResponseListener
) {
try {
config.getModelDefinition().getTrainedModel().validate();
} catch (ElasticsearchException ex) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException("Definition for [{}] has validation failures.", ex, config.getModelId())
);
return false;
}

TrainedModelType trainedModelType = TrainedModelType.typeFromTrainedModel(config.getModelDefinition().getTrainedModel());
if (trainedModelType == null) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Unknown trained model definition class [{}]",
config.getModelDefinition().getTrainedModel().getName()
)
);
return false;
}

var configModelType = config.getModelType();
if (configModelType == null) {
// Set the model type from the definition
config = new TrainedModelConfig.Builder(config).setModelType(trainedModelType).build();
} else if (trainedModelType != configModelType) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"{} [{}] does not match the model definition type [{}]",
TrainedModelConfig.MODEL_TYPE.getPreferredName(),
configModelType,
trainedModelType
)
);
return false;
}

var inferenceConfig = config.getInferenceConfig();
if (inferenceConfig.isTargetTypeSupported(config.getModelDefinition().getTrainedModel().targetType()) == false) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException(
"Model [{}] inference config type [{}] does not support definition target type [{}]",
config.getModelId(),
config.getInferenceConfig().getName(),
config.getModelDefinition().getTrainedModel().targetType()
)
);
return false;
}

var minLicenseSupported = inferenceConfig.getMinLicenseSupportedForAction(RestRequest.Method.PUT);
if (licenseState.isAllowedByLicense(minLicenseSupported) == false) {
finalResponseListener.onFailure(
new ElasticsearchSecurityException(
"Model of type [{}] requires [{}] license level",
RestStatus.FORBIDDEN,
config.getInferenceConfig().getName(),
minLicenseSupported
)
);
return false;
}

TransportVersion minCompatibilityVersion = config.getModelDefinition().getTrainedModel().getMinimalCompatibilityVersion();
if (state.getMinTransportVersion().before(minCompatibilityVersion)) {
finalResponseListener.onFailure(
ExceptionsHelper.badRequestException("Cannot create model [{}] while cluster upgrade is in progress.", config.getModelId())
);
return false;
}

return true;
}

@Override
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,21 @@
package org.elasticsearch.xpack.ml.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.license.License;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.license.internal.XPackLicenseStatus;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
Expand All @@ -35,11 +40,13 @@
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigTests;
Expand All @@ -50,6 +57,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigTests;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand All @@ -60,10 +68,12 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.getTaskInfoListOfOne;
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockClientWithTasksResponse;
import static org.elasticsearch.xpack.ml.utils.TaskRetrieverTests.mockListTasksClient;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.same;
Expand All @@ -73,6 +83,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class TransportPutTrainedModelActionTests extends ESTestCase {
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
Expand Down Expand Up @@ -273,6 +284,56 @@ public void testVerifyMlNodesAndModelArchitectures_GivenArchitecturesMatch_ThenT
ensureNoWarnings();
}

public void testValidateModelDefinition_FailsWhenLicenseIsNotSupported() throws IOException {
ModelPackageConfig packageConfig = ModelPackageConfigTests.randomModulePackageConfig();

TrainedModelConfig.Builder trainedModelConfigBuilder = new TrainedModelConfig.Builder().setModelId(
"." + packageConfig.getPackagedModelId()
).setInput(TrainedModelInputTests.createRandomInput());

TransportPutTrainedModelAction.setTrainedModelConfigFieldsFromPackagedModel(
trainedModelConfigBuilder,
packageConfig,
xContentRegistry()
);

var mockTrainedModelDefinition = mock(TrainedModelDefinition.class);
when(mockTrainedModelDefinition.getTrainedModel()).thenReturn(mock(LangIdentNeuralNetwork.class));
var trainedModelConfig = trainedModelConfigBuilder.setLicenseLevel("basic").build();

var mockModelInferenceConfig = spy(new LearningToRankConfig(1, List.of(), Map.of()));
when(mockModelInferenceConfig.isTargetTypeSupported(any())).thenReturn(true);

var mockTrainedModelConfig = spy(trainedModelConfig);
when(mockTrainedModelConfig.getModelType()).thenReturn(TrainedModelType.LANG_IDENT);
when(mockTrainedModelConfig.getModelDefinition()).thenReturn(mockTrainedModelDefinition);
when(mockTrainedModelConfig.getInferenceConfig()).thenReturn(mockModelInferenceConfig);

ActionListener<PutTrainedModelAction.Response> responseListener = ActionListener.wrap(
response -> fail("Expected exception, but got response: " + response),
exception -> {
assertThat(exception, instanceOf(ElasticsearchSecurityException.class));
assertThat(exception.getMessage(), is("Model of type [learning_to_rank] requires [ENTERPRISE] license level"));
}
);

var mockClusterState = mock(ClusterState.class);

AtomicInteger currentTime = new AtomicInteger(100);
var mockXPackLicenseStatus = new XPackLicenseStatus(License.OperationMode.BASIC, true, "");
var mockLicenseState = new XPackLicenseState(currentTime::get, mockXPackLicenseStatus);

assertThat(
TransportPutTrainedModelAction.validateModelDefinition(
mockTrainedModelConfig,
mockClusterState,
mockLicenseState,
responseListener
),
is(false)
);
}

private static void prepareGetTrainedModelResponse(Client client, List<TrainedModelConfig> trainedModels) {
doAnswer(invocationOnMock -> {
@SuppressWarnings("unchecked")
Expand Down

0 comments on commit ade5b13

Please sign in to comment.