Skip to content

Commit

Permalink
🚧 [alan-turing-institute#168]: Work in progress to allow multithreading in postgresql environments
Browse files Browse the repository at this point in the history
  • Loading branch information
Tdarnell committed Aug 14, 2023
1 parent de76b99 commit 881b810
Show file tree
Hide file tree
Showing 24 changed files with 184 additions and 120 deletions.
2 changes: 1 addition & 1 deletion airsenal/framework/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
list_teams,
)

DBSESSION = scoped_session(session)
DBSESSION = session()


def remove_db_session(dbsession=DBSESSION):
Expand Down
2 changes: 1 addition & 1 deletion airsenal/framework/aws_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def get_suggestions_string():
except Exception as e:
return f"Problem importing stuff {e}"
try:
return build_suggestion_string(session, TransferSuggestion, Player)
return build_suggestion_string(session(), TransferSuggestion, Player)

except Exception as e:
return f"Problem with the query {e}"
Expand Down
2 changes: 1 addition & 1 deletion airsenal/framework/bpl_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_fitted_team_model(


def fixture_probabilities(
gameweek, season=CURRENT_SEASON, team_model=None, dbsession=session
gameweek, season=CURRENT_SEASON, team_model=None, dbsession=session()
):
"""
Returns probabilities for all fixtures in a given gameweek and season, as a data
Expand Down
20 changes: 12 additions & 8 deletions airsenal/framework/optimization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
DEFAULT_SUB_WEIGHTS = {"GK": 0.03, "Outfield": (0.65, 0.3, 0.1)}


def check_tag_valid(pred_tag, gameweek_range, season=CURRENT_SEASON, dbsession=session):
def check_tag_valid(
pred_tag, gameweek_range, season=CURRENT_SEASON, dbsession=session()
):
"""Check a prediction tag contains predictions for all the specified gameweeks."""
# get unique gameweek and season values associated with pred_tag
fixtures = (
Expand Down Expand Up @@ -141,7 +143,8 @@ def get_squad_from_transactions(gameweek, season=CURRENT_SEASON, fpl_team_id=Non
if not fpl_team_id:
# use the most recent transaction in the table
most_recent = (
session.query(Transaction)
session()
.query(Transaction)
.order_by(Transaction.id.desc())
.filter_by(free_hit=0)
.filter_by(season=season)
Expand All @@ -155,7 +158,8 @@ def get_squad_from_transactions(gameweek, season=CURRENT_SEASON, fpl_team_id=Non
# Don't include free hit transfers as they only apply for the week the
# chip is activated
transactions = (
session.query(Transaction)
session()
.query(Transaction)
.order_by(Transaction.gameweek, Transaction.id)
.filter_by(fpl_team_id=fpl_team_id)
.filter_by(free_hit=0)
Expand Down Expand Up @@ -264,12 +268,12 @@ def fill_suggestion_table(baseline_score, best_strat, season, fpl_team_id):
ts.season = season
ts.fpl_team_id = fpl_team_id
ts.chip_played = best_strat["chips_played"][gameweek]
session.add(ts)
session.commit()
session().add(ts)
session().commit()


def fill_transaction_table(
starting_squad, best_strat, season, fpl_team_id, tag=None, dbsession=session
starting_squad, best_strat, season, fpl_team_id, tag=None, dbsession=session()
):
"""Add transactions from an optimised strategy to the transactions table in the
database. Used for simulating seasons only, for playing the current FPL season
Expand Down Expand Up @@ -323,7 +327,7 @@ def fill_initial_suggestion_table(
tag,
season=CURRENT_SEASON,
gameweek=NEXT_GAMEWEEK,
dbsession=session,
dbsession=session(),
):
"""
Fill an initial squad into the table
Expand All @@ -350,7 +354,7 @@ def fill_initial_transaction_table(
tag,
season=CURRENT_SEASON,
gameweek=NEXT_GAMEWEEK,
dbsession=session,
dbsession=session(),
):
"""Add transactions from an initial squad optimisation to the transactions table
in the database. Used for simulating seasons only, for playing the current FPL
Expand Down
28 changes: 14 additions & 14 deletions airsenal/framework/prediction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
MAX_GOALS = 10


def check_absence(player, gameweek, season, dbsession=session):
def check_absence(player, gameweek, season, dbsession=session()):
"""
Query the Absence table for a given player and season to see if the
gameweek is within the period of absence. If so, return the details of absence.
Expand Down Expand Up @@ -82,7 +82,7 @@ def get_player_history_df(
fill_blank=True,
season=CURRENT_SEASON,
gameweek=NEXT_GAMEWEEK,
dbsession=session,
dbsession=session(),
):
"""
Query the player_score table to get goals/assists/minutes, and then
Expand Down Expand Up @@ -157,7 +157,7 @@ def get_player_history_df(
print("Unknown opponent!")
team_goals = -1
absence_reason, absence_detail = check_absence(
player, row.fixture.gameweek, row.fixture.season, session
player, row.fixture.gameweek, row.fixture.season, session()
)
player_data.append(
[
Expand Down Expand Up @@ -321,7 +321,7 @@ def calc_predicted_points_for_player(
fixtures_behind=None,
min_fixtures_behind=3,
tag="",
dbsession=session,
dbsession=session(),
):
"""
Use the team-level model to get the probs of scoring or conceding
Expand Down Expand Up @@ -374,7 +374,7 @@ def calc_predicted_points_for_player(
# this should now be dealt with in get_recent_minutes_for_player, so
# throw error if not.
# recent_minutes = estimate_minutes_from_prev_season(
# player, season=season, dbsession=session
# player, season=season, dbsession=session()
# )
raise ValueError("Recent minutes is empty.")

Expand Down Expand Up @@ -458,7 +458,7 @@ def calc_predicted_points_for_pos(
gw_range,
tag,
model=ConjugatePlayerModel(),
dbsession=session,
dbsession=session(),
):
"""
Calculate points predictions for all players in a given position and
Expand Down Expand Up @@ -501,7 +501,7 @@ def make_prediction(player, fixture, points, tag):
# session.add(pp)


def fill_ep(csv_filename, dbsession=session):
def fill_ep(csv_filename, dbsession=session()):
"""
fill the database with FPLs ep_next prediction, and also
write output to a csv.
Expand Down Expand Up @@ -529,7 +529,7 @@ def fill_ep(csv_filename, dbsession=session):


def process_player_data(
prefix, season=CURRENT_SEASON, gameweek=NEXT_GAMEWEEK, dbsession=session
prefix, season=CURRENT_SEASON, gameweek=NEXT_GAMEWEEK, dbsession=session()
):
"""
transform the player dataframe, basically giving a list (for each player)
Expand Down Expand Up @@ -576,7 +576,7 @@ def process_player_data(


def fit_player_data(
position, season, gameweek, model=ConjugatePlayerModel(), dbsession=session
position, season, gameweek, model=ConjugatePlayerModel(), dbsession=session()
):
"""
fit the data for a particular position (FWD, MID, DEF)
Expand All @@ -597,7 +597,7 @@ def fit_player_data(


def get_all_fitted_player_data(
season, gameweek, model=ConjugatePlayerModel(), dbsession=session
season, gameweek, model=ConjugatePlayerModel(), dbsession=session()
):
df_positions = {"GK": None}
for pos in ["DEF", "MID", "FWD"]:
Expand All @@ -606,7 +606,7 @@ def get_all_fitted_player_data(


def get_player_scores(
season, gameweek, min_minutes=0, max_minutes=90, dbsession=session
season, gameweek, min_minutes=0, max_minutes=90, dbsession=session()
):
"""Utility function to get player scores rows up to (or the same as) season and
gameweek as a dataframe"""
Expand Down Expand Up @@ -637,7 +637,7 @@ def mean_group_min_count(df, group_col, mean_col, min_count=10):


def fit_bonus_points(
gameweek=NEXT_GAMEWEEK, season=CURRENT_SEASON, min_matches=10, dbsession=session
gameweek=NEXT_GAMEWEEK, season=CURRENT_SEASON, min_matches=10, dbsession=session()
):
"""Calculate the average bonus points scored by each player for matches they play
between 60 and 90 minutes, and matches they play between 30 and 59 minutes.
Expand Down Expand Up @@ -672,7 +672,7 @@ def fit_save_points(
season=CURRENT_SEASON,
min_matches=10,
min_minutes=90,
dbsession=session,
dbsession=session(),
):
"""Calculate the average save points scored by each goalkeeper for matches they
played at least min_minutes in.
Expand Down Expand Up @@ -702,7 +702,7 @@ def fit_card_points(
season=CURRENT_SEASON,
min_matches=10,
min_minutes=1,
dbsession=session,
dbsession=session(),
):
"""Calculate the average points per match lost to yellow or red cards
for each player.
Expand Down
84 changes: 66 additions & 18 deletions airsenal/framework/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
Use SQLAlchemy to convert between DB tables and python objects.
"""
from contextlib import contextmanager

from sqlalchemy import Column, Float, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, relationship, sessionmaker
from threading import current_thread

from sqlalchemy import Column, Engine, Float, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import (
Session,
declarative_base,
relationship,
scoped_session,
sessionmaker,
)
from sqlalchemy.pool import QueuePool

from airsenal.framework.env import AIRSENAL_HOME, get_env

Expand All @@ -14,7 +22,9 @@

class Player(Base):
__tablename__ = "player"
player_id = Column(Integer, primary_key=True, nullable=False, autoincrement=True)
player_id = Column(
Integer, primary_key=True, nullable=False, autoincrement=True, index=True
)
fpl_api_id = Column(Integer, nullable=True)
name = Column(String(100), nullable=False)
attributes = relationship("PlayerAttributes", uselist=True, back_populates="player")
Expand Down Expand Up @@ -400,19 +410,19 @@ class SessionBudget(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
session_id = Column(String(100), nullable=False)
budget = Column(Integer, nullable=False)


class TransferPriceTracker(Base):
__tablename__ = "transfer_price_tracker"
id = Column(Integer, primary_key=True, autoincrement=True)
player_id = Column(Integer, ForeignKey("player.player_id"))
player_id = Column(Integer, ForeignKey("player.player_id"), index=True)
gameweek = Column(Integer, nullable=False)
season = Column(String(4), nullable=False)
timestamp = Column(String(100), nullable=False)
price = Column(Integer, nullable=False)
transfers_in = Column(Integer, nullable=False)
transfers_out = Column(Integer, nullable=False)

def __str__(self):
return (
f"{self.player_id} ({self.season} GW{self.gameweek}): "
Expand Down Expand Up @@ -446,34 +456,72 @@ def get_connection_string():
return f"sqlite:///{get_env('AIRSENAL_DB_FILE', default=AIRSENAL_HOME / 'data.db')}"


def get_session():
engine = create_engine(get_connection_string())
def get_session() -> scoped_session[Session]:
db_session = scoped_session(get_sessionmaker())
return db_session


def get_sessionmaker() -> sessionmaker:
if "postgresql" in get_connection_string():
engine: Engine = create_engine(
get_connection_string(), use_native_hstore=False, poolclass=QueuePool
)
else:
engine: Engine = create_engine(get_connection_string())
Base.metadata.create_all(engine)
# Bind the engine to the metadata of the Base class so that the
# declaratives can be accessed through a DBSession instance
Base.metadata.bind = engine

DBSession = sessionmaker(bind=engine, autoflush=False)
return DBSession()
return DBSession


# global database session used by default throughout the package
session = get_session()
_sessions = {}
_global_session: scoped_session[Session] = get_session()
_global_thread_id: int | None = current_thread().ident
_sessions[current_thread().ident] = _global_session


# session: scoped_session[Session] = scoped_session(get_sessionmaker())
def session(ident: int | None = None) -> scoped_session[Session]:
    """
    Return the scoped session to use for a given thread.

    When the connection string does not mention postgresql, the single
    module-level session is returned regardless of the calling thread.
    For postgresql, sessions are cached in ``_sessions`` keyed by thread id
    and a new scoped session is created the first time a thread id is seen.

    ident: thread id to look up; defaults to the id of the calling thread.

    Raises RuntimeError if no thread id can be determined.
    """
    # Only the postgresql backend gets one session per thread; every other
    # backend shares the module-level session.
    if "postgresql" not in get_connection_string():
        return _global_session

    thread_id: int | None = current_thread().ident if ident is None else ident
    # Validate before emitting any message, so a missing id is reported as an
    # error rather than as "using a new session for thread None".
    if thread_id is None:
        raise RuntimeError("Could not get thread id")

    if thread_id not in _sessions:
        # Warn only when a session is actually created; a cached session
        # returned to a worker thread is not "new".
        print("WARNING: using a new session for thread", thread_id)
        _sessions[thread_id] = scoped_session(get_sessionmaker())
    return _sessions[thread_id]


@contextmanager
def session_scope():
"""Provide a transactional scope around a series of operations."""
session = get_session()
_session: scoped_session[Session] = session()
try:
yield session
session.commit()
yield _session
_session.commit()
except Exception:
session.rollback()
_session.rollback()
raise
finally:
session.close()
_session.remove()
if "postgresql" in get_connection_string():
# get the thread id that called this function
thread_id: int | None = current_thread().ident
if thread_id is None:
raise RuntimeError("Could not get thread id")
if thread_id in _sessions:
del _sessions[thread_id]


def clean_database():
Expand Down
2 changes: 1 addition & 1 deletion airsenal/framework/season.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_teams_for_season(season, dbsession):


# global variable for the module
CURRENT_TEAMS = get_teams_for_season(CURRENT_SEASON, session)
CURRENT_TEAMS = get_teams_for_season(CURRENT_SEASON, session())


def season_str_to_year(season: str) -> int:
Expand Down
Loading

0 comments on commit 881b810

Please sign in to comment.