|
1 | 1 | import gzip |
2 | 2 | import json |
| 3 | +import os |
| 4 | +import tempfile |
3 | 5 | import time |
4 | 6 | from concurrent.futures.thread import ThreadPoolExecutor |
5 | 7 | from decimal import Decimal |
6 | 8 | from io import BytesIO |
7 | | -from typing import Callable, Tuple, Optional, Any, Dict |
| 9 | +from typing import Callable, Tuple, Optional, Any, Dict, List |
8 | 10 |
|
9 | 11 | import ijson |
10 | 12 | import requests |
|
19 | 21 | from .sdk_configs import _SDK_Configs |
20 | 22 | from .statsig_context import InitContext |
21 | 23 | from .statsig_error_boundary import _StatsigErrorBoundary |
22 | | -from .statsig_options import StatsigOptions, STATSIG_API, STATSIG_CDN |
| 24 | +from .statsig_options import StatsigOptions, STATSIG_API, STATSIG_CDN, AuthenticationMode |
| 25 | +from .grpc_websocket_worker import load_credential_from_file |
23 | 26 |
|
24 | 27 | REQUEST_TIMEOUT = 20 |
25 | 28 |
|
@@ -47,7 +50,51 @@ def __init__( |
47 | 50 | self.__statsig_metadata = statsig_metadata |
48 | 51 | self.__diagnostics = diagnostics |
49 | 52 | self.__request_count = 0 |
50 | | - self.__request_session = requests.session() |
| 53 | + self.__temp_cert_files: List[str] = [] |
| 54 | + self.__request_session = self.__init_session(options) |
| 55 | + |
| 56 | + def __init_session(self, options: StatsigOptions) -> requests.Session: |
| 57 | + session = requests.Session() |
| 58 | + http_proxy_config = None |
| 59 | + for _, config in options.proxy_configs.items(): |
| 60 | + if config.protocol == NetworkProtocol.HTTP: |
| 61 | + if config.authentication_mode in [AuthenticationMode.TLS, AuthenticationMode.MTLS]: |
| 62 | + http_proxy_config = config |
| 63 | + break |
| 64 | + if http_proxy_config is None: |
| 65 | + return session |
| 66 | + try: |
| 67 | + if http_proxy_config.authentication_mode == AuthenticationMode.TLS: |
| 68 | + ca_cert = load_credential_from_file(http_proxy_config.tls_ca_cert_path, "TLS CA certificate") |
| 69 | + if ca_cert: |
| 70 | + with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as ca_file: |
| 71 | + ca_file.write(ca_cert) |
| 72 | + session.verify = ca_file.name |
| 73 | + self.__temp_cert_files.append(ca_file.name) |
| 74 | + globals.logger.log_process("HTTP Worker", "Connecting using an TLS secure channel for HTTP") |
| 75 | + elif http_proxy_config.authentication_mode == AuthenticationMode.MTLS: |
| 76 | + client_cert = load_credential_from_file(http_proxy_config.tls_client_cert_path, "TLS client certificate") |
| 77 | + client_key = load_credential_from_file(http_proxy_config.tls_client_key_path, "TLS client key") |
| 78 | + ca_cert = load_credential_from_file(http_proxy_config.tls_ca_cert_path, "TLS CA certificate") |
| 79 | + if client_cert and client_key and ca_cert: |
| 80 | + with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as cert_file: |
| 81 | + cert_file.write(client_cert) |
| 82 | + cert_path = cert_file.name |
| 83 | + self.__temp_cert_files.append(cert_path) |
| 84 | + with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.key') as key_file: |
| 85 | + key_file.write(client_key) |
| 86 | + key_path = key_file.name |
| 87 | + self.__temp_cert_files.append(key_path) |
| 88 | + with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as ca_file: |
| 89 | + ca_file.write(ca_cert) |
| 90 | + ca_path = ca_file.name |
| 91 | + self.__temp_cert_files.append(ca_path) |
| 92 | + session.cert = (cert_path, key_path) |
| 93 | + session.verify = ca_path |
| 94 | + globals.logger.log_process("HTTP Worker", "Connecting using an mTLS secure channel for HTTP") |
| 95 | + except Exception as e: |
| 96 | + self.__error_boundary.log_exception("http_worker:init_session", e) |
| 97 | + return session |
51 | 98 |
|
52 | 99 | def is_pull_worker(self) -> bool: |
53 | 100 | return True |
@@ -166,6 +213,12 @@ def log_events( |
166 | 213 |
|
167 | 214 | def shutdown(self) -> None: |
168 | 215 | self._executor.shutdown(wait=False) |
| 216 | + for temp_file in self.__temp_cert_files: |
| 217 | + try: |
| 218 | + os.unlink(temp_file) |
| 219 | + except Exception: |
| 220 | + pass |
| 221 | + self.__temp_cert_files.clear() |
169 | 222 |
|
170 | 223 | def _run_task_for_initialize( |
171 | 224 | self, task, timeout |
|
0 commit comments