diff --git a/curl_cffi/__init__.py b/curl_cffi/__init__.py index 263fe113..8245301e 100644 --- a/curl_cffi/__init__.py +++ b/curl_cffi/__init__.py @@ -16,7 +16,7 @@ from ._wrapper import ffi, lib # type: ignore from .const import CurlInfo, CurlMOpt, CurlOpt, CurlECode, CurlHttpVersion -from .curl import Curl, CurlError +from .curl import Curl, CurlError, CurlWsFrame from .aio import AsyncCurl from .__version__ import __title__, __version__, __description__, __curl_version__ diff --git a/curl_cffi/curl.py b/curl_cffi/curl.py index 672dcd0b..edb86144 100644 --- a/curl_cffi/curl.py +++ b/curl_cffi/curl.py @@ -1,16 +1,37 @@ +from __future__ import annotations + +import ctypes import re import warnings from http.cookies import SimpleCookie -from typing import Any, List, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Tuple, Union import certifi from ._wrapper import ffi, lib # type: ignore from .const import CurlHttpVersion, CurlInfo, CurlOpt, CurlWsFlag + DEFAULT_CACERT = certifi.where() +class CurlWsFrame(ctypes.Structure): + _fields_ = [ + ("age", ctypes.c_int), + ("flags", ctypes.c_int), + ("offset", ctypes.c_uint64), + ("bytesleft", ctypes.c_uint64), + ("len", ctypes.c_size_t), + ] + + if TYPE_CHECKING: + age: int + flags: int + offset: int + bytesleft: int + len: int + + class CurlError(Exception): """Base exception for curl_cffi package""" @@ -50,11 +71,13 @@ def buffer_callback(ptr, size, nmemb, userdata): buffer.write(ffi.buffer(ptr, nmemb)[:]) return nmemb * size + def ensure_int(s): if not s: return 0 return int(s) + @ffi.def_extern() def write_callback(ptr, size, nmemb, userdata): # although similar enough to the function above, kept here for performance reasons @@ -85,7 +108,7 @@ class Curl: Wrapper for `curl_easy_*` functions of libcurl. """ - def __init__(self, cacert: str = DEFAULT_CACERT, debug: bool = False, handle = None): + def __init__(self, cacert: str = DEFAULT_CACERT, debug: bool = False, handle=None): """ Parameters: cacert: CA cert path to use, by default, curl_cffi uses its own bundled cert. @@ -159,15 +182,11 @@ def setopt(self, option: CurlOpt, value: Any): elif option == CurlOpt.WRITEDATA: c_value = ffi.new_handle(value) self._write_handle = c_value - lib._curl_easy_setopt( - self._curl, CurlOpt.WRITEFUNCTION, lib.buffer_callback - ) + lib._curl_easy_setopt(self._curl, CurlOpt.WRITEFUNCTION, lib.buffer_callback) elif option == CurlOpt.HEADERDATA: c_value = ffi.new_handle(value) self._header_handle = c_value - lib._curl_easy_setopt( - self._curl, CurlOpt.HEADERFUNCTION, lib.buffer_callback - ) + lib._curl_easy_setopt(self._curl, CurlOpt.HEADERFUNCTION, lib.buffer_callback) elif option == CurlOpt.WRITEFUNCTION: c_value = ffi.new_handle(value) self._write_handle = c_value @@ -246,9 +265,7 @@ def impersonate(self, target: str, default_headers: bool = True) -> int: target: browser to impersonate. default_headers: whether to add default headers, like User-Agent. 
""" - return lib.curl_easy_impersonate( - self._curl, target.encode(), int(default_headers) - ) + return lib.curl_easy_impersonate(self._curl, target.encode(), int(default_headers)) def _ensure_cacert(self): if not self._is_cert_set: @@ -346,7 +363,7 @@ def close(self): ffi.release(self._error_buffer) self._resolve = ffi.NULL - def ws_recv(self, n: int = 1024): + def ws_recv(self, n: int = 1024) -> Tuple[bytes, CurlWsFrame]: buffer = ffi.new("char[]", n) n_recv = ffi.new("int *") p_frame = ffi.new("struct curl_ws_frame **") @@ -365,6 +382,3 @@ def ws_send(self, payload: bytes, flags: CurlWsFlag = CurlWsFlag.BINARY) -> int: ret = lib.curl_ws_send(self._curl, buffer, len(buffer), n_sent, 0, flags) self._check_error(ret, "WS_SEND") return n_sent[0] - - def ws_close(self): - self.ws_send(b"", CurlWsFlag.CLOSE) diff --git a/curl_cffi/requests/__init__.py b/curl_cffi/requests/__init__.py index aa95cd0b..9ecf22af 100644 --- a/curl_cffi/requests/__init__.py +++ b/curl_cffi/requests/__init__.py @@ -16,6 +16,7 @@ "Headers", "Request", "Response", + "AsyncWebSocket", "WebSocket", "WebSocketError", "WsCloseCode", @@ -27,11 +28,11 @@ from ..const import CurlHttpVersion, CurlWsFlag from .cookies import Cookies, CookieTypes -from .models import Request, Response +from .models import BrowserType, Request, Response from .errors import RequestsError from .headers import Headers, HeaderTypes -from .session import AsyncSession, BrowserType, Session, ProxySpec -from .websockets import WebSocket, WebSocketError, WsCloseCode +from .session import AsyncSession, Session, ProxySpec +from .websockets import AsyncWebSocket, WebSocket, WebSocketError, WsCloseCode # ThreadType = Literal["eventlet", "gevent", None] @@ -52,7 +53,7 @@ def request( proxies: Optional[ProxySpec] = None, proxy: Optional[str] = None, proxy_auth: Optional[Tuple[str, str]] = None, - verify: Optional[bool] = None, + verify: Optional[Union[bool, str]] = None, referer: Optional[str] = None, accept_encoding: Optional[str] = "gzip, deflate, br", content_callback: Optional[Callable] = None, diff --git a/curl_cffi/requests/models.py b/curl_cffi/requests/models.py index f2d50f2e..da56bff7 100644 --- a/curl_cffi/requests/models.py +++ b/curl_cffi/requests/models.py @@ -1,7 +1,8 @@ -import warnings +from enum import Enum from json import loads from typing import Optional import queue +import warnings from .. 
import Curl from .headers import Headers @@ -16,6 +17,33 @@ def clear_queue(q: queue.Queue): q.unfinished_tasks = 0 +class BrowserType(str, Enum): + edge99 = "edge99" + edge101 = "edge101" + chrome99 = "chrome99" + chrome100 = "chrome100" + chrome101 = "chrome101" + chrome104 = "chrome104" + chrome107 = "chrome107" + chrome110 = "chrome110" + chrome116 = "chrome116" + chrome119 = "chrome119" + chrome120 = "chrome120" + chrome99_android = "chrome99_android" + safari15_3 = "safari15_3" + safari15_5 = "safari15_5" + safari17_0 = "safari17_0" + safari17_2_ios = "safari17_2_ios" + + chrome = "chrome120" + safari = "safari17_0" + safari_ios = "safari17_2_ios" + + @classmethod + def has(cls, item): + return item in cls.__members__ + + class Request: def __init__(self, url: str, headers: Headers, method: str): self.url = url @@ -86,9 +114,7 @@ def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None): """ pending = None - for chunk in self.iter_content( - chunk_size=chunk_size, decode_unicode=decode_unicode - ): + for chunk in self.iter_content(chunk_size=chunk_size, decode_unicode=decode_unicode): if pending is not None: chunk = pending + chunk if delimiter: @@ -139,9 +165,7 @@ async def aiter_lines(self, chunk_size=None, decode_unicode=False, delimiter=Non """ pending = None - async for chunk in self.aiter_content( - chunk_size=chunk_size, decode_unicode=decode_unicode - ): + async for chunk in self.aiter_content(chunk_size=chunk_size, decode_unicode=decode_unicode): if pending is not None: chunk = pending + chunk if delimiter: diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index dc54ce43..f54c4af6 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -4,22 +4,18 @@ import threading import warnings import queue -from enum import Enum from functools import partialmethod from io import BytesIO -from json import dumps -from typing import Callable, Dict, List, Any, Optional, Tuple, Union, cast, TYPE_CHECKING -from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urlparse +from typing import Callable, Dict, Optional, Tuple, Union, cast, TYPE_CHECKING from concurrent.futures import ThreadPoolExecutor - from .. 
import AsyncCurl, Curl, CurlError, CurlInfo, CurlOpt, CurlHttpVersion -from ..curl import CURL_WRITEFUNC_ERROR from .cookies import Cookies, CookieTypes, CurlMorsel from .errors import RequestsError from .headers import Headers, HeaderTypes -from .models import Request, Response -from .websockets import WebSocket +from .models import BrowserType, Response +from .utils import _set_curl_options, _update_url_params, not_set +from .websockets import AsyncWebSocket try: import gevent @@ -45,97 +41,12 @@ class ProxySpec(TypedDict, total=False): ProxySpec = Dict[str, str] -class BrowserType(str, Enum): - edge99 = "edge99" - edge101 = "edge101" - chrome99 = "chrome99" - chrome100 = "chrome100" - chrome101 = "chrome101" - chrome104 = "chrome104" - chrome107 = "chrome107" - chrome110 = "chrome110" - chrome116 = "chrome116" - chrome119 = "chrome119" - chrome120 = "chrome120" - chrome99_android = "chrome99_android" - safari15_3 = "safari15_3" - safari15_5 = "safari15_5" - safari17_0 = "safari17_0" - safari17_2_ios = "safari17_2_ios" - - chrome = "chrome120" - safari = "safari17_0" - safari_ios = "safari17_2_ios" - - @classmethod - def has(cls, item): - return item in cls.__members__ - - class BrowserSpec: """A more structured way of selecting browsers""" # TODO -def _update_url_params(url: str, params: Dict) -> str: - """Add GET params to provided URL being aware of existing. - - Parameters: - url: string of target URL - params: dict containing requested params to be added - - Returns: - string with updated URL - - >> url = 'http://stackoverflow.com/test?answers=true' - >> new_params = {'answers': False, 'data': ['some','values']} - >> _update_url_params(url, new_params) - 'http://stackoverflow.com/test?data=some&data=values&answers=false' - """ - # Unquoting URL first so we don't loose existing args - url = unquote(url) - # Extracting url info - parsed_url = urlparse(url) - # Extracting URL arguments from parsed URL - get_args = parsed_url.query - # Converting URL arguments to dict - parsed_get_args = dict(parse_qsl(get_args)) - # Merging URL arguments dict with new params - parsed_get_args.update(params) - - # Bool and Dict values should be converted to json-friendly values - # you may throw this part away if you don't like it :) - parsed_get_args.update( - {k: dumps(v) for k, v in parsed_get_args.items() if isinstance(v, (bool, dict))} - ) - - # Converting URL argument to proper query string - encoded_get_args = urlencode(parsed_get_args, doseq=True) - # Creating new parsed result object based on provided with new - # URL arguments. Same thing happens inside of urlparse. 
- new_url = ParseResult( - parsed_url.scheme, - parsed_url.netloc, - parsed_url.path, - parsed_url.params, - encoded_get_args, - parsed_url.fragment, - ).geturl() - - return new_url - - -def _update_header_line(header_lines: List[str], key: str, value: str): - """Update header line list by key value pair.""" - for idx, line in enumerate(header_lines): - if line.lower().startswith(key.lower() + ":"): - header_lines[idx] = f"{key}: {value}" - break - else: # if not break - header_lines.append(f"{key}: {value}") - - def _peek_queue(q: queue.Queue, default=None): try: return q.queue[0] @@ -150,9 +61,6 @@ def _peek_aio_queue(q: asyncio.Queue, default=None): return default -not_set = object() - - class BaseSession: """Provide common methods for setting curl options and reading info in sessions.""" @@ -166,7 +74,7 @@ def __init__( proxy: Optional[str] = None, proxy_auth: Optional[Tuple[str, str]] = None, params: Optional[dict] = None, - verify: bool = True, + verify: Union[bool, str] = True, timeout: Union[float, Tuple[float, float]] = 30, trust_env: bool = True, allow_redirects: bool = True, @@ -203,17 +111,14 @@ def __init__( self.proxies: ProxySpec = proxies or {} self.proxy_auth = proxy_auth - def _set_curl_options( + def _merge_curl_options( self, curl, method: str, url: str, - params: Optional[dict] = None, - data: Optional[Union[Dict[str, str], str, BytesIO, bytes]] = None, - json: Optional[dict] = None, + *, headers: Optional[HeaderTypes] = None, cookies: Optional[CookieTypes] = None, - files: Optional[Dict] = None, auth: Optional[Tuple[str, str]] = None, timeout: Optional[Union[float, Tuple[float, float], object]] = not_set, allow_redirects: Optional[bool] = None, @@ -222,143 +127,19 @@ def _set_curl_options( proxy: Optional[str] = None, proxy_auth: Optional[Tuple[str, str]] = None, verify: Optional[Union[bool, str]] = None, - referer: Optional[str] = None, - accept_encoding: Optional[str] = "gzip, deflate, br", - content_callback: Optional[Callable] = None, impersonate: Optional[Union[str, BrowserType]] = None, default_headers: Optional[bool] = None, http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, - stream: bool = False, - max_recv_speed: int = 0, - queue_class: Any = None, - event_class: Any = None, + **kwargs, ): - c = curl - - # method - if method == "POST": - c.setopt(CurlOpt.POST, 1) - elif method != "GET": - c.setopt(CurlOpt.CUSTOMREQUEST, method.encode()) - - # url + """Merge curl options from session and request.""" if self.params: url = _update_url_params(url, self.params) - if params: - url = _update_url_params(url, params) - c.setopt(CurlOpt.URL, url.encode()) - - # data/body/json - if isinstance(data, dict): - body = urlencode(data).encode() - elif isinstance(data, str): - body = data.encode() - elif isinstance(data, BytesIO): - body = data.read() - elif isinstance(data, bytes): - body = data - elif data is None: - body = b"" - else: - raise TypeError("data must be dict, str, BytesIO or bytes") - if json is not None: - body = dumps(json, separators=(",", ":")).encode() - - # Tell libcurl to be aware of bodies and related headers when, - # 1. POST/PUT/PATCH, even if the body is empty, it's up to curl to decide what to do; - # 2. GET/DELETE with body, although it's against the RFC, some applications. e.g. Elasticsearch, use this. 
- if body or method in ("POST", "PUT", "PATCH"): - c.setopt(CurlOpt.POSTFIELDS, body) - # necessary if body contains '\0' - c.setopt(CurlOpt.POSTFIELDSIZE, len(body)) - - # headers - h = Headers(self.headers) - h.update(headers) - - # remove Host header if it's unnecessary, otherwise curl maybe confused. - # Host header will be automatically added by curl if it's not present. - # https://github.com/yifeikong/curl_cffi/issues/119 - host_header = h.get("Host") - if host_header is not None: - u = urlparse(url) - if host_header == u.netloc or host_header == u.hostname: - try: - del h["Host"] - except KeyError: - pass - - header_lines = [] - for k, v in h.multi_items(): - header_lines.append(f"{k}: {v}") - if json is not None: - _update_header_line(header_lines, "Content-Type", "application/json") - if isinstance(data, dict) and method != "POST": - _update_header_line( - header_lines, "Content-Type", "application/x-www-form-urlencoded" - ) - # print("header lines", header_lines) - c.setopt(CurlOpt.HTTPHEADER, [h.encode() for h in header_lines]) - - req = Request(url, h, method) - - # cookies - c.setopt(CurlOpt.COOKIEFILE, b"") # always enable the curl cookie engine first - c.setopt(CurlOpt.COOKIELIST, "ALL") # remove all the old cookies first. - - for morsel in self.cookies.get_cookies_for_curl(req): - # print("Setting", morsel.to_curl_format()) - curl.setopt(CurlOpt.COOKIELIST, morsel.to_curl_format()) - if cookies: - temp_cookies = Cookies(cookies) - for morsel in temp_cookies.get_cookies_for_curl(req): - curl.setopt(CurlOpt.COOKIELIST, morsel.to_curl_format()) - - # files - if files: - raise NotImplementedError("Files has not been implemented.") - - # auth - if self.auth or auth: - if self.auth: - username, password = self.auth - if auth: - username, password = auth - c.setopt(CurlOpt.USERNAME, username.encode()) # type: ignore - c.setopt(CurlOpt.PASSWORD, password.encode()) # type: ignore - - # timeout - if timeout is not_set: - timeout = self.timeout - if timeout is None: - timeout = 0 # indefinitely - - if isinstance(timeout, tuple): - connect_timeout, read_timeout = timeout - all_timeout = connect_timeout + read_timeout - c.setopt(CurlOpt.CONNECTTIMEOUT_MS, int(connect_timeout * 1000)) - if not stream: - c.setopt(CurlOpt.TIMEOUT_MS, int(all_timeout * 1000)) - else: - if not stream: - c.setopt(CurlOpt.TIMEOUT_MS, int(timeout * 1000)) # type: ignore - else: - c.setopt(CurlOpt.CONNECTTIMEOUT_MS, int(timeout * 1000)) # type: ignore - - # allow_redirects - c.setopt( - CurlOpt.FOLLOWLOCATION, - int(self.allow_redirects if allow_redirects is None else allow_redirects), - ) - # max_redirects - c.setopt( - CurlOpt.MAXREDIRS, - self.max_redirects if max_redirects is None else max_redirects, - ) + _headers = Headers(self.headers) + _headers.update(headers) - # proxies if proxy and proxies: raise TypeError("Cannot specify both 'proxy' and 'proxies'") if proxy: @@ -366,115 +147,28 @@ def _set_curl_options( if proxies is None: proxies = self.proxies - if proxies: - parts = urlparse(url) - proxy = proxies.get(parts.scheme, proxies.get("all")) - if parts.hostname: - proxy = proxies.get( - f"{parts.scheme}://{parts.hostname}", - proxies.get(f"all://{parts.hostname}"), - ) or proxy - - if proxy is not None: - if parts.scheme == "https" and proxy.startswith("https://"): - warnings.warn( - "You may be using http proxy WRONG, the prefix should be 'http://' not 'https://'," - "see: https://github.com/yifeikong/curl_cffi/issues/6", - RuntimeWarning, - stacklevel=2, - ) - - c.setopt(CurlOpt.PROXY, proxy) - # 
for http proxy, need to tell curl to enable tunneling - if not proxy.startswith("socks"): - c.setopt(CurlOpt.HTTPPROXYTUNNEL, 1) - - # proxy_auth - proxy_auth = proxy_auth or self.proxy_auth - if proxy_auth: - username, password = proxy_auth - c.setopt(CurlOpt.PROXYUSERNAME, username.encode()) - c.setopt(CurlOpt.PROXYPASSWORD, password.encode()) - - # verify - if verify is False or not self.verify and verify is None: - c.setopt(CurlOpt.SSL_VERIFYPEER, 0) - c.setopt(CurlOpt.SSL_VERIFYHOST, 0) - - # cert for this single request - if isinstance(verify, str): - c.setopt(CurlOpt.CAINFO, verify) - - # cert for the session - if verify in (None, True) and isinstance(self.verify, str): - c.setopt(CurlOpt.CAINFO, self.verify) - - # referer - if referer: - c.setopt(CurlOpt.REFERER, referer.encode()) - - # accept_encoding - if accept_encoding is not None: - c.setopt(CurlOpt.ACCEPT_ENCODING, accept_encoding.encode()) - - # impersonate - impersonate = impersonate or self.impersonate - default_headers = ( - self.default_headers if default_headers is None else default_headers + return _set_curl_options( + curl, + method, + url, + headers=_headers, + session_cookies=self.cookies, + cookies=cookies, + auth=auth or self.auth, + timeout=self.timeout if timeout is not_set else timeout, + allow_redirects=self.allow_redirects if allow_redirects is None else allow_redirects, + max_redirects=self.max_redirects if max_redirects is None else max_redirects, + proxies=proxies, + proxy_auth=proxy_auth or self.proxy_auth, + session_verify=self.verify, # Not possible to merge verify parameter + verify=verify, + impersonate=impersonate or self.impersonate, + default_headers=self.default_headers if default_headers is None else default_headers, + http_version=http_version or self.http_version, + interface=interface or self.interface, + curl_options=self.curl_options, + **kwargs, ) - if impersonate: - if not BrowserType.has(impersonate): - raise RequestsError(f"impersonate {impersonate} is not supported") - c.impersonate(impersonate, default_headers=default_headers) - - # http_version, after impersonate, which will change this to http2 - http_version = http_version or self.http_version - if http_version: - c.setopt(CurlOpt.HTTP_VERSION, http_version) - - # set extra curl options, must come after impersonate, because it will alter some options - for k, v in self.curl_options.items(): - c.setopt(k, v) - - buffer = None - q = None - header_recved = None - quit_now = None - if stream: - q = queue_class() # type: ignore - header_recved = event_class() - quit_now = event_class() - - def qput(chunk): - if not header_recved.is_set(): - header_recved.set() - if quit_now.is_set(): - return CURL_WRITEFUNC_ERROR - q.put_nowait(chunk) - return len(chunk) - - c.setopt(CurlOpt.WRITEFUNCTION, qput) # type: ignore - elif content_callback is not None: - c.setopt(CurlOpt.WRITEFUNCTION, content_callback) - else: - buffer = BytesIO() - c.setopt(CurlOpt.WRITEDATA, buffer) - header_buffer = BytesIO() - c.setopt(CurlOpt.HEADERDATA, header_buffer) - - if method == "HEAD": - c.setopt(CurlOpt.NOBODY, 1) - - # interface - interface = interface or self.interface - if interface: - c.setopt(CurlOpt.INTERFACE, interface.encode()) - - # max_recv_speed - # do not check, since 0 is a valid value to disable it - c.setopt(CurlOpt.MAX_RECV_SPEED_LARGE, max_recv_speed) - - return req, buffer, header_buffer, q, header_recved, quit_now def _parse_response(self, curl, buffer, header_buffer): c = curl @@ -504,9 +198,7 @@ def _parse_response(self, curl, buffer, 
header_buffer): header_list.append(header_line) rsp.headers = Headers(header_list) # print("Set-cookie", rsp.headers["set-cookie"]) - morsels = [ - CurlMorsel.from_curl_format(l) for l in c.getinfo(CurlInfo.COOKIELIST) - ] + morsels = [CurlMorsel.from_curl_format(l) for l in c.getinfo(CurlInfo.COOKIELIST)] # for l in c.getinfo(CurlInfo.COOKIELIST): # print("Curl Cookies", l.decode()) @@ -629,31 +321,6 @@ def stream(self, *args, **kwargs): finally: rsp.close() - def ws_connect( - self, - url, - *args, - on_message: Optional[Callable[[WebSocket, str], None]] = None, - on_error: Optional[Callable[[WebSocket, str], None]] = None, - on_open: Optional[Callable] = None, - on_close: Optional[Callable] = None, - **kwargs, - ): - self._set_curl_options(self.curl, "GET", url, *args, **kwargs) - - # https://curl.se/docs/websocket.html - self.curl.setopt(CurlOpt.CONNECT_ONLY, 2) - self.curl.perform() - - return WebSocket( - self, - self.curl, - on_message=on_message, - on_error=on_error, - on_open=on_open, - on_close=on_close, - ) - def request( self, method: str, @@ -671,7 +338,7 @@ def request( proxies: Optional[ProxySpec] = None, proxy: Optional[str] = None, proxy_auth: Optional[Tuple[str, str]] = None, - verify: Optional[bool] = None, + verify: Optional[Union[bool, str]] = None, referer: Optional[str] = None, accept_encoding: Optional[str] = "gzip, deflate, br", content_callback: Optional[Callable] = None, @@ -691,7 +358,7 @@ def request( else: c = self.curl - req, buffer, header_buffer, q, header_recved, quit_now = self._set_curl_options( + req, buffer, header_buffer, q, header_recved, quit_now = self._merge_curl_options( c, method=method, url=url, @@ -907,13 +574,82 @@ async def stream(self, *args, **kwargs): finally: await rsp.aclose() - async def ws_connect(self, url, *args, **kwargs): + async def ws_connect( + self, + url, + *, + autoclose: bool = True, + headers: Optional[HeaderTypes] = None, + cookies: Optional[CookieTypes] = None, + auth: Optional[Tuple[str, str]] = None, + timeout: Optional[Union[float, Tuple[float, float]]] = None, + allow_redirects: Optional[bool] = None, + max_redirects: Optional[int] = None, + proxies: Optional[ProxySpec] = None, + proxy: Optional[str] = None, + proxy_auth: Optional[Tuple[str, str]] = None, + verify: Optional[Union[bool, str]] = True, + referer: Optional[str] = None, + accept_encoding: Optional[str] = "gzip, deflate, br", + impersonate: Optional[Union[str, BrowserType]] = None, + default_headers: Optional[bool] = None, + http_version: Optional[CurlHttpVersion] = None, + interface: Optional[str] = None, + max_recv_speed: int = 0, + ): + """Connect to the WebSocket. + + libcurl automatically handles pings and pongs. + ref: https://curl.se/libcurl/c/libcurl-ws.html + + Parameters: + url: url for the requests. + autoclose: whether to close the WebSocket after receiving a close frame. + headers: headers to send. + cookies: cookies to use. + auth: HTTP basic auth, a tuple of (username, password), only basic auth is supported. + timeout: how many seconds to wait before giving up. + allow_redirects: whether to allow redirection. + max_redirects: max redirect counts, default unlimited(-1). + proxies: dict of proxies to use, format: {"http": proxy_url, "https": proxy_url}. + proxy: proxy to use, format: "http://proxy_url". Cannot be used with the above parameter. + proxy_auth: HTTP basic auth for proxy, a tuple of (username, password). + verify: whether to verify https certs. + referer: shortcut for setting referer header. 
+        accept_encoding: shortcut for setting accept-encoding header. +        impersonate: which browser version to impersonate. +        default_headers: whether to set default browser headers. +        http_version: limiting http version, http2 will be tried by default. +        interface: which network interface to use for the request. +        max_recv_speed: max receiving speed in bytes per second. +        """ curl = await self.pop_curl() -        # curl.debug() -        self._set_curl_options(curl, "GET", url, *args, **kwargs) +        self._merge_curl_options( +            curl, +            "GET", +            url, +            headers=headers, +            cookies=cookies, +            auth=auth, +            timeout=timeout, +            allow_redirects=allow_redirects, +            max_redirects=max_redirects, +            proxies=proxies, +            proxy=proxy, +            proxy_auth=proxy_auth, +            verify=verify, +            referer=referer, +            accept_encoding=accept_encoding, +            impersonate=impersonate, +            default_headers=default_headers, +            http_version=http_version, +            interface=interface, +            max_recv_speed=max_recv_speed, +        ) curl.setopt(CurlOpt.CONNECT_ONLY, 2)  # https://curl.se/docs/websocket.html await self.loop.run_in_executor(None, curl.perform) -        return WebSocket(self, curl) +        return AsyncWebSocket(curl, autoclose=autoclose) async def request( self, @@ -945,7 +681,7 @@ async def request( ): """Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters.""" curl = await self.pop_curl() -        req, buffer, header_buffer, q, header_recved, quit_now = self._set_curl_options( +        req, buffer, header_buffer, q, header_recved, quit_now = self._merge_curl_options( curl=curl, method=method, url=url, diff --git a/curl_cffi/requests/utils.py b/curl_cffi/requests/utils.py new file mode 100644 index 00000000..de13132b --- /dev/null +++ b/curl_cffi/requests/utils.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +from io import BytesIO +from json import dumps +from typing import Callable, Dict, List, Any, Optional, Tuple, Union, TYPE_CHECKING +from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urlparse +import warnings + +from .. import CurlOpt +from ..curl import CURL_WRITEFUNC_ERROR +from .cookies import Cookies, CookieTypes +from .errors import RequestsError +from .headers import Headers, HeaderTypes +from .models import BrowserType, Request + +if TYPE_CHECKING: +    from ..const import CurlHttpVersion +    from .cookies import CookieTypes +    from .headers import HeaderTypes +    from .session import ProxySpec + +not_set: Any = object() + + +def _update_url_params(url: str, params: dict) -> str: +    """Add GET params to the provided URL, preserving existing ones.
+ +    Parameters: +        url: string of target URL +        params: dict containing requested params to be added + +    Returns: +        string with updated URL + +    >> url = 'http://stackoverflow.com/test?answers=true' +    >> new_params = {'answers': False, 'data': ['some','values']} +    >> _update_url_params(url, new_params) +    'http://stackoverflow.com/test?data=some&data=values&answers=false' +    """ +    # Unquoting URL first so we don't lose existing args +    url = unquote(url) +    # Extracting url info +    parsed_url = urlparse(url) +    # Extracting URL arguments from parsed URL +    get_args = parsed_url.query +    # Converting URL arguments to dict +    parsed_get_args = dict(parse_qsl(get_args)) +    # Merging URL arguments dict with new params +    parsed_get_args.update(params) + +    # Bool and Dict values should be converted to json-friendly values +    # you may throw this part away if you don't like it :) +    parsed_get_args.update({k: dumps(v) for k, v in parsed_get_args.items() if isinstance(v, (bool, dict))}) + +    # Converting URL argument to proper query string +    encoded_get_args = urlencode(parsed_get_args, doseq=True) +    # Creating new parsed result object based on the provided one, with the new +    # URL arguments. Same thing happens inside of urlparse. +    new_url = ParseResult( +        parsed_url.scheme, +        parsed_url.netloc, +        parsed_url.path, +        parsed_url.params, +        encoded_get_args, +        parsed_url.fragment, +    ).geturl() + +    return new_url + + +def _update_header_line(header_lines: List[str], key: str, value: str): +    """Update header line list by key value pair.""" +    for idx, line in enumerate(header_lines): +        if line.lower().startswith(key.lower() + ":"): +            header_lines[idx] = f"{key}: {value}" +            break +    else:  # if not break +        header_lines.append(f"{key}: {value}") + + +def _set_curl_options( +    curl, +    method: str, +    url: str, +    *, +    params: Optional[dict] = None, +    data: Optional[Union[Dict[str, str], str, BytesIO, bytes]] = None, +    json: Optional[dict] = None, +    headers: Optional[HeaderTypes] = None, +    session_cookies: Optional[Cookies] = None, +    cookies: Optional[CookieTypes] = None, +    files: Optional[Dict] = None, +    auth: Optional[Tuple[str, str]] = None, +    timeout: Optional[Union[float, Tuple[float, float], object]] = not_set, +    allow_redirects: bool = True, +    max_redirects: int = -1, +    proxies: Optional[ProxySpec] = None, +    proxy_auth: Optional[Tuple[str, str]] = None, +    session_verify: Union[bool, str] = True, +    verify: Optional[Union[bool, str]] = None, +    referer: Optional[str] = None, +    accept_encoding: Optional[str] = "gzip, deflate, br", +    content_callback: Optional[Callable] = None, +    impersonate: Optional[Union[str, BrowserType]] = None, +    default_headers: bool = True, +    http_version: Optional[CurlHttpVersion] = None, +    interface: Optional[str] = None, +    stream: bool = False, +    max_recv_speed: int = 0, +    queue_class: Any = None, +    event_class: Any = None, +    curl_options: Optional[dict] = None, +): +    c = curl + +    # method +    if method == "POST": +        c.setopt(CurlOpt.POST, 1) +    elif method != "GET": +        c.setopt(CurlOpt.CUSTOMREQUEST, method.encode()) + +    # url +    if params: +        url = _update_url_params(url, params) +    c.setopt(CurlOpt.URL, url.encode()) + +    # data/body/json +    if isinstance(data, dict): +        body = urlencode(data).encode() +    elif isinstance(data, str): +        body = data.encode() +    elif isinstance(data, BytesIO): +        body = data.read() +    elif isinstance(data, bytes): +        body = data +    elif data is None: +        body = b"" +    else: +        raise TypeError("data must be dict, str, BytesIO or bytes") +    if json is not None: +        body = dumps(json, separators=(",",
":")).encode() + + # Tell libcurl to be aware of bodies and related headers when, + # 1. POST/PUT/PATCH, even if the body is empty, it's up to curl to decide what to do; + # 2. GET/DELETE with body, although it's against the RFC, some applications. e.g. Elasticsearch, use this. + if body or method in ("POST", "PUT", "PATCH"): + c.setopt(CurlOpt.POSTFIELDS, body) + # necessary if body contains '\0' + c.setopt(CurlOpt.POSTFIELDSIZE, len(body)) + + # headers + h = Headers(headers) + + # remove Host header if it's unnecessary, otherwise curl maybe confused. + # Host header will be automatically added by curl if it's not present. + # https://github.com/yifeikong/curl_cffi/issues/119 + host_header = h.get("Host") + if host_header is not None: + u = urlparse(url) + if host_header == u.netloc or host_header == u.hostname: + try: + del h["Host"] + except KeyError: + pass + + header_lines = [] + for k, v in h.multi_items(): + header_lines.append(f"{k}: {v}") + if json is not None: + _update_header_line(header_lines, "Content-Type", "application/json") + if isinstance(data, dict) and method != "POST": + _update_header_line(header_lines, "Content-Type", "application/x-www-form-urlencoded") + # print("header lines", header_lines) + c.setopt(CurlOpt.HTTPHEADER, [h.encode() for h in header_lines]) + + req = Request(url, h, method) + + # cookies + c.setopt(CurlOpt.COOKIEFILE, b"") # always enable the curl cookie engine first + c.setopt(CurlOpt.COOKIELIST, "ALL") # remove all the old cookies first. + + if session_cookies: + for morsel in session_cookies.get_cookies_for_curl(req): + # print("Setting", morsel.to_curl_format()) + curl.setopt(CurlOpt.COOKIELIST, morsel.to_curl_format()) + if cookies: + temp_cookies = Cookies(cookies) + for morsel in temp_cookies.get_cookies_for_curl(req): + curl.setopt(CurlOpt.COOKIELIST, morsel.to_curl_format()) + + # files + if files: + raise NotImplementedError("Files has not been implemented.") + + # auth + if auth: + username, password = auth + c.setopt(CurlOpt.USERNAME, username.encode()) # type: ignore + c.setopt(CurlOpt.PASSWORD, password.encode()) # type: ignore + + # timeout + if timeout is None: + timeout = 0 # indefinitely + + if isinstance(timeout, tuple): + connect_timeout, read_timeout = timeout + all_timeout = connect_timeout + read_timeout + c.setopt(CurlOpt.CONNECTTIMEOUT_MS, int(connect_timeout * 1000)) + if not stream: + c.setopt(CurlOpt.TIMEOUT_MS, int(all_timeout * 1000)) + else: + if not stream: + c.setopt(CurlOpt.TIMEOUT_MS, int(timeout * 1000)) # type: ignore + else: + c.setopt(CurlOpt.CONNECTTIMEOUT_MS, int(timeout * 1000)) # type: ignore + + # allow_redirects + c.setopt(CurlOpt.FOLLOWLOCATION, int(allow_redirects)) + + # max_redirects + c.setopt(CurlOpt.MAXREDIRS, max_redirects) + + # proxies + if proxies: + parts = urlparse(url) + proxy = proxies.get(parts.scheme, proxies.get("all")) + if parts.hostname: + proxy = ( + proxies.get( + f"{parts.scheme}://{parts.hostname}", + proxies.get(f"all://{parts.hostname}"), + ) + or proxy + ) + + if proxy is not None: + if parts.scheme == "https" and proxy.startswith("https://"): + warnings.warn( + "You may be using http proxy WRONG, the prefix should be 'http://' not 'https://'," + "see: https://github.com/yifeikong/curl_cffi/issues/6", + RuntimeWarning, + stacklevel=2, + ) + + c.setopt(CurlOpt.PROXY, proxy) + # for http proxy, need to tell curl to enable tunneling + if not proxy.startswith("socks"): + c.setopt(CurlOpt.HTTPPROXYTUNNEL, 1) + + # proxy_auth + if proxy_auth: + username, password = proxy_auth + 
c.setopt(CurlOpt.PROXYUSERNAME, username.encode()) + c.setopt(CurlOpt.PROXYPASSWORD, password.encode()) + + # verify + if verify is False or not session_verify and verify is None: + c.setopt(CurlOpt.SSL_VERIFYPEER, 0) + c.setopt(CurlOpt.SSL_VERIFYHOST, 0) + + # cert for this single request + if isinstance(verify, str): + c.setopt(CurlOpt.CAINFO, verify) + + # cert for the session + if verify in (None, True) and isinstance(session_verify, str): + c.setopt(CurlOpt.CAINFO, session_verify) + + # referer + if referer: + c.setopt(CurlOpt.REFERER, referer.encode()) + + # accept_encoding + if accept_encoding is not None: + c.setopt(CurlOpt.ACCEPT_ENCODING, accept_encoding.encode()) + + # impersonate + if impersonate: + if not BrowserType.has(impersonate): + raise RequestsError(f"impersonate {impersonate} is not supported") + c.impersonate(impersonate, default_headers=default_headers) + + # http_version, after impersonate, which will change this to http2 + if http_version: + c.setopt(CurlOpt.HTTP_VERSION, http_version) + + # set extra curl options, must come after impersonate, because it will alter some options + if curl_options: + for k, v in curl_options.items(): + c.setopt(k, v) + + buffer = None + q = None + header_recved = None + quit_now = None + if stream: + q = queue_class() # type: ignore + header_recved = event_class() + quit_now = event_class() + + def qput(chunk): + if not header_recved.is_set(): + header_recved.set() + if quit_now.is_set(): + return CURL_WRITEFUNC_ERROR + q.put_nowait(chunk) + return len(chunk) + + c.setopt(CurlOpt.WRITEFUNCTION, qput) # type: ignore + elif content_callback is not None: + c.setopt(CurlOpt.WRITEFUNCTION, content_callback) + else: + buffer = BytesIO() + c.setopt(CurlOpt.WRITEDATA, buffer) + header_buffer = BytesIO() + c.setopt(CurlOpt.HEADERDATA, header_buffer) + + if method == "HEAD": + c.setopt(CurlOpt.NOBODY, 1) + + # interface + if interface: + c.setopt(CurlOpt.INTERFACE, interface.encode()) + + # max_recv_speed + # do not check, since 0 is a valid value to disable it + c.setopt(CurlOpt.MAX_RECV_SPEED_LARGE, max_recv_speed) + + return req, buffer, header_buffer, q, header_recved, quit_now diff --git a/curl_cffi/requests/websockets.py b/curl_cffi/requests/websockets.py index 247a1c0d..46173de1 100644 --- a/curl_cffi/requests/websockets.py +++ b/curl_cffi/requests/websockets.py @@ -1,16 +1,33 @@ +from __future__ import annotations + import asyncio +from select import select import struct from enum import IntEnum -from typing import Callable, Optional, Tuple +from json import loads, dumps +from typing import Any, Callable, Optional, Tuple, TYPE_CHECKING, TypeVar, Union + +from .utils import _set_curl_options, not_set +from ..const import CurlECode, CurlOpt, CurlWsFlag +from ..curl import Curl, CurlError -from curl_cffi.const import CurlECode, CurlWsFlag -from curl_cffi.curl import CurlError +if TYPE_CHECKING: + from typing_extensions import Self + from .cookies import CookieTypes + from .headers import HeaderTypes + from .models import BrowserType + from .session import ProxySpec + from ..const import CurlHttpVersion + from ..curl import CurlWsFrame -ON_MESSAGE_T = Callable[["WebSocket", bytes], None] -ON_ERROR_T = Callable[["WebSocket", CurlError], None] -ON_OPEN_T = Callable[["WebSocket"], None] -ON_CLOSE_T = Callable[["WebSocket", int, str], None] + T = TypeVar("T") + + ON_DATA_T = Callable[["WebSocket", bytes, CurlWsFrame], None] + ON_MESSAGE_T = Callable[["WebSocket", Union[bytes, str]], None] + ON_ERROR_T = Callable[["WebSocket", CurlError], None] 
+ ON_OPEN_T = Callable[["WebSocket"], None] + ON_CLOSE_T = Callable[["WebSocket", int, str], None] class WsCloseCode(IntEnum): @@ -34,27 +51,230 @@ class WebSocketError(CurlError): pass -class WebSocket: +class BaseWebSocket: + def __init__(self, curl: Curl, *, autoclose: bool = True): + self.curl: Curl = curl + self.autoclose: bool = autoclose + self._close_code: Optional[int] = None + self._close_reason: Optional[str] = None + + @property + def closed(self) -> bool: + """Whether the WebSocket is closed.""" + return self.curl is not_set + + @property + def close_code(self) -> Optional[int]: + """The WebSocket close code, if the connection is closed.""" + return self._close_code + + @property + def close_reason(self) -> Optional[str]: + """The WebSocket close reason, if the connection is closed.""" + return self._close_reason + + @staticmethod + def _pack_close_frame(code: int, reason: bytes) -> bytes: + return struct.pack("!H", code) + reason + + @staticmethod + def _unpack_close_frame(frame: bytes) -> Tuple[int, str]: + if len(frame) < 2: + code = WsCloseCode.UNKNOWN + reason = "" + else: + try: + code = struct.unpack_from("!H", frame)[0] + reason = frame[2:].decode() + except UnicodeDecodeError: + raise WebSocketError("Invalid close message", WsCloseCode.INVALID_DATA) + except Exception: + raise WebSocketError("Invalid close frame", WsCloseCode.PROTOCOL_ERROR) + else: + if code < 3000 and (code not in WsCloseCode or code == 1005): + raise WebSocketError("Invalid close code", WsCloseCode.PROTOCOL_ERROR) + return code, reason + + def terminate(self): + """Terminate the underlying connection.""" + if self.curl is not_set: + return + self.curl.close() + self.curl = not_set + + +class WebSocket(BaseWebSocket): + """A WebSocket implementation using libcurl.""" + def __init__( self, - session, - curl, - on_message: Optional[ON_MESSAGE_T] = None, - on_error: Optional[ON_ERROR_T] = None, + *, + autoclose: bool = True, + skip_utf8_validation: bool = False, + debug: bool = False, on_open: Optional[ON_OPEN_T] = None, on_close: Optional[ON_CLOSE_T] = None, + on_data: Optional[ON_DATA_T] = None, + on_message: Optional[ON_MESSAGE_T] = None, + on_error: Optional[ON_ERROR_T] = None, + ): + """ + Parameters: + autoclose: whether to close the WebSocket after receiving a close frame. + skip_utf8_validation: whether to skip UTF-8 validation for text frames in run_forever(). + debug: print extra curl debug info. + on_open: callback to receive open events. + The callback should accept one argument: WebSocket. + on_close: callback to receive close events. + The callback should accept three arguments: WebSocket, int, and str. + on_data: callback to receive raw data frames. + The callback should accept three arguments: WebSocket, bytes, and CurlWsFrame. + on_message: callback to receive text frames. + The callback should accept two arguments: WebSocket and Union[bytes, str]. + on_error: callback to receive errors. + The callback should accept two arguments: WebSocket and CurlError. 
+ """ + super().__init__(not_set, autoclose=autoclose) + self.skip_utf8_validation = skip_utf8_validation + self.debug = debug + + self._emitters: dict[str, Callable] = {} + if on_open: + self._emitters["open"] = on_open + if on_close: + self._emitters["close"] = on_close + if on_data: + self._emitters["data"] = on_data + if on_message: + self._emitters["message"] = on_message + if on_error: + self._emitters["error"] = on_error + + def __iter__(self) -> WebSocket: + if self.closed: + raise TypeError("WebSocket is closed") + return self + + def __next__(self) -> bytes: + msg, flags = self.recv() + if flags & CurlWsFlag.CLOSE: + raise StopIteration + return msg + + def _emit(self, event_type: str, *args) -> None: + callback = self._emitters.get(event_type) + if callback: + try: + callback(self, *args) + except Exception as e: + error_callback = self._emitters.get("error") + if error_callback: + error_callback(self, e) + + def connect( + self, + url: str, + *, + headers: Optional[HeaderTypes] = None, + cookies: Optional[CookieTypes] = None, + auth: Optional[Tuple[str, str]] = None, + timeout: Optional[Union[float, Tuple[float, float]]] = None, + allow_redirects: bool = True, + max_redirects: int = -1, + proxies: Optional[ProxySpec] = None, + proxy: Optional[str] = None, + proxy_auth: Optional[Tuple[str, str]] = None, + verify: Optional[Union[bool, str]] = True, + referer: Optional[str] = None, + accept_encoding: Optional[str] = "gzip, deflate, br", + impersonate: Optional[Union[str, BrowserType]] = None, + default_headers: bool = True, + curl_options: Optional[dict] = None, + http_version: Optional[CurlHttpVersion] = None, + interface: Optional[str] = None, + max_recv_speed: int = 0, ): - self.session = session - self.curl = curl - self.on_message = on_message - self.on_error = on_error - self.on_open = on_open - self.on_close = on_close - self.keep_running = True - self._loop = None - - def recv_fragment(self): - return self.curl.ws_recv() + """Connect to the WebSocket. + + libcurl automatically handles pings and pongs. + ref: https://curl.se/libcurl/c/libcurl-ws.html + + Parameters: + url: url for the requests. + headers: headers to send. + cookies: cookies to use. + auth: HTTP basic auth, a tuple of (username, password), only basic auth is supported. + timeout: how many seconds to wait before giving up. + allow_redirects: whether to allow redirection. + max_redirects: max redirect counts, default unlimited(-1). + proxies: dict of proxies to use, format: {"http": proxy_url, "https": proxy_url}. + proxy: proxy to use, format: "http://proxy_url". Cannot be used with the above parameter. + proxy_auth: HTTP basic auth for proxy, a tuple of (username, password). + verify: whether to verify https certs. + referer: shortcut for setting referer header. + accept_encoding: shortcut for setting accept-encoding header. + impersonate: which browser version to impersonate. + default_headers: whether to set default browser headers. + curl_options: extra curl options to use. + http_version: limiting http version, http2 will be tries by default. + interface: which interface use in request to server. + max_recv_speed: max receiving speed in bytes per second. 
+ """ + if not self.closed: + raise TypeError("WebSocket is already connected") + + if proxy and proxies: + raise TypeError("Cannot specify both 'proxy' and 'proxies'") + if proxy: + proxies = {"all": proxy} + + self.curl = curl = Curl(debug=self.debug) + _set_curl_options( + curl, + "GET", + url, + headers=headers, + cookies=cookies, + auth=auth, + timeout=timeout, + allow_redirects=allow_redirects, + max_redirects=max_redirects, + proxies=proxies, + proxy_auth=proxy_auth, + verify=verify, + referer=referer, + accept_encoding=accept_encoding, + impersonate=impersonate, + default_headers=default_headers, + curl_options=curl_options, + http_version=http_version, + interface=interface, + max_recv_speed=max_recv_speed, + ) + + # https://curl.se/docs/websocket.html + curl.setopt(CurlOpt.CONNECT_ONLY, 2) + curl.perform() + + def recv_fragment(self) -> Tuple[bytes, CurlWsFrame]: + """Receive a single frame as bytes.""" + if self.closed: + raise TypeError("WebSocket is closed") + + chunk, frame = self.curl.ws_recv() + if frame.flags & CurlWsFlag.CLOSE: + try: + self._close_code, self._close_reason = self._unpack_close_frame(chunk) + except WebSocketError as e: + # Follow the spec to close the connection + # Errors do not respect autoclose + self._close_code = e.code + self.close(e.code) + raise + if self.autoclose: + self.close() + + return chunk, frame def recv(self) -> Tuple[bytes, int]: """ @@ -68,7 +288,7 @@ def recv(self) -> Tuple[bytes, int]: # TODO use select here while True: try: - chunk, frame = self.curl.ws_recv() + chunk, frame = self.recv_fragment() flags = frame.flags chunks.append(chunk) if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0: @@ -81,60 +301,121 @@ def recv(self) -> Tuple[bytes, int]: return b"".join(chunks), flags - def send(self, payload: bytes, flags: CurlWsFlag = CurlWsFlag.BINARY): - """Send a data frame""" + def recv_str(self) -> str: + """Receive a text frame.""" + data, flags = self.recv() + if not flags & CurlWsFlag.TEXT: + raise TypeError("Received non-text frame") + return data.decode() + + def recv_json(self, *, loads: Callable[[str], T] = loads) -> T: + """Receive a JSON frame.""" + data = self.recv_str() + return loads(data) + + def send(self, payload: Union[str, bytes], flags: CurlWsFlag = CurlWsFlag.BINARY): + """Send a data frame.""" + if self.closed: + raise TypeError("WebSocket is closed") + + # curl expects bytes + if isinstance(payload, str): + payload = payload.encode() return self.curl.ws_send(payload, flags) - def run_forever(self): + def send_binary(self, payload: bytes): + """Send a binary frame.""" + return self.send(payload, CurlWsFlag.BINARY) + + def send_bytes(self, payload: bytes): + """Send a binary frame.""" + return self.send(payload, CurlWsFlag.BINARY) + + def send_str(self, payload: str): + """Send a text frame.""" + return self.send(payload, CurlWsFlag.TEXT) + + def send_json(self, payload: Any, *, dumps: Callable[[Any], str] = dumps): + """Send a JSON frame.""" + return self.send_str(dumps(payload)) + + def ping(self, payload: Union[str, bytes]): + """Send a ping frame.""" + return self.send(payload, CurlWsFlag.PING) + + def run_forever(self, url: str, **kwargs): """ libcurl automatically handles pings and pongs. 
- ref: https://curl.se/libcurl/c/libcurl-ws.html """ - if self.on_open: - self.on_open(self) + if not self.closed: + raise TypeError("WebSocket is already connected") + + self.connect(url, **kwargs) + self._emit("open") # Keep reading the messages and invoke callbacks - while self.keep_running: + # TODO: Reconnect logic + chunks = [] + keep_running = True + while keep_running: try: - msg, flags = self.recv() - if self.on_message: - self.on_message(self, msg) - if flags & CurlWsFlag.CLOSE: - self.keep_running = False - # Unpack close code and reason - if len(msg) < 2: - code = WsCloseCode.UNKNOWN - reason = "" - else: + msg, frame = self.recv_fragment() + flags = frame.flags + self._emit("data", msg, frame) + + if not (frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0): + chunks.append(msg) + continue + + # Avoid unnecessary computation + if "message" in self._emitters: + if (flags & CurlWsFlag.TEXT) and not self.skip_utf8_validation: try: - code = struct.unpack_from("!H", msg)[0] - reason = msg[2:].decode() + msg = msg.decode() except UnicodeDecodeError: - raise WebSocketError("Invalid close message", WsCloseCode.INVALID_DATA) - except Exception: - raise WebSocketError("Invalid close frame", WsCloseCode.PROTOCOL_ERROR) - else: - if code < 3000 and (code not in WsCloseCode or code == 1005): - raise WebSocketError("Invalid close code", WsCloseCode.PROTOCOL_ERROR) - if self.on_close: - self.on_close(self, code, reason) - except WebSocketError as e: - # Follow the spec to close the connection - # TODO: Consider adding setting to autoclose connection on error-free close - self.close(e.code) - if self.on_error: - self.on_error(self, e) + self._close_code = WsCloseCode.INVALID_DATA + self.close(WsCloseCode.INVALID_DATA) + raise WebSocketError("Invalid UTF-8", WsCloseCode.INVALID_DATA) + if (flags & CurlWsFlag.BINARY) or (flags & CurlWsFlag.TEXT): + self._emit("message", msg) + if flags & CurlWsFlag.CLOSE: + keep_running = False + self._emit("close", self._close_code or 0, self._close_reason or "") except CurlError as e: - if self.on_error: - self.on_error(self, e) + self._emit("error", e) + if e.code == CurlECode.AGAIN: + pass + else: + if not self.closed: + code = 1000 + if isinstance(e, WebSocketError): + code = e.code + self.close(code) + raise def close(self, code: int = WsCloseCode.OK, message: bytes = b""): - msg = struct.pack("!H", code) + message - # FIXME how to reset, or can a curl handle connect to two websockets? 
+ """Close the connection.""" + if self.curl is not_set: + return + + # TODO: As per spec, we should wait for the server to close the connection + # But this is not a requirement + msg = self._pack_close_frame(code, message) self.send(msg, CurlWsFlag.CLOSE) - self.keep_running = False - self.curl.close() + # The only way to close the connection appears to be curl_easy_cleanup + # But this renders the curl handle unusable, so we do not push it back to the pool + self.terminate() + + +class AsyncWebSocket(BaseWebSocket): + """A pseudo-async WebSocket implementation using libcurl.""" + + def __init__(self, curl: Curl, *, autoclose: bool = True): + super().__init__(curl, autoclose=autoclose) + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._recv_lock = asyncio.Lock() + self._send_lock = asyncio.Lock() @property def loop(self): @@ -142,13 +423,117 @@ def loop(self): self._loop = asyncio.get_running_loop() return self._loop - async def arecv(self) -> Tuple[bytes, int]: - return await self.loop.run_in_executor(None, self.recv) + async def __aiter__(self) -> Self: + if self.closed: + raise TypeError("WebSocket is closed") + return self + + async def __anext__(self) -> bytes: + msg, flags = await self.recv() + if flags & CurlWsFlag.CLOSE: + raise StopAsyncIteration + return msg + + async def recv_fragment(self, *, timeout: Optional[float] = None) -> Tuple[bytes, CurlWsFrame]: + """Receive a single frame as bytes.""" + if self.closed: + raise TypeError("WebSocket is closed") + if self._recv_lock.locked(): + raise TypeError("Concurrent call to recv_fragment() is not allowed") + + async with self._recv_lock: + chunk, frame = await asyncio.wait_for(self.loop.run_in_executor(None, self.curl.ws_recv), timeout) + if frame.flags & CurlWsFlag.CLOSE: + try: + self._close_code, self._close_reason = self._unpack_close_frame(chunk) + except WebSocketError as e: + # Follow the spec to close the connection + # Errors do not respect autoclose + self._close_code = e.code + await self.close(e.code) + raise + if self.autoclose: + await self.close() + + return chunk, frame + + async def recv(self, *, timeout: Optional[float] = None) -> Tuple[bytes, int]: + """ + Receive a frame as bytes. + + libcurl split frames into fragments, so we have to collect all the chunks for + a frame. 
+ """ + chunks = [] + flags = 0 + # TODO use select here + while True: + try: + chunk, frame = await self.recv_fragment(timeout=timeout) + flags = frame.flags + chunks.append(chunk) + if frame.bytesleft == 0 and flags & CurlWsFlag.CONT == 0: + break + except CurlError as e: + if e.code == CurlECode.AGAIN: + pass + else: + raise + + return b"".join(chunks), flags + + async def recv_str(self, *, timeout: Optional[float] = None) -> str: + """Receive a text frame.""" + data, flags = await self.recv(timeout=timeout) + if not flags & CurlWsFlag.TEXT: + raise TypeError("Received non-text frame") + return data.decode() + + async def recv_json(self, *, loads: Callable[[str], T] = loads, timeout: Optional[float] = None) -> T: + """Receive a JSON frame.""" + data = await self.recv_str(timeout=timeout) + return loads(data) + + async def send(self, payload: Union[str, bytes], flags: CurlWsFlag = CurlWsFlag.BINARY): + """Send a data frame.""" + if self.closed: + raise TypeError("WebSocket is closed") + + # curl expects bytes + if isinstance(payload, str): + payload = payload.encode() + async with self._send_lock: # TODO: Why does concurrently sending fail + return await self.loop.run_in_executor(None, self.curl.ws_send, payload, flags) + + async def send_binary(self, payload: bytes): + """Send a binary frame.""" + return await self.send(payload, CurlWsFlag.BINARY) + + async def send_bytes(self, payload: bytes): + """Send a binary frame.""" + return await self.send(payload, CurlWsFlag.BINARY) + + async def send_str(self, payload: str): + """Send a text frame.""" + return await self.send(payload, CurlWsFlag.TEXT) + + async def send_json(self, payload: Any, *, dumps: Callable[[Any], str] = dumps): + """Send a JSON frame.""" + return await self.send_str(dumps(payload)) + + async def ping(self, payload: Union[str, bytes]): + """Send a ping frame.""" + return await self.send(payload, CurlWsFlag.PING) - async def asend(self, payload: bytes, flags: CurlWsFlag = CurlWsFlag.BINARY): - return await self.loop.run_in_executor(None, self.send, payload, flags) + async def close(self, code: int = WsCloseCode.OK, message: bytes = b""): + """Close the connection.""" + if self.curl is not_set: + return - async def aclose(self, code: int = WsCloseCode.OK, message: bytes = b""): - await self.loop.run_in_executor(None, self.close, code, message) - self.curl.reset() - self.session.push_curl(self.curl) + # TODO: As per spec, we should wait for the server to close the connection + # But this is not a requirement + msg = self._pack_close_frame(code, message) + await self.send(msg, CurlWsFlag.CLOSE) + # The only way to close the connection appears to be curl_easy_cleanup + # But this renders the curl handle unusable, so we do not push it back to the pool + self.terminate() diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py index c69a7da0..c6e98641 100644 --- a/tests/unittest/conftest.py +++ b/tests/unittest/conftest.py @@ -562,7 +562,9 @@ async def watch_restarts(self): # pragma: nocover async def echo(websocket): while True: - name = (await websocket.recv()).decode() + name = await websocket.recv() + if isinstance(name, bytes): + name = name.decode() # print(f"<<< {name}") await websocket.send(name) diff --git a/tests/unittest/test_websockets.py b/tests/unittest/test_websockets.py index 81342c13..469c7919 100644 --- a/tests/unittest/test_websockets.py +++ b/tests/unittest/test_websockets.py @@ -1,27 +1,40 @@ -from curl_cffi.requests import Session +from curl_cffi.requests import AsyncSession, WebSocket def 
test_websocket(ws_server): - with Session() as s: - s.ws_connect(ws_server.url) + ws = WebSocket() + ws.connect(ws_server.url) def test_hello(ws_server): - with Session() as s: - ws = s.ws_connect(ws_server.url) - ws.send(b"Foo me once") - content, _ = ws.recv() - assert content == b"Foo me once" + ws = WebSocket() + ws.connect(ws_server.url) + ws.send(b"Foo me once") + content, _ = ws.recv() + assert content == b"Foo me once" def test_hello_twice(ws_server): - with Session() as s: - w = s.ws_connect(ws_server.url) + ws = WebSocket() + ws.connect(ws_server.url) - w.send(b"Bar") - reply, _ = w.recv() + ws.send(b"Bar") + reply, _ = ws.recv() + + for _ in range(10): + ws.send_str("Bar") + reply = ws.recv_str() + assert reply == "Bar" + + +async def test_hello_twice_async(ws_server): + async with AsyncSession() as s: + ws = await s.ws_connect(ws_server.url) + + await ws.send(b"Bar") + reply, _ = await ws.recv() for _ in range(10): - w.send(b"Bar") - reply, _ = w.recv() - assert reply == b"Bar" + await ws.send_str("Bar") + reply = await ws.recv_str() + assert reply == "Bar"
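
Taken together, the diff replaces the old session-bound WebSocket with a standalone client. Below is a minimal usage sketch of the new synchronous API; the ws://localhost:8964 URL is a placeholder for any echo server (like the test fixture above), not a real endpoint.

from curl_cffi.requests import WebSocket

# Blocking round-trip: connect, send a text frame, read the echo, close.
ws = WebSocket()
ws.connect("ws://localhost:8964")  # placeholder echo endpoint
ws.send_str("hello")
print(ws.recv_str())  # an echo server would send "hello" back
ws.close()

# Callback style: handlers are dispatched by run_forever(), which
# connects and keeps reading frames until a CLOSE frame arrives.
def on_message(ws: WebSocket, message):
    print("message:", message)

def on_close(ws: WebSocket, code: int, reason: str):
    print("closed:", code, reason)

WebSocket(on_message=on_message, on_close=on_close).run_forever("ws://localhost:8964")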
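
On the async side, AsyncSession.ws_connect now returns an AsyncWebSocket bound to its own curl handle instead of reusing the session's. A sketch under the same placeholder-URL assumption:

import asyncio
from curl_cffi.requests import AsyncSession

async def main():
    async with AsyncSession() as s:
        ws = await s.ws_connect("ws://localhost:8964")  # placeholder echo endpoint
        await ws.send_json({"greeting": "hello"})
        # recv, recv_str and recv_json accept a keyword-only timeout in seconds
        reply = await ws.recv_json(timeout=10)
        print(reply)
        await ws.close()

asyncio.run(main())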
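
The diff also widens verify from bool to Union[bool, str] across request(), the sessions, and _set_curl_options: a string is treated as a CA bundle path and handed to CURLOPT_CAINFO, either per request or per session. A sketch with a hypothetical bundle path:

from curl_cffi import requests

# Per-request CA bundle; True/False still toggle verification as before.
r = requests.get("https://example.com", verify="/path/to/corporate-ca.pem")  # hypothetical path

# Session-wide bundle; a per-request verify (bool or str) overrides it.
s = requests.Session(verify="/path/to/corporate-ca.pem")
r = s.get("https://example.com")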