Skip to content

Commit

Permalink
Merge pull request #55 from milselarch/feat/close_poll_context
Browse files Browse the repository at this point in the history
Feat/close poll context
  • Loading branch information
milselarch authored Jan 3, 2025
2 parents 4d2dd2e + a13ee50 commit 4172201
Show file tree
Hide file tree
Showing 12 changed files with 355 additions and 273 deletions.
196 changes: 5 additions & 191 deletions base_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import hmac
import json
import logging
Expand All @@ -10,26 +9,24 @@
import hashlib
import textwrap
import dataclasses

import database
import aioredlock

from enum import IntEnum
from typing_extensions import Any
from strenum import StrEnum
from requests import PreparedRequest

from helpers.constants import BLANK_ID
from helpers.rcv_tally import RCVTally
from helpers.redis_cache_manager import RedisCacheManager
from helpers.start_get_params import StartGetParams
from helpers import constants, strings
from helpers.strings import generate_poll_closed_message
from load_config import TELEGRAM_BOT_TOKEN
from telegram.ext import ApplicationBuilder
from py_rcv import VotesCounter as PyVotesCounter

from typing import List, Dict, Optional, Tuple
from result import Ok, Err, Result
from concurrent.futures import ThreadPoolExecutor
from helpers.message_buillder import MessageBuilder
from helpers.special_votes import SpecialVotes
from load_config import WEBHOOK_URL
Expand All @@ -42,9 +39,8 @@
InlineKeyboardButton, InlineKeyboardMarkup, User as TeleUser,
ReplyKeyboardMarkup, KeyboardButton, WebAppInfo, Bot as TelegramBot
)
from aioredlock import Aioredlock, LockError
from database.database import (
PollWinners, UserID, PollMetadata, ChatWhitelist
UserID, PollMetadata, ChatWhitelist
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -109,33 +105,14 @@ class PollMessage(object):
poll_info: PollInfo


class GetPollWinnerStatus(IntEnum):
CACHED = 0
NEWLY_COMPUTED = 1
COMPUTING = 2
FAILED = 3


class BaseAPI(object):
POLL_WINNER_KEY = "POLL_WINNER"
POLL_WINNER_LOCK_KEY = "POLL_WINNER_LOCK"
# CACHE_LOCK_NAME = "REDIS_CACHE_LOCK"
POLL_CACHE_EXPIRY = 60
DELETION_TOKEN_EXPIRY = 60 * 5
SHORT_HASH_LENGTH = 6

def __init__(self):
self.cache = RedisCacheManager()
self.rcv_tally = RCVTally()
database.initialize_db()
self.redis_lock_manager = self.create_redis_lock_manager()

@staticmethod
def create_redis_lock_manager(
connections: list[dict[str, str | int]] | None = None
):
if connections is not None:
return Aioredlock(connections)
else:
return Aioredlock()

@staticmethod
def __get_telegram_token():
Expand Down Expand Up @@ -179,16 +156,6 @@ def create_application_builder(cls):
builder.token(cls.__get_telegram_token())
return builder

@staticmethod
def _build_cache_key(header: str, key: str):
return f"{header}:{key}"

def _build_poll_winner_lock_cache_key(self, poll_id: int) -> str:
assert isinstance(poll_id, int)
return self._build_cache_key(
self.__class__.POLL_WINNER_LOCK_KEY, str(poll_id)
)

@staticmethod
def get_poll_closed(poll_id: int) -> Result[int, MessageBuilder]:
error_message = MessageBuilder()
Expand Down Expand Up @@ -216,159 +183,6 @@ def spawn_inline_keyboard_button(
))
)

@staticmethod
def fetch_poll(poll_id: int) -> Result[Polls, MessageBuilder]:
error_message = MessageBuilder()

try:
poll = Polls.select().where(Polls.id == poll_id).get()
except Polls.DoesNotExist:
error_message.add(f'poll {poll_id} does not exist')
return Err(error_message)

return Ok(poll)

@classmethod
def get_num_active_poll_voters(
cls, poll_id: int
) -> Result[int, MessageBuilder]:
result = cls.fetch_poll(poll_id)
if result.is_err():
return result

poll = result.unwrap()
return Ok(poll.num_active_voters)

@staticmethod
async def refresh_lock(lock: aioredlock.Lock, interval: float):
try:
while True:
print('WAIT')
await asyncio.sleep(interval)
await lock.extend()
except asyncio.CancelledError:
pass

