Skip to content

Commit

Permalink
Merge pull request #47 from milselarch/feat/prompt_dm_votes
Browse files Browse the repository at this point in the history
Feat/prompt dm votes
  • Loading branch information
milselarch authored Nov 27, 2024
2 parents 38df98a + 698268e commit 4030b9c
Show file tree
Hide file tree
Showing 11 changed files with 434 additions and 137 deletions.
106 changes: 44 additions & 62 deletions base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from aioredlock import Aioredlock, LockError
from database.database import (
PollWinners, BaseModel, UserID, PollMetadata, ChatWhitelist
PollWinners, UserID, PollMetadata, ChatWhitelist
)

logger = logging.getLogger(__name__)
Expand All @@ -57,6 +57,7 @@ class CallbackCommands(StrEnum):
UNDO_OPTION = 'UNDO'
RESET_VOTE = 'RESET'
SUBMIT_VOTE = 'SUBMIT_VOTE'
REGISTER_OR_SUBMIT = 'REGISTER_OR_SUBMIT'
VIEW_VOTE = 'VIEW_VOTE'


Expand Down Expand Up @@ -366,16 +367,6 @@ def _determine_poll_winner(cls, poll_id: int) -> Optional[int]:
winning_option_id = votes_aggregator.determine_winner()
return winning_option_id

@staticmethod
def get_poll_voter(
poll_id: int, user_id: UserID
) -> Result[PollVoters, Optional[BaseModel.DoesNotExist]]:
# check if voter is part of the poll
return PollVoters.safe_get(
(PollVoters.poll == poll_id) &
(PollVoters.user == user_id)
)

