Skip to content

Commit

Permalink
Add support for Conscrypt SSL provider, direct buf option for index (#…
Browse files Browse the repository at this point in the history
…1269)

- Conscrypt (https://github.com/google/conscrypt) is a security provider
  library that leverages BoringSSL to implement a more performant
  SSLEngine. This change adds support for using it as a provider in
  JdkSslFactory.
- Add an option to load IndexSegments into a direct (off-heap) ByteBuffer.
- Rename the IndexMemState values to better match their meanings.
  • Loading branch information
cgtz authored and zzmao committed Sep 30, 2019
1 parent 0656d72 commit ac98eec
Show file tree
Hide file tree
Showing 13 changed files with 189 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ public class NetworkConfig {
@Default("-1")
public final int selectorMaxKeyToProcess;

/**
* True to allocate direct buffers within the selector (for things like SSL work).
*/
@Config("selector.use.direct.buffers")
@Default("false")
public final boolean selectorUseDirectBuffers;

public NetworkConfig(VerifiableProperties verifiableProperties) {

numIoThreads = verifiableProperties.getIntInRange("num.io.threads", 8, 1, Integer.MAX_VALUE);
Expand All @@ -106,5 +113,6 @@ public NetworkConfig(VerifiableProperties verifiableProperties) {
verifiableProperties.getIntInRange("selector.executor.pool.size", 4, 0, Integer.MAX_VALUE);
selectorMaxKeyToProcess =
verifiableProperties.getIntInRange("selector.max.key.to.process", -1, -1, Integer.MAX_VALUE);
selectorUseDirectBuffers = verifiableProperties.getBoolean("selector.use.direct.buffers", false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ public class StoreConfig {
* Provides a hint for how indexes should be treated w.r.t memory
*/
@Config(storeIndexMemStateName)
@Default("NOT_IN_MEM")
@Default("MMAP_WITHOUT_FORCE_LOAD")
public final IndexMemState storeIndexMemState;
public static final String storeIndexMemStateName = "store.index.mem.state";

Expand Down Expand Up @@ -343,8 +343,8 @@ public StoreConfig(VerifiableProperties verifiableProperties) {
storeValidateAuthorization = verifiableProperties.getBoolean("store.validate.authorization", false);
storeTtlUpdateBufferTimeSeconds =
verifiableProperties.getIntInRange(storeTtlUpdateBufferTimeSecondsName, 60 * 60 * 24, 0, Integer.MAX_VALUE);
storeIndexMemState =
IndexMemState.valueOf(verifiableProperties.getString(storeIndexMemStateName, IndexMemState.NOT_IN_MEM.name()));
storeIndexMemState = IndexMemState.valueOf(
verifiableProperties.getString(storeIndexMemStateName, IndexMemState.MMAP_WITHOUT_FORCE_LOAD.name()));
storeIoErrorCountToTriggerShutdown =
verifiableProperties.getIntInRange("store.io.error.count.to.trigger.shutdown", Integer.MAX_VALUE, 1,
Integer.MAX_VALUE);
Expand Down
16 changes: 11 additions & 5 deletions ambry-api/src/main/java/com.github.ambry/store/IndexMemState.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@
*/
public enum IndexMemState {
/**
* Index should not be in memory
* Index should be read from an mmap-ed file, but not forced to reside in memory.
*/
NOT_IN_MEM,
MMAP_WITHOUT_FORCE_LOAD,

/**
* Index should be in heap memory
* Index should be mmap-ed and force loaded into memory. The index should make a best effort to keep the segments in
* memory, but it is not guaranteed.
*/
MMAP_WITH_FORCE_LOAD,

/**
* Index should be in heap memory.
*/
IN_HEAP_MEM,

/**
* If mmaped, the index should be force loaded into memory
* Index should be in direct (off-heap) memory.
*/
FORCE_LOAD_MMAP
IN_DIRECT_MEM
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.security.Security;
import java.util.ArrayList;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.TrustManagerFactory;
import org.conscrypt.Conscrypt;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -35,8 +37,15 @@
* Factory to create SSLContext and SSLEngine
*/
public class JdkSslFactory implements SSLFactory {
private static final Logger LOGGER = LoggerFactory.getLogger(JdkSslFactory.class);

protected static final Logger logger = LoggerFactory.getLogger(JdkSslFactory.class);
static {
if (Conscrypt.isAvailable()) {
Security.addProvider(Conscrypt.newProvider());
} else {
LOGGER.warn("Conscrypt not available for this platform; will not be able to use OpenSSL-based engine");
}
}

private final SSLContext sslContext;
private final String[] cipherSuites;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@

public class TestSSLUtils {
private final static String SSL_CONTEXT_PROTOCOL = "TLS";
private final static String SSL_CONTEXT_PROVIDER = "SunJSSE";
private final static String SSL_CONTEXT_PROVIDER = "Conscrypt";
private final static String TLS_V1_2_PROTOCOL = "TLSv1.2";
private static final String SSL_V2_HELLO_PROTOCOL = "SSLv2Hello";
private final static String ENDPOINT_IDENTIFICATION_ALGORITHM = "HTTPS";
Expand Down Expand Up @@ -163,8 +163,35 @@ public static <T extends Certificate> void createTrustStore(String filename, Str
saveKeyStore(ks, filename, password);
}

/**
* Generate a cert and add SSL related properties to {@code props}. This will use {@link #SSL_CONTEXT_PROVIDER} as the
* provider for SSL routines/libraries.
* @param props the {@link Properties} instance.
* @param sslEnabledDatacenters a comma separated list of datacenters where SSL should be enabled.
* @param mode whether to generate the "client" or "server" certificate.
* @param trustStoreFile the file path at which to create the trust store.
* @param certAlias the cert alias to use.
* @throws IOException
* @throws GeneralSecurityException
*/
public static void addSSLProperties(Properties props, String sslEnabledDatacenters, SSLFactory.Mode mode,
File trustStoreFile, String certAlias) throws IOException, GeneralSecurityException {
addSSLProperties(props, sslEnabledDatacenters, mode, trustStoreFile, certAlias, SSL_CONTEXT_PROVIDER);
}

/**
* Generate a cert and add SSL related properties to {@code props}
* @param props the {@link Properties} instance.
* @param sslEnabledDatacenters a comma separated list of datacenters where SSL should be enabled.
* @param mode whether to generate the "client" or "server" certificate.
* @param trustStoreFile the file path at which to create the trust store.
* @param certAlias the cert alias to use.
* @param sslContextProvider the name of a registered security provider to use for instantiating {@link SSLContext}.
* @throws IOException
* @throws GeneralSecurityException
*/
public static void addSSLProperties(Properties props, String sslEnabledDatacenters, SSLFactory.Mode mode,
File trustStoreFile, String certAlias, String sslContextProvider) throws IOException, GeneralSecurityException {
Map<String, X509Certificate> certs = new HashMap<>();
File keyStoreFile;
String password;
Expand All @@ -188,7 +215,7 @@ public static void addSSLProperties(Properties props, String sslEnabledDatacente
createTrustStore(trustStoreFile.getPath(), TRUSTSTORE_PASSWORD, certs);

props.put("ssl.context.protocol", SSL_CONTEXT_PROTOCOL);
props.put("ssl.context.provider", SSL_CONTEXT_PROVIDER);
props.put("ssl.context.provider", sslContextProvider == null ? SSL_CONTEXT_PROVIDER : sslContextProvider);
props.put("ssl.enabled.protocols", TLS_V1_2_PROTOCOL);
props.put("ssl.endpoint.identification.algorithm", ENDPOINT_IDENTIFICATION_ALGORITHM);
props.put("ssl.client.authentication", CLIENT_AUTHENTICATION);
Expand Down Expand Up @@ -220,14 +247,14 @@ public static void addSSLProperties(Properties props, String sslEnabledDatacente
public static VerifiableProperties createSslProps(String sslEnabledDatacenters, SSLFactory.Mode mode,
File trustStoreFile, String certAlias) throws IOException, GeneralSecurityException {
Properties props = new Properties();
addSSLProperties(props, sslEnabledDatacenters, mode, trustStoreFile, certAlias);
addSSLProperties(props, sslEnabledDatacenters, mode, trustStoreFile, certAlias, SSL_CONTEXT_PROVIDER);
props.setProperty("clustermap.cluster.name", "test");
props.setProperty("clustermap.datacenter.name", "dc1");
props.setProperty("clustermap.host.name", "localhost");
return new VerifiableProperties(props);
}

public static void verifySSLConfig(SSLContext sslContext, SSLEngine sslEngine, boolean isClient) {
private static void verifySSLConfig(SSLContext sslContext, SSLEngine sslEngine, boolean isClient) {
// SSLContext verify
Assert.assertEquals(sslContext.getProtocol(), SSL_CONTEXT_PROTOCOL);
Assert.assertEquals(sslContext.getProvider().getName(), SSL_CONTEXT_PROVIDER);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,19 @@ public class SSLTransmission extends Transmission implements ReadableByteChannel
private long handshakeStartTime;

public SSLTransmission(SSLFactory sslFactory, String connectionId, SocketChannel socketChannel, SelectionKey key,
String remoteHost, int remotePort, Time time, NetworkMetrics metrics, SSLFactory.Mode mode) throws IOException {
String remoteHost, int remotePort, Time time, NetworkMetrics metrics, SSLFactory.Mode mode,
boolean useDirectBuffers) throws IOException {
super(connectionId, socketChannel, key, time, metrics);
this.sslEngine = sslFactory.createSSLEngine(remoteHost, remotePort, mode);
this.netReadBuffer = ByteBuffer.allocate(netReadBufferSize());
this.netWriteBuffer = ByteBuffer.allocate(netWriteBufferSize());
this.appReadBuffer = ByteBuffer.allocate(appReadBufferSize());
if (useDirectBuffers) {
this.netReadBuffer = ByteBuffer.allocateDirect(netReadBufferSize());
this.netWriteBuffer = ByteBuffer.allocateDirect(netWriteBufferSize());
this.appReadBuffer = ByteBuffer.allocateDirect(appReadBufferSize());
} else {
this.netReadBuffer = ByteBuffer.allocate(netReadBufferSize());
this.netWriteBuffer = ByteBuffer.allocate(netWriteBufferSize());
this.appReadBuffer = ByteBuffer.allocate(appReadBufferSize());
}
startHandshake();
}

Expand Down Expand Up @@ -121,29 +128,39 @@ public void close() {
closing = true;
sslEngine.closeOutbound();
try {
if (!flush(netWriteBuffer)) {
throw new IOException("Remaining data in the network buffer, can't send SSL close message.");
}
//prep the buffer for the close message
netWriteBuffer.clear();
//perform the close, since we called sslEngine.closeOutbound
SSLEngineResult handshake = sslEngine.wrap(emptyBuf, netWriteBuffer);
//we should be in a close state
if (handshake.getStatus() != Status.CLOSED) {
throw new IOException("Invalid close state, will not send network data.");
if (socketChannel.isConnected()) {
if (!flush(netWriteBuffer)) {
throw new IOException("Remaining data in the network buffer, can't send SSL close message.");
}
//prep the buffer for the close message
netWriteBuffer.clear();
//perform the close, since we called sslEngine.closeOutbound
SSLEngineResult handshake = sslEngine.wrap(emptyBuf, netWriteBuffer);
//we should be in a close state
if (handshake.getStatus() != Status.CLOSED) {
throw new IOException("Invalid close state, will not send network data.");
}
netWriteBuffer.flip();
flush(netWriteBuffer);
}
netWriteBuffer.flip();
flush(netWriteBuffer);
clearReceive();
clearSend();
socketChannel.socket().close();
socketChannel.close();
} catch (IOException ie) {
metrics.selectorCloseSocketErrorCount.inc();
logger.debug("Failed to send SSL close message ", ie);
} finally {
try {
clearReceive();
clearSend();
clearBuffers();
socketChannel.socket().close();
socketChannel.close();
} catch (IOException ie) {
metrics.selectorCloseSocketErrorCount.inc();
logger.debug("Failed to close socket", ie);
} finally {
key.attach(null);
key.cancel();
}
}
key.attach(null);
key.cancel();
}

/**
Expand Down Expand Up @@ -330,7 +347,7 @@ private void handshakeFinished() throws IOException {
* @return SSLEngineResult
* @throws IOException
*/
private SSLEngineResult handshakeWrap(Boolean doWrite) throws IOException {
private SSLEngineResult handshakeWrap(boolean doWrite) throws IOException {
logger.trace("SSLHandshake handshakeWrap {}", getConnectionId());
if (netWriteBuffer.hasRemaining()) {
throw new IllegalStateException("handshakeWrap called with netWriteBuffer not empty");
Expand All @@ -357,7 +374,7 @@ private SSLEngineResult handshakeWrap(Boolean doWrite) throws IOException {
* @return SSLEngineResult
* @throws IOException
*/
private SSLEngineResult handshakeUnwrap(Boolean doRead) throws IOException {
private SSLEngineResult handshakeUnwrap(boolean doRead) throws IOException {
logger.trace("SSLHandshake handshakeUnwrap {}", getConnectionId());
int read;
if (doRead) {
Expand Down Expand Up @@ -637,9 +654,9 @@ int appReadBufferSize() {
private void handleUnwrapOverflow() {
int currentAppReadBufferSize = appReadBufferSize();
appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppReadBufferSize);
if (appReadBuffer.position() >= currentAppReadBufferSize) {
if (appReadBuffer.position() > currentAppReadBufferSize) {
throw new IllegalStateException(
"Buffer overflow when available data size (" + appReadBuffer.position() + ") >= application buffer size ("
"Buffer overflow when available data size (" + appReadBuffer.position() + ") > application buffer size ("
+ currentAppReadBufferSize + ")");
}
}
Expand Down Expand Up @@ -683,4 +700,13 @@ private void handshakeFailure() {
logger.debug("SSLEngine.closeInBound() raised an exception.", e);
}
}

/**
* Nullify buffers used for I/O with {@link SSLEngine}.
*/
private void clearBuffers() {
appReadBuffer = null;
netReadBuffer = null;
netWriteBuffer = null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,8 @@ protected Transmission createTransmission(String connectionId, SelectionKey key,
} else if (portType == PortType.SSL) {
try {
transmission =
new SSLTransmission(sslFactory, connectionId, channel(key), key, hostname, port, time, metrics, mode);
new SSLTransmission(sslFactory, connectionId, channel(key), key, hostname, port, time, metrics, mode,
networkConfig.selectorUseDirectBuffers);
metrics.sslTransmissionInitializationCount.inc();
} catch (IOException e) {
metrics.sslTransmissionInitializationErrorCount.inc();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
Expand All @@ -53,20 +52,37 @@ public class SSLSelectorTest {
private final EchoServer server;
private Selector selector;
private final File trustStoreFile;
private final int poolSize;
private final NetworkConfig networkConfig;

@Parameterized.Parameters
public static List<Object[]> data() {
return Arrays.asList(new Object[][]{{0}, {2}});
List<Object[]> params = new ArrayList<>();
for (String provider : new String[]{"Conscrypt", "SunJSSE"}) {
for (int poolSize : new int[]{0, 2}) {
for (boolean useDirectBuffers : TestUtils.BOOLEAN_VALUES) {
params.add(new Object[]{provider, poolSize, useDirectBuffers});
}
}
}
return params;
}

public SSLSelectorTest(int poolSize) throws Exception {
this.poolSize = poolSize;
/**
* @param sslContextProvider the name of the SSL library provider to use.
* @param poolSize the size of the worker pool for reading/writing/connecting to sockets.
* @param useDirectBuffers true to allocate direct buffers in {@link SSLTransmission}.
* @throws Exception
*/
public SSLSelectorTest(String sslContextProvider, int poolSize, boolean useDirectBuffers) throws Exception {
trustStoreFile = File.createTempFile("truststore", ".jks");
SSLConfig sslConfig =
new SSLConfig(TestSSLUtils.createSslProps("DC1,DC2,DC3", SSLFactory.Mode.SERVER, trustStoreFile, "server"));
SSLConfig clientSSLConfig =
new SSLConfig(TestSSLUtils.createSslProps("DC1,DC2,DC3", SSLFactory.Mode.CLIENT, trustStoreFile, "client"));
Properties serverProps = new Properties();
TestSSLUtils.addSSLProperties(serverProps, "DC1,DC2,DC3", SSLFactory.Mode.SERVER, trustStoreFile, "server",
sslContextProvider);
SSLConfig sslConfig = new SSLConfig(new VerifiableProperties(serverProps));
Properties clientProps = new Properties();
TestSSLUtils.addSSLProperties(clientProps, "DC1,DC2,DC3", SSLFactory.Mode.CLIENT, trustStoreFile, "client",
sslContextProvider);
SSLConfig clientSSLConfig = new SSLConfig(new VerifiableProperties(clientProps));
SSLFactory serverSSLFactory = SSLFactory.getNewInstance(sslConfig);
clientSSLFactory = SSLFactory.getNewInstance(clientSSLConfig);
server = new EchoServer(serverSSLFactory, 18383);
Expand All @@ -76,8 +92,9 @@ public SSLSelectorTest(int poolSize) throws Exception {
.getApplicationBufferSize();
Properties props = new Properties();
props.setProperty("selector.executor.pool.size", Integer.toString(poolSize));
props.setProperty("selector.use.direct.buffers", Boolean.toString(useDirectBuffers));
VerifiableProperties vprops = new VerifiableProperties(props);
NetworkConfig networkConfig = new NetworkConfig(vprops);
networkConfig = new NetworkConfig(vprops);
selector = new Selector(new NetworkMetrics(new MetricRegistry()), SystemTime.getInstance(), clientSSLFactory,
networkConfig);
}
Expand Down Expand Up @@ -347,10 +364,6 @@ private void useCustomBufferSizeSelector(final Integer netReadBufSizeStart, fina
selector.close();
NetworkMetrics metrics = new NetworkMetrics(new MetricRegistry());
Time time = SystemTime.getInstance();
Properties props = new Properties();
props.setProperty("selector.executor.pool.size", Integer.toString(poolSize));
VerifiableProperties vprops = new VerifiableProperties(props);
NetworkConfig networkConfig = new NetworkConfig(vprops);
selector = new Selector(metrics, time, clientSSLFactory, networkConfig) {
@Override
protected Transmission createTransmission(String connectionId, SelectionKey key, String hostname, int port,
Expand All @@ -359,7 +372,7 @@ protected Transmission createTransmission(String connectionId, SelectionKey key,
AtomicReference<Integer> netWriteBufSizeOverride = new AtomicReference<>(netWriteBufSizeStart);
AtomicReference<Integer> appReadBufSizeOverride = new AtomicReference<>(appReadBufSizeStart);
return new SSLTransmission(clientSSLFactory, connectionId, (SocketChannel) key.channel(), key, hostname, port,
time, metrics, mode) {
time, metrics, mode, networkConfig.selectorUseDirectBuffers) {
@Override
protected int netReadBufferSize() {
// netReadBufferSize() is invoked in SSLTransportLayer.read() prior to the read
Expand Down
Loading

0 comments on commit ac98eec

Please sign in to comment.