forked from BiFangKNT/mtga
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtemp_old_transport.py
More file actions
162 lines (136 loc) · 11.4 KB
/
temp_old_transport.py
File metadata and controls
162 lines (136 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from __future__ import annotations
import contextlib
import json
import os
import ssl
import time
import uuid
from collections.abc import Generator
import requests
from requests.adapters import HTTPAdapter
from modules.runtime.resource_manager import ResourceManager, is_packaged
class SSLContextAdapter(HTTPAdapter):
    """HTTP adapter that uses a caller-supplied ``ssl.SSLContext``.

    The context is handed to both the direct connection pool and any proxy
    manager the adapter creates, so the certificate verification policy can be
    adjusted for every HTTPS connection made through the adapter.
    """

    def __init__(self, ssl_context: ssl.SSLContext, *args, **kwargs) -> None:
        """Store ``ssl_context`` and forward the remaining arguments to ``HTTPAdapter``."""
        # Assigned before super().__init__() so the attribute already exists if
        # the base constructor calls init_poolmanager(), which reads it.
        self.ssl_context = ssl_context
        super().__init__(*args, **kwargs)

    def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
        """Create the pool manager, defaulting ``ssl_context`` to the stored one."""
        # setdefault keeps an explicitly passed ssl_context untouched.
        pool_kwargs.setdefault("ssl_context", self.ssl_context)
        return super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs)

    def proxy_manager_for(self, proxy, **proxy_kwargs):
        """Create the proxy manager, defaulting ``ssl_context`` to the stored one."""
        proxy_kwargs.setdefault("ssl_context", self.ssl_context)
        return super().proxy_manager_for(proxy, **proxy_kwargs)
