|
18 | 18 | import java.util.Locale; |
19 | 19 | import java.util.Map; |
20 | 20 | import java.util.concurrent.CompletableFuture; |
| 21 | +import java.util.concurrent.atomic.AtomicReference; |
21 | 22 |
|
22 | 23 | import org.apache.commons.text.StringEscapeUtils; |
23 | 24 | import org.apache.logging.log4j.Logger; |
|
40 | 41 | import org.opensearch.transport.StreamTransportService; |
41 | 42 | import org.opensearch.transport.client.Client; |
42 | 43 |
|
| 44 | +import com.google.common.annotations.VisibleForTesting; |
| 45 | + |
43 | 46 | import lombok.Getter; |
44 | 47 | import lombok.Setter; |
45 | 48 | import lombok.extern.log4j.Log4j2; |
@@ -70,19 +73,18 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor { |
70 | 73 | @Getter |
71 | 74 | private MLGuard mlGuard; |
72 | 75 |
|
73 | | - private SdkAsyncHttpClient httpClient; |
| 76 | + private final AtomicReference<SdkAsyncHttpClient> httpClientRef = new AtomicReference<>(); |
74 | 77 |
|
75 | 78 | @Setter |
76 | 79 | @Getter |
77 | 80 | private StreamTransportService streamTransportService; |
78 | 81 |
|
| 82 | + @Setter |
| 83 | + private boolean connectorPrivateIpEnabled; |
| 84 | + |
79 | 85 | public AwsConnectorExecutor(Connector connector) { |
80 | 86 | super.initialize(connector); |
81 | 87 | this.connector = (AwsConnector) connector; |
82 | | - Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); |
83 | | - Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); |
84 | | - Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); |
85 | | - this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); |
86 | 88 | } |
87 | 89 |
|
88 | 90 | @Override |
@@ -129,7 +131,8 @@ public void invokeRemoteService( |
129 | 131 | ) |
130 | 132 | ) |
131 | 133 | .build(); |
132 | | - AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest)); |
| 134 | + AccessController |
| 135 | + .doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> getHttpClient().execute(executeRequest)); |
133 | 136 | } catch (RuntimeException exception) { |
134 | 137 | log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception); |
135 | 138 | actionListener.onFailure(exception); |
@@ -180,4 +183,19 @@ private void validateLLMInterface(String llmInterface) { |
180 | 183 | throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface)); |
181 | 184 | } |
182 | 185 | } |
| 186 | + |
| 187 | + @VisibleForTesting |
| 188 | + protected SdkAsyncHttpClient getHttpClient() { |
| 189 | + if (httpClientRef.get() == null) { |
| 190 | + Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout()); |
| 191 | + Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); |
| 192 | + Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); |
| 193 | + this.httpClientRef |
| 194 | + .compareAndSet( |
| 195 | + null, |
| 196 | + MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled) |
| 197 | + ); |
| 198 | + } |
| 199 | + return httpClientRef.get(); |
| 200 | + } |
183 | 201 | } |
0 commit comments