diff --git a/README.rst b/README.rst index 6772bcc..28295b7 100644 --- a/README.rst +++ b/README.rst @@ -1295,7 +1295,7 @@ pg8000.native.DatabaseError For errors that originate from the server. -pg8000.native.Connection(user, host='localhost', database=None, port=5432, password=None, source_address=None, unix_sock=None, ssl_context=None, timeout=None, tcp_keepalive=True, application_name=None, replication=None) +pg8000.native.Connection(user, host='localhost', database=None, port=5432, password=None, source_address=None, unix_sock=None, ssl_context=None, timeout=None, tcp_keepalive=True, application_name=None, replication=None, sock=None) ``````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````` Creates a connection to a PostgreSQL database. @@ -1374,6 +1374,11 @@ replication character encoding is not ``ascii`` or ``utf8``, then you need to provide values as bytes, eg. ``'database'.encode('EUC-JP')``. +sock + A socket-like object to use for the connection. For example, ``sock`` could be a plain + ``socket.socket``, or it could represent an SSH tunnel or perhaps an + ``ssl.SSLSocket`` to an SSL proxy. If an |ssl.SSLContext| is provided, then it will be + used to attempt to create an SSL socket from the provided socket. pg8000.native.Connection.notifications `````````````````````````````````````` @@ -1619,7 +1624,7 @@ ROWID type oid Functions ````````` -pg8000.dbapi.connect(user, host='localhost', database=None, port=5432, password=None, source_address=None, unix_sock=None, ssl_context=None, timeout=None, tcp_keepalive=True, application_name=None, replication=None) +pg8000.dbapi.connect(user, host='localhost', database=None, port=5432, password=None, source_address=None, unix_sock=None, ssl_context=None, timeout=None, tcp_keepalive=True, application_name=None, replication=None, sock=None) ::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::: Creates a connection to a PostgreSQL database. @@ -1696,6 +1701,12 @@ replication character encoding is not ``ascii`` or ``utf8``, then you need to provide values as bytes, eg. ``'database'.encode('EUC-JP')``. +sock + A socket-like object to use for the connection. For example, ``sock`` could be a plain + ``socket.socket``, or it could represent an SSH tunnel or perhaps an + ``ssl.SSLSocket`` to an SSL proxy. If an |ssl.SSLContext| is provided, then it will be + used to attempt to create an SSL socket from the provided socket. + pg8000.dbapi.Date(year, month, day) diff --git a/pg8000/core.py b/pg8000/core.py index 1bfcbef..5dfcec6 100644 --- a/pg8000/core.py +++ b/pg8000/core.py @@ -172,6 +172,81 @@ def _write(sock, d): raise InterfaceError("network error") from e +def _make_socket( + unix_sock, sock, host, port, timeout, source_address, tcp_keepalive, ssl_context +): + if unix_sock is not None: + if sock is not None: + raise InterfaceError("If unix_sock is provided, sock must be None") + + try: + if not hasattr(socket, "AF_UNIX"): + raise InterfaceError( + "attempt to connect to unix socket on unsupported platform" + ) + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.connect(unix_sock) + except socket.error as e: + if sock is not None: + sock.close() + raise InterfaceError("communication error") from e + + elif sock is not None: + pass + + elif host is not None: + if unix_sock is not None: + raise InterfaceError("If the host is provided, unix_sock must be None") + if sock is not None: + raise InterfaceError("If the host is provided, sock must be None") + try: + sock = socket.create_connection((host, port), timeout, source_address) + except socket.error as e: + raise InterfaceError( + f"Can't create a connection to host {host} and port {port} " + f"(timeout is {timeout} and source_address is {source_address})." + ) from e + + else: + raise InterfaceError("one of host, sock or unix_sock must be provided") + + if tcp_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + + channel_binding = None + if ssl_context is not None: + try: + import ssl + + if ssl_context is True: + ssl_context = ssl.create_default_context() + + request_ssl = getattr(ssl_context, "request_ssl", True) + + if request_ssl: + # Int32(8) - Message length, including self. + # Int32(80877103) - The SSL request code. + sock.sendall(ii_pack(8, 80877103)) + resp = sock.recv(1) + if resp != b"S": + raise InterfaceError("Server refuses SSL") + + sock = ssl_context.wrap_socket(sock, server_hostname=host) + + if request_ssl: + channel_binding = scramp.make_channel_binding( + "tls-server-end-point", sock + ) + + except ImportError: + raise InterfaceError( + "SSL required but ssl module not available in this python " + "installation." + ) + return channel_binding, sock + + class CoreConnection: def __enter__(self): return self @@ -193,6 +268,7 @@ def __init__( tcp_keepalive=True, application_name=None, replication=None, + sock=None, ): self._client_encoding = "utf8" self._commands_with_count = ( @@ -238,67 +314,16 @@ def __init__( self._caches = {} - if unix_sock is None and host is not None: - try: - self._usock = socket.create_connection( - (host, port), timeout, source_address - ) - except socket.error as e: - raise InterfaceError( - f"Can't create a connection to host {host} and port {port} " - f"(timeout is {timeout} and source_address is {source_address})." - ) from e - - elif unix_sock is not None: - try: - if not hasattr(socket, "AF_UNIX"): - raise InterfaceError( - "attempt to connect to unix socket on unsupported platform" - ) - self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self._usock.settimeout(timeout) - self._usock.connect(unix_sock) - except socket.error as e: - if self._usock is not None: - self._usock.close() - raise InterfaceError("communication error") from e - - else: - raise InterfaceError("one of host or unix_sock must be provided") - - if tcp_keepalive: - self._usock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - - self.channel_binding = None - if ssl_context is not None: - try: - import ssl - - if ssl_context is True: - ssl_context = ssl.create_default_context() - - request_ssl = getattr(ssl_context, "request_ssl", True) - - if request_ssl: - # Int32(8) - Message length, including self. - # Int32(80877103) - The SSL request code. - self._usock.sendall(ii_pack(8, 80877103)) - resp = self._usock.recv(1) - if resp != b"S": - raise InterfaceError("Server refuses SSL") - - self._usock = ssl_context.wrap_socket(self._usock, server_hostname=host) - - if request_ssl: - self.channel_binding = scramp.make_channel_binding( - "tls-server-end-point", self._usock - ) - - except ImportError: - raise InterfaceError( - "SSL required but ssl module not available in this python " - "installation." - ) + self.channel_binding, self._usock = _make_socket( + unix_sock, + sock, + host, + port, + timeout, + source_address, + tcp_keepalive, + ssl_context, + ) self._sock = self._usock.makefile(mode="rwb") diff --git a/pg8000/dbapi.py b/pg8000/dbapi.py index d639120..50b140c 100644 --- a/pg8000/dbapi.py +++ b/pg8000/dbapi.py @@ -206,6 +206,7 @@ def connect( tcp_keepalive=True, application_name=None, replication=None, + sock=None, ): return Connection( user, @@ -220,6 +221,7 @@ def connect( tcp_keepalive=tcp_keepalive, application_name=application_name, replication=replication, + sock=sock, ) diff --git a/test/dbapi/test_connection.py b/test/dbapi/test_connection.py index cc23753..a9e110d 100644 --- a/test/dbapi/test_connection.py +++ b/test/dbapi/test_connection.py @@ -1,4 +1,5 @@ import datetime +import socket import warnings import pytest @@ -25,6 +26,22 @@ def test_internet_socket_connection_refused(): connect(**conn_params) +def test_Connection_plain_socket(db_kwargs): + host = db_kwargs.get("host", "localhost") + port = db_kwargs.get("port", 5432) + with socket.create_connection((host, port)) as sock: + user = db_kwargs["user"] + password = db_kwargs["password"] + conn_params = {"sock": sock, "user": user, "password": password} + + con = connect(**conn_params) + cur = con.cursor() + + cur.execute("SELECT 1") + res = cur.fetchall() + assert res[0][0] == 1 + + def test_database_missing(db_kwargs): db_kwargs["database"] = "missing-db" with pytest.raises(DatabaseError): diff --git a/test/native/test_connection.py b/test/native/test_connection.py index 845adc1..9dfc2b7 100644 --- a/test/native/test_connection.py +++ b/test/native/test_connection.py @@ -1,3 +1,5 @@ +import socket + from datetime import time as Time import pytest @@ -23,6 +25,20 @@ def test_internet_socket_connection_refused(): Connection(**conn_params) +def test_Connection_plain_socket(db_kwargs): + host = db_kwargs.get("host", "localhost") + port = db_kwargs.get("port", 5432) + with socket.create_connection((host, port)) as sock: + user = db_kwargs["user"] + password = db_kwargs["password"] + conn_params = {"sock": sock, "user": user, "password": password} + + con = Connection(**conn_params) + + res = con.run("SELECT 1") + assert res[0][0] == 1 + + def test_database_missing(db_kwargs): db_kwargs["database"] = "missing-db" with pytest.raises(DatabaseError): diff --git a/test/native/test_core.py b/test/native/test_core.py index 0590756..7648484 100644 --- a/test/native/test_core.py +++ b/test/native/test_core.py @@ -8,11 +8,26 @@ NULL_BYTE, PASSWORD, _create_message, + _make_socket, _read, ) from pg8000.native import InterfaceError +def test_make_socket(mocker): + unix_sock = None + sock = mocker.Mock() + host = "localhost" + port = 5432 + timeout = None + source_address = None + tcp_keepalive = True + ssl_context = None + _make_socket( + unix_sock, sock, host, port, timeout, source_address, tcp_keepalive, ssl_context + ) + + def test_handle_AUTHENTICATION_3(mocker): """Shouldn't send a FLUSH message, as FLUSH only used in extended-query"""