async def get_poll_winner(
self, poll_id: int
) -> Tuple[Optional[int], GetPollWinnerStatus]:
"""
Returns poll winner for specified poll
Attempts to get poll winner from cache if it exists,
otherwise will run the ranked choice voting computation
and write to the redis cache before returning
# TODO: test that redis lock refresh works
:param poll_id:
:return:
poll winner, status of poll winner computation
"""
assert isinstance(poll_id, int)

cache_result = PollWinners.read_poll_winner_id(poll_id)
if cache_result.is_ok():
# print('CACHE_HIT', cache_result)
return cache_result.unwrap(), GetPollWinnerStatus.CACHED

redis_lock_key = self._build_poll_winner_lock_cache_key(poll_id)
# print('CACHE_KEY', redis_cache_key)
if await self.redis_lock_manager.is_locked(redis_lock_key):
# print('PRE_LOCKED')
return None, GetPollWinnerStatus.COMPUTING

try:
# prevents race conditions where multiple computations
# are run concurrently for the same poll
async with await self.redis_lock_manager.lock(
redis_lock_key, lock_timeout=self.POLL_CACHE_EXPIRY
) as lock:
# Start a task to refresh the lock periodically
refresh_task = asyncio.create_task(self.refresh_lock(
lock, self.POLL_CACHE_EXPIRY / 2
))

try:
cache_result = PollWinners.read_poll_winner_id(poll_id)
if cache_result.is_ok():
# print('INNER_CACHE_HIT', cache_result)
return cache_result.unwrap(), GetPollWinnerStatus.CACHED

# compute the winner in a separate thread to not block
# the async event loop
loop = asyncio.get_event_loop()
with ThreadPoolExecutor() as executor:
poll_winner_id = await loop.run_in_executor(
executor, self._determine_poll_winner, poll_id
)

# Store computed winner in the db
PollWinners.build_from_fields(
poll_id=poll_id, option_id=poll_winner_id
).get_or_create()
finally:
# Cancel the refresh task
refresh_task.cancel()
await refresh_task

except LockError:
# print('LOCK_ERROR')
return None, GetPollWinnerStatus.COMPUTING

# print('CACHE_MISS', poll_winner_id)
return poll_winner_id, GetPollWinnerStatus.NEWLY_COMPUTED

@classmethod
def _determine_poll_winner(cls, poll_id: int) -> Optional[int]:
"""
Runs the ranked choice voting algorithm to determine
the winner of the poll
:param poll_id:
:return:
ID of winning option, or None if there's no winner
"""
num_poll_voters_result = cls.get_num_active_poll_voters(poll_id)
if num_poll_voters_result.is_err():
return None

num_poll_voters: int = num_poll_voters_result.unwrap()
# get votes for the poll sorted by PollVoter and from
# the lowest ranking option (most favored)
# to the highest ranking option (least favored)
votes = VoteRankings.select().join(
PollVoters, on=(PollVoters.id == VoteRankings.poll_voter)
).where(
PollVoters.poll == poll_id
).order_by(
PollVoters.id, VoteRankings.ranking.asc() # TODO: test ordering
)

prev_voter_id, num_votes_cast = None, 0
votes_aggregator = PyVotesCounter()

for vote_ranking in votes:
option_row = vote_ranking.option
voter_id = vote_ranking.poll_voter.id

if prev_voter_id != voter_id:
votes_aggregator.flush_votes()
prev_voter_id = voter_id
num_votes_cast += 1

if option_row is None:
vote_value = vote_ranking.special_value
else:
vote_value = option_row.id

# print('VOTE_VAL', vote_value, int(vote_value))
votes_aggregator.insert_vote_ranking(voter_id, vote_value)

votes_aggregator.flush_votes()
voters_without_votes = num_poll_voters - num_votes_cast
assert voters_without_votes >= 0
votes_aggregator.insert_empty_votes(voters_without_votes)
winning_option_id = votes_aggregator.determine_winner()
return winning_option_id

