Skip to content

Commit

Permalink
feat: allow connection with pre-configured SSLSocket
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon authored and tlocke committed Jul 26, 2023
1 parent 71b806d commit 16c481a
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 63 deletions.
15 changes: 13 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ pg8000.native.DatabaseError

For errors that originate from the server.

pg8000.native.Connection(user, host='localhost', database=None, port=5432, password=None, source_address=None, unix_sock=None, ssl_context=None, timeout=None, tcp_keepalive=True, application_name=None, replication=None)
pg8000.native.Connection(user, host='localhost', database=None, port=5432, password=None, source_address=None, unix_sock=None, ssl_context=None, timeout=None, tcp_keepalive=True, application_name=None, replication=None, sock=None)
```````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````

Creates a connection to a PostgreSQL database.
Expand Down Expand Up @@ -1374,6 +1374,11 @@ replication
character encoding is not ``ascii`` or ``utf8``, then you need to provide values as
bytes, eg. ``'database'.encode('EUC-JP')``.

sock
A socket-like object to use for the connection. For example, ``sock`` could be a plain
``socket.socket``, or it could represent an SSH tunnel or perhaps an
``ssl.SSLSocket`` to an SSL proxy. If an |ssl.SSLContext| is provided, then it will be
used to attempt to create an SSL socket from the provided socket.

pg8000.native.Connection.notifications
``````````````````````````````````````
Expand Down Expand Up @@ -1619,7 +1624,7 @@ ROWID type oid
Functions
`````````

pg8000.dbapi.connect(user, host='localhost', database=None, port=5432, password=None, source_address=None, unix_sock=None, ssl_context=None, timeout=None, tcp_keepalive=True, application_name=None, replication=None)
pg8000.dbapi.connect(user, host='localhost', database=None, port=5432, password=None, source_address=None, unix_sock=None, ssl_context=None, timeout=None, tcp_keepalive=True, application_name=None, replication=None, sock=None)
:::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::::

Creates a connection to a PostgreSQL database.
Expand Down Expand Up @@ -1696,6 +1701,12 @@ replication
character encoding is not ``ascii`` or ``utf8``, then you need to provide values as
bytes, eg. ``'database'.encode('EUC-JP')``.

sock
A socket-like object to use for the connection. For example, ``sock`` could be a plain
``socket.socket``, or it could represent an SSH tunnel or perhaps an
``ssl.SSLSocket`` to an SSL proxy. If an |ssl.SSLContext| is provided, then it will be
used to attempt to create an SSL socket from the provided socket.


pg8000.dbapi.Date(year, month, day)

Expand Down
147 changes: 86 additions & 61 deletions pg8000/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,81 @@ def _write(sock, d):
raise InterfaceError("network error") from e


def _make_socket(
unix_sock, sock, host, port, timeout, source_address, tcp_keepalive, ssl_context
):
if unix_sock is not None:
if sock is not None:
raise InterfaceError("If unix_sock is provided, sock must be None")

try:
if not hasattr(socket, "AF_UNIX"):
raise InterfaceError(
"attempt to connect to unix socket on unsupported platform"
)
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(timeout)
sock.connect(unix_sock)
except socket.error as e:
if sock is not None:
sock.close()
raise InterfaceError("communication error") from e

elif sock is not None:
pass

elif host is not None:
if unix_sock is not None:
raise InterfaceError("If the host is provided, unix_sock must be None")
if sock is not None:
raise InterfaceError("If the host is provided, sock must be None")
try:
sock = socket.create_connection((host, port), timeout, source_address)
except socket.error as e:
raise InterfaceError(
f"Can't create a connection to host {host} and port {port} "
f"(timeout is {timeout} and source_address is {source_address})."
) from e

else:
raise InterfaceError("one of host, sock or unix_sock must be provided")

if tcp_keepalive:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

channel_binding = None
if ssl_context is not None:
try:
import ssl

if ssl_context is True:
ssl_context = ssl.create_default_context()

request_ssl = getattr(ssl_context, "request_ssl", True)

if request_ssl:
# Int32(8) - Message length, including self.
# Int32(80877103) - The SSL request code.
sock.sendall(ii_pack(8, 80877103))
resp = sock.recv(1)
if resp != b"S":
raise InterfaceError("Server refuses SSL")

sock = ssl_context.wrap_socket(sock, server_hostname=host)

if request_ssl:
channel_binding = scramp.make_channel_binding(
"tls-server-end-point", sock
)

except ImportError:
raise InterfaceError(
"SSL required but ssl module not available in this python "
"installation."
)
return channel_binding, sock


