From 395e62be813f7fb04e48ac1a78b0313340b16a1b Mon Sep 17 00:00:00 2001 From: milselarch Date: Tue, 1 Oct 2024 23:00:50 +0800 Subject: [PATCH 01/23] refactor users middlware, add deleted_at field for users --- bot.py | 74 +++++++++++----- database/database.py | 1 + migrations/0002_migration_202410012257.py | 103 ++++++++++++++++++++++ 3 files changed, 156 insertions(+), 22 deletions(-) create mode 100644 migrations/0002_migration_202410012257.py diff --git a/bot.py b/bot.py index 24ae9c93..b4b0acd1 100644 --- a/bot.py +++ b/bot.py @@ -18,7 +18,7 @@ from database.database import UserID from database.db_helpers import EmptyField, Empty, BoundRowFields from load_config import TELEGRAM_BOT_TOKEN, WEBHOOK_URL -from typing import List, Tuple, Dict, Optional, Sequence, Iterable +from typing import List, Tuple, Dict, Optional, Sequence, Iterable, Coroutine, Awaitable, Callable from LocksManager import PollsLockManager from database import ( @@ -69,11 +69,8 @@ def __init__(self, config_path='config.yml'): self.webhook_url = None @staticmethod - def record_username_wrapper(func, include_self=True): - """ - updates user id to username mapping - """ - def caller(self, update: Update, *args, **kwargs): + def users_middleware(func: Callable[..., Awaitable], include_self=True): + async def caller(self, update: Update, *args, **kwargs): # print("SELF", self) if update.callback_query is not None: query = update.callback_query @@ -84,23 +81,37 @@ def caller(self, update: Update, *args, **kwargs): else: tele_user = None - if tele_user is not None: - assert isinstance(tele_user, TeleUser) - chat_username: str = tele_user.username - tele_id = tele_user.id - # print('UPDATE_USER', user.id, chat_username) + if tele_user is None: + if update.message is not None: + respond_callback = update.message.reply_text + elif update.callback_query is not None: + query = update.callback_query + respond_callback = query.answer + else: + logger.error(f'NO USER FOUND FOR ENDPOINT {func}') + return False + + await respond_callback("User not found") + + tele_id = tele_user.id + chat_username: str = tele_user.username + assert isinstance(tele_user, TeleUser) + user, _ = Users.build_from_fields(tele_id=tele_id).get_or_create() + # don't allow deleted users to interact with the bot + if user.deleted_at is not None: + await tele_user.send_message("User has been deleted") + return False - Users.build_from_fields( - tele_id=tele_id, username=chat_username - ).insert().on_conflict( - preserve=[Users.tele_id], - update={Users.username: chat_username} - ).execute() + # update user tele id to username mapping + if user.username != chat_username: + user.username = chat_username + user.save() + # TODO: get user db entry and pass to inner func if include_self: - return func(self, update, *args, **kwargs) + return await func(self, update, *args, **kwargs) else: - return func(update, *args, **kwargs) + return await func(update, *args, **kwargs) if not include_self: return lambda *args, **kwargs: caller(None, *args, **kwargs) @@ -109,7 +120,7 @@ def caller(self, update: Update, *args, **kwargs): @classmethod def wrap_command_handler(cls, handler): - return track_errors(cls.record_username_wrapper( + return track_errors(cls.users_middleware( handler, include_self=False )) @@ -143,6 +154,7 @@ def start_bot(self): view_voters=self.view_poll_voters, about=self.show_about, delete_poll=self.delete_poll, + delete_account=self.delete_account, help=self.show_help, vote_admin=self.vote_for_poll_admin, @@ -168,6 +180,11 @@ def start_bot(self): self.app.add_handler(CallbackQueryHandler( self.inline_keyboard_handler )) + # catch-all to handle all other messages + self.app.add_handler(MessageHandler( + filters.Regex(r'.*') & filters.TEXT, + self.handle_other_messages + )) # self.app.add_error_handler(self.error_handler) self.app.run_polling(allowed_updates=Update.ALL_TYPES) @@ -326,10 +343,18 @@ async def send_post_vote_reply(self, message: Message, poll_id: int): """)) @track_errors - @record_username_wrapper + @users_middleware async def handle_unknown_command(self, update: Update, _): await update.message.reply_text("Command not found") + @track_errors + @users_middleware + async def handle_other_messages(self, update: Update, _): + # TODO: implement callback contexts voting and poll creation + await update.message.reply_text( + "Message support is still in development" + ) + def generate_poll_url(self, poll_id: int, tele_user: TeleUser) -> str: req = PreparedRequest() auth_date = str(int(time.time())) @@ -371,7 +396,7 @@ def build_private_vote_markup( return markup_layout @track_errors - @record_username_wrapper + @users_middleware async def inline_keyboard_handler( self, update: Update, context: CallbackContext ): @@ -1867,6 +1892,11 @@ async def delete_poll(cls, update: Update, *_, **__): ) return True + @classmethod + def delete_account(cls, update: Update, *_, **__): + # TODO: implement this + raise NotImplementedError + @staticmethod async def show_help(update: Update, *_, **__): message: Message = update.message diff --git a/database/database.py b/database/database.py index e2710e53..9577eda5 100644 --- a/database/database.py +++ b/database/database.py @@ -66,6 +66,7 @@ class Users(BaseModel): username = CharField(max_length=255, default=None, null=True) credits = IntegerField(default=0) subscription_tier = IntegerField(default=0) + deleted_at = DateTimeField(default=None, null=True) class Meta: database = database_proxy diff --git a/migrations/0002_migration_202410012257.py b/migrations/0002_migration_202410012257.py new file mode 100644 index 00000000..52a95f05 --- /dev/null +++ b/migrations/0002_migration_202410012257.py @@ -0,0 +1,103 @@ +# auto-generated snapshot +from peewee import * +import datetime +import peewee + + +snapshot = Snapshot() + + +@snapshot.append +class Users(peewee.Model): + id = BigAutoField(primary_key=True) + tele_id = BigIntegerField(index=True, unique=True) + username = CharField(max_length=255, null=True) + credits = IntegerField(default=0) + subscription_tier = IntegerField(default=0) + deleted_at = DateTimeField(null=True) + class Meta: + table_name = "users" + indexes = ( + (('username',), False), + ) + + +@snapshot.append +class Polls(peewee.Model): + desc = TextField(default='') + close_time = DateTimeField() + open_time = DateTimeField(default=datetime.datetime.now) + closed = BooleanField(default=False) + open_registration = BooleanField(default=False) + auto_refill = BooleanField(default=False) + creator = snapshot.ForeignKeyField(index=True, model='users', on_delete='CASCADE') + max_voters = IntegerField(default=10) + num_voters = IntegerField(default=0) + num_votes = IntegerField(default=0) + class Meta: + table_name = "polls" + + +@snapshot.append +class ChatWhitelist(peewee.Model): + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + chat_id = BigIntegerField() + broadcasted = BooleanField(default=False) + class Meta: + table_name = "chatwhitelist" + indexes = ( + (('poll', 'chat_id'), True), + ) + + +@snapshot.append +class PollOptions(peewee.Model): + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + option_name = CharField(max_length=255) + option_number = IntegerField() + class Meta: + table_name = "polloptions" + + +@snapshot.append +class PollVoters(peewee.Model): + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + user = snapshot.ForeignKeyField(index=True, model='users', on_delete='CASCADE') + voted = BooleanField(default=False) + class Meta: + table_name = "pollvoters" + indexes = ( + (('poll', 'user'), True), + ) + + +@snapshot.append +class PollWinners(peewee.Model): + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + option = snapshot.ForeignKeyField(index=True, model='polloptions', null=True, on_delete='CASCADE') + class Meta: + table_name = "pollwinners" + + +@snapshot.append +class UsernameWhitelist(peewee.Model): + username = CharField(max_length=255) + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + user = snapshot.ForeignKeyField(index=True, model='users', null=True, on_delete='CASCADE') + class Meta: + table_name = "usernamewhitelist" + indexes = ( + (('poll', 'username'), True), + ) + + +@snapshot.append +class VoteRankings(peewee.Model): + poll_voter = snapshot.ForeignKeyField(index=True, model='pollvoters', on_delete='CASCADE') + option = snapshot.ForeignKeyField(index=True, model='polloptions', null=True, on_delete='CASCADE') + special_value = IntegerField(constraints=[SQL('CHECK (special_value < 0)')], null=True) + ranking = IntegerField() + class Meta: + table_name = "voterankings" + + From 66849f745397cb135d99bf890a44df5a5ec07297 Mon Sep 17 00:00:00 2001 From: milselarch Date: Tue, 1 Oct 2024 23:02:16 +0800 Subject: [PATCH 02/23] refactor imports --- bot.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bot.py b/bot.py index b4b0acd1..5f06df75 100644 --- a/bot.py +++ b/bot.py @@ -18,9 +18,11 @@ from database.database import UserID from database.db_helpers import EmptyField, Empty, BoundRowFields from load_config import TELEGRAM_BOT_TOKEN, WEBHOOK_URL -from typing import List, Tuple, Dict, Optional, Sequence, Iterable, Coroutine, Awaitable, Callable from LocksManager import PollsLockManager +from typing import ( + List, Tuple, Dict, Optional, Sequence, Iterable, Awaitable, Callable +) from database import ( Users, Polls, PollVoters, UsernameWhitelist, PollOptions, VoteRankings, db, ChatWhitelist, PollWinners @@ -85,8 +87,7 @@ async def caller(self, update: Update, *args, **kwargs): if update.message is not None: respond_callback = update.message.reply_text elif update.callback_query is not None: - query = update.callback_query - respond_callback = query.answer + respond_callback = update.callback_query.answer else: logger.error(f'NO USER FOUND FOR ENDPOINT {func}') return False From f2813da829136861f6114d2bf30fb783cc6219d4 Mon Sep 17 00:00:00 2001 From: milselarch Date: Wed, 2 Oct 2024 19:04:52 +0800 Subject: [PATCH 03/23] added ModifiedTeleUpdate --- BaseAPI.py | 28 ++++++- ModifiedTeleUpdate.py | 22 ++++++ bot.py | 170 ++++++++++++++++++++++++++---------------- 3 files changed, 152 insertions(+), 68 deletions(-) create mode 100644 ModifiedTeleUpdate.py diff --git a/BaseAPI.py b/BaseAPI.py index f7f0826a..95d62d7f 100644 --- a/BaseAPI.py +++ b/BaseAPI.py @@ -3,19 +3,25 @@ import json import secrets import string + +import telegram import time import hashlib import textwrap import dataclasses + +from telegram.ext import ApplicationBuilder + import database import aioredlock +import ranked_choice_vote import redis from enum import IntEnum from typing_extensions import Any from collections import defaultdict -from ranked_choice_vote import ranked_choice_vote from strenum import StrEnum +from load_config import TELEGRAM_BOT_TOKEN from typing import List, Dict, Optional, Tuple from result import Ok, Err, Result @@ -127,6 +133,21 @@ def __init__(self): self.redis_cache = redis.Redis() self.redis_lock_manager = Aioredlock() + @staticmethod + def __get_telegram_token(): + # TODO: move methods using tele token to a separate class + return TELEGRAM_BOT_TOKEN + + @classmethod + def create_tele_bot(cls): + return telegram.Bot(token=cls.__get_telegram_token()) + + @classmethod + def create_application_builder(cls): + builder = ApplicationBuilder() + builder.token(cls.__get_telegram_token()) + return builder + @staticmethod def _build_cache_key(header: str, key: str): return f"{header}:{key}" @@ -865,10 +886,11 @@ def make_data_check_string( return data_check_string - @staticmethod + @classmethod def sign_data_check_string( - data_check_string: str, bot_token: str + cls, data_check_string: str ) -> str: + bot_token = cls.__get_telegram_token() secret_key = hmac.new( key=b"WebAppData", msg=bot_token.encode(), digestmod=hashlib.sha256 diff --git a/ModifiedTeleUpdate.py b/ModifiedTeleUpdate.py new file mode 100644 index 00000000..cca340f7 --- /dev/null +++ b/ModifiedTeleUpdate.py @@ -0,0 +1,22 @@ +from database import Users +from telegram import Update as BaseTeleUpdate + + +class ModifiedTeleUpdate(object): + def __init__( + self, update: BaseTeleUpdate, user: Users + ): + self.update: BaseTeleUpdate = update + self.user: Users = user + + @property + def callback_query(self): + return self.update.callback_query + + @property + def message(self): + return self.update.message + + @property + def effective_message(self): + return self.update.effective_message \ No newline at end of file diff --git a/bot.py b/bot.py index 5f06df75..8b145109 100644 --- a/bot.py +++ b/bot.py @@ -11,17 +11,22 @@ from json import JSONDecodeError from result import Ok, Err, Result +# noinspection PyProtectedMember +from telegram.ext._utils.types import CCT, RT +from telegram.ext.filters import BaseFilter from MessageBuilder import MessageBuilder from requests.models import PreparedRequest + +from ModifiedTeleUpdate import ModifiedTeleUpdate from SpecialVotes import SpecialVotes from bot_middleware import track_errors, admin_only from database.database import UserID from database.db_helpers import EmptyField, Empty, BoundRowFields -from load_config import TELEGRAM_BOT_TOKEN, WEBHOOK_URL +from load_config import WEBHOOK_URL from LocksManager import PollsLockManager from typing import ( - List, Tuple, Dict, Optional, Sequence, Iterable, Awaitable, Callable + List, Tuple, Dict, Optional, Sequence, Iterable, Callable, Coroutine, Any ) from database import ( Users, Polls, PollVoters, UsernameWhitelist, @@ -32,12 +37,12 @@ CallbackCommands, GetPollWinnerStatus ) from telegram import ( - Update, Message, WebAppInfo, ReplyKeyboardMarkup, + Message, WebAppInfo, ReplyKeyboardMarkup, KeyboardButton, InlineKeyboardMarkup, InlineKeyboardButton, - User as TeleUser + User as TeleUser, Update as BaseTeleUpdate ) from telegram.ext import ( - CommandHandler, ApplicationBuilder, ContextTypes, + CommandHandler, ContextTypes, MessageHandler, filters, CallbackContext, CallbackQueryHandler, Application ) @@ -71,15 +76,23 @@ def __init__(self, config_path='config.yml'): self.webhook_url = None @staticmethod - def users_middleware(func: Callable[..., Awaitable], include_self=True): - async def caller(self, update: Update, *args, **kwargs): + def users_middleware( + func: Callable[..., Coroutine], include_self=True + ) -> Callable[[BaseTeleUpdate, ...], Coroutine]: + async def caller( + self, update: BaseTeleUpdate | CallbackContext, + *args, **kwargs + ): # print("SELF", self) - if update.callback_query is not None: - query = update.callback_query - tele_user = query.from_user - elif update.message is not None: + # print('UPDATE', update, args, kwargs) + is_tele_update = isinstance(update, BaseTeleUpdate) + + if update.message is not None: message: Message = update.message tele_user = message.from_user + elif is_tele_update and update.callback_query is not None: + query = update.callback_query + tele_user = query.from_user else: tele_user = None @@ -108,16 +121,19 @@ async def caller(self, update: Update, *args, **kwargs): user.username = chat_username user.save() - # TODO: get user db entry and pass to inner func + modified_tele_update = ModifiedTeleUpdate( + update=update, user=user + ) + if include_self: - return await func(self, update, *args, **kwargs) + return await func(self, modified_tele_update, *args, **kwargs) else: - return await func(update, *args, **kwargs) + return await func(modified_tele_update, *args, **kwargs) - if not include_self: - return lambda *args, **kwargs: caller(None, *args, **kwargs) + def caller_without_self(update: BaseTeleUpdate, *args, **kwargs): + return caller(None, update, *args, **kwargs) - return caller + return caller if include_self else caller_without_self @classmethod def wrap_command_handler(cls, handler): @@ -126,11 +142,10 @@ def wrap_command_handler(cls, handler): )) def start_bot(self): - self.bot = telegram.Bot(token=TELEGRAM_BOT_TOKEN) + self.bot = self.create_tele_bot() self.webhook_url = WEBHOOK_URL - builder = ApplicationBuilder() - builder.token(TELEGRAM_BOT_TOKEN) + builder = self.create_application_builder() builder.concurrent_updates(MAX_CONCURRENT_UPDATES) builder.post_init(self.post_init) self.app = builder.build() @@ -171,24 +186,26 @@ def start_bot(self): wrap_func=self.wrap_command_handler ) # catch-all to handle responses to unknown commands - self.app.add_handler(MessageHandler( - filters.Regex(r'^/') & filters.COMMAND, + self.register_message_handler( + self.app, filters.Regex(r'^/') & filters.COMMAND, self.handle_unknown_command - )) - self.app.add_handler(MessageHandler( - filters.StatusUpdate.WEB_APP_DATA, self.web_app_handler - )) - self.app.add_handler(CallbackQueryHandler( - self.inline_keyboard_handler - )) + ) + # handle web app updates + self.register_message_handler( + self.app, filters.StatusUpdate.WEB_APP_DATA, + self.web_app_handler + ) # catch-all to handle all other messages - self.app.add_handler(MessageHandler( - filters.Regex(r'.*') & filters.TEXT, + self.register_message_handler( + self.app, filters.Regex(r'.*') & filters.TEXT, self.handle_other_messages - )) + ) + self.register_callback_handler( + self.app, self.inline_keyboard_handler + ) # self.app.add_error_handler(self.error_handler) - self.app.run_polling(allowed_updates=Update.ALL_TYPES) + self.app.run_polling(allowed_updates=BaseTeleUpdate.ALL_TYPES) @staticmethod async def post_init(application: Application): @@ -295,7 +312,7 @@ async def start_handler( ) @track_errors - async def web_app_handler(self, update: Update, _): + async def web_app_handler(self, update: ModifiedTeleUpdate, _): message: Message = update.message payload = json.loads(update.effective_message.web_app_data.data) @@ -344,13 +361,11 @@ async def send_post_vote_reply(self, message: Message, poll_id: int): """)) @track_errors - @users_middleware - async def handle_unknown_command(self, update: Update, _): + async def handle_unknown_command(self, update: ModifiedTeleUpdate, _): await update.message.reply_text("Command not found") @track_errors - @users_middleware - async def handle_other_messages(self, update: Update, _): + async def handle_other_messages(self, update: ModifiedTeleUpdate, _): # TODO: implement callback contexts voting and poll creation await update.message.reply_text( "Message support is still in development" @@ -369,7 +384,7 @@ def generate_poll_url(self, poll_id: int, tele_user: TeleUser) -> str: auth_date=auth_date, query_id=query_id, user=user_info ) validation_hash = self.sign_data_check_string( - data_check_string=data_check_string, bot_token=TELEGRAM_BOT_TOKEN + data_check_string=data_check_string ) params = { @@ -397,9 +412,8 @@ def build_private_vote_markup( return markup_layout @track_errors - @users_middleware async def inline_keyboard_handler( - self, update: Update, context: CallbackContext + self, update: ModifiedTeleUpdate, context: CallbackContext ): """ callback method for buttons in chat group messages @@ -554,7 +568,7 @@ def is_whitelisted_chat(poll_id: int, chat_id: int): return query.exists() @staticmethod - async def user_details_handler(update: Update, *_): + async def user_details_handler(update: ModifiedTeleUpdate, *_): """ returns current user id and username """ @@ -566,14 +580,14 @@ async def user_details_handler(update: Update, *_): """)) @staticmethod - async def chat_details_handler(update: Update, *_): + async def chat_details_handler(update: ModifiedTeleUpdate, *_): """ returns current chat id """ chat_id = update.message.chat.id await update.message.reply_text(f"chat id: {chat_id}") - async def has_voted(self, update: Update, *_, **__): + async def has_voted(self, update: ModifiedTeleUpdate, *_, **__): """ usage: /has_voted {poll_id} @@ -884,7 +898,9 @@ async def create_poll( poll_message, reply_markup=reply_markup ) - async def register_user_by_tele_id(self, update: Update, *_, **__): + async def register_user_by_tele_id( + self, update: ModifiedTeleUpdate, *_, **__ + ): """ registers a user by user_tele_id for a poll /whitelist_user_id {poll_id} {user_tele_id} @@ -954,21 +970,21 @@ async def register_user_by_tele_id(self, update: Update, *_, **__): return True async def whitelist_chat_registration( - self, update: Update, *_, **__ + self, update: ModifiedTeleUpdate, *_, **__ ): return await self.set_chat_registration_status( update=update, whitelist=True ) async def blacklist_chat_registration( - self, update: Update, *_, **__ + self, update: ModifiedTeleUpdate, *_, **__ ): return await self.set_chat_registration_status( update=update, whitelist=False ) async def set_chat_registration_status( - self, update: Update, whitelist: bool + self, update: ModifiedTeleUpdate, whitelist: bool ) -> bool: message = update.message tele_user: TeleUser | None = message.from_user @@ -1030,7 +1046,7 @@ async def set_chat_registration_status( ) return True - async def view_votes(self, update: Update, *_, **__): + async def view_votes(self, update: ModifiedTeleUpdate, *_, **__): message: Message = update.message extract_result = self.extract_poll_id(update) @@ -1161,8 +1177,8 @@ async def close_poll_admin(self, update, *_, **__): await self._set_poll_status(update, True) @admin_only - async def _set_poll_status(self, update: Update, closed=True): - assert isinstance(update, Update) + async def _set_poll_status(self, update: ModifiedTeleUpdate, closed=True): + assert isinstance(update, ModifiedTeleUpdate) message = update.message extract_result = self.extract_poll_id(update) @@ -1185,12 +1201,14 @@ async def _set_poll_status(self, update: Update, closed=True): await message.reply_text(f'poll {poll_id} has been unclosed') @admin_only - async def lookup_from_username_admin(self, update: Update, *_, **__): + async def lookup_from_username_admin( + self, update: ModifiedTeleUpdate, *_, **__ + ): """ /lookup_from_username_admin {username} Looks up user_ids for users with a matching username """ - assert isinstance(update, Update) + assert isinstance(update, ModifiedTeleUpdate) message = update.message raw_text = message.text.strip() @@ -1208,7 +1226,7 @@ async def lookup_from_username_admin(self, update: Update, *_, **__): """)) @admin_only - async def insert_user_admin(self, update: Update, *_, **__): + async def insert_user_admin(self, update: ModifiedTeleUpdate, *_, **__): """ Inserts a user with the given user_id and username into the Users table @@ -1266,11 +1284,13 @@ async def insert_user_admin(self, update: Update, *_, **__): ) @admin_only - async def lookup_from_username_admin(self, update: Update, *_, **__): + async def lookup_from_username_admin( + self, update: ModifiedTeleUpdate, *_, **__ + ): """ Looks up user_ids for users with a matching username """ - assert isinstance(update, Update) + assert isinstance(update, ModifiedTeleUpdate) message = update.message raw_text = message.text.strip() @@ -1288,7 +1308,7 @@ async def lookup_from_username_admin(self, update: Update, *_, **__): """)) @admin_only - async def insert_user_admin(self, update: Update, *_, **__): + async def insert_user_admin(self, update: ModifiedTeleUpdate, *_, **__): """ Inserts a user with the given user_id and username into the Users table @@ -1410,7 +1430,7 @@ async def view_poll(self, update, context: ContextTypes.DEFAULT_TYPE): return True @staticmethod - async def view_all_polls(update: Update, *_, **__): + async def view_all_polls(update: ModifiedTeleUpdate, *_, **__): # TODO: show voted / voters count + open / close status for each poll message: Message = update.message tele_user: TeleUser = update.message.from_user @@ -1525,7 +1545,7 @@ async def close_poll(self, update, *_, **__): return await message.reply_text('Poll has no winner') @admin_only - async def vote_for_poll_admin(self, update: Update, *_, **__): + async def vote_for_poll_admin(self, update: ModifiedTeleUpdate, *_, **__): """ telegram command formats: /vote_admin {username} {poll_id}: {option_1} > ... > {option_n} @@ -1807,7 +1827,7 @@ def unpack_rankings_and_poll_id( return Ok((poll_id, rankings)) @staticmethod - async def show_about(update: Update, *_, **__): + async def show_about(update: ModifiedTeleUpdate, *_, **__): message: Message = update.message await message.reply_text(textwrap.dedent(f""" Version {__VERSION__} @@ -1818,7 +1838,7 @@ async def show_about(update: Update, *_, **__): """)) @classmethod - async def delete_poll(cls, update: Update, *_, **__): + async def delete_poll(cls, update: ModifiedTeleUpdate, *_, **__): message: Message = update.message tele_user = message.from_user user_tele_id = tele_user.id @@ -1894,12 +1914,12 @@ async def delete_poll(cls, update: Update, *_, **__): return True @classmethod - def delete_account(cls, update: Update, *_, **__): + def delete_account(cls, update: ModifiedTeleUpdate, *_, **__): # TODO: implement this raise NotImplementedError @staticmethod - async def show_help(update: Update, *_, **__): + async def show_help(update: ModifiedTeleUpdate, *_, **__): message: Message = update.message await message.reply_text(textwrap.dedent(""" /start - start bot @@ -2136,9 +2156,29 @@ async def fetch_poll_results(self, update, *_, **__): else: return await message.reply_text('Poll has no winner') + @classmethod + def register_message_handler( + cls, dispatcher: Application, message_filter: BaseFilter, + callback: Callable[[ModifiedTeleUpdate, CCT], Coroutine[Any, Any, RT]] + ): + dispatcher.add_handler(MessageHandler( + message_filter, cls.users_middleware(callback, include_self=False) + )) + + @classmethod + def register_callback_handler( + cls, dispatcher: Application, + callback: Callable[[ModifiedTeleUpdate, CCT], Coroutine[Any, Any, RT]] + ): + dispatcher.add_handler(CallbackQueryHandler( + cls.users_middleware(callback, include_self=False) + )) + @staticmethod def register_commands( - dispatcher, commands_mapping, wrap_func=lambda func: func + dispatcher: Application, + commands_mapping, + wrap_func=lambda func: func ): for command_name in commands_mapping: handler = commands_mapping[command_name] @@ -2149,7 +2189,7 @@ def register_commands( @staticmethod def extract_poll_id( - update: telegram.Update + update: ModifiedTeleUpdate ) -> Result[int, MessageBuilder]: message: telegram.Message = update.message error_message = MessageBuilder() From 631e4b73c0bab56bccb4b2c98a5c64138f9eb5a4 Mon Sep 17 00:00:00 2001 From: milselarch Date: Thu, 3 Oct 2024 21:31:16 +0800 Subject: [PATCH 04/23] chore: type command registration better --- bot.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bot.py b/bot.py index 8b145109..215f64aa 100644 --- a/bot.py +++ b/bot.py @@ -150,7 +150,7 @@ def start_bot(self): builder.post_init(self.post_init) self.app = builder.build() - commands_mapping = self.kwargify( + commands_mapping = dict( start=self.start_handler, user_details=self.user_details_handler, chat_details=self.chat_details_handler, @@ -182,8 +182,7 @@ def start_bot(self): # on different commands - answer in Telegram self.register_commands( - self.app, commands_mapping=commands_mapping, - wrap_func=self.wrap_command_handler + self.app, commands_mapping=commands_mapping ) # catch-all to handle responses to unknown commands self.register_message_handler( @@ -2174,15 +2173,16 @@ def register_callback_handler( cls.users_middleware(callback, include_self=False) )) - @staticmethod + @classmethod def register_commands( - dispatcher: Application, - commands_mapping, - wrap_func=lambda func: func + cls, dispatcher: Application, + commands_mapping: Dict[ + str, Callable[[ModifiedTeleUpdate, ...], Coroutine] + ], ): for command_name in commands_mapping: handler = commands_mapping[command_name] - wrapped_handler = wrap_func(handler) + wrapped_handler = cls.wrap_command_handler(handler) dispatcher.add_handler(CommandHandler( command_name, wrapped_handler )) From 74fcc7654443a0a9561f3c7eda0d7755e2afc4c4 Mon Sep 17 00:00:00 2001 From: milselarch Date: Fri, 4 Oct 2024 22:56:33 +0800 Subject: [PATCH 05/23] chore: refactor some endpoints to use passed in db user entry --- bot.py | 40 +++++++++++++++------------------------- database/database.py | 14 ++++++++++++-- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/bot.py b/bot.py index 215f64aa..9d3550a0 100644 --- a/bot.py +++ b/bot.py @@ -259,7 +259,7 @@ async def post_init(application: Application): )]) async def start_handler( - self, update, context: ContextTypes.DEFAULT_TYPE + self, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE ): # Send a message when the command /start is issued. message = update.message @@ -281,14 +281,8 @@ async def start_handler( return False poll_id = int(pattern_match.group(1)) - tele_user: TeleUser = update.message.from_user - user_tele_id = tele_user.id - - try: - user = Users.build_from_fields(tele_id=user_tele_id).get() - except Users.DoesNotExist: - await message.reply_text(f'UNEXPECTED ERROR: USER DOES NOT EXIST') - return False + tele_user: TeleUser = message.from_user + user: Users = update.user user_id = user.get_user_id() view_poll_result = self.get_poll_message( @@ -572,7 +566,7 @@ async def user_details_handler(update: ModifiedTeleUpdate, *_): returns current user id and username """ # when command /user_details is invoked - user = update.message.from_user + user: TeleUser = update.message.from_user await update.message.reply_text(textwrap.dedent(f""" user id: {user.id} username: {user.username} @@ -626,7 +620,7 @@ async def has_voted(self, update: ModifiedTeleUpdate, *_, **__): await message.reply_text("you haven't voted") async def create_group_poll( - self, update, context: ContextTypes.DEFAULT_TYPE + self, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE ): """ /create_group_poll @username_1 @username_2 ... @username_n: @@ -650,7 +644,7 @@ async def create_group_poll( ) async def create_poll( - self, update, context: ContextTypes.DEFAULT_TYPE, + self, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, open_registration: bool = False, whitelisted_chat_ids: Sequence[int] = () ): @@ -673,14 +667,8 @@ async def create_poll( creator_tele_id = creator_user.id assert isinstance(creator_tele_id, int) raw_text = message.text.strip() - # print('CHAT_IDS', whitelisted_chat_ids) - - user_res = Users.get_from_tele_id(creator_tele_id) - if user_res.is_err(): - await message.reply_text("Creator user does not exist") - return False + user_entry: Users = update.user - user_entry: Users = user_res.unwrap() try: subscription_tier = SubscriptionTiers( user_entry.subscription_tier @@ -905,12 +893,12 @@ async def register_user_by_tele_id( /whitelist_user_id {poll_id} {user_tele_id} """ message: Message = update.message - user = message.from_user + tele_user = message.from_user raw_text = message.text.strip() pattern = re.compile(r'^\S+\s+([1-9]\d*)\s+([1-9]\d*)$') matches = pattern.match(raw_text) - if user is None: + if tele_user is None: await message.reply_text(f'user not found') return False if matches is None: @@ -923,10 +911,12 @@ async def register_user_by_tele_id( return False poll_id = int(capture_groups[0]) - user_tele_id = int(capture_groups[1]) + target_user_tele_id = int(capture_groups[1]) try: - user = Users.build_from_fields(tele_id=user_tele_id).get() + target_user = Users.build_from_fields( + tele_id=target_user_tele_id + ).get() except Users.DoesNotExist: await message.reply_text(f'UNEXPECTED ERROR: USER DOES NOT EXIST') return False @@ -937,7 +927,7 @@ async def register_user_by_tele_id( await message.reply_text(f'poll {poll_id} does not exist') return False - user_id = user.get_user_id() + user_id = target_user.get_user_id() creator_id: UserID = poll.get_creator().get_user_id() if creator_id != user_id: await message.reply_text( @@ -965,7 +955,7 @@ async def register_user_by_tele_id( await message.reply_text(response_text) return False - await message.reply_text(f'User #{user_tele_id} registered') + await message.reply_text(f'User #{target_user_tele_id} registered') return True async def whitelist_chat_registration( diff --git a/database/database.py b/database/database.py index 9577eda5..de95c1cf 100644 --- a/database/database.py +++ b/database/database.py @@ -134,19 +134,29 @@ def get_creator_id(self) -> UserID: @classmethod def build_from_fields( - cls, desc: str | EmptyField = Empty, + cls, poll_id: int | EmptyField = Empty, + desc: str | EmptyField = Empty, creator_id: UserID | EmptyField = Empty, num_voters: int | EmptyField = Empty, open_registration: bool | EmptyField = Empty, max_voters: int | EmptyField = Empty ) -> BoundRowFields[Self]: return BoundRowFields(cls, { - cls.desc: desc, cls.creator: creator_id, + cls.id: poll_id, cls.desc: desc, cls.creator: creator_id, cls.num_voters: num_voters, cls.open_registration: open_registration, cls.max_voters: max_voters }) + @classmethod + def get_as_creator(cls, poll_id: int, user_id: UserID) -> Polls: + # TODO: wrap this in a Result with an enum error type + # (not found, unauthorized, etc) and use this in + # register_user_by_tele_id + return cls.build_from_fields( + poll_id=poll_id, creator_id=user_id + ).get() + # whitelisted group chats from which users are # allowed to register as voters for a poll From 2e10e7053a050a10d802574df383b3592695036e Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 6 Oct 2024 13:50:44 +0800 Subject: [PATCH 06/23] feat: verify deletion token --- BaseAPI.py | 43 +++++++++++++++++++++++++++++- bot.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 110 insertions(+), 10 deletions(-) diff --git a/BaseAPI.py b/BaseAPI.py index 95d62d7f..3aa2c591 100644 --- a/BaseAPI.py +++ b/BaseAPI.py @@ -23,7 +23,7 @@ from strenum import StrEnum from load_config import TELEGRAM_BOT_TOKEN -from typing import List, Dict, Optional, Tuple +from typing import List, Dict, Optional, Tuple, Literal from result import Ok, Err, Result from concurrent.futures import ThreadPoolExecutor from telegram import InlineKeyboardButton, InlineKeyboardMarkup @@ -127,6 +127,8 @@ class BaseAPI(object): POLL_WINNER_LOCK_KEY = "POLL_WINNER_LOCK" # CACHE_LOCK_NAME = "REDIS_CACHE_LOCK" POLL_CACHE_EXPIRY = 60 + DELETION_TOKEN_EXPIRY = 60 * 5 + SHORT_HASH_LENGTH = 6 def __init__(self): database.initialize_db() @@ -138,6 +140,33 @@ def __get_telegram_token(): # TODO: move methods using tele token to a separate class return TELEGRAM_BOT_TOKEN + def generate_delete_token(self, user: Users): + stamp = int(time.time()) + hex_stamp = hex(stamp)[2:].upper() + user_id = user.get_user_id() + hash_input = f'{user_id}:{stamp}' + + signed_message = self.sign_message(hash_input).upper() + short_signed_message = signed_message[:self.SHORT_HASH_LENGTH] + return f'{hex_stamp}:{short_signed_message}' + + def validate_delete_token( + self, user: Users, stamp: int, short_hash: str + ) -> Result[bool, str]: + current_stamp = int(time.time()) + if abs(current_stamp - stamp) > self.DELETION_TOKEN_EXPIRY: + return Err('Token expired') + + user_id = user.get_user_id() + hash_input = f'{user_id}:{stamp}' + signed_message = self.sign_message(hash_input).upper() + short_signed_message = signed_message[:self.SHORT_HASH_LENGTH] + + if short_signed_message != short_hash: + return Err('Invalid token') + + return Ok(True) + @classmethod def create_tele_bot(cls): return telegram.Bot(token=cls.__get_telegram_token()) @@ -899,7 +928,19 @@ def sign_data_check_string( validation_hash = hmac.new( secret_key, data_check_string.encode(), hashlib.sha256 ).hexdigest() + return validation_hash + @classmethod + def sign_message(cls, message: str) -> str: + bot_token = cls.__get_telegram_token() + secret_key = hmac.new( + key=b"SIGN_MESSAGE", msg=bot_token.encode(), + digestmod=hashlib.sha256 + ).digest() + + validation_hash = hmac.new( + secret_key, message.encode(), hashlib.sha256 + ).hexdigest() return validation_hash @staticmethod diff --git a/bot.py b/bot.py index 9d3550a0..f6972269 100644 --- a/bot.py +++ b/bot.py @@ -9,13 +9,14 @@ import asyncio import re -from json import JSONDecodeError +from peewee import JOIN from result import Ok, Err, Result # noinspection PyProtectedMember from telegram.ext._utils.types import CCT, RT from telegram.ext.filters import BaseFilter from MessageBuilder import MessageBuilder from requests.models import PreparedRequest +from json import JSONDecodeError from ModifiedTeleUpdate import ModifiedTeleUpdate from SpecialVotes import SpecialVotes @@ -254,6 +255,8 @@ async def post_init(application: Application): 'about', 'miscellaneous info about the bot' ), ( 'delete_poll', 'delete a poll' + ), ( + 'delete_account', 'delete your user account' ), ( 'help', 'view commands available to the bot' )]) @@ -1902,10 +1905,52 @@ async def delete_poll(cls, update: ModifiedTeleUpdate, *_, **__): ) return True - @classmethod - def delete_account(cls, update: ModifiedTeleUpdate, *_, **__): - # TODO: implement this - raise NotImplementedError + @track_errors + async def delete_account(self, update: ModifiedTeleUpdate, *_, **__): + deletion_token_res = self.get_raw_command_args(update) + + if deletion_token_res.is_err(): + # deletion token not provided, send deletion instructions + delete_token = self.generate_delete_token(update.user) + return await update.message.reply_text(textwrap.dedent(f""" + Deleting your account will accomplish the following: + - all polls you've created will be deleted + - all votes you've cast for any ongoing polls + will be replaced with an abstain vote + - all votes you've cast for any closed polls will + be decoupled from your user account + - your username will be removed from your user account + - your user account will be marked as deleted and you + will not be able to create new polls or vote using + your account moving forward + - your user account will be removed from the database + 28 days after being marked for deletion + + Confirm account deletion by running the delete command + with the provided deletion token: + —————————————————— + /delete_account {delete_token} + """)) + + deletion_token = deletion_token_res.unwrap() + match_pattern = f'^[0-9A-F]+:[0-9A-F]+$' + if re.match(match_pattern, deletion_token) is None: + return await update.message.reply_text('Invalid deletion token') + + hex_stamp, short_hash = deletion_token.split(':') + deletion_stamp = int(hex_stamp, 16) + validation_result = self.validate_delete_token( + user=update.user, stamp=deletion_stamp, short_hash=short_hash + ) + + # TODO: actually implement deletion + if validation_result.is_err(): + err_message = validation_result.err() + return await update.message.reply_text(err_message) + else: + return await update.message.reply_text( + 'Account deleted successfully' + ) @staticmethod async def show_help(update: ModifiedTeleUpdate, *_, **__): @@ -2035,7 +2080,8 @@ async def view_poll_voters(self, update, *_, **__): vote_count = read_vote_count_result.unwrap() poll_voters: Iterable[PollVoters] = PollVoters.select().join( - Users, on=(PollVoters.user == Users.id) + Users, on=(PollVoters.user == Users.id), + join_type=JOIN.LEFT_OUTER ).where( PollVoters.poll == poll_id ) @@ -2051,6 +2097,7 @@ async def view_poll_voters(self, update, *_, **__): recorded_user_ids: set[int] = set() for voter in poll_voters: + # TODO: check if user has been deleted username: Optional[str] = voter.user.username voter_user = voter.get_voter_user() user_tele_id = voter_user.get_tele_id() @@ -2178,9 +2225,9 @@ def register_commands( )) @staticmethod - def extract_poll_id( + def get_raw_command_args( update: ModifiedTeleUpdate - ) -> Result[int, MessageBuilder]: + ) -> Result[str, MessageBuilder]: message: telegram.Message = update.message error_message = MessageBuilder() @@ -2193,7 +2240,19 @@ def extract_poll_id( error_message.add('no poll id specified') return Err(error_message) - raw_poll_id = raw_text[raw_text.index(' '):].strip() + raw_command_args = raw_text[raw_text.index(' '):].strip() + return Ok(raw_command_args) + + @classmethod + def extract_poll_id( + cls, update: ModifiedTeleUpdate + ) -> Result[int, MessageBuilder]: + raw_args_res = cls.get_raw_command_args(update) + if raw_args_res.is_err(): + return raw_args_res + + raw_poll_id = raw_args_res.unwrap() + error_message = MessageBuilder() try: poll_id = int(raw_poll_id) From 8e96c24549f8cd1676e0d97ec5ec427eed427837 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 6 Oct 2024 14:26:13 +0800 Subject: [PATCH 07/23] migrate: add deleted_voters to Polls, make PollVoters.user nullable --- database/database.py | 5 +- migrations/0003_migration_202410061424.py | 117 ++++++++++++++++++++++ 2 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 migrations/0003_migration_202410061424.py diff --git a/database/database.py b/database/database.py index de95c1cf..c49cab7b 100644 --- a/database/database.py +++ b/database/database.py @@ -123,6 +123,7 @@ class Polls(BaseModel): num_voters = IntegerField(default=0) # number of registered votes in the poll num_votes = IntegerField(default=0) + deleted_voters = IntegerField(default=0) def get_creator(self) -> Users: # TODO: do a unit test for this @@ -188,7 +189,9 @@ class PollVoters(BaseModel): # poll that voter is eligible to vote for poll = ForeignKeyField(Polls, to_field='id', on_delete='CASCADE') # telegram user id of voter - user = ForeignKeyField(Users, to_field='id', on_delete='CASCADE') + user = ForeignKeyField( + Users, to_field='id', null=True, on_delete='CASCADE' + ) voted = BooleanField(default=False) class Meta: diff --git a/migrations/0003_migration_202410061424.py b/migrations/0003_migration_202410061424.py new file mode 100644 index 00000000..c505886e --- /dev/null +++ b/migrations/0003_migration_202410061424.py @@ -0,0 +1,117 @@ +# auto-generated snapshot +from peewee import * +import datetime +import peewee + + +snapshot = Snapshot() + + +@snapshot.append +class Users(peewee.Model): + id = BigAutoField(primary_key=True) + tele_id = BigIntegerField(index=True, unique=True) + username = CharField(max_length=255, null=True) + credits = IntegerField(default=0) + subscription_tier = IntegerField(default=0) + deleted_at = DateTimeField(null=True) + class Meta: + table_name = "users" + indexes = ( + (('username',), False), + ) + + +@snapshot.append +class Polls(peewee.Model): + desc = TextField(default='') + close_time = DateTimeField() + open_time = DateTimeField(default=datetime.datetime.now) + closed = BooleanField(default=False) + open_registration = BooleanField(default=False) + auto_refill = BooleanField(default=False) + creator = snapshot.ForeignKeyField(index=True, model='users', on_delete='CASCADE') + max_voters = IntegerField(default=10) + num_voters = IntegerField(default=0) + num_votes = IntegerField(default=0) + deleted_voters = IntegerField(default=0) + class Meta: + table_name = "polls" + + +@snapshot.append +class ChatWhitelist(peewee.Model): + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + chat_id = BigIntegerField() + broadcasted = BooleanField(default=False) + class Meta: + table_name = "chatwhitelist" + indexes = ( + (('poll', 'chat_id'), True), + ) + + +@snapshot.append +class PollOptions(peewee.Model): + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + option_name = CharField(max_length=255) + option_number = IntegerField() + class Meta: + table_name = "polloptions" + + +@snapshot.append +class PollVoters(peewee.Model): + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + user = snapshot.ForeignKeyField(index=True, model='users', null=True, on_delete='CASCADE') + voted = BooleanField(default=False) + class Meta: + table_name = "pollvoters" + indexes = ( + (('poll', 'user'), True), + ) + + +@snapshot.append +class PollWinners(peewee.Model): + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + option = snapshot.ForeignKeyField(index=True, model='polloptions', null=True, on_delete='CASCADE') + class Meta: + table_name = "pollwinners" + + +@snapshot.append +class UsernameWhitelist(peewee.Model): + username = CharField(max_length=255) + poll = snapshot.ForeignKeyField(index=True, model='polls', on_delete='CASCADE') + user = snapshot.ForeignKeyField(index=True, model='users', null=True, on_delete='CASCADE') + class Meta: + table_name = "usernamewhitelist" + indexes = ( + (('poll', 'username'), True), + ) + + +@snapshot.append +class VoteRankings(peewee.Model): + poll_voter = snapshot.ForeignKeyField(index=True, model='pollvoters', on_delete='CASCADE') + option = snapshot.ForeignKeyField(index=True, model='polloptions', null=True, on_delete='CASCADE') + special_value = IntegerField(constraints=[SQL('CHECK (special_value < 0)')], null=True) + ranking = IntegerField() + class Meta: + table_name = "voterankings" + + +def forward(old_orm, new_orm): + polls = new_orm['polls'] + return [ + # Apply default value 0 to the field polls.deleted_voters, + polls.update({polls.deleted_voters: 0}).where(polls.deleted_voters.is_null(True)), + ] + + +def backward(old_orm, new_orm): + pollvoters = new_orm['pollvoters'] + return [ + # Check the field `pollvoters.user` does not contain null values, + ] From e9663886426e89c68fad9d97e1b2da098ba6ca9a Mon Sep 17 00:00:00 2001 From: milselarch Date: Wed, 9 Oct 2024 23:04:25 +0800 Subject: [PATCH 08/23] feat: add background process to flush deleted users --- BaseAPI.py | 2 +- bot.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/BaseAPI.py b/BaseAPI.py index 3aa2c591..836d5a93 100644 --- a/BaseAPI.py +++ b/BaseAPI.py @@ -23,7 +23,7 @@ from strenum import StrEnum from load_config import TELEGRAM_BOT_TOKEN -from typing import List, Dict, Optional, Tuple, Literal +from typing import List, Dict, Optional, Tuple from result import Ok, Err, Result from concurrent.futures import ThreadPoolExecutor from telegram import InlineKeyboardButton, InlineKeyboardMarkup diff --git a/bot.py b/bot.py index f6972269..d87fde8d 100644 --- a/bot.py +++ b/bot.py @@ -2,11 +2,13 @@ import json import logging +import multiprocessing import time import telegram import textwrap import asyncio +import datetime import re from peewee import JOIN @@ -17,6 +19,7 @@ from MessageBuilder import MessageBuilder from requests.models import PreparedRequest from json import JSONDecodeError +from datetime import datetime as Datetime from ModifiedTeleUpdate import ModifiedTeleUpdate from SpecialVotes import SpecialVotes @@ -64,10 +67,13 @@ class RankedChoiceBot(BaseAPI): # how long before the delete poll button expires DELETE_POLL_BUTTON_EXPIRY = 60 + DELETE_USERS_BACKLOG = datetime.timedelta(days=28) + FLUSH_USERS_INTERVAL = 600 def __init__(self, config_path='config.yml'): super().__init__() self.config_path = config_path + self.scheduled_processes = [] self.bot = None self.app = None @@ -76,6 +82,21 @@ def __init__(self, config_path='config.yml'): self.poll_locks_manager = PollsLockManager() self.webhook_url = None + @classmethod + def run_flush_deleted_users(cls): + asyncio.run(cls.flush_deleted_users()) + + @classmethod + async def flush_deleted_users(cls): + # TODO: write tests for this + while True: + deletion_cutoff = Datetime.now() - cls.DELETE_USERS_BACKLOG + Users.delete().where( + Users.deleted_at < deletion_cutoff + ).execute() + + await asyncio.sleep(cls.FLUSH_USERS_INTERVAL) + @staticmethod def users_middleware( func: Callable[..., Coroutine], include_self=True @@ -142,9 +163,21 @@ def wrap_command_handler(cls, handler): handler, include_self=False )) + def schedule_tasks(self, tasks: List[Callable[[], None]]): + assert len(self.scheduled_processes) == 0 + + for task in tasks: + process = multiprocessing.Process(target=task) + self.scheduled_processes.append(process) + process.start() + def start_bot(self): + assert self.bot is None self.bot = self.create_tele_bot() self.webhook_url = WEBHOOK_URL + self.schedule_tasks([ + self.run_flush_deleted_users + ]) builder = self.create_application_builder() builder.concurrent_updates(MAX_CONCURRENT_UPDATES) @@ -206,6 +239,10 @@ def start_bot(self): # self.app.add_error_handler(self.error_handler) self.app.run_polling(allowed_updates=BaseTeleUpdate.ALL_TYPES) + print('<<< BOT POLLING LOOP ENDED >>>') + for process in self.scheduled_processes: + process.terminate() + process.join() @staticmethod async def post_init(application: Application): @@ -1444,13 +1481,17 @@ async def view_all_polls(update: ModifiedTeleUpdate, *_, **__): polls = Polls.select().where(Polls.creator == user_id) poll_descriptions = [] + if len(polls) == 0: + await message.reply_text("No polls found") + return False + for poll in polls: poll_descriptions.append( f'#{poll.id}: {poll.desc}' ) await message.reply_text( - 'Polls created:\n' + '\n'.join(poll_descriptions) + 'Polls found:\n' + '\n'.join(poll_descriptions) ) async def vote_and_report( @@ -1937,21 +1978,35 @@ async def delete_account(self, update: ModifiedTeleUpdate, *_, **__): if re.match(match_pattern, deletion_token) is None: return await update.message.reply_text('Invalid deletion token') + user: Users = update.user hex_stamp, short_hash = deletion_token.split(':') deletion_stamp = int(hex_stamp, 16) validation_result = self.validate_delete_token( - user=update.user, stamp=deletion_stamp, short_hash=short_hash + user=user, stamp=deletion_stamp, short_hash=short_hash ) - # TODO: actually implement deletion if validation_result.is_err(): err_message = validation_result.err() return await update.message.reply_text(err_message) - else: - return await update.message.reply_text( - 'Account deleted successfully' + + with db.atomic(): + poll_registrations = PollVoters.select().where( + PollVoters.user == user.id + ).join( + VoteRankings, on=(PollVoters.id == VoteRankings.poll_voter), + join_type=JOIN.LEFT_OUTER ) + for poll_voter in poll_registrations: + print('REGISTRATIONS', poll_voter) + + # user.deleted_at = Datetime.now() + # user.save() + + return await update.message.reply_text( + 'Account deleted successfully' + ) + @staticmethod async def show_help(update: ModifiedTeleUpdate, *_, **__): message: Message = update.message From b150848831604da7a0f9f4cd41968d79e8c94275 Mon Sep 17 00:00:00 2001 From: milselarch Date: Thu, 10 Oct 2024 23:24:00 +0800 Subject: [PATCH 09/23] feat: do user deletion actions --- bot.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/bot.py b/bot.py index d87fde8d..59b57336 100644 --- a/bot.py +++ b/bot.py @@ -1957,10 +1957,10 @@ async def delete_account(self, update: ModifiedTeleUpdate, *_, **__): Deleting your account will accomplish the following: - all polls you've created will be deleted - all votes you've cast for any ongoing polls - will be replaced with an abstain vote + will be deleted, and you will be deregistered + as a voter from said ongoing polls - all votes you've cast for any closed polls will be decoupled from your user account - - your username will be removed from your user account - your user account will be marked as deleted and you will not be able to create new polls or vote using your account moving forward @@ -1979,6 +1979,7 @@ async def delete_account(self, update: ModifiedTeleUpdate, *_, **__): return await update.message.reply_text('Invalid deletion token') user: Users = update.user + user_id = user.get_user_id() hex_stamp, short_hash = deletion_token.split(':') deletion_stamp = int(hex_stamp, 16) validation_result = self.validate_delete_token( @@ -1990,18 +1991,26 @@ async def delete_account(self, update: ModifiedTeleUpdate, *_, **__): return await update.message.reply_text(err_message) with db.atomic(): - poll_registrations = PollVoters.select().where( - PollVoters.user == user.id - ).join( - VoteRankings, on=(PollVoters.id == VoteRankings.poll_voter), - join_type=JOIN.LEFT_OUTER - ) + # delete all polls created by the user + Polls.delete().where(Polls.creator == user_id).execute() + user.deleted_at = Datetime.now() # mark as deleted + user.save() - for poll_voter in poll_registrations: - print('REGISTRATIONS', poll_voter) + poll_registrations: Iterable[PollVoters] = ( + PollVoters.select().where(PollVoters.user == user_id) + ) + for poll_registration in poll_registrations: + poll: Polls = poll_registration.poll - # user.deleted_at = Datetime.now() - # user.save() + if poll.closed: + # decouple poll voter from user + poll_registration.user = None + poll_registration.save() + else: + # delete poll voter and increment deleted voters count + poll.deleted_voters += 1 + poll_registration.delete_instance() + poll.save() return await update.message.reply_text( 'Account deleted successfully' From a18e2e68dbbb35c26559b4b9219d8a15d57b5e5e Mon Sep 17 00:00:00 2001 From: milselarch Date: Sat, 12 Oct 2024 22:11:29 +0800 Subject: [PATCH 10/23] fix: factor in deleted voters when getting num voters for a poll --- BaseAPI.py | 37 +++++++++---------------------------- bot.py | 23 +++++++++++++---------- database/database.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 38 deletions(-) diff --git a/BaseAPI.py b/BaseAPI.py index 836d5a93..2c8969c6 100644 --- a/BaseAPI.py +++ b/BaseAPI.py @@ -33,7 +33,7 @@ Polls, PollVoters, UsernameWhitelist, PollOptions, VoteRankings, db, Users ) -from database.database import PollWinners, BaseModel, UserID +from database.database import PollWinners, BaseModel, UserID, PollMetadata from aioredlock import Aioredlock, LockError @@ -95,17 +95,6 @@ class PollMessage(object): reply_markup: Optional[InlineKeyboardMarkup] -@dataclasses.dataclass -class PollMetadata(object): - id: int - question: str - num_voters: int - num_votes: int - - open_registration: bool - closed: bool - - class GetPollWinnerStatus(IntEnum): CACHED = 0 NEWLY_COMPUTED = 1 @@ -216,13 +205,15 @@ def fetch_poll(poll_id: int) -> Result[Polls, MessageBuilder]: return Ok(poll) @classmethod - def get_num_poll_voters(cls, poll_id: int) -> Result[int, MessageBuilder]: + def get_num_active_poll_voters( + cls, poll_id: int + ) -> Result[int, MessageBuilder]: result = cls.fetch_poll(poll_id) if result.is_err(): return result poll = result.unwrap() - return Ok(poll.num_voters) + return Ok(poll.num_active_voters) @staticmethod async def refresh_lock(lock: aioredlock.Lock, interval: float): @@ -311,7 +302,7 @@ def _determine_poll_winner(cls, poll_id: int) -> Optional[int]: :return: ID of winning option, or None if there's no winner """ - num_poll_voters_result = cls.get_num_poll_voters(poll_id) + num_poll_voters_result = cls.get_num_active_poll_voters(poll_id) if num_poll_voters_result.is_err(): return None @@ -533,7 +524,7 @@ def _register_user_id( return Err(UserRegistrationStatus.POLL_NOT_FOUND) # print('NUM_VOTES', poll.num_voters, poll.max_voters) - voter_limit_reached = (poll.num_voters >= poll.max_voters) + voter_limit_reached = (poll.num_active_voters >= poll.max_voters) if ignore_voter_limit: voter_limit_reached = False @@ -719,7 +710,7 @@ def _generate_poll_message( poll_metadata.id, poll_metadata.question, poll_info.poll_options, closed=poll_metadata.closed, bot_username=bot_username, - num_voters=poll_metadata.num_voters, + num_voters=poll_metadata.num_active_voters, num_votes=poll_metadata.num_votes ) @@ -764,19 +755,9 @@ def read_poll_info( return Ok(cls._read_poll_info(poll_id=poll_id)) - @classmethod - def _read_poll_metadata(cls, poll_id: int) -> PollMetadata: - poll = Polls.select().where(Polls.id == poll_id).get() - return PollMetadata( - id=poll.id, question=poll.desc, - num_voters=poll.num_voters, num_votes=poll.num_votes, - open_registration=poll.open_registration, - closed=poll.closed - ) - @classmethod def _read_poll_info(cls, poll_id: int) -> PollInfo: - poll_metadata = cls._read_poll_metadata(poll_id) + poll_metadata = Polls.read_poll_metadata(poll_id) poll_option_rows = PollOptions.select().where( PollOptions.poll == poll_id ).order_by(PollOptions.option_number) diff --git a/bot.py b/bot.py index 59b57336..27b3eab6 100644 --- a/bot.py +++ b/bot.py @@ -383,9 +383,10 @@ async def web_app_handler(self, update: ModifiedTeleUpdate, _): message=message, poll_id=poll_id ) - async def send_post_vote_reply(self, message: Message, poll_id: int): - poll_metadata = self._read_poll_metadata(poll_id) - num_voters = poll_metadata.num_voters + @classmethod + async def send_post_vote_reply(cls, message: Message, poll_id: int): + poll_metadata = Polls.read_poll_metadata(poll_id) + num_voters = poll_metadata.num_active_voters num_votes = poll_metadata.num_votes await message.reply_text(textwrap.dedent(f""" @@ -561,7 +562,7 @@ async def update_poll_message( """ poll_id = poll_info.metadata.id bot_username = context.bot.username - voter_count = poll_info.metadata.num_voters + voter_count = poll_info.metadata.num_active_voters poll_locks = await self.poll_locks_manager.get_poll_locks( poll_id=poll_id ) @@ -812,9 +813,11 @@ async def create_poll( return False assert len(set(duplicate_tele_ids)) == len(duplicate_tele_ids) - num_voters = len(poll_user_tele_ids) + len(whitelisted_usernames) + initial_num_voters = ( + len(poll_user_tele_ids) + len(whitelisted_usernames) + ) max_voters = subscription_tier.get_max_voters() - if num_voters > max_voters: + if initial_num_voters > max_voters: await message.reply_text(f'Whitelisted voters exceeds limit') return False @@ -825,7 +828,7 @@ async def create_poll( return False creator_id = user.get_user_id() - assert num_voters <= max_voters + assert initial_num_voters <= max_voters num_user_created_polls = self.count_polls_created(creator_id) poll_creation_limit = subscription_tier.get_max_polls() limit_reached_text = textwrap.dedent(f""" @@ -859,7 +862,7 @@ async def create_poll( new_poll = Polls.build_from_fields( desc=poll_question, creator_id=creator_id, - num_voters=num_voters, open_registration=open_registration, + num_voters=initial_num_voters, open_registration=open_registration, max_voters=subscription_tier.get_max_voters() ).create() @@ -903,7 +906,7 @@ async def create_poll( poll_message = self.generate_poll_info( new_poll_id, poll_question, poll_options, bot_username=bot_username, closed=False, - num_voters=num_voters + num_voters=initial_num_voters ) chat_type = update.message.chat.type @@ -1958,7 +1961,7 @@ async def delete_account(self, update: ModifiedTeleUpdate, *_, **__): - all polls you've created will be deleted - all votes you've cast for any ongoing polls will be deleted, and you will be deregistered - as a voter from said ongoing polls + from these ongoing polls - all votes you've cast for any closed polls will be decoupled from your user account - your user account will be marked as deleted and you diff --git a/database/database.py b/database/database.py index c49cab7b..52dabae1 100644 --- a/database/database.py +++ b/database/database.py @@ -1,5 +1,6 @@ from __future__ import annotations +import dataclasses import datetime import os import sys @@ -104,6 +105,22 @@ def get_from_tele_id( return cls.build_from_fields(tele_id=tele_id).safe_get() +@dataclasses.dataclass +class PollMetadata(object): + id: int + question: str + num_voters: int + num_deleted: int + num_votes: int + + open_registration: bool + closed: bool + + @property + def num_active_voters(self) -> int: + return self.num_voters - self.num_deleted + + # stores poll metadata (description, open time, etc etc) class Polls(BaseModel): id = AutoField(primary_key=True) @@ -120,11 +137,20 @@ class Polls(BaseModel): creator = ForeignKeyField(Users, to_field='id', on_delete='CASCADE') max_voters = IntegerField(default=10) # number of registered voters in the poll + # TODO: rename to raw_num_voters or _num_voters to make it clear + # that this number includes deleted voters as well num_voters = IntegerField(default=0) # number of registered votes in the poll num_votes = IntegerField(default=0) + # TODO: rename to num_deleted_voters deleted_voters = IntegerField(default=0) + @property + def num_active_voters(self) -> int: + assert isinstance(self.num_voters, int) + assert isinstance(self.deleted_voters, int) + return self.num_voters - self.deleted_voters + def get_creator(self) -> Users: # TODO: do a unit test for this assert isinstance(self.creator, Users) @@ -133,6 +159,16 @@ def get_creator(self) -> Users: def get_creator_id(self) -> UserID: return self.get_creator().get_user_id() + @classmethod + def read_poll_metadata(cls, poll_id: int) -> PollMetadata: + poll = cls.select().where(cls.id == poll_id).get() + return PollMetadata( + id=poll.id, question=poll.desc, + num_voters=poll.num_voters, num_votes=poll.num_votes, + open_registration=poll.open_registration, + closed=poll.closed, num_deleted=poll.deleted_voters + ) + @classmethod def build_from_fields( cls, poll_id: int | EmptyField = Empty, From 50e711a23743b83666486099a70fc8e195114d44 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sat, 12 Oct 2024 22:13:35 +0800 Subject: [PATCH 11/23] chore: rename num_voters to _num_voters in PollMetadata --- database/database.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/database/database.py b/database/database.py index 52dabae1..40f6579c 100644 --- a/database/database.py +++ b/database/database.py @@ -109,7 +109,7 @@ def get_from_tele_id( class PollMetadata(object): id: int question: str - num_voters: int + _num_voters: int num_deleted: int num_votes: int @@ -118,7 +118,7 @@ class PollMetadata(object): @property def num_active_voters(self) -> int: - return self.num_voters - self.num_deleted + return self._num_voters - self.num_deleted # stores poll metadata (description, open time, etc etc) @@ -164,7 +164,7 @@ def read_poll_metadata(cls, poll_id: int) -> PollMetadata: poll = cls.select().where(cls.id == poll_id).get() return PollMetadata( id=poll.id, question=poll.desc, - num_voters=poll.num_voters, num_votes=poll.num_votes, + _num_voters=poll.num_voters, num_votes=poll.num_votes, open_registration=poll.open_registration, closed=poll.closed, num_deleted=poll.deleted_voters ) From fb6ab0ca6027ededb4c820ad1f1189ec99e350f7 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sat, 12 Oct 2024 22:56:46 +0800 Subject: [PATCH 12/23] feat: block deleted users in webapp --- .gitignore | 1 + database/database.py | 3 +++ webapp.py | 11 +++++++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 755ae185..4da05670 100644 --- a/.gitignore +++ b/.gitignore @@ -125,6 +125,7 @@ celerybeat.pid # config files *.yml !config.example.yml +!.github/**/*.yml # Environments .env diff --git a/database/database.py b/database/database.py index 40f6579c..59fd623c 100644 --- a/database/database.py +++ b/database/database.py @@ -79,6 +79,9 @@ class Meta: (('username',), False), ) + def is_deleted(self) -> bool: + return self.deleted_at is not None + @classmethod def build_from_fields( cls, user_id: int | EmptyField = Empty, diff --git a/webapp.py b/webapp.py index 567afbe2..86bd9175 100644 --- a/webapp.py +++ b/webapp.py @@ -45,7 +45,10 @@ async def dispatch(self, request: Request, call_next): content = {'detail': 'Missing telegram-data header'} return JSONResponse(content=content, status_code=401) + # TODO: expire old auth tokens as new ones are created """ + # This is commented out cause its not very intuitive for + # the webapp button to just expire after 24 hours if PRODUCTION_MODE: # only allow auth headers that were created in the last 24 hours parsed_query = parse_qs(telegram_data_header) @@ -129,8 +132,12 @@ def fetch_poll_endpoint( return JSONResponse( status_code=400, content={'error': 'User not found'} ) - user = user_res.unwrap() + if user.is_deleted(): + return JSONResponse( + status_code=403, content={'error': 'User is deleted'} + ) + user_id = user.get_user_id() username = user_info['username'] read_poll_result = self.read_poll_info( @@ -144,7 +151,7 @@ def fetch_poll_endpoint( status_code=500, content={'error': error.get_content()} ) - poll_info = read_poll_result.ok() + poll_info = read_poll_result.unwrap() return dataclasses.asdict(poll_info) From 56bc4eded56cb7d808caaf25223db3942f9cdc58 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sat, 12 Oct 2024 23:04:10 +0800 Subject: [PATCH 13/23] feat: refactor webapp to do sign_data_check_string without passing bot token --- webapp.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/webapp.py b/webapp.py index 86bd9175..4c69b860 100644 --- a/webapp.py +++ b/webapp.py @@ -65,9 +65,7 @@ async def dispatch(self, request: Request, call_next): return JSONResponse(content=content, status_code=401) """ - user_params = self.check_authorization( - telegram_data_header, TELEGRAM_BOT_TOKEN - ) + user_params = self.check_authorization(telegram_data_header) if user_params is None: content = {'detail': 'Unauthorized'} @@ -92,14 +90,12 @@ def parse_auth_string(cls, init_data: str): return data_check_string, signature, params @classmethod - def check_authorization( - cls, init_data: str, bot_token: str - ) -> Optional[dict]: + def check_authorization(cls, init_data: str) -> Optional[dict]: parse_result = cls.parse_auth_string(init_data) # print('PARSE_RESULT', parse_result) data_check_string, signature, params = parse_result validation_hash = BaseAPI.sign_data_check_string( - data_check_string=data_check_string, bot_token=bot_token + data_check_string=data_check_string ) # print('VALIDATION_HASH', validation_hash, signature) From d5a99f8e56df43c0a56fc629db8c89dde6d65d9d Mon Sep 17 00:00:00 2001 From: milselarch Date: Sat, 12 Oct 2024 23:05:00 +0800 Subject: [PATCH 14/23] chore: bump version 1.1.1 --- bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bot.py b/bot.py index 27b3eab6..7d7499aa 100644 --- a/bot.py +++ b/bot.py @@ -51,7 +51,7 @@ Application ) -__VERSION__ = '1.1.0' +__VERSION__ = '1.1.1' ID_PATTERN = re.compile(r"^[1-9]\d*$") MAX_DISPLAY_VOTE_COUNT = 30 MAX_CONCURRENT_UPDATES = 256 From 601a2df42958d990c068060ad4f18c9972f38d37 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 16:43:34 +0800 Subject: [PATCH 15/23] add workflow actions --- .github/algo_tests.yml | 25 +++++++ config.example.yml | 2 +- tests/enum_test.py | 16 +++-- tests/test_ranked_vote.py | 142 ++++++++++++++++++++++++++------------ 4 files changed, 131 insertions(+), 54 deletions(-) create mode 100644 .github/algo_tests.yml diff --git a/.github/algo_tests.yml b/.github/algo_tests.yml new file mode 100644 index 00000000..1c39a291 --- /dev/null +++ b/.github/algo_tests.yml @@ -0,0 +1,25 @@ +name: Core Algorithm Tests + +on: [push, pull_request] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' # Specify the Python version you need + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run tests + run: | + pytest tests/enum_test.py tests/test_ranked_vote.py \ No newline at end of file diff --git a/config.example.yml b/config.example.yml index 16c34fbd..ec2bcce2 100644 --- a/config.example.yml +++ b/config.example.yml @@ -4,7 +4,7 @@ settings: database: name: ranked_choice_voting user: rcv_user - password: your_mysql_password + password: YOUR_DATABASE_PASSWORD host: localhost telegram: bot_token: YOUR_BOT_TOKEN diff --git a/tests/enum_test.py b/tests/enum_test.py index 80e6cbe4..a13b0a52 100644 --- a/tests/enum_test.py +++ b/tests/enum_test.py @@ -1,5 +1,6 @@ import unittest -import ParentImport +# noinspection PyUnresolvedReferences +# import ParentImport from SpecialVotes import SpecialVotes as SpecV @@ -15,13 +16,13 @@ def test_enum_values(self): def test_to_string(self): # Test the to_string method - self.assertEqual(SpecV.WITHHOLD_VOTE.to_string(), '0') - self.assertEqual(SpecV.ABSTAIN_VOTE.to_string(), 'nil') + self.assertEqual(SpecV.WITHHOLD_VOTE.to_string(), 'withhold') + self.assertEqual(SpecV.ABSTAIN_VOTE.to_string(), 'abstain') def test_from_string(self): # Test the from_string method - self.assertEqual(SpecV.from_string('0'), SpecV.WITHHOLD_VOTE) - self.assertEqual(SpecV.from_string('nil'), SpecV.ABSTAIN_VOTE) + self.assertEqual(SpecV.from_string('withhold'), SpecV.WITHHOLD_VOTE) + self.assertEqual(SpecV.from_string('abstain'), SpecV.ABSTAIN_VOTE) with self.assertRaises(ValueError): SpecV.from_string('invalid') @@ -34,8 +35,9 @@ def test_is_valid(self): def test_string_maps(self): # Test if string maps are correctly set self.assertEqual( - SpecV.get_string_map(), - {SpecV.WITHHOLD_VOTE: '0', SpecV.ABSTAIN_VOTE: 'nil'} + SpecV.get_string_map(), { + SpecV.WITHHOLD_VOTE: 'withhold', SpecV.ABSTAIN_VOTE: 'abstain' + } ) diff --git a/tests/test_ranked_vote.py b/tests/test_ranked_vote.py index 23545dc4..0687ac1a 100644 --- a/tests/test_ranked_vote.py +++ b/tests/test_ranked_vote.py @@ -1,8 +1,8 @@ import unittest -import ParentImport +# noinspection PyUnresolvedReferences +# import ParentImport +import ranked_choice_vote -from RankedChoice import ranked_choice_vote -from RankedVote import RankedVote from SpecialVotes import SpecialVotes @@ -16,98 +16,148 @@ def __init__(self, *args, verbose=False, **kwargs): """ def test_basic_scenario(self): # Basic test with predefined votes + votes_aggregator = ranked_choice_vote.VotesAggregator() votes = [ - RankedVote([1, 2, 3, 4]), - RankedVote([1, 2, 3]), - RankedVote([3]), - RankedVote([3, 2, 4]), - RankedVote([4, 1]) + [1, 2, 3, 4], + [1, 2, 3], + [3], + [3, 2, 4], + [4, 1] ] - result = ranked_choice_vote(votes, verbose=self.verbose) + + for vote_idx in range(len(votes)): + vote_rankings = votes[vote_idx] + for vote_ranking in vote_rankings: + votes_aggregator.insert_vote_ranking(vote_idx, vote_ranking) + + votes_aggregator.flush_votes() + winner = votes_aggregator.determine_winner() self.assertEqual( - result, 1, + winner, 1, "The winner should be candidate 1" ) def test_simple_majority(self): # Basic test where there is a winner in round 1 + votes_aggregator = ranked_choice_vote.VotesAggregator() votes = [ - RankedVote([1, 2, 3, 4]), - RankedVote([1, 2, 3]), - RankedVote([3]), - RankedVote([3, 2, 4]), - RankedVote([1, 2]) + [1, 2, 3, 4], + [1, 2, 3], + [3], + [3, 2, 4], + [1, 2] ] - result = ranked_choice_vote(votes, verbose=self.verbose) + + for vote_idx in range(len(votes)): + vote_rankings = votes[vote_idx] + for vote_ranking in vote_rankings: + votes_aggregator.insert_vote_ranking(vote_idx, vote_ranking) + + votes_aggregator.flush_votes() + winner = votes_aggregator.determine_winner() self.assertEqual( - result, 1, + winner, 1, "The winner should be candidate 1" ) def test_tie_scenario(self): # Test for a tie + votes_aggregator = ranked_choice_vote.VotesAggregator() votes = [ - RankedVote([1, 2]), - RankedVote([2, 1]) + [1, 2], + [2, 1] ] - result = ranked_choice_vote(votes) - self.assertIsNone(result, "There should be a tie") + for vote_idx in range(len(votes)): + vote_rankings = votes[vote_idx] + for vote_ranking in vote_rankings: + votes_aggregator.insert_vote_ranking(vote_idx, vote_ranking) + + votes_aggregator.flush_votes() + winner = votes_aggregator.determine_winner() + self.assertIsNone(winner, "There should be a tie") def test_zero_vote_end(self): # Test that a zero vote ends with no one winning + votes_aggregator = ranked_choice_vote.VotesAggregator() votes = [ - RankedVote([1, SpecialVotes.WITHHOLD_VOTE]), - RankedVote([2, 1]), - RankedVote([3, 2]), - RankedVote([3]) + [1, SpecialVotes.WITHHOLD_VOTE], + [2, 1], + [3, 2], + [3] ] - result = ranked_choice_vote(votes, verbose=self.verbose) + for vote_idx in range(len(votes)): + vote_rankings = votes[vote_idx] + for vote_ranking in vote_rankings: + votes_aggregator.insert_vote_ranking(vote_idx, vote_ranking) + + votes_aggregator.flush_votes() + winner = votes_aggregator.determine_winner() self.assertEqual( - result, None, + winner, None, "Candidate 1's vote should not count, no one should win" ) def test_zero_nil_votes_only(self): # Test that having only zero and nil votes ends with no one winning, # and also that there are no errors in computing the poll result + votes_aggregator = ranked_choice_vote.VotesAggregator() votes = [ - RankedVote([SpecialVotes.WITHHOLD_VOTE]), - RankedVote([SpecialVotes.WITHHOLD_VOTE]), - RankedVote([SpecialVotes.WITHHOLD_VOTE]), - RankedVote([SpecialVotes.ABSTAIN_VOTE]) + [SpecialVotes.WITHHOLD_VOTE], + [SpecialVotes.WITHHOLD_VOTE], + [SpecialVotes.WITHHOLD_VOTE], + [SpecialVotes.ABSTAIN_VOTE] ] - result = ranked_choice_vote(votes, verbose=self.verbose) + for vote_idx in range(len(votes)): + vote_rankings = votes[vote_idx] + for vote_ranking in vote_rankings: + votes_aggregator.insert_vote_ranking(vote_idx, vote_ranking) + + votes_aggregator.flush_votes() + winner = votes_aggregator.determine_winner() self.assertEqual( - result, None, + winner, None, "No one should win if all votes were 0 or nil" ) def test_null_vote_end(self): # Test that a null vote ends with someone winning + votes_aggregator = ranked_choice_vote.VotesAggregator() votes = [ - RankedVote([1, SpecialVotes.ABSTAIN_VOTE]), - RankedVote([2, 1]), - RankedVote([3, 2]), - RankedVote([3]) + [1, SpecialVotes.ABSTAIN_VOTE], + [2, 1], + [3, 2], + [3] ] - result = ranked_choice_vote(votes, verbose=self.verbose) + for vote_idx in range(len(votes)): + vote_rankings = votes[vote_idx] + for vote_ranking in vote_rankings: + votes_aggregator.insert_vote_ranking(vote_idx, vote_ranking) + + votes_aggregator.flush_votes() + winner = votes_aggregator.determine_winner() self.assertEqual( - result, 3, + winner, 3, "Candidate 3's vote should not count, no one should win" ) def test_majoritarian_rule(self): + votes_aggregator = ranked_choice_vote.VotesAggregator() votes = [ - RankedVote([1, 6, 15]), - RankedVote([1, 2, 6, 15, 5, 4, 7, 3, 11]), - RankedVote([6, 15, 1, 11, 10, 16, 17, 8, 2, 3, 5, 7]), - RankedVote([9, 8, 6, 11, 13, 3, 1]), - RankedVote([13, 14, 16, 6, 3, 4, 5, 2, 1, 8, 9]) + [1, 6, 15], + [1, 2, 6, 15, 5, 4, 7, 3, 11], + [6, 15, 1, 11, 10, 16, 17, 8, 2, 3, 5, 7], + [9, 8, 6, 11, 13, 3, 1], + [13, 14, 16, 6, 3, 4, 5, 2, 1, 8, 9] ] + for vote_idx in range(len(votes)): + vote_rankings = votes[vote_idx] + for vote_ranking in vote_rankings: + votes_aggregator.insert_vote_ranking(vote_idx, vote_ranking) - result = ranked_choice_vote(votes, verbose=self.verbose) + votes_aggregator.flush_votes() + winner = votes_aggregator.determine_winner() self.assertEqual( - result, 6, + winner, 6, "Candidate 6 should be the majoritarian winner" ) From 25be0726d0fc7df473a90d153d60da6e28e18503 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 19:53:03 +0800 Subject: [PATCH 16/23] add initial github actions --- .github/{ => workflows}/algo_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) rename .github/{ => workflows}/algo_tests.yml (87%) diff --git a/.github/algo_tests.yml b/.github/workflows/algo_tests.yml similarity index 87% rename from .github/algo_tests.yml rename to .github/workflows/algo_tests.yml index 1c39a291..7b46acbf 100644 --- a/.github/algo_tests.yml +++ b/.github/workflows/algo_tests.yml @@ -13,7 +13,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.9' # Specify the Python version you need + python-version: '3.12' - name: Install dependencies run: | From 8549ddde6128e8d9fd83cb1714d736091565ab4c Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 19:54:07 +0800 Subject: [PATCH 17/23] fix: only run github action on pull request --- .github/workflows/algo_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/algo_tests.yml b/.github/workflows/algo_tests.yml index 7b46acbf..df3e59b3 100644 --- a/.github/workflows/algo_tests.yml +++ b/.github/workflows/algo_tests.yml @@ -1,6 +1,6 @@ name: Core Algorithm Tests -on: [push, pull_request] +on: [pull_request] jobs: test: From 213439a96ca307adb667c125da0029c3c7d7f348 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 19:55:53 +0800 Subject: [PATCH 18/23] chore: run python -m pytest --- .github/workflows/algo_tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/algo_tests.yml b/.github/workflows/algo_tests.yml index df3e59b3..7d7dcd7b 100644 --- a/.github/workflows/algo_tests.yml +++ b/.github/workflows/algo_tests.yml @@ -22,4 +22,4 @@ jobs: - name: Run tests run: | - pytest tests/enum_test.py tests/test_ranked_vote.py \ No newline at end of file + python -m pytest tests/enum_test.py tests/test_ranked_vote.py \ No newline at end of file From 26782aa41cd2d095ffd8d066cf742a6b22d229d4 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 19:59:56 +0800 Subject: [PATCH 19/23] chore: add rust install to github actions --- .github/workflows/algo_tests.yml | 9 +++++++++ requirements.txt | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/algo_tests.yml b/.github/workflows/algo_tests.yml index 7d7dcd7b..85f057ca 100644 --- a/.github/workflows/algo_tests.yml +++ b/.github/workflows/algo_tests.yml @@ -19,6 +19,15 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + + - name: Install Rust + run: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + source $HOME/.cargo/env + + - name: Build Maturin crate + run: | + maturin develop --bindings pyo3 --release - name: Run tests run: | diff --git a/requirements.txt b/requirements.txt index ae8e474a..312ad504 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,4 +15,5 @@ starlette~=0.27.0 setuptools==72.2.0 requests==2.31.0 aioredlock==0.7.3 -redis==5.0.3 \ No newline at end of file +redis==5.0.3 +pytest==8.3.3 \ No newline at end of file From 0e23f1ecdb24b61f928bd3f67a01b6a601f588fc Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 20:02:06 +0800 Subject: [PATCH 20/23] fix: add virtualenv setup to github actions --- .github/workflows/algo_tests.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/algo_tests.yml b/.github/workflows/algo_tests.yml index 85f057ca..c354597e 100644 --- a/.github/workflows/algo_tests.yml +++ b/.github/workflows/algo_tests.yml @@ -18,6 +18,8 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + python -m venv venv + source venv/bin/activate pip install -r requirements.txt - name: Install Rust @@ -27,8 +29,10 @@ jobs: - name: Build Maturin crate run: | + source venv/bin/activate maturin develop --bindings pyo3 --release - name: Run tests run: | + source venv/bin/activate python -m pytest tests/enum_test.py tests/test_ranked_vote.py \ No newline at end of file From dacb7dd78933e7b21e981fb952da0472d7c14918 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 20:10:00 +0800 Subject: [PATCH 21/23] chore: add /delete_account command info to /help --- bot.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/bot.py b/bot.py index 7d7499aa..b8514853 100644 --- a/bot.py +++ b/bot.py @@ -2084,8 +2084,8 @@ async def show_help(update: ModifiedTeleUpdate, *_, **__): —————————————————— /view_votes {poll_id} View all the votes entered for the poll - with the specified poll_id. This can only be done - after the poll has been closed first + with the specified poll_id. + This can only be done after the poll has been closed first —————————————————— /view_voters {poll_id} Show which voters have voted and which have not @@ -2094,9 +2094,13 @@ async def show_help(update: ModifiedTeleUpdate, *_, **__): /view_polls - view all polls created by you —————————————————— /delete_poll {poll_id} - delete poll by poll_id - use /delete_poll --force to force delete the poll without + Use /delete_poll --force to force delete the poll without confirmation, regardless of whether poll is open or closed —————————————————— + /delete_account + /delete_account {deletion_token} + Delete your user account (this cannot be undone) + —————————————————— /help - view commands available to the bot """)) From a89df50c86f1d6ef91166d86d9be7cf50c1a3bde Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 20:26:51 +0800 Subject: [PATCH 22/23] chore: wrap account deletion logic in try-catch --- bot.py | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/bot.py b/bot.py index b8514853..4def00aa 100644 --- a/bot.py +++ b/bot.py @@ -1993,27 +1993,34 @@ async def delete_account(self, update: ModifiedTeleUpdate, *_, **__): err_message = validation_result.err() return await update.message.reply_text(err_message) - with db.atomic(): - # delete all polls created by the user - Polls.delete().where(Polls.creator == user_id).execute() - user.deleted_at = Datetime.now() # mark as deleted - user.save() + try: + with db.atomic(): + # delete all polls created by the user + Polls.delete().where(Polls.creator == user_id).execute() + user.deleted_at = Datetime.now() # mark as deleted + user.save() - poll_registrations: Iterable[PollVoters] = ( - PollVoters.select().where(PollVoters.user == user_id) + poll_registrations: Iterable[PollVoters] = ( + PollVoters.select().where(PollVoters.user == user_id) + ) + for poll_registration in poll_registrations: + poll: Polls = poll_registration.poll + + if poll.closed: + # decouple poll voter from user + poll_registration.user = None + poll_registration.save() + else: + # delete poll voter and increment deleted voters count + poll.deleted_voters += 1 + poll_registration.delete_instance() + poll.save() + + except Exception as e: + await update.message.reply_text( + 'Unexpected error occurred during account deletion' ) - for poll_registration in poll_registrations: - poll: Polls = poll_registration.poll - - if poll.closed: - # decouple poll voter from user - poll_registration.user = None - poll_registration.save() - else: - # delete poll voter and increment deleted voters count - poll.deleted_voters += 1 - poll_registration.delete_instance() - poll.save() + raise e return await update.message.reply_text( 'Account deleted successfully' From 78d90200e2184e50dfa45f4e2a1cb2d51e649cd1 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sun, 13 Oct 2024 20:27:40 +0800 Subject: [PATCH 23/23] chore: remove deletion todo (done) --- bot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bot.py b/bot.py index 4def00aa..bdf2e9c6 100644 --- a/bot.py +++ b/bot.py @@ -1988,7 +1988,6 @@ async def delete_account(self, update: ModifiedTeleUpdate, *_, **__): validation_result = self.validate_delete_token( user=user, stamp=deletion_stamp, short_hash=short_hash ) - # TODO: actually implement deletion if validation_result.is_err(): err_message = validation_result.err() return await update.message.reply_text(err_message)