Skip to content

Commit 05dfe19

Browse files
committed
fix rebase conflict
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent fa709a8 commit 05dfe19

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.util.Locale;
1919
import java.util.Map;
2020
import java.util.concurrent.CompletableFuture;
21+
import java.util.concurrent.atomic.AtomicReference;
2122

2223
import org.apache.commons.text.StringEscapeUtils;
2324
import org.apache.logging.log4j.Logger;
@@ -40,6 +41,8 @@
4041
import org.opensearch.transport.StreamTransportService;
4142
import org.opensearch.transport.client.Client;
4243

44+
import com.google.common.annotations.VisibleForTesting;
45+
4346
import lombok.Getter;
4447
import lombok.Setter;
4548
import lombok.extern.log4j.Log4j2;
@@ -70,19 +73,18 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
7073
@Getter
7174
private MLGuard mlGuard;
7275

73-
private SdkAsyncHttpClient httpClient;
76+
private final AtomicReference<SdkAsyncHttpClient> httpClientRef = new AtomicReference<>();
7477

7578
@Setter
7679
@Getter
7780
private StreamTransportService streamTransportService;
7881

82+
@Setter
83+
private boolean connectorPrivateIpEnabled;
84+
7985
public AwsConnectorExecutor(Connector connector) {
8086
super.initialize(connector);
8187
this.connector = (AwsConnector) connector;
82-
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
83-
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
84-
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
85-
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
8688
}
8789

8890
@Override
@@ -129,7 +131,8 @@ public void invokeRemoteService(
129131
)
130132
)
131133
.build();
132-
AccessController.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
134+
AccessController
135+
.doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> getHttpClient().execute(executeRequest));
133136
} catch (RuntimeException exception) {
134137
log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception);
135138
actionListener.onFailure(exception);
@@ -180,4 +183,19 @@ private void validateLLMInterface(String llmInterface) {
180183
throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface));
181184
}
182185
}
186+
187+
@VisibleForTesting
188+
protected SdkAsyncHttpClient getHttpClient() {
189+
if (httpClientRef.get() == null) {
190+
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
191+
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
192+
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
193+
this.httpClientRef
194+
.compareAndSet(
195+
null,
196+
MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection, connectorPrivateIpEnabled)
197+
);
198+
}
199+
return httpClientRef.get();
200+
}
183201
}

0 commit comments

Comments
 (0)