class ProxyTransport:
    """Proxy transport layer: HTTP session, SSE parsing, upstream event normalization."""

    # Blank-line sequences that terminate one SSE event. The event-stream
    # format allows CRLF, LF and CR line endings, so all three forms are split.
    _SSE_EVENT_SEPARATORS = (b"\r\n\r\n", b"\n\n", b"\r\r")

    def __init__(
        self,
        *,
        resource_manager: ResourceManager,
        disable_ssl_strict_mode: bool,
        log_func=print,
    ) -> None:
        """Create the transport and its HTTP session.

        Args:
            resource_manager: Provides the directories used for SSE log files.
            disable_ssl_strict_mode: When true, mount an HTTPS adapter whose
                SSL context has ``VERIFY_X509_STRICT`` cleared.
            log_func: Callable receiving one message string; defaults to ``print``.
        """
        self._resource_manager = resource_manager
        self._log = log_func
        self._session = self._create_http_client(disable_ssl_strict_mode)

    @property
    def session(self) -> requests.Session:
        """The HTTP session owned by this transport."""
        return self._session

    def close(self) -> None:
        """Close the HTTP session, ignoring any error raised while closing."""
        if self._session:
            with contextlib.suppress(Exception):
                self._session.close()

    def _create_http_client(self, disable_ssl_strict_mode: bool) -> requests.Session:
        """Build the session, optionally relaxing strict X.509 verification.

        If building the relaxed context fails, the failure is logged and the
        session is returned with its default HTTPS settings.
        """
        session = requests.Session()
        if disable_ssl_strict_mode:
            try:
                ctx = ssl.create_default_context()
                # Only the strict X.509 flag is cleared; the rest of the default
                # context is left as created.
                ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT
                adapter = SSLContextAdapter(ctx)
                session.mount("https://", adapter)
                self._log("关闭 SSL 严格模式: 使用自定义 HTTPS 上下文")
            except Exception as exc:  # noqa: BLE001
                self._log(f"配置非严格 SSL 上下文失败，继续使用默认设置: {exc}")
        return session

    def prepare_sse_log_path(self) -> str:
        """Return a fresh SSE log file path, creating the ``logs`` directory.

        The base directory is ``user_data_dir`` when ``is_packaged()`` is true,
        otherwise ``program_resource_dir``. The log file itself is not created.
        """
        base_dir = (
            self._resource_manager.user_data_dir
            if is_packaged()
            else self._resource_manager.program_resource_dir
        )
        log_dir = os.path.join(base_dir, "logs")
        os.makedirs(log_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        # The millisecond suffix separates files created within the same second.
        filename = f"sse_{timestamp}_{int(time.time() * 1000)}.log"
        return os.path.join(log_dir, filename)

    @classmethod
    def _find_event_boundary(cls, buffer: bytes) -> tuple[int, int] | None:
        """Return ``(offset, separator_length)`` of the earliest event separator, or ``None``."""
        boundary = None
        for separator in cls._SSE_EVENT_SEPARATORS:
            offset = buffer.find(separator)
            if offset != -1 and (boundary is None or offset < boundary[0]):
                boundary = (offset, len(separator))
        return boundary

    def extract_sse_events(
        self, response, *, log_file=None, log
    ) -> Generator[tuple[int, bytes], None, None]:
        """Split a streamed response into raw SSE events.

        Args:
            response: Object whose ``iter_content(chunk_size=None)`` yields bytes.
            log_file: Optional binary file object; every raw chunk is written
                and flushed to it. After the first write failure it is closed
                and logging stops, but streaming continues.
            log: Callable receiving one message string.

        Yields:
            ``(chunk_index, event)`` pairs, where ``chunk_index`` is the 1-based
            index of the chunk that completed the event and ``event`` is the
            event bytes without the trailing blank-line separator. Data left
            over when the stream ends is yielded as a final event after a
            warning is logged.
        """
        buffer = b""
        chunk_index = 0
        for chunk in response.iter_content(chunk_size=None):
            chunk_index += 1
            if log_file:
                try:
                    log_file.write(chunk)
                    log_file.flush()
                except Exception as write_exc:  # noqa: BLE001
                    log(f"SSE 日志写入失败，停止记录: {write_exc}")
                    with contextlib.suppress(Exception):
                        log_file.close()
                    log_file = None
            buffer += chunk
            # One chunk may complete several events; drain all of them.
            while True:
                boundary = self._find_event_boundary(buffer)
                if boundary is None:
                    break
                end, separator_length = boundary
                event = buffer[:end]
                buffer = buffer[end + separator_length:]
                yield chunk_index, event
        if buffer.strip():
            log("警告: 上游 SSE 结束时存在未完整分隔的残留数据")
            yield chunk_index, buffer

    @staticmethod
    def _new_request_id() -> str:
        """Return a random 6-character hexadecimal identifier."""
        return uuid.uuid4().hex[:6]

    @staticmethod
    def _as_dict(value: object) -> dict:
        """Return ``value`` if it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    def normalize_openai_event(
        self, data_str: str, event_index: int, *, model_name: str, log
    ) -> tuple[bytes, str | None]:
        """Rewrite one upstream event payload as a ``chat.completion.chunk`` SSE frame.

        Args:
            data_str: The JSON text carried by the event's ``data:`` field.
            event_index: Index of the event; index 1 always gets a ``role``.
            model_name: Model name used when the payload has none.
            log: Callable receiving one message string.

        Returns:
            ``(frame, finish_reason)``. ``frame`` is the encoded
            ``data: ...\\n\\n`` bytes. ``finish_reason`` is the first choice's
            finish reason, or ``None`` when it is missing or empty. Payloads
            that are not valid JSON, or are not a JSON object, are passed
            through unchanged with ``finish_reason`` ``None``.
        """
        try:
            payload = json.loads(data_str)
        except Exception as exc:  # noqa: BLE001
            log(f"chunk#{event_index} JSON 解析失败，原样透传: {exc}")
            return f"data: {data_str}\n\n".encode(), None
        if not isinstance(payload, dict):
            # Valid JSON that is not an object (number, string, list, null)
            # has no fields to normalize; forward it untouched.
            log(f"chunk#{event_index} JSON 不是对象，原样透传")
            return f"data: {data_str}\n\n".encode(), None

        choices = payload.get("choices")
        first_choice = choices[0] if isinstance(choices, list) and choices else {}
        choice0 = self._as_dict(first_choice)
        raw_delta = self._as_dict(choice0.get("delta"))
        message = self._as_dict(choice0.get("message"))

        delta: dict[str, object] = {}
        role = raw_delta.get("role") or message.get("role")
        if role or event_index == 1:
            delta["role"] = role or "assistant"
        # Fall back to the non-streaming "message" field when "delta" is empty.
        content = raw_delta.get("content") or message.get("content")
        if content:
            delta["content"] = content
        # NOTE(review): "function_calls" is forwarded as written; OpenAI's legacy
        # field is spelled "function_call" - confirm which one upstream sends.
        for key in ("tool_calls", "function_calls", "reasoning_content"):
            value = raw_delta.get(key)
            if value not in (None, []):
                delta[key] = value

        finish_reason = choice0.get("finish_reason")
        normalized_finish = finish_reason if finish_reason not in (None, "") else None

        try:
            created = int(payload.get("created") or time.time())
        except (TypeError, ValueError):
            # A non-numeric "created" must not abort the stream.
            created = int(time.time())

        chunk_obj = {
            "id": payload.get("id") or self._new_request_id(),
            "object": "chat.completion.chunk",
            "created": created,
            "model": payload.get("model") or model_name,
            "choices": [
                {
                    "index": choice0.get("index", 0),
                    "delta": delta,
                    "logprobs": None,
                    "finish_reason": normalized_finish,
                }
            ],
        }
        chunk_json = json.dumps(chunk_obj, ensure_ascii=False)
        return f"data: {chunk_json}\n\n".encode(), normalized_finish
# Names exported by a star-import of this module.
__all__ = ["ProxyTransport", "SSLContextAdapter"]