From 62dae82e59b9ac9c7b32a7879484e2128b5c0248 Mon Sep 17 00:00:00 2001 From: milselarch Date: Sat, 30 Nov 2024 21:15:02 +0800 Subject: [PATCH] feat: add whitelist user by username feature --- base_api.py | 31 ++++++++++++ bot.py | 72 +++++++++++++++++++++++----- handlers/inline_keyboard_handlers.py | 4 +- handlers/payment_handlers.py | 49 +++++++++++++++---- helpers/commands.py | 1 + helpers/strings.py | 5 +- 6 files changed, 139 insertions(+), 23 deletions(-) diff --git a/base_api.py b/base_api.py index e7b3b7cc..b9b0e007 100644 --- a/base_api.py +++ b/base_api.py @@ -885,6 +885,37 @@ def get_whitelist_entry( return Ok(whitelist_entry) + @classmethod + def _whitelist_username_for_poll( + cls, poll: Polls, target_username: str + ) -> Result[None, Exception]: + poll_id = poll.id + + with db.atomic(): + whitelist_entry = UsernameWhitelist.build_from_fields( + poll_id=poll_id, username=target_username + ).safe_get() + + if whitelist_entry.is_ok(): + # TODO: move await out of atomic block + return Err(ValueError( + f'{target_username} already whitelisted' + )) + if poll.max_voters <= poll.num_active_voters: + # TODO: move await out of atomic block + return Err(ValueError( + "Maximum number of voters reached" + )) + + UsernameWhitelist.build_from_fields( + poll_id=poll_id, username=target_username + ).insert().execute() + + poll.num_voters += 1 + poll.save() + + return Ok(None) + @staticmethod def generate_poll_info( poll_id, poll_question, poll_options: list[str], diff --git a/bot.py b/bot.py index 18de1f90..b1ecc555 100644 --- a/bot.py +++ b/bot.py @@ -152,6 +152,7 @@ def start_bot(self): Command.HELP: self.show_help, Command.DONE: context_handlers.complete_chat_context, Command.SET_MAX_VOTERS: self.payment_handlers.set_max_voters, + Command.WHITELIST_USERNAME: self.whitelist_username, Command.PAY_SUPPORT: self.payment_support_handler, Command.VOTE_ADMIN: self.vote_for_poll_admin, @@ -223,6 +224,9 @@ async def post_init(self, _: Application): ), ( Command.REGISTER_USER_ID, 'registers a user by user_id for a poll' + ), ( + Command.WHITELIST_USERNAME, + 'whitelist a username for a poll' ), ( Command.WHITELIST_CHAT_REGISTRATION, 'whitelist a chat for self registration' @@ -625,17 +629,21 @@ async def register_user_by_tele_id( """ pattern = re.compile(r'^\S+\s+([1-9]\d*)\s+([1-9]\d*)$') matches = pattern.match(raw_text) + format_invalid_message = textwrap.dedent(f""" + Format invalid. + Use /{Command.REGISTER_USER_ID} {{poll_id}} {{user_tele_id}} + """) if tele_user is None: await message.reply_text(f'user not found') return False if matches is None: - await message.reply_text(f'Format invalid') + await message.reply_text(format_invalid_message) return False capture_groups = matches.groups() if len(capture_groups) != 2: - await message.reply_text(f'Format invalid') + await message.reply_text(format_invalid_message) return False poll_id = int(capture_groups[0]) @@ -655,9 +663,10 @@ async def register_user_by_tele_id( await message.reply_text(f'poll {poll_id} does not exist') return False - user_id = target_user.get_user_id() + target_user_id = target_user.get_user_id() + current_user_id = update.user.get_user_id() creator_id: UserID = poll.get_creator().get_user_id() - if creator_id != user_id: + if creator_id != current_user_id: await message.reply_text( 'only poll creator is allowed to whitelist chats ' 'for open user registration' @@ -665,14 +674,14 @@ async def register_user_by_tele_id( return False try: - PollVoters.get(poll_id=poll_id, user_id=user_id) - await message.reply_text(f'User #{user_id} already registered') + PollVoters.get(poll_id=poll_id, user_id=target_user_id) + await message.reply_text(f'User #{target_user_id} already registered') return False except PollVoters.DoesNotExist: pass register_result = self.register_user_id( - poll_id=poll_id, user_id=user_id, + poll_id=poll_id, user_id=target_user_id, ignore_voter_limit=False, from_whitelist=False ) @@ -1139,15 +1148,56 @@ async def close_poll(self, update, *_, **__): else: return await message.reply_text('Poll has no winner') + @classmethod + async def whitelist_username(cls, update: ModifiedTeleUpdate, *_, **__): + """ + Command usage: + /whitelist_username {poll_id} {username} + """ + message = update.message + message_text = TelegramHelpers.read_raw_command_args(update) + invalid_format_text = textwrap.dedent(f""" + Input format is invalid, try: + /{Command.WHITELIST_USERNAME} {{poll_id}} {{username}} + """) + + if ' ' not in message_text: + return await message.reply_text(invalid_format_text) + + raw_poll_id = message_text[:message_text.index(' ')].strip() + target_username = message_text[message_text.index(' '):].strip() + if constants.ID_PATTERN.match(raw_poll_id) is None: + return await message.reply_text('Invalid poll id') + + poll_id = int(raw_poll_id) + user_id = update.user.get_user_id() + poll_res = Polls.get_as_creator(poll_id, user_id) + if poll_res.is_err(): + return await message.reply_text( + "You're not the creator of this poll" + ) + + poll = poll_res.unwrap() + whitelist_res = BaseAPI._whitelist_username_for_poll( + poll=poll, target_username=target_username + ) + if whitelist_res.is_err(): + err_msg = str(whitelist_res.unwrap_err()) + return await message.reply_text(err_msg) + + return await update.message.reply_text( + f"Username {target_username} has been whitelisted" + ) + @admin_only async def vote_for_poll_admin(self, update: ModifiedTeleUpdate, *_, **__): """ telegram command formats: - /vote_admin {username} {poll_id}: {option_1} > ... > {option_n} - /vote_admin {username} {poll_id} {option_1} > ... > {option_n} + /vote_admin {@username or #tele_id} {poll_id}: {opt_1} > ... > {opt_n} + /vote_admin {@username or #tele_id} {poll_id} {opt_1} > ... > {opt_n} examples: - /vote 3: 1 > 2 > 3 - /vote 3 1 > 2 > 3 + /vote #100 3: 1 > 2 > 3 + /vote #100 3 1 > 2 > 3 """ # vote for someone else message: Message = update.message diff --git a/handlers/inline_keyboard_handlers.py b/handlers/inline_keyboard_handlers.py index c3e06b1b..385e7f38 100644 --- a/handlers/inline_keyboard_handlers.py +++ b/handlers/inline_keyboard_handlers.py @@ -139,8 +139,8 @@ def _register_voter( if whitelist_user_result.is_ok(): whitelist_entry = whitelist_user_result.unwrap() assert ( - (whitelist_entry.user is None) or - (whitelist_entry.user == user_id) + (whitelist_entry.user is None) or + (whitelist_entry.user == user_id) ) if whitelist_entry.user == user_id: diff --git a/handlers/payment_handlers.py b/handlers/payment_handlers.py index 5b184bd8..eb92c5dc 100644 --- a/handlers/payment_handlers.py +++ b/handlers/payment_handlers.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import json import logging @@ -123,6 +124,15 @@ async def successful_payment_callback( ... + +@dataclasses.dataclass +class ProcessPaymentsResult(object): + poll: Polls + initial_max_voters: int + new_max_voters: int + + + class IncreaseVoteLimitHandler(BasePaymentHandler): async def pre_checkout_callback( self, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE @@ -176,14 +186,37 @@ async def successful_payment_callback( receipt.paid = True receipt.save() + process_res = self._process_payment( + invoice=invoice, receipt=receipt, + voters_increase=voters_increase + ) + if process_res.is_err(): + err_message = process_res.unwrap_err() + return await message.reply_text( + ok=False, error_message=err_message + ) + + post_payment_info = process_res.unwrap() + poll = post_payment_info.poll + initial_max_voters = post_payment_info.initial_max_voters + new_max_voters = post_payment_info.new_max_voters + reply_message = ( + f"The maximum number of voters for poll #{poll.id} " + f"has been raised from {initial_max_voters} to {new_max_voters}" + ) + self.logger.warning(reply_message) + return await message.reply_text(reply_message) + + def _process_payment( + self, invoice: IncreaseVoterLimitParams, + receipt: Payments, voters_increase: int + ) -> Result[ProcessPaymentsResult, Exception]: with db.atomic(): poll_id = invoice.poll_id poll_res = Polls.build_from_fields(poll_id=poll_id).safe_get() if poll_res.is_err(): self.logger.error(f"FAILED TO GET POLL {poll_id}") - return await message.reply_text( - f"Failed to get poll #{poll_id}" - ) + return Err(ValueError(f"Failed to get poll #{poll_id}")) poll = poll_res.unwrap() initial_max_voters = poll.max_voters @@ -194,12 +227,10 @@ async def successful_payment_callback( receipt.processed = True receipt.save() - reply_message = ( - f"The maximum number of voters for poll #{poll.id} " - f"has been raised from {initial_max_voters} to {new_max_voters}" - ) - self.logger.warning(reply_message) - return await message.reply_text(reply_message) + return Ok(ProcessPaymentsResult( + poll=poll, initial_max_voters=initial_max_voters, + new_max_voters=new_max_voters + )) class PaymentHandlers(object): diff --git a/helpers/commands.py b/helpers/commands.py index 84d7cde7..d7030851 100644 --- a/helpers/commands.py +++ b/helpers/commands.py @@ -11,6 +11,7 @@ class Command(StrEnum): CREATE_PRIVATE_POLL = "create_private_poll" CREATE_GROUP_POLL = "create_poll" REGISTER_USER_ID = "register_user_id" + WHITELIST_USERNAME = "whitelist_username" WHITELIST_CHAT_REGISTRATION = "whitelist_chat_registration" BLACKLIST_CHAT_REGISTRATION = "blacklist_chat_registration" DELETE_POLL = "delete_poll" diff --git a/helpers/strings.py b/helpers/strings.py index e9b5e25d..38379d77 100644 --- a/helpers/strings.py +++ b/helpers/strings.py @@ -3,7 +3,7 @@ from helpers.commands import Command -__VERSION__ = '1.3.5' +__VERSION__ = '1.3.6' READ_SUBSCRIPTION_TIER_FAILED = "Unexpected error reading subscription tier" POLL_OPTIONS_LIMIT_REACHED_TEXT = textwrap.dedent(f""" @@ -44,6 +44,9 @@ /{Command.REGISTER_USER_ID} {{poll_id}} {{user_id}} Registers a user by user_id for a poll —————————————————— + /{Command.WHITELIST_USERNAME} {{poll_id}} {{username}} + Whitelists a user by username for a poll + —————————————————— /{Command.WHITELIST_CHAT_REGISTRATION} {{poll_id}} Whitelists the current chat so that chat members can self-register for the poll specified by poll_id within the chat group