Skip to content

Commit

Permalink
file access tree
Browse files Browse the repository at this point in the history
  • Loading branch information
rjernst committed Jan 14, 2025
1 parent 2ca136e commit 4d43977
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.entitlement.runtime.policy;

import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class FileAccessTree {
static final FileAccessTree EMPTY = new FileAccessTree(List.of());

private final String[] readPaths;
private final String[] writePaths;

FileAccessTree(List<FileEntitlement> fileEntitlements) {
List<String> readPaths = new ArrayList<>();
List<String> writePaths = new ArrayList<>();
for (FileEntitlement fileEntitlement : fileEntitlements) {
var mode = fileEntitlement.mode();
if (mode == FileEntitlement.Mode.READ_WRITE) {
writePaths.add(fileEntitlement.path());
}
readPaths.add(fileEntitlement.path());
}

readPaths.sort(String::compareTo);
writePaths.sort(String::compareTo);

this.readPaths = readPaths.toArray(new String[0]);
this.writePaths = writePaths.toArray(new String[0]);
}

boolean canRead(Path path) {
return checkPath(normalize(path), readPaths);
}

boolean canRead(File file) {
return checkPath(normalize(file.toPath()), readPaths);
}

boolean canWrite(Path path) {
return checkPath(normalize(path), writePaths);
}

boolean canWrite(File file) {
return checkPath(normalize(file.toPath()), writePaths);
}

private static String normalize(Path path) {
return path.toAbsolutePath().normalize().toString();
}

private static boolean checkPath(String path, String[] paths) {
if (paths.length == 0) {
return false;
}
int ndx = Arrays.binarySearch(paths, path);
if (ndx < -1) {
String maybeParent = paths[-ndx - 2];
return path.startsWith(maybeParent);
}
return ndx >= 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

package org.elasticsearch.entitlement.runtime.policy;

import java.util.Locale;
import java.nio.file.Paths;

/**
* Describes a file entitlement with a path and mode.
Expand All @@ -21,8 +21,28 @@ public enum Mode {
READ_WRITE
}

public FileEntitlement {
path = normalizePath(path);
}

private static String normalizePath(String path) {
return Paths.get(path).toAbsolutePath().normalize().toString();
}

// TODO: think about whether read/write or "all"
// TODO: think about mode parsing?
private static Mode parseMode(String mode) {
if (mode.equals("read")) {
return Mode.READ;
} else if (mode.equals("read_write")) {
return Mode.READ_WRITE;
} else {
throw new IllegalArgumentException("invalid mode: " + mode + ", valid values: [read, read_write]");
}
}

@ExternalEntitlement(parameterNames = { "path", "mode" }, esModulesOnly = false)
public FileEntitlement(String path, String mode) {
this(path, Mode.valueOf(mode.toUpperCase(Locale.ROOT)));
this(path, parseMode(mode));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,24 @@
import static java.lang.StackWalker.Option.RETAIN_CLASS_REFERENCE;
import static java.util.Objects.requireNonNull;
import static java.util.function.Predicate.not;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toUnmodifiableMap;

public class PolicyManager {
private static final Logger logger = LogManager.getLogger(PolicyManager.class);

record ModuleEntitlements(Map<Class<? extends Entitlement>, List<Entitlement>> entitlementsByType) {
public static final ModuleEntitlements NONE = new ModuleEntitlements(Map.of());
record ModuleEntitlements(Set<Class<? extends Entitlement>> flagEntitlements, FileAccessTree fileAccess) {
public static final ModuleEntitlements NONE = new ModuleEntitlements(Set.of(), FileAccessTree.EMPTY);

ModuleEntitlements {
entitlementsByType = Map.copyOf(entitlementsByType);
flagEntitlements = Set.copyOf(flagEntitlements);
}

public static ModuleEntitlements from(List<Entitlement> entitlements) {
return new ModuleEntitlements(entitlements.stream().collect(groupingBy(Entitlement::getClass)));
}

public boolean hasEntitlement(Class<? extends Entitlement> entitlementClass) {
return entitlementsByType.containsKey(entitlementClass);
}

public <E extends Entitlement> Stream<E> getEntitlements(Class<E> entitlementClass) {
return entitlementsByType.get(entitlementClass).stream().map(entitlementClass::cast);
Set<Class<? extends Entitlement>> flagEntitlements = entitlements.stream().map(Entitlement::getClass).collect(Collectors.toSet());
var fileEntitlements = entitlements.stream()
.filter(e -> e.getClass().equals(FileEntitlement.class))
.map(e -> (FileEntitlement) e).toList();
return new ModuleEntitlements(flagEntitlements, new FileAccessTree(fileEntitlements));
}
}

Expand Down Expand Up @@ -197,7 +192,7 @@ private void checkEntitlementPresent(Class<?> callerClass, Class<? extends Entit
}

ModuleEntitlements entitlements = getEntitlements(requestingClass);
if (entitlements.hasEntitlement(entitlementClass)) {
if (entitlements.flagEntitlements.contains(entitlementClass)) {
logger.debug(
() -> Strings.format(
"Entitled: class [%s], module [%s], entitlement [%s]",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.entitlement.runtime.policy;

import org.elasticsearch.test.ESTestCase;

import java.util.List;

import static org.hamcrest.Matchers.is;

public class FileAccessTreeTests extends ESTestCase {

public void testEmpty() {
var tree = new FileAccessTree(List.of());
assertThat(tree.canRead("/path"), is(false));
assertThat(tree.canWrite("/path"), is(false));
}

public void testRead() {
var tree = new FileAccessTree(List.of(new FileEntitlement("/foo", "read")));
assertThat(tree.canRead("/foo"), is(true));
assertThat(tree.canRead("/foo/subdir"), is(true));
assertThat(tree.canWrite("/foo"), is(false));

assertThat(tree.canRead("/before"), is(false));
assertThat(tree.canRead("/later"), is(false));
}

public void testWrite() {
var tree = new FileAccessTree(List.of(new FileEntitlement("/foo", "read_write")));
assertThat(tree.canWrite("/foo"), is(true));
assertThat(tree.canWrite("/foo/subdir"), is(true));
assertThat(tree.canRead("/foo"), is(true));

assertThat(tree.canWrite("/before"), is(false));
assertThat(tree.canWrite("/later"), is(false));
}

public void testTwoPaths() {
var tree = new FileAccessTree(List.of(new FileEntitlement("/foo", "read"), new FileEntitlement("/bar", "read")));
assertThat(tree.canRead("/a"), is(false));
assertThat(tree.canRead("/bar"), is(true));
assertThat(tree.canRead("/bar/subdir"), is(true));
assertThat(tree.canRead("/c"), is(false));
assertThat(tree.canRead("/foo"), is(true));
assertThat(tree.canRead("/foo/subdir"), is(true));
assertThat(tree.canRead("/z"), is(false));
}

public void testReadWriteUnderRead() {
var tree = new FileAccessTree(List.of(new FileEntitlement("/foo", "read"), new FileEntitlement("/foo/bar", "read_write")));
assertThat(tree.canRead("/foo"), is(true));
assertThat(tree.canWrite("/foo"), is(false));
assertThat(tree.canRead("/foo/bar"), is(true));
assertThat(tree.canWrite("/foo/bar"), is(true));
}

public void testNormalizePath() {
var tree = new FileAccessTree(List.of(new FileEntitlement("/foo/../bar", "read")));
assertThat(tree.canRead("/foo/../bar"), is(true));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public void testGetEntitlementsReturnsEntitlementsForPluginUnnamedModule() {
var callerClass = this.getClass();

var entitlements = policyManager.getEntitlements(callerClass);
assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true));
assertThat(entitlements.flagEntitlements().contains(CreateClassLoaderEntitlement.class), is(true));
}

public void testGetEntitlementsThrowsOnMissingPolicyForServer() throws ClassNotFoundException {
Expand Down Expand Up @@ -148,8 +148,8 @@ public void testGetEntitlementsReturnsEntitlementsForServerModule() throws Class
var requestingModule = mockServerClass.getModule();

var entitlements = policyManager.getEntitlements(mockServerClass);
assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true));
assertThat(entitlements.hasEntitlement(ExitVMEntitlement.class), is(true));
assertThat(entitlements.flagEntitlements().contains(CreateClassLoaderEntitlement.class), is(true));
assertThat(entitlements.flagEntitlements().contains(ExitVMEntitlement.class), is(true));
}

public void testGetEntitlementsReturnsEntitlementsForPluginModule() throws IOException, ClassNotFoundException {
Expand All @@ -167,14 +167,10 @@ public void testGetEntitlementsReturnsEntitlementsForPluginModule() throws IOExc

var layer = createLayerForJar(jar, "org.example.plugin");
var mockPluginClass = layer.findLoader("org.example.plugin").loadClass("q.B");
var requestingModule = mockPluginClass.getModule();

var entitlements = policyManager.getEntitlements(mockPluginClass);
assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true));
assertThat(
entitlements.getEntitlements(FileEntitlement.class).toList(),
contains(transformedMatch(FileEntitlement::toString, containsString("/test/path")))
);
assertThat(entitlements.flagEntitlements().contains(CreateClassLoaderEntitlement.class), is(true));
assertThat(entitlements.fileAccess().canRead("/test/path"), is(true));
}

public void testGetEntitlementsResultIsCached() {
Expand All @@ -190,7 +186,7 @@ public void testGetEntitlementsResultIsCached() {
var callerClass = this.getClass();

var entitlements = policyManager.getEntitlements(callerClass);
assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true));
assertThat(entitlements.flagEntitlements().contains(CreateClassLoaderEntitlement.class), is(true));
assertThat(policyManager.moduleEntitlementsMap, aMapWithSize(1));
var cachedResult = policyManager.moduleEntitlementsMap.values().stream().findFirst().get();
var entitlementsAgain = policyManager.getEntitlements(callerClass);
Expand Down Expand Up @@ -257,7 +253,7 @@ private static Policy createPluginPolicy(String... pluginModules) {
.map(
name -> new Scope(
name,
List.of(new FileEntitlement("/test/path", List.of(FileEntitlement.READ)), new CreateClassLoaderEntitlement())
List.of(new FileEntitlement("/test/path", FileEntitlement.Mode.READ), new CreateClassLoaderEntitlement())
)
)
.toList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public void testPolicyBuilder() throws IOException {
.parsePolicy();
Policy expected = new Policy(
"test-policy.yaml",
List.of(new Scope("entitlement-module-name", List.of(new FileEntitlement("test/path/to/file", List.of("read", "write")))))
List.of(new Scope("entitlement-module-name", List.of(new FileEntitlement("test/path/to/file", "read_write"))))
);
assertEquals(expected, parsedPolicy);
}
Expand All @@ -47,7 +47,7 @@ public void testPolicyBuilderOnExternalPlugin() throws IOException {
.parsePolicy();
Policy expected = new Policy(
"test-policy.yaml",
List.of(new Scope("entitlement-module-name", List.of(new FileEntitlement("test/path/to/file", List.of("read", "write")))))
List.of(new Scope("entitlement-module-name", List.of(new FileEntitlement("test/path/to/file", "read_write"))))
);
assertEquals(expected, parsedPolicy);
}
Expand Down

0 comments on commit 4d43977

Please sign in to comment.