diff --git a/base_api.py b/base_api.py index 73b50327..612b9392 100644 --- a/base_api.py +++ b/base_api.py @@ -43,7 +43,7 @@ ) from aioredlock import Aioredlock, LockError from database.database import ( - PollWinners, BaseModel, UserID, PollMetadata, ChatWhitelist + PollWinners, UserID, PollMetadata, ChatWhitelist ) logger = logging.getLogger(__name__) @@ -57,6 +57,7 @@ class CallbackCommands(StrEnum): UNDO_OPTION = 'UNDO' RESET_VOTE = 'RESET' SUBMIT_VOTE = 'SUBMIT_VOTE' + REGISTER_OR_SUBMIT = 'REGISTER_OR_SUBMIT' VIEW_VOTE = 'VIEW_VOTE' @@ -366,16 +367,6 @@ def _determine_poll_winner(cls, poll_id: int) -> Optional[int]: winning_option_id = votes_aggregator.determine_winner() return winning_option_id - @staticmethod - def get_poll_voter( - poll_id: int, user_id: UserID - ) -> Result[PollVoters, Optional[BaseModel.DoesNotExist]]: - # check if voter is part of the poll - return PollVoters.safe_get( - (PollVoters.poll == poll_id) & - (PollVoters.user == user_id) - ) - @classmethod def verify_voter( cls, poll_id: int, user_id: UserID, username: Optional[str] = None, @@ -389,7 +380,7 @@ def verify_voter( and whether user was newly whitelisted from chat whitelist """ - poll_voter_res = cls.get_poll_voter(poll_id, user_id) + poll_voter_res = PollVoters.get_poll_voter(poll_id, user_id) if poll_voter_res.is_ok(): poll_voter = poll_voter_res.unwrap() return Ok((poll_voter, False)) @@ -606,14 +597,11 @@ def check_has_voted(poll_id: int, user_id: UserID) -> bool: user_id=user_id, poll_id=poll_id, voted=True ).safe_get().is_ok() - @classmethod - def is_poll_voter(cls, poll_id: int, user_id: UserID) -> bool: - return cls.get_poll_voter(poll_id=poll_id, user_id=user_id).is_ok() - @classmethod def get_poll_message( cls, poll_id: int, user_id: UserID, bot_username: str, - username: Optional[str], add_webapp_link: bool = True + username: Optional[str], add_webapp_link: bool = False, + add_instructions: bool = False ) -> Result[PollMessage, MessageBuilder]: if not cls.has_access_to_poll_id( poll_id=poll_id, user_id=user_id, username=username @@ -624,24 +612,28 @@ def get_poll_message( return Ok(cls._get_poll_message( poll_id=poll_id, bot_username=bot_username, - add_webapp_link=add_webapp_link + add_webapp_link=add_webapp_link, + add_instructions=add_instructions )) @classmethod def _get_poll_message( cls, poll_id: int, bot_username: str, - add_webapp_link: bool = True + add_webapp_link: bool = False, + add_instructions: bool = False ) -> PollMessage: poll_info = cls.unverified_read_poll_info(poll_id=poll_id) return cls.generate_poll_message( poll_info=poll_info, bot_username=bot_username, - add_webapp_link=add_webapp_link + add_webapp_link=add_webapp_link, + add_instructions=add_instructions ) @classmethod def generate_poll_message( cls, poll_info: PollInfo, bot_username: str, - add_webapp_link: bool = True + add_webapp_link: bool = False, + add_instructions: bool = False ) -> PollMessage: poll_metadata = poll_info.metadata poll_message = cls.generate_poll_info( @@ -650,7 +642,8 @@ def generate_poll_message( bot_username=bot_username, max_voters=poll_metadata.max_voters, num_voters=poll_metadata.num_active_voters, num_votes=poll_metadata.num_votes, - add_webapp_link=add_webapp_link + add_webapp_link=add_webapp_link, + add_instructions=add_instructions ) reply_markup = None @@ -703,7 +696,8 @@ def build_private_vote_markup( logger.warning(f'POLL_URL = {poll_url}') # create vote button for reply message markup_layout = [[KeyboardButton( - text=f'Vote for Poll #{poll_id}', web_app=WebAppInfo(url=poll_url) + text=f'Vote for Poll #{poll_id} Online', + web_app=WebAppInfo(url=poll_url) )]] return markup_layout @@ -714,20 +708,11 @@ def build_group_vote_markup( ) -> List[List[InlineKeyboardButton]]: """ TODO: implement button vote context - < poll registration button > < vote option rows > - < undo, abstain, withhold, reset > - < submit / check button > + < undo, view, reset > + < register / submit button > """ markup_rows, current_row = [], [] - - # create first row with just registration button - markup_rows.append([cls.spawn_inline_keyboard_button( - text='Register for Poll', - command=CallbackCommands.REGISTER_FOR_POLL, - callback_data=dict(poll_id=poll_id) - )]) - # fill in rows containing poll option numbers for ranking in range(1, num_options+1): current_row.append(cls.spawn_inline_keyboard_button( @@ -745,26 +730,16 @@ def build_group_vote_markup( markup_rows.append(current_row) current_row = [] - # add row with undo, abstain, withhold, reset buttons + # add row with undo, view, reset buttons markup_rows.append([ cls.spawn_inline_keyboard_button( text='undo', command=CallbackCommands.UNDO_OPTION, callback_data=dict(poll_id=poll_id) ), cls.spawn_inline_keyboard_button( - text='abstain', - command=CallbackCommands.ADD_VOTE_OPTION, - callback_data=dict( - poll_id=poll_id, - option=SpecialVotes.ABSTAIN_VOTE.value - ) - ), cls.spawn_inline_keyboard_button( - text='withhold', - command=CallbackCommands.ADD_VOTE_OPTION, - callback_data=dict( - poll_id=poll_id, - option=SpecialVotes.WITHHOLD_VOTE.value - ) + text='view', + command=CallbackCommands.VIEW_VOTE, + callback_data=dict(poll_id=poll_id) ), cls.spawn_inline_keyboard_button( text='reset', command=CallbackCommands.RESET_VOTE, @@ -775,13 +750,8 @@ def build_group_vote_markup( # add final row with view vote, submit vote buttons markup_rows.append([ cls.spawn_inline_keyboard_button( - text='View Vote', - command=CallbackCommands.VIEW_VOTE, - callback_data=dict(poll_id=poll_id) - ), - cls.spawn_inline_keyboard_button( - text='Submit Vote', - command=CallbackCommands.SUBMIT_VOTE, + text='Register / Submit Vote', + command=CallbackCommands.REGISTER_OR_SUBMIT, callback_data=dict(poll_id=poll_id) ) ]) @@ -853,7 +823,7 @@ def has_access_to_poll( if creator_id == user_id: return True - if cls.is_poll_voter(poll_id, user_id): + if PollVoters.is_poll_voter(poll_id, user_id): return True if username is not None: @@ -912,7 +882,8 @@ def generate_poll_info( poll_id, poll_question, poll_options: list[str], bot_username: str, max_voters: int, num_votes: int = 0, num_voters: int = 0, closed: bool = False, - add_webapp_link: bool = True + add_webapp_link: bool = False, + add_instructions: bool = False ): close_tag = '(closed)' if closed else '' numbered_poll_options = [ @@ -927,14 +898,24 @@ def generate_poll_info( ) webapp_link_footer = '' + instructions_footer = '' + if add_webapp_link: webapp_link_footer = ( f'\n——————————————————' - f'\nvote on the webapp at {deep_link_url}' + f'\nvote on the webapp at {deep_link_url}\n' + ) + if add_instructions: + instructions_footer = ( + '\n——————————————————\n' + 'How to vote:\n' + '- press the register button, then ' + 'start the bot via chat DM\n' + '- alternatively, press the number buttons in order of most ' + 'to least favourite option, then press submit' ) - return ( - textwrap.dedent(f""" + return (textwrap.dedent(f""" Poll #{poll_id} {close_tag} {poll_question} —————————————————— @@ -942,8 +923,9 @@ def generate_poll_info( —————————————————— """) + f'\n'.join(numbered_poll_options) + - webapp_link_footer - ) + webapp_link_footer + + instructions_footer + ).strip() @staticmethod def make_data_check_string( diff --git a/bot.py b/bot.py index 392f4b6a..18de1f90 100644 --- a/bot.py +++ b/bot.py @@ -5,6 +5,7 @@ import os import time import textwrap +import telegram import asyncio import re @@ -41,7 +42,7 @@ from helpers.strings import ( POLL_OPTIONS_LIMIT_REACHED_TEXT, READ_SUBSCRIPTION_TIER_FAILED, - INCREASE_MAX_VOTERS_TEXT + generate_poll_created_message ) from helpers.chat_contexts import ( PollCreationChatContext, PollCreatorTemplate, POLL_MAX_OPTIONS, @@ -150,17 +151,19 @@ def start_bot(self): Command.DELETE_ACCOUNT: self.delete_account, Command.HELP: self.show_help, Command.DONE: context_handlers.complete_chat_context, + Command.SET_MAX_VOTERS: self.payment_handlers.set_max_voters, + Command.PAY_SUPPORT: self.payment_support_handler, + Command.VOTE_ADMIN: self.vote_for_poll_admin, Command.CLOSE_POLL_ADMIN: self.close_poll_admin, Command.UNCLOSE_POLL_ADMIN: self.unclose_poll_admin, Command.LOOKUP_FROM_USERNAME_ADMIN: self.lookup_from_username_admin, Command.INSERT_USER_ADMIN: self.insert_user_admin, - Command.SET_MAX_VOTERS: self.payment_handlers.set_max_voters, - Command.PAY_SUPPORT: self.payment_support_handler, Command.REFUND_ADMIN: self.refund_payment_support_handler, Command.ENTER_MAINTENANCE_ADMIN: self.enter_maintenance_admin, - Command.EXIT_MAINTENANCE_ADMIN: self.exit_maintenance_admin + Command.EXIT_MAINTENANCE_ADMIN: self.exit_maintenance_admin, + Command.SEND_MSG_ADMIN: self.send_msg_admin } # on different commands - answer in Telegram @@ -199,10 +202,9 @@ def start_bot(self): self.app.run_polling(allowed_updates=BaseTeleUpdate.ALL_TYPES) print('<<< BOT POLLING LOOP ENDED >>>') - @staticmethod - async def post_init(application: Application): + async def post_init(self, _: Application): # print('SET COMMANDS') - await application.bot.set_my_commands([( + await self.get_bot().set_my_commands([( Command.START, 'start bot' ), ( Command.USER_DETAILS, 'shows your username and user id' @@ -346,7 +348,7 @@ async def has_voted(self, update: ModifiedTeleUpdate, *_, **__): user_id = user.get_user_id() poll_id = extract_poll_id_result.unwrap() - is_voter = self.is_poll_voter( + is_voter = PollVoters.is_poll_voter( poll_id=poll_id, user_id=user_id ) @@ -551,7 +553,8 @@ async def create_poll( new_poll_id, poll_question, poll_options, bot_username=bot_username, closed=False, num_voters=poll_creator.initial_num_voters, - max_voters=new_poll.max_voters + max_voters=new_poll.max_voters, + add_instructions=update.is_group_chat() ) chat_type = update.message.chat.type @@ -570,10 +573,8 @@ async def create_poll( ) reply_markup = InlineKeyboardMarkup(vote_markup_data) - await message.reply_text( - poll_message, reply_markup=reply_markup - ) - await message.reply_text(INCREASE_MAX_VOTERS_TEXT) + await message.reply_text(poll_message, reply_markup=reply_markup) + await message.reply_text(generate_poll_created_message(new_poll_id)) @classmethod async def whitelist_chat_registration( @@ -1079,6 +1080,14 @@ def read_vote_count(cls, poll_id: int) -> Result[int, MessageBuilder]: async def close_poll(self, update, *_, **__): message = update.message + message_text = TelegramHelpers.read_raw_command_args(update) + + if constants.ID_PATTERN.match(message_text) is None: + return await message.reply_text(textwrap.dedent(f""" + Input format is invalid, try: + /{Command.CLOSE_POLL} {{poll_id}} + """)) + extract_result = TelegramHelpers.extract_poll_id(update) if extract_result.is_err(): @@ -1675,20 +1684,40 @@ async def refund_payment_support_handler( payment.save() @admin_only - def enter_maintenance_admin( + async def enter_maintenance_admin( self, update: ModifiedTeleUpdate, _: ContextTypes.DEFAULT_TYPE ): message = update.message self.payment_handlers.enter_maintenance_mode() - return message.reply_text('Maintenance mode entered') + return await message.reply_text('Maintenance mode entered') @admin_only - def exit_maintenance_admin( + async def exit_maintenance_admin( self, update: ModifiedTeleUpdate, _: ContextTypes.DEFAULT_TYPE ): message = update.message self.payment_handlers.exit_maintenance_mode() - return message.reply_text('Maintenance mode exited') + return await message.reply_text('Maintenance mode exited') + + @admin_only + async def send_msg_admin( + self, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE + ): + raw_args = TelegramHelpers.read_raw_command_args(update) + split_index = raw_args.index(' ') + chat_id = int(raw_args[:split_index]) + payload = raw_args[split_index+1:] + reply_text = update.message.reply_text + + try: + response = await context.bot.send_message( + chat_id=chat_id, text=payload + ) + except telegram.error.BadRequest: + return await reply_text("Failed to send message") + + logger.info(f"SEND_RESP {response}") + return await reply_text("Message sent") if __name__ == '__main__': diff --git a/database/database.py b/database/database.py index 90d663cd..24c11b2a 100644 --- a/database/database.py +++ b/database/database.py @@ -234,6 +234,19 @@ def get_voter_user(self) -> Users: assert isinstance(self.user, Users) return self.user + @classmethod + def get_poll_voter( + cls, poll_id: int, user_id: UserID + ) -> Result[PollVoters, Optional[BaseModel.DoesNotExist]]: + # check if voter is part of the poll + return cls.safe_get( + (cls.poll == poll_id) & (cls.user == user_id) + ) + + @classmethod + def is_poll_voter(cls, poll_id: int, user_id: UserID) -> bool: + return cls.get_poll_voter(poll_id=poll_id, user_id=user_id).is_ok() + # whitelists voters for a poll by their username # assigns their user_id to the corresponding username diff --git a/handlers/chat_context_handlers.py b/handlers/chat_context_handlers.py index cf0a8796..e84e3efe 100644 --- a/handlers/chat_context_handlers.py +++ b/handlers/chat_context_handlers.py @@ -7,7 +7,7 @@ from base_api import BaseAPI from bot_middleware import track_errors from handlers.payment_handlers import IncMaxVotersChatContext, PaymentHandlers -from handlers.start_handlers import StartGetParams +from helpers.start_get_params import StartGetParams from helpers import strings from helpers.commands import Command from helpers.constants import BLANK_POLL_ID @@ -18,7 +18,7 @@ extract_chat_context ) from helpers.strings import ( - READ_SUBSCRIPTION_TIER_FAILED, INCREASE_MAX_VOTERS_TEXT + READ_SUBSCRIPTION_TIER_FAILED, generate_poll_created_message ) from database import ( Users, CallbackContextState, ChatContextStateTypes, Polls, SupportTickets @@ -36,15 +36,50 @@ async def complete_chat_context( @abstractmethod async def handle_messages( self, extracted_context: ExtractedChatContext, - update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE + update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, + is_from_start: bool ): + """ + :param extracted_context: + :param update: + :param context: + :param is_from_start: + whether the chat just got initiated from the start command + """ ... -class PollCreationContextHandler(BaseContextHandler): +class ClosePollContextHandler(BaseContextHandler): async def handle_messages( self, extracted_context: ExtractedChatContext, + update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, + is_from_start: bool + ): + message = update.message + raw_poll_id = message.text + + try: + poll_id = int(raw_poll_id) + except ValueError: + return await message.reply_text( + f"Invalid poll id: {raw_poll_id}" + ) + + # TODO: implement poll closing here + + async def complete_chat_context( + self, chat_context: CallbackContextState, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE + ): + return await update.message.reply_text( + f"/{Command.DONE} not supported for closing polls" + ) + +class PollCreationContextHandler(BaseContextHandler): + async def handle_messages( + self, extracted_context: ExtractedChatContext, + update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, + is_from_start: bool ): message: Message = update.message chat_context = extracted_context.chat_context @@ -88,6 +123,7 @@ async def complete_chat_context( ): user_entry: Users = update.user message: Message = update.message + reply_text = message.reply_text tele_user: TeleUser | None = message.from_user chat_type = message.chat.type user_id = user_entry.get_user_id() @@ -95,14 +131,14 @@ async def complete_chat_context( poll_creation_context_res = PollCreationChatContext.load(chat_context) if poll_creation_context_res.is_err(): chat_context.delete() - return await message.reply_text( + return await reply_text( "Unexpected error loading poll creation context" ) poll_creation_context = poll_creation_context_res.unwrap() subscription_tier_res = user_entry.get_subscription_tier() if subscription_tier_res.is_err(): - return await message.reply_text(READ_SUBSCRIPTION_TIER_FAILED) + return await reply_text(READ_SUBSCRIPTION_TIER_FAILED) subscription_tier = subscription_tier_res.unwrap() poll_creator = poll_creation_context.to_template( @@ -112,7 +148,7 @@ async def complete_chat_context( create_poll_res = poll_creator.save_poll_to_db() if create_poll_res.is_err(): error_message = create_poll_res.err() - return await error_message.call(message.reply_text) + return await error_message.call(reply_text) new_poll: Polls = create_poll_res.unwrap() poll_id = int(new_poll.id) @@ -125,11 +161,12 @@ async def complete_chat_context( username=user_entry.username, # set to false here to discourage sending webapp # link before group chat has been whitelisted - add_webapp_link=False + add_webapp_link=False, + add_instructions=update.is_group_chat() ) if view_poll_result.is_err(): error_message = view_poll_result.err() - return await error_message.call(message.reply_text) + return await error_message.call(reply_text) poll_message = view_poll_result.unwrap() reply_markup = BaseAPI.generate_vote_markup( @@ -138,7 +175,6 @@ async def complete_chat_context( num_options=poll_message.poll_info.max_options ) - reply_text = message.reply_text bot_username = context.bot.username deep_link_url = ( f'https://t.me/{bot_username}?startgroup=' @@ -155,25 +191,24 @@ async def complete_chat_context( "Alternatively, click the following link to share the " "poll to the group chat of your choice:" ) - # https://stackoverflow.com/questions/76538913/ - return await message.reply_markdown_v2(textwrap.dedent(f""" - {strings.escape_markdown(INCREASE_MAX_VOTERS_TEXT)} - - Run the following command: - `/{Command.WHITELIST_CHAT_REGISTRATION} {poll_id}` - {group_chat_text}\\. - - {share_link_text} - [{escaped_deep_link_url}]({escaped_deep_link_url}) - """)) + return await message.reply_markdown_v2( + strings.escape_markdown(generate_poll_created_message(poll_id)) + + f'\n\n' + + f'Run the following command:\n' + f"`/{Command.WHITELIST_CHAT_REGISTRATION} {poll_id}` " + f"{group_chat_text}\\.\n" + + f'\n' + + share_link_text + + f" [{escaped_deep_link_url}]({escaped_deep_link_url})" + ) class VoteContextHandler(BaseContextHandler): - @track_errors async def handle_messages( self, extracted_context: ExtractedChatContext, - update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE + update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, + is_from_start: bool ): message: Message = update.message chat_context = extracted_context.chat_context @@ -186,7 +221,44 @@ async def handle_messages( "Unexpected error loading vote context" ) + user = update.user + tele_user: TeleUser = update.message.from_user + bot_username = context.bot.username vote_context = vote_context_res.unwrap() + + if is_from_start: + """ + if called from /start command, we send all the information + about the poll in the chat context and prompt them + to choose poll options interactively + """ + if not vote_context.has_poll_id: + return await message.reply_text("Invalid poll ID") + + poll_id = vote_context.poll_id + poll_info_res = BaseAPI.read_poll_info( + poll_id=poll_id, user_id=user.get_user_id(), + username=tele_user.username, chat_id=message.chat_id + ) + if poll_info_res.is_err(): + error_message = poll_info_res.err() + return await error_message.call(message.reply_text) + + poll_info = poll_info_res.unwrap() + poll_message = BaseAPI.generate_poll_message( + poll_info=poll_info, bot_username=bot_username + ) + poll = poll_message.poll_info.metadata + reply_markup = BaseAPI.generate_vote_markup( + tele_user=tele_user, poll_id=poll_id, chat_type='private', + open_registration=poll.open_registration, + num_options=poll_message.poll_info.max_options + ) + poll_contents = poll_message.text + await message.reply_text(poll_contents, reply_markup=reply_markup) + prompt = vote_context.generate_vote_option_prompt() + return await message.reply_text(prompt) + if not vote_context.has_poll_id: # accept the current text message as the poll_id and set it try: @@ -194,9 +266,8 @@ async def handle_messages( except ValueError: return await message.reply_text("Invalid poll ID") - tele_user: TeleUser = update.message.from_user poll_info_res = BaseAPI.read_poll_info( - poll_id=poll_id, user_id=update.user.get_user_id(), + poll_id=poll_id, user_id=user.get_user_id(), username=tele_user.username, chat_id=message.chat_id ) @@ -310,7 +381,8 @@ async def complete_chat_context( async def handle_messages( self, extracted_context: ExtractedChatContext, - update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE + update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, + is_from_start: bool ): msg: Message = update.message chat_context = extracted_context.chat_context @@ -371,7 +443,8 @@ async def complete_chat_context( async def handle_messages( self, extracted_context: ExtractedChatContext, - update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE + update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, + is_from_start: bool ): chat_context = extracted_context.chat_context raw_args = TelegramHelpers.read_raw_command_args(update) @@ -427,7 +500,8 @@ def __init__(self): @track_errors async def handle_other_messages( - self, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE + self, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, + is_from_start: bool = False ): message: Message = update.message chat_context_res = extract_chat_context(update) @@ -445,7 +519,8 @@ async def handle_other_messages( context_handler_cls = self.context_handlers[context_type] context_handler = context_handler_cls() return await context_handler.handle_messages( - extracted_context, update, context + extracted_context, update, context, + is_from_start=is_from_start ) @track_errors diff --git a/handlers/inline_keyboard_handlers.py b/handlers/inline_keyboard_handlers.py index 5a3001bc..32d7fba0 100644 --- a/handlers/inline_keyboard_handlers.py +++ b/handlers/inline_keyboard_handlers.py @@ -2,6 +2,10 @@ import json import logging import time + +import telegram +from telegram._utils.types import ReplyMarkup + import base_api from abc import ABCMeta, abstractmethod @@ -10,9 +14,10 @@ from telegram.ext import CallbackContext from bot_middleware import track_errors from database.db_helpers import UserID -from helpers import constants +from helpers import constants, strings +from helpers.chat_contexts import VoteChatContext from helpers.locks_manager import PollsLockManager -from helpers.strings import generate_poll_closed_message +from helpers.strings import generate_poll_closed_message, generate_poll_deleted_message from tele_helpers import ModifiedTeleUpdate, TelegramHelpers from telegram import User as TeleUser, Message from json import JSONDecodeError @@ -290,7 +295,7 @@ async def handle_queries( poll_closed_res = Polls.get_is_closed(poll_id) if poll_closed_res.is_err(): - return await query.answer('FAILED TO CHECK IF POLL CLOSED') + return await query.answer(generate_poll_deleted_message(poll_id)) elif poll_closed_res.unwrap(): return await query.answer(generate_poll_closed_message(poll_id)) @@ -353,7 +358,7 @@ async def handle_queries( poll_closed_res = Polls.get_is_closed(poll_id) if poll_closed_res.is_err(): - return await query.answer('FAILED TO CHECK IF POLL CLOSED') + return await query.answer(generate_poll_deleted_message(poll_id)) elif poll_closed_res.unwrap(): return await query.answer(generate_poll_closed_message(poll_id)) @@ -386,7 +391,7 @@ async def handle_queries( poll_closed_res = Polls.get_is_closed(poll_id) if poll_closed_res.is_err(): - return await query.answer('FAILED TO CHECK IF POLL CLOSED') + return await query.answer(generate_poll_deleted_message(poll_id)) elif poll_closed_res.unwrap(): return await query.answer(generate_poll_closed_message(poll_id)) else: @@ -421,7 +426,7 @@ async def handle_queries( poll_closed_res = Polls.get_is_closed(poll_id) if poll_closed_res.is_err(): - return await query.answer('FAILED TO CHECK IF POLL CLOSED') + return await query.answer(generate_poll_deleted_message(poll_id)) elif poll_closed_res.unwrap(): return await query.answer(generate_poll_closed_message(poll_id)) @@ -445,16 +450,18 @@ async def handle_queries( chat_id = message.chat_id message_id = query.message.message_id - extracted_message_context_res = extract_message_context(update) poll_id = int(callback_data['poll_id']) poll_closed_res = Polls.get_is_closed(poll_id) if poll_closed_res.is_err(): - return await query.answer('FAILED TO CHECK IF POLL CLOSED') + return await query.answer(generate_poll_deleted_message(poll_id)) elif poll_closed_res.unwrap(): return await query.answer(generate_poll_closed_message(poll_id)) + extracted_message_context_res = extract_message_context(update) if extracted_message_context_res.is_err(): + # message chat context is empty + # (i.e. number buttons weren't pressed) has_voted = BaseAPI.check_has_voted( poll_id=poll_id, user_id=update.user.id ) @@ -496,6 +503,154 @@ async def handle_queries( ) +class RegisterSubmitMessageHandler(BaseMessageHandler): + async def handle_queries( + self, update: ModifiedTeleUpdate, context: CallbackContext, + callback_data: dict[str, any] + ): + user = update.user + user_id = user.get_user_id() + query = update.callback_query + message: Message = query.message + tele_user: TeleUser = query.from_user + message_id = query.message.message_id + poll_id = int(callback_data['poll_id']) + poll_closed_res = Polls.get_is_closed(poll_id) + chat_id = message.chat_id + coroutines = [] + + if poll_closed_res.is_err(): + return await query.answer(generate_poll_deleted_message(poll_id)) + elif poll_closed_res.unwrap(): + return await query.answer(generate_poll_closed_message(poll_id)) + + extracted_message_context_res = extract_message_context(update) + poll_voter_res = PollVoters.get_poll_voter( + poll_id=poll_id, user_id=user_id + ) + registered = poll_voter_res.is_ok() + has_message_context = extracted_message_context_res.is_ok() + + if has_message_context: + # message context vote info exists, + # therefore we just submit the vote in the message vote context + extracted_message_context = extracted_message_context_res.unwrap() + vote_context_res = VoteMessageContext.load( + extracted_message_context.message_context + ) + if vote_context_res.is_err(): + return await query.answer("Failed to load context") + + vote_context = vote_context_res.unwrap() + # print('TELE_USER_ID:', tele_user.id) + register_vote_result = BaseAPI.register_vote( + chat_id=chat_id, rankings=vote_context.rankings, + poll_id=vote_context.poll_id, + username=tele_user.username, user_tele_id=tele_user.id + ) + + if register_vote_result.is_err(): + error_message = register_vote_result.unwrap_err() + return await error_message.call(query.answer) + + # whether the voter was registered for the poll during the vote itself + _, newly_registered = register_vote_result.unwrap() + extracted_message_context.message_context.delete_instance() + + if newly_registered: + poll_info = BaseAPI.unverified_read_poll_info(poll_id=poll_id) + await TelegramHelpers.update_poll_message( + poll_info=poll_info, chat_id=chat_id, + message_id=message_id, context=context, + poll_locks_manager=_poll_locks_manager + ) + + return await query.answer("Vote Submitted") + + assert not has_message_context + newly_registered = False + + if not registered: + # not registered, no message context vote found + if not ChatWhitelist.is_whitelisted(poll_id, chat_id): + return await query.answer( + "Not allowed to register from this chat" + ) + + register_status = _register_voter( + poll_id=poll_id, user_id=user_id, + username=tele_user.username + ) + if register_status == UserRegistrationStatus.REGISTERED: + newly_registered = True + poll_info = BaseAPI.unverified_read_poll_info(poll_id=poll_id) + coroutines.append(TelegramHelpers.update_poll_message( + poll_info=poll_info, chat_id=chat_id, + message_id=message_id, context=context, + poll_locks_manager=_poll_locks_manager + )) + else: + return await query.answer(BaseAPI.reg_status_to_msg( + register_status, poll_id + )) + + # create vote chat DM context and try to send a message to the user + poll_info_res = BaseAPI.read_poll_info( + poll_id=poll_id, user_id=user_id, + username=tele_user.username, chat_id=message.chat_id + ) + if poll_info_res.is_err(): + error_message = poll_info_res.err() + return await error_message.call(query.answer) + + poll_info = poll_info_res.unwrap() + vote_context = VoteChatContext( + user_id=user_id, chat_id=tele_user.id, + max_options=poll_info.max_options, poll_id=poll_id + ) + vote_context.save_state() + bot_username = context.bot.username + async def send_dm(text, markup: Optional[ReplyMarkup] = None): + await context.bot.send_message( + text=text, chat_id=tele_user.id, reply_markup=markup + ) + + if newly_registered: + resp_header = "Registered for poll" + else: + resp_header = "Registered already" + + try: + # check that we can send a message to user directly + # i.e. check that bot DM with user has been opened + await send_dm(strings.BOT_STARTED) + # raise telegram.error.BadRequest("") + except telegram.error.BadRequest: + resp = f"{resp_header} - start the bot to cast your vote" + return await query.answer(resp) + + coroutine = query.answer(resp_header) + coroutines.append(coroutine) + poll_message = BaseAPI.generate_poll_message( + poll_info=poll_info, bot_username=bot_username, + add_instructions=False + ) + poll = poll_message.poll_info.metadata + reply_markup = BaseAPI.generate_vote_markup( + tele_user=tele_user, poll_id=poll_id, chat_type='private', + open_registration=poll.open_registration, + num_options=poll_message.poll_info.max_options + ) + # display poll info in chat DMs at the start + poll_contents = poll_message.text + async def dm_poll_info(): + await send_dm(poll_contents, markup=reply_markup) + await send_dm(vote_context.generate_vote_option_prompt()) + + coroutines.append(dm_poll_info()) + await asyncio.gather(*coroutines) + + class InlineKeyboardHandlers(object): def __init__(self, logger: logging.Logger): self.logger = logger @@ -508,8 +663,11 @@ def __init__(self, logger: logging.Logger): CallbackCommands.UNDO_OPTION: UndoVoteRankingMessageHandler, CallbackCommands.RESET_VOTE: ResetVoteMessageHandler, CallbackCommands.VIEW_VOTE: ViewVoteMessageHandler, - CallbackCommands.SUBMIT_VOTE: SubmitVoteMessageHandler + CallbackCommands.SUBMIT_VOTE: SubmitVoteMessageHandler, + CallbackCommands.REGISTER_OR_SUBMIT: RegisterSubmitMessageHandler } + for callback_command in CallbackCommands: + assert callback_command in self.handlers, callback_command @track_errors async def route( diff --git a/handlers/payment_handlers.py b/handlers/payment_handlers.py index f796dc2b..bfa70093 100644 --- a/handlers/payment_handlers.py +++ b/handlers/payment_handlers.py @@ -319,7 +319,6 @@ async def set_max_voters( user = update.user if raw_args == '': - # TODO: implement callback context behavior IncMaxVotersChatContext( user_id=user.get_user_id(), chat_id=msg.chat_id ).save_state() diff --git a/handlers/start_handlers.py b/handlers/start_handlers.py index efbc2a8a..aef0d137 100644 --- a/handlers/start_handlers.py +++ b/handlers/start_handlers.py @@ -3,13 +3,17 @@ from base_api import BaseAPI from database import Users, Payments, Polls -from handlers.payment_handlers import BasePaymentParams, InvoiceTypes, IncreaseVoterLimitParams, PaymentHandlers from helpers import strings +from helpers.chat_contexts import extract_chat_context from tele_helpers import ModifiedTeleUpdate, TelegramHelpers from telegram import User as TeleUser, ReplyKeyboardMarkup from telegram.ext import ContextTypes from helpers.start_get_params import StartGetParams - +from handlers.chat_context_handlers import context_handlers +from handlers.payment_handlers import ( + BasePaymentParams, InvoiceTypes, IncreaseVoterLimitParams, + PaymentHandlers +) class BaseMessageHandler(object, metaclass=ABCMeta): @abstractmethod @@ -47,7 +51,8 @@ async def handle_messages( view_poll_result = BaseAPI.get_poll_message( poll_id=poll_id, user_id=user_id, bot_username=context.bot.username, - username=tele_user.username + username=tele_user.username, + add_instructions=update.is_group_chat() ) if view_poll_result.is_err(): @@ -158,8 +163,27 @@ async def start_handler( args = context.args if len(args) == 0: - await update.message.reply_text('Bot started') - return True + await update.message.reply_text(strings.BOT_STARTED) + # check for existing chat context and process it if it exists + chat_context_res = extract_chat_context(update) + if chat_context_res.is_err(): + return + + extracted_context = chat_context_res.unwrap() + context_type = extracted_context.context_type + chat_handlers = context_handlers.context_handlers + + if context_type not in chat_handlers: + return await message.reply_text( + f"{context_type} context unsupported" + ) + + context_handler_cls = chat_handlers[context_type] + context_handler = context_handler_cls() + return await context_handler.handle_messages( + extracted_context, update, context, + is_from_start=True + ) command_params: str = args[0] assert isinstance(command_params, str) diff --git a/helpers/commands.py b/helpers/commands.py index 1ca53924..84d7cde7 100644 --- a/helpers/commands.py +++ b/helpers/commands.py @@ -37,3 +37,4 @@ class Command(StrEnum): REFUND_ADMIN = "refund_admin" ENTER_MAINTENANCE_ADMIN = "enter_maintenance_admin" EXIT_MAINTENANCE_ADMIN = "exit_maintenance_admin" + SEND_MSG_ADMIN = "send_msg_admin" diff --git a/helpers/contexts.py b/helpers/contexts.py index b2ca3fdb..6365adec 100644 --- a/helpers/contexts.py +++ b/helpers/contexts.py @@ -61,7 +61,10 @@ def num_options(self) -> int: return len(self.rankings) def to_vote_message(self) -> str: - return f'{self.poll_id}: ' + self.rankings_to_str() + return ( + f'Current vote for Poll #{self.poll_id}: \n' + + self.rankings_to_str() + ) def rankings_to_str(self): return ' > '.join([ diff --git a/helpers/strings.py b/helpers/strings.py index a65cea6a..041f3cdc 100644 --- a/helpers/strings.py +++ b/helpers/strings.py @@ -3,7 +3,7 @@ from helpers.commands import Command -__VERSION__ = '1.3.0' +__VERSION__ = '1.3.2' READ_SUBSCRIPTION_TIER_FAILED = "Unexpected error reading subscription tier" POLL_OPTIONS_LIMIT_REACHED_TEXT = textwrap.dedent(f""" @@ -124,11 +124,6 @@ def generate_delete_text(deletion_token: str) -> str: /{Command.DELETE_ACCOUNT} {deletion_token} """) - -INCREASE_MAX_VOTERS_TEXT = ( - f"Poll created. Use /{Command.SET_MAX_VOTERS} to change " - f"the maximum number of voters who can vote for the poll. " -) MAX_VOTERS_NOT_EDITABLE = ( "Invalid poll ID - note that only the poll's creator is allowed to " "change the max number of voters" @@ -141,6 +136,7 @@ def generate_delete_text(deletion_token: str) -> str: "New poll max voter limit must be greater " "than the existing limit" ) +BOT_STARTED = 'Bot started' def escape_markdown(string: str) -> str: @@ -151,12 +147,21 @@ def escape_markdown(string: str) -> str: ) +def generate_poll_created_message(poll_id: int): + return ( + f"Poll #{poll_id} created.\n" + f"Run /close_poll to close the poll.\n" + f"Use /{Command.SET_MAX_VOTERS} to change " + f"the maximum number of voters who can vote for the poll. " + ) + + def generate_vote_option_prompt(rank: int) -> str: if rank == 1: - return f"Enter the poll option you want to rank #{rank}:" + return f"Enter the poll option no. you want to rank #{rank}:" else: return ( - f"Enter the poll option you want to rank #{rank}, " + f"Enter the poll option no. you want to rank #{rank}, " f"or use /done if you're done:" ) @@ -170,3 +175,6 @@ def generate_max_voters_prompt(poll_id: int, current_max: int): def generate_poll_closed_message(poll_id: int): return f"Poll #{poll_id} has been closed already" + +def generate_poll_deleted_message(poll_id: int): + return f"Poll #{poll_id} has been deleted already" diff --git a/tele_helpers.py b/tele_helpers.py index a3c01249..a0bb70d7 100644 --- a/tele_helpers.py +++ b/tele_helpers.py @@ -56,6 +56,9 @@ def effective_message(self): def pre_checkout_query(self): return self.update.pre_checkout_query + def is_group_chat(self) -> bool: + return self.update.message.chat.type != 'private' + class TelegramHelpers(object): @classmethod @@ -289,7 +292,7 @@ def extract_poll_id( @classmethod async def set_chat_registration_status( cls, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, - whitelist: bool, poll_id: int, add_webapp_link: bool = True + whitelist: bool, poll_id: int, add_webapp_link: bool = False ) -> bool: message = update.message tele_user: TeleUser | None = message.from_user @@ -352,7 +355,7 @@ async def set_chat_registration_status( @classmethod async def view_poll_by_id( cls, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE, - poll_id: int, add_webapp_link: bool = True + poll_id: int, add_webapp_link: bool = False ) -> bool: user = update.user message = update.message @@ -363,7 +366,8 @@ async def view_poll_by_id( poll_id=poll_id, user_id=user_id, bot_username=context.bot.username, username=user.username, - add_webapp_link=add_webapp_link + add_webapp_link=add_webapp_link, + add_instructions=update.is_group_chat() ) if view_poll_result.is_err(): @@ -388,7 +392,7 @@ async def view_poll_by_id( async def update_poll_message( cls, poll_info: PollInfo, chat_id: int, message_id: int, context: CallbackContext, poll_locks_manager: PollsLockManager, - verbose: bool = False + verbose: bool = False, add_instructions: bool = True ): """ attempts to update the poll info message such that in @@ -411,7 +415,8 @@ async def update_poll_message( if await poll_locks.has_correct_voter_count(voter_count): try: poll_display_message = BaseAPI.generate_poll_message( - poll_info=poll_info, bot_username=bot_username + poll_info=poll_info, bot_username=bot_username, + add_instructions=add_instructions ) await context.bot.edit_message_text( chat_id=chat_id, message_id=message_id,