-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
625 additions
and
375 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ | |
""" | ||
from alembic import op | ||
import sqlalchemy as sa | ||
|
||
|
||
# revision identifiers, used by Alembic. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""agents | ||
Revision ID: ecc49f9f55bc | ||
Revises: 516ddfaaa55f | ||
Create Date: 2023-05-02 16:36:47.627765 | ||
""" | ||
from alembic import op | ||
import sqlalchemy as sa | ||
|
||
|
||
# revision identifiers, used by Alembic. | ||
revision = "ecc49f9f55bc" | ||
down_revision = "516ddfaaa55f" | ||
branch_labels = None | ||
depends_on = None | ||
|
||
|
||
def upgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.create_table( | ||
"agent_deployment", | ||
sa.Column("agent_id", sa.BigInteger(), nullable=False), | ||
sa.Column("url", sa.String(), nullable=False), | ||
sa.Column("healthy", sa.Boolean(), nullable=False), | ||
sa.Column("active", sa.Boolean(), nullable=False), | ||
sa.ForeignKeyConstraint( | ||
["agent_id"], | ||
["agents.id"], | ||
), | ||
sa.PrimaryKeyConstraint("agent_id"), | ||
) | ||
op.create_table( | ||
"agent_history", | ||
sa.Column("agent_id", sa.BigInteger(), nullable=False), | ||
sa.Column("wins", sa.Integer(), nullable=False), | ||
sa.Column("losses", sa.Integer(), nullable=False), | ||
sa.Column("draws", sa.Integer(), nullable=False), | ||
sa.Column("errors", sa.Integer(), nullable=False), | ||
sa.ForeignKeyConstraint( | ||
["agent_id"], | ||
["agents.id"], | ||
), | ||
sa.PrimaryKeyConstraint("agent_id"), | ||
) | ||
op.add_column( | ||
"agents", | ||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), | ||
) | ||
op.execute("UPDATE agents SET created_at = NOW()") | ||
op.alter_column("agents", "created_at", nullable=False) | ||
# ### end Alembic commands ### | ||
|
||
|
||
def downgrade() -> None: | ||
# ### commands auto generated by Alembic - please adjust! ### | ||
op.drop_column("agents", "created_at") | ||
op.drop_table("agent_history") | ||
op.drop_table("agent_deployment") | ||
# ### end Alembic commands ### |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,17 @@ | ||
from .schemas import Agent, AgentDeployment | ||
from .repo import ( | ||
from .schemas import AgentDeployment, AgentHistory | ||
from .service import ( | ||
create_agent, | ||
get_agent_by_id, | ||
get_agent_by_username_and_agentname, | ||
get_agent_id_for_username_and_agentname, | ||
get_agent_deployment_by_id | ||
get_agent_action, | ||
) | ||
|
||
__all__ = [ | ||
"Agent", | ||
"AgentDeployment", | ||
"AgentHistory", | ||
"get_agent_by_id", | ||
"get_agent_by_username_and_agentname", | ||
"get_agent_id_for_username_and_agentname", | ||
"get_agent_deployment_by_id", | ||
"get_agent_action", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,15 @@ | ||
from typing import Literal | ||
from pydantic import BaseModel, HttpUrl | ||
|
||
from gameplay_computer.common import BasePlayer, Game | ||
|
||
|
||
class Agent(BasePlayer): | ||
kind: Literal["agent"] = "agent" | ||
game: Game | ||
username: str | ||
agentname: str | ||
|
||
class AgentDeployment(BaseModel): | ||
url: HttpUrl | ||
active: bool | ||
healthy: bool | ||
|
||
|
||
class AgentHistory(BaseModel): | ||
wins: int | ||
losses: int | ||
draws: int | ||
errors: int |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from databases import Database | ||
from fastapi import HTTPException, status | ||
import httpx | ||
import asyncio | ||
|
||
from gameplay_computer.gameplay import Match, Connect4Action, Action, Game, Agent | ||
|
||
from gameplay_computer import users | ||
from . import repo | ||
|
||
|
||
async def create_agent( | ||
database: Database, | ||
created_by_user_id: str, | ||
game: Game, | ||
agentname: str, | ||
url: str, | ||
) -> int: | ||
created_by_user = await users.get_user_by_id(created_by_user_id) | ||
if created_by_user is None: | ||
raise HTTPException( | ||
status_code=status.HTTP_403_FORBIDDEN, | ||
detail="Unknown user.", | ||
) | ||
agent_id = await repo.create_agent( | ||
database, created_by_user_id, game, agentname, url | ||
) | ||
return agent_id | ||
|
||
|
||
async def get_agent_by_id(database: Database, agent_id: int) -> Agent: | ||
agent = await repo.get_agent_by_id(database, agent_id) | ||
if agent is None: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Unknown agent.", | ||
) | ||
return agent | ||
|
||
|
||
async def get_agent_by_username_and_agentname( | ||
database: Database, username: str, agentname: str | ||
) -> Agent: | ||
agent = await repo.get_agent_by_username_and_agentname( | ||
database, username, agentname | ||
) | ||
if agent is None: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Unknown agent.", | ||
) | ||
return agent | ||
|
||
|
||
async def get_agent_id_for_username_and_agentname( | ||
database: Database, username: str, agentname: str | ||
) -> int: | ||
agent_id = await repo.get_agent_id_for_username_and_agentname( | ||
database, username, agentname | ||
) | ||
if agent_id is None: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Unknown agent.", | ||
) | ||
return agent_id | ||
|
||
|
||
async def get_agent_action( | ||
database: Database, | ||
client: httpx.AsyncClient, | ||
agent_id: int, | ||
match: Match, | ||
) -> Action: | ||
agent = await get_agent_by_id(database, agent_id) | ||
deployment = await repo.get_agent_deployment_by_id(database, agent_id) | ||
if deployment is None: | ||
raise HTTPException( | ||
status_code=status.HTTP_404_NOT_FOUND, | ||
detail="Unknown agent.", | ||
) | ||
if agent.game != match.game: | ||
raise HTTPException( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
detail="Wrong game.", | ||
) | ||
# We are sort of assuming right now that we know the agent is the next player and that | ||
# somebody else checked that before calling this. tbd | ||
|
||
action: Connect4Action | None = None | ||
|
||
retries = 0 | ||
while retries < 3: | ||
try: | ||
response = await client.post(deployment.url, json=match.dict()) | ||
response.raise_for_status() | ||
action = Connect4Action(**response.json()) | ||
break | ||
except httpx.HTTPError as e: | ||
print(f"Error: {e}") | ||
retries += 1 | ||
await asyncio.sleep(retries) | ||
|
||
# todo: log errors | ||
if action is None: | ||
print("ERROR: too many agent errors") | ||
raise HTTPException( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="Too many agent errors.", | ||
) | ||
|
||
return action |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,17 @@ | ||
from . import tables | ||
from .schemas import Game, BasePlayer, BaseAction, BaseState | ||
from .schemas import ALogic | ||
from .service import ( | ||
serialize_state, | ||
serialize_action, | ||
deserialize_state, | ||
deserialize_action, | ||
) | ||
|
||
__all__ = [ | ||
"tables", | ||
"Game", | ||
"BasePlayer", | ||
"BaseAction", | ||
"BaseState", | ||
"ALogic", | ||
"serialize_state", | ||
"serialize_action", | ||
"deserialize_state", | ||
"deserialize_action", | ||
] |
Oops, something went wrong.