Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Khakers feature/thread find refactor #27

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ async def on_ready(self):
)
logger.line()

await self.threads.populate_cache()
await self.threads.quick_populate_cache()

# closures
closures = self.config["closures"]
Expand Down Expand Up @@ -621,21 +621,10 @@ async def on_ready(self):
for log in await self.api.get_open_logs():
if self.get_channel(int(log["channel_id"])) is None:
logger.debug("Unable to resolve thread with channel %s.", log["channel_id"])
log_data = await self.api.post_log(
log["channel_id"],
{
"open": False,
"title": None,
"closed_at": str(discord.utils.utcnow()),
"close_message": "Channel has been deleted, no closer found.",
"closer": {
"id": str(self.user.id),
"name": self.user.name,
"discriminator": self.user.discriminator,
"avatar_url": self.user.display_avatar.url,
"mod": True,
},
},
log_data = await self.api.close_log(
channel_id=log["channel_id"],
close_message="Channel has been deleted, no closer found.",
closer=self.user,
)
if log_data:
logger.debug("Successfully closed thread with channel %s.", log["channel_id"])
Expand Down
62 changes: 61 additions & 1 deletion core/clients.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import secrets
import sys
from json import JSONDecodeError
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

import discord
import pymongo.results
from aiohttp import ClientResponse, ClientResponseError
from discord import DMChannel, Member, Message, TextChannel
from discord.ext import commands
Expand Down Expand Up @@ -429,6 +430,12 @@ async def update_repository(self) -> dict:
async def get_user_info(self) -> Optional[dict]:
return NotImplemented

async def add_recipients(self, channel_id: int, recipient: List[discord.User]):
return NotImplemented

async def close_log(self, channel_id: int, title: str, close_message: str, closer: discord.User) -> dict:
return NotImplemented

async def update_title(self, title: str, channel_id: Union[str, int]):
return NotImplemented

Expand Down Expand Up @@ -566,6 +573,9 @@ async def get_log(self, channel_id: Union[str, int]) -> dict:
logger.debug("Retrieving channel %s logs.", channel_id)
return await self.logs.find_one({"channel_id": str(channel_id)})

async def get_logs(self, channel_id: List[Union[str, int]]) -> dict:
return await self.logs.find({"channel_id": {"$in": [str(i) for i in channel_id]}}).to_list(None)

async def get_log_link(self, channel_id: Union[str, int]) -> str:
doc = await self.get_log(channel_id)
logger.debug("Retrieving log link for channel %s.", channel_id)
Expand Down Expand Up @@ -593,13 +603,15 @@ async def create_log_entry(self, recipient: Member, channel: TextChannel, creato
"recipient": {
"id": str(recipient.id),
"name": recipient.name,
"global_name": recipient.global_name,
"discriminator": recipient.discriminator,
"avatar_url": recipient.display_avatar.url,
"mod": False,
},
"creator": {
"id": str(creator.id),
"name": creator.name,
"global_name": creator.global_name,
"discriminator": creator.discriminator,
"avatar_url": creator.display_avatar.url,
"mod": isinstance(creator, Member),
Expand Down Expand Up @@ -662,6 +674,7 @@ async def append_log(
"author": {
"id": str(message.author.id),
"name": message.author.name,
"global_name": message.author.global_name,
"discriminator": message.author.discriminator,
"avatar_url": message.author.display_avatar.url,
"mod": not isinstance(message.channel, DMChannel),
Expand Down Expand Up @@ -714,6 +727,7 @@ async def create_note(self, recipient: Member, message: Message, message_id: Uni
"author": {
"id": str(message.author.id),
"name": message.author.name,
"global_name": message.author.global_name,
"discriminator": message.author.discriminator,
"avatar_url": message.author.display_avatar.url,
},
Expand All @@ -735,6 +749,52 @@ async def delete_note(self, message_id: Union[int, str]):
async def edit_note(self, message_id: Union[int, str], message: str):
await self.db.notes.update_one({"message_id": str(message_id)}, {"$set": {"message": message}})

async def add_recipients(self, channel_id: int, recipient: List[discord.User]):
results: pymongo.results.UpdateResult = await self.bot.db.logs.update_one(
{"channel_id": str(channel_id)},
{
"$addToSet": {
"other_recipients": {
"$each": [
{
"id": r.id,
"name": r.name,
"global_name": r.global_name,
"discriminator": r.discriminator,
"avatar_url": r.display_avatar.url,
}
for r in recipient
]
}
}
},
)
if results.matched_count == 0:
raise ValueError(f"Channel id {channel_id} not found in mongodb")
return

async def close_log(self, channel_id: int, title: str, close_message: str, closer: discord.User) -> dict:
# TODO doesn't set title yet
return await self.bot.db.logs.find_one_and_update(
{"channel_id": str(channel_id)},
{
"$set": {
"open": False,
"closed_at": str(discord.utils.utcnow()),
"title": title,
"close_message": close_message,
"closer": {
"id": str(closer.id),
"name": closer.name,
"global_name": closer.name,
"discriminator": closer.discriminator,
"avatar_url": closer.display_avatar.url,
"mod": True,
},
},
},
)

def get_plugin_partition(self, cog):
cls_name = cog.__class__.__name__
return self.db.plugins[cls_name]
Expand Down
95 changes: 52 additions & 43 deletions core/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import typing
import warnings
from datetime import timedelta
from time import perf_counter

import discord
import isodate
Expand Down Expand Up @@ -441,22 +442,11 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,

# Logging
if self.channel:
log_data = await self.bot.api.post_log(
self.channel.id,
{
"open": False,
"title": match_title(self.channel.topic),
"closed_at": str(discord.utils.utcnow()),
"nsfw": self.channel.nsfw,
"close_message": message,
"closer": {
"id": str(closer.id),
"name": closer.name,
"discriminator": closer.discriminator,
"avatar_url": closer.display_avatar.url,
"mod": True,
},
},
log_data = await self.bot.api.close_log(
channel_id=self.channel.id,
title=match_title(self.channel.topic),
closer=closer,
close_message=message,
)
else:
log_data = None
Expand Down Expand Up @@ -1217,6 +1207,9 @@ async def add_users(self, users: typing.List[typing.Union[discord.Member, discor

topic += f"\nOther Recipients: {ids}"

# Add recipients to database
await self.bot.api.add_recipients(self._channel.id, users)

await self.channel.edit(topic=topic)
await self._update_users_genesis()

Expand Down Expand Up @@ -1244,11 +1237,14 @@ class ThreadManager:

def __init__(self, bot):
self.bot = bot
self.cache = {}
self.cache: typing.Dict[int, Thread] = {}

async def populate_cache(self) -> None:
# time method runtime
start = perf_counter()
for channel in self.bot.modmail_guild.text_channels:
await self.find(channel=channel)
logger.info("Cache populated in %fs.", time.perf_counter() - start)

def __len__(self):
return len(self.cache)
Expand All @@ -1259,6 +1255,27 @@ def __iter__(self):
def __getitem__(self, item: str) -> Thread:
return self.cache[item]

async def quick_populate_cache(self) -> None:
start = perf_counter()

# create a list containing the id of every text channel in the modmail guild
channel_ids = [channel.id for channel in self.bot.modmail_guild.text_channels]
logs = await self.bot.api.get_logs(channel_ids)

for log in logs:
recipients = log["other_recipients"]

tasks = [self.bot.get_or_fetch_user(user_data["id"]) for user_data in recipients]
recipient_users: list[discord.Member] = await asyncio.gather(*tasks)

self.cache[log["recipient"]["id"]] = Thread(
self,
recipient=log["creator"]["id"],
channel=log["channel_id"],
other_recipients=recipient_users,
)
logger.debug("Cache populated in %fs.", perf_counter() - start)

async def find(
self,
*,
Expand Down Expand Up @@ -1322,44 +1339,36 @@ def check(topic):

return thread

async def _find_from_channel(self, channel):
async def _find_from_channel(self, channel) -> typing.Optional[Thread]:
"""
Tries to find a thread from a channel channel topic,
if channel topic doesnt exist for some reason, falls back to
Tries to find a thread from a channel topic,
if channel topic doesn't exist for some reason, falls back to
searching channel history for genesis embed and
extracts user_id from that.
"""

if not channel.topic:
return None
logger.debug("_find_from_channel")
logger.debug(f"channel: {channel}")

_, user_id, other_ids = parse_channel_topic(channel.topic)
# TODO cache thread for channel ID

if user_id == -1:
return None
log = await self.bot.api.get_log(channel.id)

if user_id in self.cache:
return self.cache[user_id]
if log is None:
return None

try:
recipient = await self.bot.get_or_fetch_user(user_id)
except discord.NotFound:
recipient = None
logger.debug("This is a thread channel")

other_recipients = []
for uid in other_ids:
try:
other_recipient = await self.bot.get_or_fetch_user(uid)
except discord.NotFound:
continue
other_recipients.append(other_recipient)
recipients = log["other_recipients"]
# Create a list of tasks to fetch the users
tasks = [self.bot.get_or_fetch_user(user_data["id"]) for user_data in recipients]
# Fetch the users
recipient_users: list[discord.Member] = await asyncio.gather(*tasks)

if recipient is None:
thread = Thread(self, user_id, channel, other_recipients)
else:
self.cache[user_id] = thread = Thread(self, recipient, channel, other_recipients)
thread = Thread(
self, recipient=log["creator"]["id"], channel=channel, other_recipients=recipient_users
)
thread.ready = True

return thread

async def create(
Expand Down