Skip to content

Commit

Permalink
Introduces config generation script to simplify server setup.
Browse files Browse the repository at this point in the history
Also, modifies all paths to point to package root.
  • Loading branch information
jsharf committed Jul 4, 2023
1 parent eff76c7 commit f6e303d
Show file tree
Hide file tree
Showing 13 changed files with 116 additions and 75 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "cb2game"
version = "0.9.1"
version = "0.9.2"
authors = [
{ name="Jacob Sharf", email="[email protected]" },
{ name="Mustafa Omer Gul", email="[email protected]" },
Expand Down Expand Up @@ -53,6 +53,7 @@ dependencies = [
"gitpython==3.1.31",
"aiohttp_session==2.12.0",
"aiohttp_session[secure]==2.12.0",
"cryptography==41.0.1",
]

[project.urls]
Expand Down
3 changes: 2 additions & 1 deletion src/cb2game/agents/local_agent_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from cb2game.server.config.config import Config, ReadConfigOrDie
from cb2game.server.lobbies.open_lobby import OpenLobby
from cb2game.server.lobby import LobbyInfo, LobbyType
from cb2game.server.util import PackageRoot

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -94,7 +95,7 @@ def PlayNGames(


def main(
config_filepath="server/config/local-covers-config.yaml",
config_filepath=(PackageRoot() / "server/config/local-covers-config.yaml"),
event_uuid="",
profile=False,
num_games=10,
Expand Down
6 changes: 5 additions & 1 deletion src/cb2game/envs/demo_self_play.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from cb2game.server.hex import HecsCoord
from cb2game.server.messages.map_update import MapUpdate
from cb2game.server.messages.prop import PropUpdate
from cb2game.server.util import PackageRoot

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -277,7 +278,10 @@ def PlayGame(coordinator, log_to_db: bool = True):
return turn_state["score"], duration


def main(config_filepath="server/config/local-covers-config.yaml", instruction_uuid=""):
def main(
config_filepath=(PackageRoot() / "server/config/local-covers-config.yaml"),
instruction_uuid="",
):
nest_asyncio.apply()
# Disabling most logs improves performance by about 50ms per game.
logging.basicConfig(level=logging.INFO)
Expand Down
9 changes: 4 additions & 5 deletions src/cb2game/server/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ def ValidateConfig(config):
def ReadConfigOrDie(config_path):
"""Reads a config file and returns a Config object."""
with open(config_path, "r") as cfg_file:
data = yaml.load(cfg_file, Loader=yaml.CLoader)
config = Config.from_dict(data)
config = yaml.load(cfg_file, Loader=yaml.CLoader)
# If the type of the resulting data isn't already a config, convert it.
if not isinstance(config, Config):
config = Config.from_dict(config)
valid_config, reason = ValidateConfig(config)
if not valid_config:
raise ValueError(f"Config file is invalid: {reason}")
Expand Down Expand Up @@ -125,9 +127,6 @@ class Config(DataClassJSONMixin):

http_port: int = 8080

# Optional feature configurations.
gui: bool = False

map_cache_size: int = 500

comment: str = ""
Expand Down
5 changes: 4 additions & 1 deletion src/cb2game/server/config/map_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List, Tuple
from typing import List, Optional, Tuple

from mashumaro.mixins.json import DataClassJSONMixin

Expand All @@ -19,6 +19,9 @@ class MapConfig(DataClassJSONMixin):
Assets are categorized by equivalence class. See the asset class enum in assets.py.
"""

# RNG seed. If unspecified, a random seed is chosen.
rng_seed: Optional[int] = None

map_width: int = 25
map_height: int = 25
# These are specified as tuples (min, max) (integers only)
Expand Down
3 changes: 2 additions & 1 deletion src/cb2game/server/db_tools/gen_usernames.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import cb2game.server.schemas.defaults as defaults_db
from cb2game.server.schemas import base
from cb2game.server.schemas.mturk import Worker
from cb2game.server.util import PackageRoot


def InitWorkerDefaultUsernameIfNotExists(worker):
Expand All @@ -19,7 +20,7 @@ def InitWorkerDefaultUsernameIfNotExists(worker):
print(f"Username exists for {worker.hashed_id} ({username})")


def main(config_filepath="server/config/server-config.yaml"):
def main(config_filepath=(PackageRoot() / "server/config/server-config.yaml")):
cfg = config.ReadConfigOrDie(config_filepath)

print(f"Reading database from {cfg.database_path()}")
Expand Down
3 changes: 2 additions & 1 deletion src/cb2game/server/db_tools/ldrbrd_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from cb2game.server.schemas.google_user import GoogleUser
from cb2game.server.schemas.leaderboard import Leaderboard, Username
from cb2game.server.schemas.mturk import Worker, WorkerExperience, WorkerQualLevel
from cb2game.server.util import PackageRoot

COMMANDS = [
"list",
Expand Down Expand Up @@ -460,7 +461,7 @@ def main(
role="noop",
nosparklines=False,
threshold=3,
config_filepath="server/config/server-config.yaml",
config_filepath=(PackageRoot() / "server/config/server-config.yaml"),
):
if command == "help":
PrintUsage()
Expand Down
72 changes: 43 additions & 29 deletions src/cb2game/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import zipfile
from datetime import datetime, timedelta, timezone

from cb2game.server.util import SafePasswordCompare
from cb2game.server.util import PackageRoot, SafePasswordCompare

os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "" # Hide pygame welcome message

Expand Down Expand Up @@ -105,7 +105,7 @@ async def transmit_bytes(ws, message):

@routes.get("/")
async def Index(request):
return web.FileResponse("server/www/index.html")
return web.FileResponse(PackageRoot() / "server/www/index.html")


# Login form for password-protected backend URLs.
Expand All @@ -122,7 +122,7 @@ async def Login(request):
if len(config.server_password_sha512) == 0:
next_url = request.query.get("next", "/")
return web.HTTPFound(next_url)
return web.FileResponse("server/www/login.html")
return web.FileResponse(PackageRoot() / "server/www/login.html")


# Authentication endpoint for password-protected backend URLs.
Expand Down Expand Up @@ -163,99 +163,105 @@ async def Auth(request):

@routes.get("/play")
async def GamePage(request):
return web.FileResponse("server/www/WebGL/index.html")
return web.FileResponse(PackageRoot() / "server/www/WebGL/index.html")


@routes.get("/consent-form")
async def ConsentForm(request):
return web.FileResponse("server/www/pdfs/consent-form.pdf")
return web.FileResponse(PackageRoot() / "server/www/pdfs/consent-form.pdf")


@routes.get("/rules")
async def Rules(request):
return web.FileResponse("server/www/rules.html")
return web.FileResponse(PackageRoot() / "server/www/rules.html")


@routes.get("/payout")
async def Rules(request):
return web.FileResponse("server/www/payout.html")
return web.FileResponse(PackageRoot() / "server/www/payout.html")


@routes.get("/example_sets")
async def Rules(request):
return web.FileResponse("server/www/example_sets.html")
return web.FileResponse(PackageRoot() / "server/www/example_sets.html")


@routes.get("/oneoff")
async def OneoffComp(request):
return web.FileResponse("server/www/oneoff.html")
return web.FileResponse(PackageRoot() / "server/www/oneoff.html")


@routes.get("/follower-model-study")
async def TaskPage(request):
return web.FileResponse("server/www/follower-model-study.html")
return web.FileResponse(PackageRoot() / "server/www/follower-model-study.html")


@routes.get("/leader-model-study")
async def TaskPage(request):
return web.FileResponse("server/www/leader-model-study.html")
return web.FileResponse(PackageRoot() / "server/www/leader-model-study.html")


@routes.get("/main-study")
async def TaskPage(request):
return web.FileResponse("server/www/main-study.html")
return web.FileResponse(PackageRoot() / "server/www/main-study.html")


@routes.get("/mturk-task")
async def TaskPage(request):
return web.FileResponse("server/www/mturk-task.html")
return web.FileResponse(PackageRoot() / "server/www/mturk-task.html")


@routes.get("/follower-qual")
async def TaskPage(request):
return web.FileResponse("server/www/follower-qual.html")
return web.FileResponse(PackageRoot() / "server/www/follower-qual.html")


@routes.get("/leader-qual")
async def TaskPage(request):
return web.FileResponse("server/www/leader-qual.html")
return web.FileResponse(PackageRoot() / "server/www/leader-qual.html")


@routes.get("/changelist")
async def Changelist(request):
return web.FileResponse("server/www/changelist.html")
return web.FileResponse(PackageRoot() / "server/www/changelist.html")


@routes.get("/privacy")
async def Privacy(request):
return web.FileResponse("server/www/privacy-policy.html")
return web.FileResponse(PackageRoot() / "server/www/privacy-policy.html")


@routes.get("/view/dashboard")
@password_protected
async def Dashboard(request):
return web.FileResponse("server/www/dashboard.html")
return web.FileResponse(PackageRoot() / "server/www/dashboard.html")


@routes.get("/images/{filename}")
async def Images(request):
if not request.match_info.get("filename"):
return web.HTTPNotFound()
return web.FileResponse(f"server/www/images/{request.match_info['filename']}")
return web.FileResponse(
PackageRoot() / f"server/www/images/{request.match_info['filename']}"
)


@routes.get("/css/{filename}")
async def css(request):
if not request.match_info.get("filename"):
return web.HTTPNotFound()
return web.FileResponse(f"server/www/css/{request.match_info['filename']}")
return web.FileResponse(
PackageRoot() / f"server/www/css/{request.match_info['filename']}"
)


@routes.get("/js/{filename}")
async def Js(request):
if not request.match_info.get("filename"):
return web.HTTPNotFound()
return web.FileResponse(f"server/www/js/{request.match_info['filename']}")
return web.FileResponse(
PackageRoot() / f"server/www/js/{request.match_info['filename']}"
)


def JsonFromEvent(event: event_db.Event):
Expand Down Expand Up @@ -632,7 +638,7 @@ async def RetrieveData(request):
async def DataDownloadStart(request):
global download_requested
download_requested = True
return web.FileResponse("server/www/download.html")
return web.FileResponse(PackageRoot() / "server/www/download.html")


@routes.get("/data/game-list")
Expand Down Expand Up @@ -748,20 +754,20 @@ async def ClientExceptionList(request):
@routes.get("/view/client-exceptions")
@password_protected
async def ClientExceptionViewer(request):
return web.FileResponse("server/www/exceptions_viewer.html")
return web.FileResponse(PackageRoot() / "server/www/exceptions_viewer.html")


@routes.get("/view/games")
@password_protected
async def GamesViewer(request):
return web.FileResponse("server/www/games_viewer.html")
return web.FileResponse(PackageRoot() / "server/www/games_viewer.html")


@routes.get("/view/game/{game_id}")
@password_protected
async def GameViewer(request):
# Extract the game_id from the request.
return web.FileResponse("server/www/game_viewer.html")
return web.FileResponse(PackageRoot() / "server/www/game_viewer.html")


@routes.get("/data/config")
Expand All @@ -775,7 +781,7 @@ async def GetConfig(request):
@routes.get("/view/stats")
@password_protected
async def Stats(request):
return web.FileResponse("server/www/stats.html")
return web.FileResponse(PackageRoot() / "server/www/stats.html")


@routes.get("/data/turns/{game_id}")
Expand Down Expand Up @@ -1352,14 +1358,14 @@ async def asset(request):

async def serve(config):
# Check if the server/www/WebGL directory exists.
if not os.path.isdir("server/www/WebGL"):
if not os.path.isdir(os.path.join(PackageRoot() / "server/www/WebGL")):
logger.warning(
"WebGL directory not found. This directory contains the compiled Unity front-end. You can download it by running `python3 -m cb2game.server.fetch_client` or manually here https://github.com/lil-lab/cb2/releases. You can also compile from source, but this requires installing Unity and getting a license. See game/ for client code and build_client.sh for instructions building the client from headless mode in Unity."
)
return

# Add a route for serving web frontend files on /.
routes.static("/", "server/www/WebGL")
routes.static("/", os.path.join(PackageRoot() / "server/www/WebGL"))

app = web.Application()
fernet_key = cryptography.fernet.Fernet.generate_key()
Expand Down Expand Up @@ -1443,14 +1449,22 @@ def CreateExceptionDirectory(config):
exception_dir.mkdir(parents=False, exist_ok=True)


def main(config_filepath="server/config/server-config.yaml"):
def main(config_filepath=""):
global assets_map
global lobby

# On exit, deletes temporary download files.
atexit.register(CleanupDownloadFiles)
atexit.register(SaveClientExceptionsToDB)

# If the config filepath doesn't exist, log an error and tell the user to
# try running `python3 -m cb2game.server.generate_config`.
if not os.path.isfile(config_filepath):
logger.error(
f"Config file not found at {config_filepath}. Try running `python3 -m cb2game.server.generate_config`."
)
return

InitPythonLogging()
InitGlobalConfig(config_filepath)

Expand Down
12 changes: 11 additions & 1 deletion src/cb2game/server/map_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import math
import random
import sys
from dataclasses import dataclass
from enum import Enum
from queue import Queue
Expand Down Expand Up @@ -580,6 +581,15 @@ def place_outpost(map, outpost, map_config: MapConfig):

def RandomMap(map_config: MapConfig):
"""Random map of Tile objects, each with HECS coordinates and locations."""
# First, set the RNG seed if specified.
if map_config.rng_seed is not None:
start_seed = map_config.rng_seed
else:
start_seed = random.randint(0, sys.maxsize)

# Set the RNG seed.
random.seed(start_seed)

map = []
for r in range(0, map_config.map_height):
row = []
Expand All @@ -588,7 +598,7 @@ def RandomMap(map_config: MapConfig):
row.append(tile)
map.append(row)

map_metadata = MapMetadata([], [], [], [], 0)
map_metadata = MapMetadata([], [], [], [], 0, start_seed)

# Generate candidates for feature centers.
rows = list(range(1, map_config.map_height - 2, 6))
Expand Down
Loading

0 comments on commit f6e303d

Please sign in to comment.