Skip to content
Merged
8 changes: 8 additions & 0 deletions .changes/4d3a104a-3225-4dcb-be05-f5155d320952.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "4d3a104a-3225-4dcb-be05-f5155d320952",
"type": "feature",
"description": "Allow configuring a custom OkHttpClient in OkHttpEngine",
"issues": [
"https://github.com/aws/aws-sdk-kotlin/issues/1707"
]
}
2 changes: 0 additions & 2 deletions runtime/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,6 @@ subprojects {
}
}

// configureIosSimulatorTasks()

val excludeFromDocumentation = listOf(
":runtime:testing",
":runtime:smithy-test",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public final class aws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngine : a
public static final field Companion Laws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngine$Companion;
public fun <init> ()V
public fun <init> (Laws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngineConfig;)V
public synthetic fun <init> (Laws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngineConfig;Lokhttp3/OkHttpClient;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Lokhttp3/OkHttpClient;)V
public synthetic fun getConfig ()Laws/smithy/kotlin/runtime/http/engine/HttpClientEngineConfig;
public fun getConfig ()Laws/smithy/kotlin/runtime/http/engine/okhttp/OkHttpEngineConfig;
public fun roundTrip (Laws/smithy/kotlin/runtime/operation/ExecutionContext;Laws/smithy/kotlin/runtime/http/request/HttpRequest;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import aws.smithy.kotlin.runtime.InternalApi
import aws.smithy.kotlin.runtime.http.HttpCall
import aws.smithy.kotlin.runtime.http.config.EngineFactory
import aws.smithy.kotlin.runtime.http.engine.AlpnId
import aws.smithy.kotlin.runtime.http.engine.HttpClientEngine
import aws.smithy.kotlin.runtime.http.engine.HttpClientEngineBase
import aws.smithy.kotlin.runtime.http.engine.TlsContext
import aws.smithy.kotlin.runtime.http.engine.callContext
Expand All @@ -34,18 +35,23 @@ import okhttp3.TlsVersion as OkHttpTlsVersion
/**
* [aws.smithy.kotlin.runtime.http.engine.HttpClientEngine] based on OkHttp.
*/
public class OkHttpEngine(
public class OkHttpEngine private constructor(
override val config: OkHttpEngineConfig,
private val userProvidedClient: OkHttpClient?,
) : HttpClientEngineBase("OkHttp") {
public constructor() : this(OkHttpEngineConfig.Default)
public constructor() : this(OkHttpEngineConfig.Default, null)

public constructor(config: OkHttpEngineConfig) : this(config, null)

public constructor(client: OkHttpClient) : this(OkHttpEngineConfig.Default, client)

public companion object : EngineFactory<OkHttpEngineConfig.Builder, OkHttpEngine> {
/**
* Initializes a new [OkHttpEngine] via a DSL builder block
* @param block A receiver lambda which sets the properties of the config to be built
*/
public operator fun invoke(block: OkHttpEngineConfig.Builder.() -> Unit): OkHttpEngine =
OkHttpEngine(OkHttpEngineConfig(block))
OkHttpEngine(OkHttpEngineConfig(block), null)

override val engineConstructor: (OkHttpEngineConfig.Builder.() -> Unit) -> OkHttpEngine = ::invoke
}
Expand All @@ -57,7 +63,10 @@ public class OkHttpEngine(
}

private val metrics = HttpClientMetrics(TELEMETRY_SCOPE, config.telemetryProvider)
private val client = config.buildClient(metrics, connectionMonitoringListener)
private val client: OkHttpClient by lazy {
userProvidedClient?.withMetrics(metrics, config)
?: config.buildClient(metrics, connectionMonitoringListener)
}

@OptIn(ExperimentalCoroutinesApi::class)
override suspend fun roundTrip(context: ExecutionContext, request: HttpRequest): HttpCall {
Expand Down Expand Up @@ -85,9 +94,11 @@ public class OkHttpEngine(

override fun shutdown() {
connectionMonitoringListener?.closeIfCloseable()
client.connectionPool.evictAll()
client.dispatcher.executorService.shutdown()
metrics.close()
if (userProvidedClient == null) {
client.connectionPool.evictAll()
client.dispatcher.executorService.shutdown()
}
}
}

Expand Down Expand Up @@ -170,6 +181,18 @@ public fun OkHttpEngineConfig.buildClient(
}.build()
}

// Configure a user-provided client to collect SDK metrics
private fun OkHttpClient.withMetrics(metrics: HttpClientMetrics, config: OkHttpEngineConfig) = newBuilder().apply {
eventListenerFactory { call ->
EventListenerChain(
listOf(
HttpEngineEventListener(connectionPool, config.hostResolver, dispatcher, metrics, call),
),
)
}
addInterceptor(MetricsInterceptor)
}.build()

private fun tlsConnectionSpec(tlsContext: TlsContext, cipherSuites: List<String>?): ConnectionSpec {
val minVersion = tlsContext.minVersion ?: TlsVersion.TLS_1_2
val okHttpTlsVersions = SdkTlsVersion
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package aws.smithy.kotlin.runtime.http.engine.okhttp

import aws.smithy.kotlin.runtime.content.ByteStream
import aws.smithy.kotlin.runtime.http.Headers
import aws.smithy.kotlin.runtime.http.HttpException
import aws.smithy.kotlin.runtime.http.HttpMethod
import aws.smithy.kotlin.runtime.http.SdkHttpClient
import aws.smithy.kotlin.runtime.http.request.HttpRequest
import aws.smithy.kotlin.runtime.http.toHttpBody
import aws.smithy.kotlin.runtime.net.url.Url
import kotlinx.coroutines.test.runTest
import okhttp3.OkHttpClient
import java.io.IOException
import kotlin.test.Test
import kotlin.test.assertFailsWith
import kotlin.test.assertTrue

class OkHttpEngineConfigTest {
@Test
fun testUserClient() = runTest {
val userClient = OkHttpClient.Builder().apply {
addInterceptor { throw DummyOkHttpClientException() }
}.build()

val engine = OkHttpEngine(userClient)
val sdkClient = SdkHttpClient(engine)

val data = "a".repeat(100)
val url = Url.parse("https://aws.amazon.com")
val request = HttpRequest(HttpMethod.POST, url, Headers.Empty, ByteStream.fromString(data).toHttpBody())

val ex = assertFailsWith<HttpException> {
sdkClient.call(request)
}
assertTrue(ex.cause is DummyOkHttpClientException)
}

private class DummyOkHttpClientException : IOException("Custom OkHttpClient interceptor was called")
}
Loading