Skip to content

Commit

Permalink
Add "always denied" network access checks (elastic#119867)
Browse files Browse the repository at this point in the history
  • Loading branch information
ldematte authored Jan 13, 2025
1 parent 80729f9 commit d3a1d9b
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
import java.net.ContentHandlerFactory;
import java.net.DatagramSocketImplFactory;
import java.net.FileNameMap;
import java.net.ProxySelector;
import java.net.ResponseCache;
import java.net.SocketImplFactory;
import java.net.URL;
import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
import java.util.List;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;

@SuppressWarnings("unused") // Called from instrumentation code inserted by the Entitlements agent
Expand Down Expand Up @@ -167,4 +171,22 @@ public interface EntitlementChecker {

void check$java_net_URLConnection$$setContentHandlerFactory(Class<?> callerClass, ContentHandlerFactory fac);

////////////////////
//
// Network access
//
void check$java_net_ProxySelector$$setDefault(Class<?> callerClass, ProxySelector ps);

void check$java_net_ResponseCache$$setDefault(Class<?> callerClass, ResponseCache rc);

void check$java_net_spi_InetAddressResolverProvider$(Class<?> callerClass);

void check$java_net_spi_URLStreamHandlerProvider$(Class<?> callerClass);

void check$java_net_URL$(Class<?> callerClass, String protocol, String host, int port, String file, URLStreamHandler handler);

void check$java_net_URL$(Class<?> callerClass, URL context, String spec, URLStreamHandler handler);

// The only implementation of SSLSession#getSessionContext(); unfortunately it's an interface, so we need to check the implementation
void check$sun_security_ssl_SSLSessionImpl$getSessionContext(Class<?> callerClass, SSLSession sslSession);
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,19 @@
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.DatagramSocket;
import java.net.DatagramSocketImpl;
import java.net.DatagramSocketImplFactory;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.ProxySelector;
import java.net.ResponseCache;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.URL;
import java.net.URLClassLoader;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.net.spi.InetAddressResolver;
import java.net.spi.InetAddressResolverProvider;
import java.net.spi.URLStreamHandlerProvider;
import java.security.NoSuchAlgorithmException;
import java.util.List;
import java.util.Map;
Expand All @@ -50,13 +55,17 @@

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;

import static java.util.Map.entry;
import static org.elasticsearch.entitlement.qa.common.RestEntitlementsCheckAction.CheckAction.alwaysDenied;
import static org.elasticsearch.entitlement.qa.common.RestEntitlementsCheckAction.CheckAction.deniedToPlugins;
import static org.elasticsearch.entitlement.qa.common.RestEntitlementsCheckAction.CheckAction.forPlugins;
import static org.elasticsearch.rest.RestRequest.Method.GET;

@SuppressWarnings("unused")
public class RestEntitlementsCheckAction extends BaseRestHandler {
private static final Logger logger = LogManager.getLogger(RestEntitlementsCheckAction.class);
public static final Thread NO_OP_SHUTDOWN_HOOK = new Thread(() -> {}, "Shutdown hook for testing");
Expand Down Expand Up @@ -125,9 +134,87 @@ static CheckAction alwaysDenied(Runnable action) {
entry("socket_setSocketImplFactory", alwaysDenied(RestEntitlementsCheckAction::socket$$setSocketImplFactory)),
entry("url_setURLStreamHandlerFactory", alwaysDenied(RestEntitlementsCheckAction::url$$setURLStreamHandlerFactory)),
entry("urlConnection_setFileNameMap", alwaysDenied(RestEntitlementsCheckAction::urlConnection$$setFileNameMap)),
entry("urlConnection_setContentHandlerFactory", alwaysDenied(RestEntitlementsCheckAction::urlConnection$$setContentHandlerFactory))
entry("urlConnection_setContentHandlerFactory", alwaysDenied(RestEntitlementsCheckAction::urlConnection$$setContentHandlerFactory)),

entry("proxySelector_setDefault", alwaysDenied(RestEntitlementsCheckAction::setDefaultProxySelector)),
entry("responseCache_setDefault", alwaysDenied(RestEntitlementsCheckAction::setDefaultResponseCache)),
entry("createInetAddressResolverProvider", alwaysDenied(RestEntitlementsCheckAction::createInetAddressResolverProvider)),
entry("createURLStreamHandlerProvider", alwaysDenied(RestEntitlementsCheckAction::createURLStreamHandlerProvider)),
entry("createURLWithURLStreamHandler", alwaysDenied(RestEntitlementsCheckAction::createURLWithURLStreamHandler)),
entry("createURLWithURLStreamHandler2", alwaysDenied(RestEntitlementsCheckAction::createURLWithURLStreamHandler2)),
entry("sslSessionImpl_getSessionContext", alwaysDenied(RestEntitlementsCheckAction::sslSessionImplGetSessionContext))
);

private static void createURLStreamHandlerProvider() {
var x = new URLStreamHandlerProvider() {
@Override
public URLStreamHandler createURLStreamHandler(String protocol) {
return null;
}
};
}

private static void sslSessionImplGetSessionContext() {
SSLSocketFactory factory = HttpsURLConnection.getDefaultSSLSocketFactory();
try (SSLSocket socket = (SSLSocket) factory.createSocket()) {
SSLSession session = socket.getSession();

session.getSessionContext();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

@SuppressWarnings("deprecation")
private static void createURLWithURLStreamHandler() {
try {
var x = new URL("http", "host", 1234, "file", new URLStreamHandler() {
@Override
protected URLConnection openConnection(URL u) {
return null;
}
});
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}

@SuppressWarnings("deprecation")
private static void createURLWithURLStreamHandler2() {
try {
var x = new URL(null, "spec", new URLStreamHandler() {
@Override
protected URLConnection openConnection(URL u) {
return null;
}
});
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}

private static void createInetAddressResolverProvider() {
var x = new InetAddressResolverProvider() {
@Override
public InetAddressResolver get(Configuration configuration) {
return null;
}

@Override
public String name() {
return "TEST";
}
};
}

private static void setDefaultResponseCache() {
ResponseCache.setDefault(null);
}

private static void setDefaultProxySelector() {
ProxySelector.setDefault(null);
}

private static void setDefaultSSLContext() {
try {
SSLContext.setDefault(SSLContext.getDefault());
Expand Down Expand Up @@ -270,12 +357,7 @@ private static void setHttpsConnectionProperties() {
@SuppressForbidden(reason = "We're required to prevent calls to this forbidden API")
private static void datagramSocket$$setDatagramSocketImplFactory() {
try {
DatagramSocket.setDatagramSocketImplFactory(new DatagramSocketImplFactory() {
@Override
public DatagramSocketImpl createDatagramSocketImpl() {
throw new IllegalStateException();
}
});
DatagramSocket.setDatagramSocketImplFactory(() -> { throw new IllegalStateException(); });
} catch (IOException e) {
throw new IllegalStateException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
import java.net.ContentHandlerFactory;
import java.net.DatagramSocketImplFactory;
import java.net.FileNameMap;
import java.net.ProxySelector;
import java.net.ResponseCache;
import java.net.SocketImplFactory;
import java.net.URL;
import java.net.URLStreamHandler;
import java.net.URLStreamHandlerFactory;
import java.util.List;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocketFactory;

/**
Expand Down Expand Up @@ -310,4 +314,39 @@ public ElasticsearchEntitlementChecker(PolicyManager policyManager) {
public void check$javax_net_ssl_SSLContext$$setDefault(Class<?> callerClass, SSLContext context) {
policyManager.checkChangeJVMGlobalState(callerClass);
}

@Override
public void check$java_net_ProxySelector$$setDefault(Class<?> callerClass, ProxySelector ps) {
policyManager.checkChangeNetworkHandling(callerClass);
}

@Override
public void check$java_net_ResponseCache$$setDefault(Class<?> callerClass, ResponseCache rc) {
policyManager.checkChangeNetworkHandling(callerClass);
}

@Override
public void check$java_net_spi_InetAddressResolverProvider$(Class<?> callerClass) {
policyManager.checkChangeNetworkHandling(callerClass);
}

@Override
public void check$java_net_spi_URLStreamHandlerProvider$(Class<?> callerClass) {
policyManager.checkChangeNetworkHandling(callerClass);
}

@Override
public void check$java_net_URL$(Class<?> callerClass, String protocol, String host, int port, String file, URLStreamHandler handler) {
policyManager.checkChangeNetworkHandling(callerClass);
}

@Override
public void check$java_net_URL$(Class<?> callerClass, URL context, String spec, URLStreamHandler handler) {
policyManager.checkChangeNetworkHandling(callerClass);
}

@Override
public void check$sun_security_ssl_SSLSessionImpl$getSessionContext(Class<?> callerClass, SSLSession sslSession) {
policyManager.checkReadSensitiveNetworkInformation(callerClass);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ public void checkChangeJVMGlobalState(Class<?> callerClass) {
});
}

/**
* Check for operations that can modify the way network operations are handled
*/
public void checkChangeNetworkHandling(Class<?> callerClass) {
checkChangeJVMGlobalState(callerClass);
}

/**
* Check for operations that can access sensitive network information, e.g. secrets, tokens or SSL sessions
*/
public void checkReadSensitiveNetworkInformation(Class<?> callerClass) {
neverEntitled(callerClass, "access sensitive network information");
}

private String operationDescription(String methodName) {
// TODO: Use a more human-readable description. Perhaps share code with InstrumentationServiceImpl.parseCheckerMethodName
return methodName.substring(methodName.indexOf('$'));
Expand Down

0 comments on commit d3a1d9b

Please sign in to comment.