Skip to content

Commit

Permalink
refactor: Changed asyncpg to psycopg3
Browse files Browse the repository at this point in the history
Signed-off-by: carlos.vdr <[email protected]>
  • Loading branch information
carlosvdr committed Nov 22, 2023
1 parent ae4b055 commit 203272a
Show file tree
Hide file tree
Showing 11 changed files with 311 additions and 323 deletions.
24 changes: 13 additions & 11 deletions autoagora/logs_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

import asyncpg
import graphql
import psycopg_pool
from psycopg import sql


class LogsDB:
Expand All @@ -16,7 +17,7 @@ class QueryStats:
avg_time: float
stddev_time: float

def __init__(self, pgpool: asyncpg.Pool) -> None:
def __init__(self, pgpool: psycopg_pool.AsyncConnectionPool) -> None:
self.pgpool = pgpool

def return_query_body(self, query):
Expand All @@ -30,9 +31,11 @@ def return_query_body(self, query):
async def get_most_frequent_queries(
self, subgraph_ipfs_hash: str, min_count: int = 100
):
async with self.pgpool.acquire() as connection:
rows = await connection.fetch(
"""

async with self.pgpool.connection() as connection:
rows = await connection.execute(
sql.SQL(
"""
SELECT
query,
count_id,
Expand All @@ -54,22 +57,21 @@ async def get_most_frequent_queries(
FROM
query_logs
WHERE
subgraph = $1
subgraph = {hash}
AND query_time_ms IS NOT NULL
GROUP BY
qhash
HAVING
Count(id) >= $2
Count(id) >= {min_count}
) as query_logs
ON
qhash = hash
ORDER BY
count_id DESC
""",
subgraph_ipfs_hash,
min_count,
"""
).format(hash=subgraph_ipfs_hash, min_count=str(min_count)),
)

rows = await rows.fetchall()
return [
LogsDB.QueryStats(
query=self.return_query_body(row[0])
Expand Down
22 changes: 13 additions & 9 deletions autoagora/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass
from typing import Dict, Optional

import asyncpg
import psycopg_pool
from prometheus_async.aio.web import start_http_server

from autoagora.config import args, init_config
Expand Down Expand Up @@ -39,15 +39,19 @@ async def allocated_subgraph_watcher():

# Initialize connection pool to PG database
try:
pgpool = await asyncpg.create_pool(
host=args.postgres_host,
database=args.postgres_database,
user=args.postgres_username,
password=args.postgres_password,
port=args.postgres_port,
min_size=1,
max_size=args.postgres_max_connections,
conn_string = (
f"host={args.postgres_host} "
f"dbname={args.postgres_database} "
f"user={args.postgres_username} "
f'password="{args.postgres_password}" '
f"port={args.postgres_port}"
)

pgpool = psycopg_pool.AsyncConnectionPool(
conn_string, min_size=1, max_size=args.postgres_max_connections, open=False
)
await pgpool.open()
await pgpool.wait()
assert pgpool
except:
logging.exception(
Expand Down
4 changes: 2 additions & 2 deletions autoagora/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
from importlib.metadata import version

import asyncpg
import psycopg_pool
from jinja2 import Template

from autoagora.config import args
Expand All @@ -15,7 +15,7 @@
from autoagora.utils.constants import AGORA_ENTRY_TEMPLATE


async def model_builder(subgraph: str, pgpool: asyncpg.Pool) -> str:
async def model_builder(subgraph: str, pgpool: psycopg_pool.AsyncConnectionPool) -> str:
logs_db = LogsDB(pgpool)
most_frequent_queries = await logs_db.get_most_frequent_queries(subgraph)
model = build_template(subgraph, most_frequent_queries)
Expand Down
6 changes: 4 additions & 2 deletions autoagora/price_multiplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from datetime import datetime, timedelta, timezone
from typing import Tuple

import asyncpg
import psycopg_pool
from autoagora_agents.agent_factory import AgentFactory
from prometheus_client import Gauge

Expand Down Expand Up @@ -36,7 +36,9 @@


async def price_bandit_loop(
subgraph: str, pgpool: asyncpg.Pool, metrics_endpoints: MetricsEndpoints
subgraph: str,
pgpool: psycopg_pool.AsyncConnectionPool,
metrics_endpoints: MetricsEndpoints,
):
try:
# Instantiate environment.
Expand Down
53 changes: 29 additions & 24 deletions autoagora/price_save_state_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from datetime import datetime, timezone
from typing import Optional

import asyncpg
import psycopg_pool
from psycopg import sql


@dataclass
Expand All @@ -16,13 +17,13 @@ class SaveState:


class PriceSaveStateDB:
def __init__(self, pgpool: asyncpg.Pool) -> None:
def __init__(self, pgpool: psycopg_pool.AsyncConnectionPool) -> None:
self.pgpool = pgpool
self._table_created = False

async def _create_table_if_not_exists(self) -> None:
if not self._table_created:
async with self.pgpool.acquire() as connection:
async with self.pgpool.connection() as connection:
await connection.execute( # type: ignore
"""
CREATE TABLE IF NOT EXISTS price_save_state (
Expand All @@ -38,45 +39,49 @@ async def _create_table_if_not_exists(self) -> None:
async def save_state(self, subgraph: str, mean: float, stddev: float):
await self._create_table_if_not_exists()

async with self.pgpool.acquire() as connection:
async with self.pgpool.connection() as connection:
await connection.execute(
"""
sql.SQL(
"""
INSERT INTO price_save_state (subgraph, last_update, mean, stddev)
VALUES($1, $2, $3, $4)
VALUES({subgraph_hash}, {datetime}, {mean}, {stddev})
ON CONFLICT (subgraph)
DO
UPDATE SET
last_update = $2,
mean = $3,
stddev = $4
""",
subgraph,
datetime.now(timezone.utc),
mean,
stddev,
last_update = {datetime},
mean = {mean},
stddev = {stddev}
"""
).format(
subgraph_hash=subgraph,
datetime=str(datetime.now(timezone.utc)),
mean=mean,
stddev=stddev,
)
)

async def load_state(self, subgraph: str) -> Optional[SaveState]:
await self._create_table_if_not_exists()

async with self.pgpool.acquire() as connection:
row = await connection.fetchrow(
"""
async with self.pgpool.connection() as connection:
row = await connection.execute(
sql.SQL(
"""
SELECT
last_update,
mean,
stddev
FROM
price_save_state
WHERE
subgraph = $1
""",
subgraph,
subgraph = {subgraph_hash}
"""
).format(subgraph_hash=subgraph)
)

row = await row.fetchone()
if row:
return SaveState(
last_update=row["last_update"], # type: ignore
mean=row["mean"], # type: ignore
stddev=row["stddev"], # type: ignore
last_update=row[0], # type: ignore
mean=row[1], # type: ignore
stddev=row[2], # type: ignore
)
Loading

0 comments on commit 203272a

Please sign in to comment.