class CoreConnection:
def __enter__(self):
return self
Expand All @@ -193,6 +268,7 @@ def __init__(
tcp_keepalive=True,
application_name=None,
replication=None,
sock=None,
):
self._client_encoding = "utf8"
self._commands_with_count = (
Expand Down Expand Up @@ -238,67 +314,16 @@ def __init__(

self._caches = {}

if unix_sock is None and host is not None:
try:
self._usock = socket.create_connection(
(host, port), timeout, source_address
)
except socket.error as e:
raise InterfaceError(
f"Can't create a connection to host {host} and port {port} "
f"(timeout is {timeout} and source_address is {source_address})."
) from e

elif unix_sock is not None:
try:
if not hasattr(socket, "AF_UNIX"):
raise InterfaceError(
"attempt to connect to unix socket on unsupported platform"
)
self._usock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._usock.settimeout(timeout)
self._usock.connect(unix_sock)
except socket.error as e:
if self._usock is not None:
self._usock.close()
raise InterfaceError("communication error") from e

else:
raise InterfaceError("one of host or unix_sock must be provided")

if tcp_keepalive:
self._usock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

self.channel_binding = None
if ssl_context is not None:
try:
import ssl

if ssl_context is True:
ssl_context = ssl.create_default_context()

request_ssl = getattr(ssl_context, "request_ssl", True)

if request_ssl:
# Int32(8) - Message length, including self.
# Int32(80877103) - The SSL request code.
self._usock.sendall(ii_pack(8, 80877103))
resp = self._usock.recv(1)
if resp != b"S":
raise InterfaceError("Server refuses SSL")

self._usock = ssl_context.wrap_socket(self._usock, server_hostname=host)

if request_ssl:
self.channel_binding = scramp.make_channel_binding(
"tls-server-end-point", self._usock
)

except ImportError:
raise InterfaceError(
"SSL required but ssl module not available in this python "
"installation."
)
self.channel_binding, self._usock = _make_socket(
unix_sock,
sock,
host,
port,
timeout,
source_address,
tcp_keepalive,
ssl_context,
)

self._sock = self._usock.makefile(mode="rwb")

Expand Down
2 changes: 2 additions & 0 deletions pg8000/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def connect(
tcp_keepalive=True,
application_name=None,
replication=None,
sock=None,
):
return Connection(
user,
Expand All @@ -220,6 +221,7 @@ def connect(
tcp_keepalive=tcp_keepalive,
application_name=application_name,
replication=replication,
sock=sock,
)


Expand Down
17 changes: 17 additions & 0 deletions test/dbapi/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import socket
import warnings

import pytest
Expand All @@ -25,6 +26,22 @@ def test_internet_socket_connection_refused():
connect(**conn_params)


def test_Connection_plain_socket(db_kwargs):
host = db_kwargs.get("host", "localhost")
port = db_kwargs.get("port", 5432)
with socket.create_connection((host, port)) as sock:
user = db_kwargs["user"]
password = db_kwargs["password"]
conn_params = {"sock": sock, "user": user, "password": password}

con = connect(**conn_params)
cur = con.cursor()

cur.execute("SELECT 1")
res = cur.fetchall()
assert res[0][0] == 1


def test_database_missing(db_kwargs):
db_kwargs["database"] = "missing-db"
with pytest.raises(DatabaseError):
Expand Down
16 changes: 16 additions & 0 deletions test/native/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import socket

from datetime import time as Time

import pytest
Expand All @@ -23,6 +25,20 @@ def test_internet_socket_connection_refused():
Connection(**conn_params)


def test_Connection_plain_socket(db_kwargs):
host = db_kwargs.get("host", "localhost")
port = db_kwargs.get("port", 5432)
with socket.create_connection((host, port)) as sock:
user = db_kwargs["user"]
password = db_kwargs["password"]
conn_params = {"sock": sock, "user": user, "password": password}

con = Connection(**conn_params)

res = con.run("SELECT 1")
assert res[0][0] == 1


def test_database_missing(db_kwargs):
db_kwargs["database"] = "missing-db"
with pytest.raises(DatabaseError):
Expand Down
15 changes: 15 additions & 0 deletions test/native/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,26 @@
NULL_BYTE,
PASSWORD,
_create_message,
_make_socket,
_read,
)
from pg8000.native import InterfaceError


def test_make_socket(mocker):
unix_sock = None
sock = mocker.Mock()
host = "localhost"
port = 5432
timeout = None
source_address = None
tcp_keepalive = True
ssl_context = None
_make_socket(
unix_sock, sock, host, port, timeout, source_address, tcp_keepalive, ssl_context
)


def test_handle_AUTHENTICATION_3(mocker):
"""Shouldn't send a FLUSH message, as FLUSH only used in extended-query"""

Expand Down

0 comments on commit 16c481a

Please sign in to comment.