@classmethod
def verify_voter(
cls, poll_id: int, user_id: UserID, username: Optional[str] = None,
Expand All @@ -389,7 +380,7 @@ def verify_voter(
and whether user was newly whitelisted from chat whitelist
"""

poll_voter_res = cls.get_poll_voter(poll_id, user_id)
poll_voter_res = PollVoters.get_poll_voter(poll_id, user_id)
if poll_voter_res.is_ok():
poll_voter = poll_voter_res.unwrap()
return Ok((poll_voter, False))
Expand Down Expand Up @@ -606,14 +597,11 @@ def check_has_voted(poll_id: int, user_id: UserID) -> bool:
user_id=user_id, poll_id=poll_id, voted=True
).safe_get().is_ok()

@classmethod
def is_poll_voter(cls, poll_id: int, user_id: UserID) -> bool:
return cls.get_poll_voter(poll_id=poll_id, user_id=user_id).is_ok()

@classmethod
def get_poll_message(
cls, poll_id: int, user_id: UserID, bot_username: str,
username: Optional[str], add_webapp_link: bool = True
username: Optional[str], add_webapp_link: bool = False,
add_instructions: bool = False
) -> Result[PollMessage, MessageBuilder]:
if not cls.has_access_to_poll_id(
poll_id=poll_id, user_id=user_id, username=username
Expand All @@ -624,24 +612,28 @@ def get_poll_message(

return Ok(cls._get_poll_message(
poll_id=poll_id, bot_username=bot_username,
add_webapp_link=add_webapp_link
add_webapp_link=add_webapp_link,
add_instructions=add_instructions
))

@classmethod
def _get_poll_message(
cls, poll_id: int, bot_username: str,
add_webapp_link: bool = True
add_webapp_link: bool = False,
add_instructions: bool = False
) -> PollMessage:
poll_info = cls.unverified_read_poll_info(poll_id=poll_id)
return cls.generate_poll_message(
poll_info=poll_info, bot_username=bot_username,
add_webapp_link=add_webapp_link
add_webapp_link=add_webapp_link,
add_instructions=add_instructions
)

@classmethod
def generate_poll_message(
cls, poll_info: PollInfo, bot_username: str,
add_webapp_link: bool = True
add_webapp_link: bool = False,
add_instructions: bool = False
) -> PollMessage:
poll_metadata = poll_info.metadata
poll_message = cls.generate_poll_info(
Expand All @@ -650,7 +642,8 @@ def generate_poll_message(
bot_username=bot_username, max_voters=poll_metadata.max_voters,
num_voters=poll_metadata.num_active_voters,
num_votes=poll_metadata.num_votes,
add_webapp_link=add_webapp_link
add_webapp_link=add_webapp_link,
add_instructions=add_instructions
)

reply_markup = None
Expand Down Expand Up @@ -703,7 +696,8 @@ def build_private_vote_markup(
logger.warning(f'POLL_URL = {poll_url}')
# create vote button for reply message
markup_layout = [[KeyboardButton(
text=f'Vote for Poll #{poll_id}', web_app=WebAppInfo(url=poll_url)
text=f'Vote for Poll #{poll_id} Online',
web_app=WebAppInfo(url=poll_url)
)]]

return markup_layout
Expand All @@ -714,20 +708,11 @@ def build_group_vote_markup(
) -> List[List[InlineKeyboardButton]]:
"""
TODO: implement button vote context
< poll registration button >
< vote option rows >
< undo, abstain, withhold, reset >
< submit / check button >
< undo, view, reset >
< register / submit button >
"""
markup_rows, current_row = [], []

# create first row with just registration button
markup_rows.append([cls.spawn_inline_keyboard_button(
text='Register for Poll',
command=CallbackCommands.REGISTER_FOR_POLL,
callback_data=dict(poll_id=poll_id)
)])

# fill in rows containing poll option numbers
for ranking in range(1, num_options+1):
current_row.append(cls.spawn_inline_keyboard_button(
Expand All @@ -745,26 +730,16 @@ def build_group_vote_markup(
markup_rows.append(current_row)
current_row = []

# add row with undo, abstain, withhold, reset buttons
# add row with undo, view, reset buttons
markup_rows.append([
cls.spawn_inline_keyboard_button(
text='undo',
command=CallbackCommands.UNDO_OPTION,
callback_data=dict(poll_id=poll_id)
), cls.spawn_inline_keyboard_button(
text='abstain',
command=CallbackCommands.ADD_VOTE_OPTION,
callback_data=dict(
poll_id=poll_id,
option=SpecialVotes.ABSTAIN_VOTE.value
)
), cls.spawn_inline_keyboard_button(
text='withhold',
command=CallbackCommands.ADD_VOTE_OPTION,
callback_data=dict(
poll_id=poll_id,
option=SpecialVotes.WITHHOLD_VOTE.value
)
text='view',
command=CallbackCommands.VIEW_VOTE,
callback_data=dict(poll_id=poll_id)
), cls.spawn_inline_keyboard_button(
text='reset',
command=CallbackCommands.RESET_VOTE,
Expand All @@ -775,13 +750,8 @@ def build_group_vote_markup(
# add final row with view vote, submit vote buttons
markup_rows.append([
cls.spawn_inline_keyboard_button(
text='View Vote',
command=CallbackCommands.VIEW_VOTE,
callback_data=dict(poll_id=poll_id)
),
cls.spawn_inline_keyboard_button(
text='Submit Vote',
command=CallbackCommands.SUBMIT_VOTE,
text='Register / Submit Vote',
command=CallbackCommands.REGISTER_OR_SUBMIT,
callback_data=dict(poll_id=poll_id)
)
])
Expand Down Expand Up @@ -853,7 +823,7 @@ def has_access_to_poll(

if creator_id == user_id:
return True
if cls.is_poll_voter(poll_id, user_id):
if PollVoters.is_poll_voter(poll_id, user_id):
return True

if username is not None:
Expand Down Expand Up @@ -912,7 +882,8 @@ def generate_poll_info(
poll_id, poll_question, poll_options: list[str],
bot_username: str, max_voters: int, num_votes: int = 0,
num_voters: int = 0, closed: bool = False,
add_webapp_link: bool = True
add_webapp_link: bool = False,
add_instructions: bool = False
):
close_tag = '(closed)' if closed else ''
numbered_poll_options = [
Expand All @@ -927,23 +898,34 @@ def generate_poll_info(
)

webapp_link_footer = ''
instructions_footer = ''

if add_webapp_link:
webapp_link_footer = (
f'\n——————————————————'
f'\nvote on the webapp at {deep_link_url}'
f'\nvote on the webapp at {deep_link_url}\n'
)
if add_instructions:
instructions_footer = (
'\n——————————————————\n'
'How to vote:\n'
'- press the register button, then '
'start the bot via chat DM\n'
'- alternatively, press the number buttons in order of most '
'to least favourite option, then press submit'
)

return (
textwrap.dedent(f"""
return (textwrap.dedent(f"""
Poll #{poll_id} {close_tag}
{poll_question}
——————————————————
{num_votes} / {num_voters} voted (max {max_voters})
——————————————————
""") +
f'\n'.join(numbered_poll_options) +
webapp_link_footer
)
webapp_link_footer +
instructions_footer
).strip()

@staticmethod
def make_data_check_string(
Expand Down
63 changes: 46 additions & 17 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import time
import textwrap
import telegram
import asyncio
import re

Expand Down Expand Up @@ -41,7 +42,7 @@

from helpers.strings import (
POLL_OPTIONS_LIMIT_REACHED_TEXT, READ_SUBSCRIPTION_TIER_FAILED,
INCREASE_MAX_VOTERS_TEXT
generate_poll_created_message
)
from helpers.chat_contexts import (
PollCreationChatContext, PollCreatorTemplate, POLL_MAX_OPTIONS,
Expand Down Expand Up @@ -150,17 +151,19 @@ def start_bot(self):
Command.DELETE_ACCOUNT: self.delete_account,
Command.HELP: self.show_help,
Command.DONE: context_handlers.complete_chat_context,
Command.SET_MAX_VOTERS: self.payment_handlers.set_max_voters,
Command.PAY_SUPPORT: self.payment_support_handler,

Command.VOTE_ADMIN: self.vote_for_poll_admin,
Command.CLOSE_POLL_ADMIN: self.close_poll_admin,
Command.UNCLOSE_POLL_ADMIN: self.unclose_poll_admin,
Command.LOOKUP_FROM_USERNAME_ADMIN:
self.lookup_from_username_admin,
Command.INSERT_USER_ADMIN: self.insert_user_admin,
Command.SET_MAX_VOTERS: self.payment_handlers.set_max_voters,
Command.PAY_SUPPORT: self.payment_support_handler,
Command.REFUND_ADMIN: self.refund_payment_support_handler,
Command.ENTER_MAINTENANCE_ADMIN: self.enter_maintenance_admin,
Command.EXIT_MAINTENANCE_ADMIN: self.exit_maintenance_admin
Command.EXIT_MAINTENANCE_ADMIN: self.exit_maintenance_admin,
Command.SEND_MSG_ADMIN: self.send_msg_admin
}

# on different commands - answer in Telegram
Expand Down Expand Up @@ -199,10 +202,9 @@ def start_bot(self):
self.app.run_polling(allowed_updates=BaseTeleUpdate.ALL_TYPES)
print('<<< BOT POLLING LOOP ENDED >>>')

@staticmethod
async def post_init(application: Application):
async def post_init(self, _: Application):
# print('SET COMMANDS')
await application.bot.set_my_commands([(
await self.get_bot().set_my_commands([(
Command.START, 'start bot'
), (
Command.USER_DETAILS, 'shows your username and user id'
Expand Down Expand Up @@ -346,7 +348,7 @@ async def has_voted(self, update: ModifiedTeleUpdate, *_, **__):

user_id = user.get_user_id()
poll_id = extract_poll_id_result.unwrap()
is_voter = self.is_poll_voter(
is_voter = PollVoters.is_poll_voter(
poll_id=poll_id, user_id=user_id
)

Expand Down Expand Up @@ -551,7 +553,8 @@ async def create_poll(
new_poll_id, poll_question, poll_options,
bot_username=bot_username, closed=False,
num_voters=poll_creator.initial_num_voters,
max_voters=new_poll.max_voters
max_voters=new_poll.max_voters,
add_instructions=update.is_group_chat()
)

chat_type = update.message.chat.type
Expand All @@ -570,10 +573,8 @@ async def create_poll(
)
reply_markup = InlineKeyboardMarkup(vote_markup_data)

await message.reply_text(
poll_message, reply_markup=reply_markup
)
await message.reply_text(INCREASE_MAX_VOTERS_TEXT)
await message.reply_text(poll_message, reply_markup=reply_markup)
await message.reply_text(generate_poll_created_message(new_poll_id))

@classmethod
async def whitelist_chat_registration(
Expand Down Expand Up @@ -1079,6 +1080,14 @@ def read_vote_count(cls, poll_id: int) -> Result[int, MessageBuilder]:

async def close_poll(self, update, *_, **__):
message = update.message
message_text = TelegramHelpers.read_raw_command_args(update)

if constants.ID_PATTERN.match(message_text) is None:
return await message.reply_text(textwrap.dedent(f"""
Input format is invalid, try:
/{Command.CLOSE_POLL} {{poll_id}}
"""))

extract_result = TelegramHelpers.extract_poll_id(update)

if extract_result.is_err():
Expand Down Expand Up @@ -1675,20 +1684,40 @@ async def refund_payment_support_handler(
payment.save()

@admin_only
def enter_maintenance_admin(
async def enter_maintenance_admin(
self, update: ModifiedTeleUpdate, _: ContextTypes.DEFAULT_TYPE
):
message = update.message
self.payment_handlers.enter_maintenance_mode()
return message.reply_text('Maintenance mode entered')
return await message.reply_text('Maintenance mode entered')

@admin_only
def exit_maintenance_admin(
async def exit_maintenance_admin(
self, update: ModifiedTeleUpdate, _: ContextTypes.DEFAULT_TYPE
):
message = update.message
self.payment_handlers.exit_maintenance_mode()
return message.reply_text('Maintenance mode exited')
return await message.reply_text('Maintenance mode exited')

@admin_only
async def send_msg_admin(
self, update: ModifiedTeleUpdate, context: ContextTypes.DEFAULT_TYPE
):
raw_args = TelegramHelpers.read_raw_command_args(update)
split_index = raw_args.index(' ')
chat_id = int(raw_args[:split_index])
payload = raw_args[split_index+1:]
reply_text = update.message.reply_text

try:
response = await context.bot.send_message(
chat_id=chat_id, text=payload
)
except telegram.error.BadRequest:
return await reply_text("Failed to send message")

logger.info(f"SEND_RESP {response}")
return await reply_text("Message sent")


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit 4030b9c

Please sign in to comment.