@classmethod
def verify_voter(
cls, poll_id: int, user_id: UserID, username: Optional[str] = None,
Expand Down
78 changes: 25 additions & 53 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
from datetime import datetime

from helpers import strings
from helpers.rcv_tally import RCVTally
from helpers.redis_cache_manager import GetPollWinnerStatus
from tele_helpers import ModifiedTeleUpdate
from helpers.special_votes import SpecialVotes
from bot_middleware import track_errors, admin_only
from database.database import UserID, CallbackContextState
from database.db_helpers import EmptyField, Empty
from handlers.chat_context_handlers import context_handlers
from handlers.chat_context_handlers import context_handlers, ClosePollContextHandler
from helpers import constants

from telegram import (
Expand All @@ -48,18 +50,14 @@
)
from helpers.chat_contexts import (
PollCreationChatContext, PollCreatorTemplate, POLL_MAX_OPTIONS,
VoteChatContext, PaySupportChatContext
VoteChatContext, PaySupportChatContext, ClosePollChatContext
)
from database import (
Users, Polls, PollVoters, UsernameWhitelist,
PollOptions, VoteRankings, db, ChatWhitelist, PollWinners,
MessageContextState, Payments
)
from base_api import (
BaseAPI, UserRegistrationStatus,
CallbackCommands, GetPollWinnerStatus
)

from base_api import BaseAPI, UserRegistrationStatus, CallbackCommands
from tele_helpers import TelegramHelpers

# https://stackoverflow.com/questions/15892946/
Expand Down Expand Up @@ -145,7 +143,7 @@ def start_bot(self):
Command.VOTE: self.vote_for_poll_handler,
Command.POLL_RESULTS: self.fetch_poll_results,
Command.HAS_VOTED: self.has_voted,
Command.CLOSE_POLL: self.close_poll,
Command.CLOSE_POLL: self.close_poll_handler,
Command.VIEW_VOTES: self.view_votes,
Command.VIEW_VOTERS: self.view_poll_voters,
Command.ABOUT: self.show_about,
Expand Down Expand Up @@ -1108,17 +1106,27 @@ async def view_all_polls(update: ModifiedTeleUpdate, *_, **__):
@classmethod
def read_vote_count(cls, poll_id: int) -> Result[int, MessageBuilder]:
# returns all registered voters who have cast a vote
fetch_poll_result = cls.fetch_poll(poll_id)
fetch_poll_result = RCVTally.fetch_poll(poll_id)

if fetch_poll_result.is_err():
return fetch_poll_result

poll = fetch_poll_result.unwrap()
return Ok(poll.num_votes)

async def close_poll(self, update, *_, **__):
@classmethod
async def close_poll_handler(cls, update, *_, **__):
message = update.message
message_text = TelegramHelpers.read_raw_command_args(update)
user = update.user

if message_text == '':
ClosePollChatContext(
user_id=user.get_user_id(), chat_id=message.chat.id
).save_state()
return await message.reply_text(
'Enter the poll ID for the poll you want to close'
)

if constants.ID_PATTERN.match(message_text) is None:
return await message.reply_text(textwrap.dedent(f"""
Expand All @@ -1134,48 +1142,11 @@ async def close_poll(self, update, *_, **__):
return False

poll_id = extract_result.unwrap()
tele_user: TeleUser | None = message.from_user

try:
poll = Polls.select().where(Polls.id == poll_id).get()
except Polls.DoesNotExist:
await message.reply_text(f'poll {poll_id} does not exist')
return False

try:
user = Users.build_from_fields(tele_id=tele_user.id).get()
except Users.DoesNotExist:
await message.reply_text(f'UNEXPECTED ERROR: USER DOES NOT EXIST')
return False

user_id = user.get_user_id()
creator_id: UserID = poll.get_creator().get_user_id()
if creator_id != user_id:
await message.reply_text(
"Only the creator of this poll is allowed to close it"
)
return False

poll.closed = True
poll.save()

await message.reply_text(f'poll {poll_id} closed')
winning_option_id, get_status = await self.get_poll_winner(poll_id)

if get_status == GetPollWinnerStatus.COMPUTING:
return await message.reply_text(textwrap.dedent(f"""
Poll winner computation in progress
Please check again later
"""))
elif winning_option_id is not None:
winning_options = PollOptions.select().where(
PollOptions.id == winning_option_id
)

option_name = winning_options[0].option_name
return await message.reply_text(f'Poll winner is: {option_name}')
else:
return await message.reply_text('Poll has no winner')
user_id = update.user.get_user_id()
await ClosePollContextHandler.close_poll(
poll_id=poll_id, user_id=user_id,
update=update
)

@classmethod
async def whitelist_username(cls, update: ModifiedTeleUpdate, *_, **__):
Expand Down Expand Up @@ -1677,7 +1648,8 @@ async def fetch_poll_results(self, update, *_, **__):
error_message = get_poll_closed_result.err()
await error_message.call(message.reply_text)

winning_option_id, get_status = await self.get_poll_winner(poll_id)
get_winner_result = self.rcv_tally.get_poll_winner(poll_id)
winning_option_id, get_status = get_winner_result

if get_status == GetPollWinnerStatus.COMPUTING:
await message.reply_text(textwrap.dedent(f"""
Expand Down
Loading

0 comments on commit 4172201

Please sign in to comment.