diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 816336d..e87def4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: check-case-conflict - id: check-merge-conflict @@ -12,3 +12,4 @@ repos: - id: mixed-line-ending args: [ --fix=lf ] - id: end-of-file-fixer + exclude: .devcontainer/devcontainer-lock.json diff --git a/Dockerfile b/Dockerfile index e43bbf6..21df9a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.11-slim@sha256:edaf703dce209d774af3ff768fc92b1e3b60261e7602126276f9ceb0e3a96874 +FROM python:3.12-slim@sha256:541d45d3d675fb8197f534525a671e2f8d66c882b89491f9dda271f4f94dcd06 # Define Git SHA build argument for Sentry ARG git_sha="development" diff --git a/docs/source/conf.py b/docs/source/conf.py index cc722dc..4255f0a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -53,16 +53,16 @@ releases_release_uri = f"{REPO_LINK}/releases/tag/v%s" -def linkcode_resolve(domain: str, info: dict) -> str: +def linkcode_resolve(domain: str, info: dict) -> str | None: """linkcode_resolve.""" if domain != "py": return None if not info["module"]: return None - import importlib - import inspect - import types + import importlib # noqa: PLC0415 + import inspect # noqa: PLC0415 + import types # noqa: PLC0415 mod = importlib.import_module(info["module"]) diff --git a/make.ps1 b/make.ps1 index 19cb79c..6a2d898 100644 --- a/make.ps1 +++ b/make.ps1 @@ -60,7 +60,7 @@ function Invoke-Upgrade-Deps function Invoke-Lint { pre-commit run --all-files - python -m ruff --fix . + python -m ruff check --fix . python -m ruff format . python -m mypy --strict src/ } diff --git a/pyproject.toml b/pyproject.toml index 736c9c5..b3a1620 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,10 +22,13 @@ docs = { file = ["requirements/requirements-docs.txt"] } [tool.ruff] preview = true unsafe-fixes = true -target-version = "py311" +target-version = "py312" +line-length = 120 [tool.ruff.lint] -select = ["ALL"] +select = [ + "ALL", +] ignore = [ "CPY001", # (Missing copyright notice at top of file) "ERA001", # (Found commented-out code) - Porting features a piece at a time diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 763d55f..366c886 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,10 +1,10 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile requirements/requirements-dev.in # -build==1.0.3 +build==1.2.1 # via pip-tools cfgv==3.4.0 # via pre-commit @@ -12,11 +12,11 @@ click==8.1.7 # via pip-tools distlib==0.3.8 # via virtualenv -filelock==3.13.1 +filelock==3.13.4 # via # -c requirements/requirements.txt # virtualenv -identify==2.5.34 +identify==2.5.35 # via pre-commit mypy==1.9.0 # via -r requirements/requirements-dev.in @@ -24,7 +24,7 @@ mypy-extensions==1.0.0 # via mypy nodeenv==1.8.0 # via pre-commit -packaging==23.2 +packaging==24.0 # via build pip-tools==7.4.1 # via -r requirements/requirements-dev.in @@ -40,13 +40,13 @@ pyyaml==6.0.1 # via pre-commit ruff==0.3.7 # via -r requirements/requirements-dev.in -typing-extensions==4.9.0 +typing-extensions==4.11.0 # via # -c requirements/requirements.txt # mypy -virtualenv==20.25.0 +virtualenv==20.25.1 # via pre-commit -wheel==0.42.0 +wheel==0.43.0 # via pip-tools # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements/requirements-docs.txt b/requirements/requirements-docs.txt index 6e5a0bc..73b0504 100644 --- a/requirements/requirements-docs.txt +++ b/requirements/requirements-docs.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile requirements/requirements-docs.in @@ -8,7 +8,7 @@ alabaster==0.7.16 # via sphinx anyascii==0.3.2 # via sphinx-autoapi -astroid==3.0.3 +astroid==3.1.0 # via sphinx-autoapi babel==2.14.0 # via sphinx @@ -26,7 +26,7 @@ docutils==0.20.1 # via sphinx furo==2024.1.29 # via -r requirements/requirements-docs.in -idna==3.6 +idna==3.7 # via # -c requirements/requirements.txt # requests @@ -38,7 +38,7 @@ jinja2==3.1.3 # sphinx-autoapi markupsafe==2.1.5 # via jinja2 -packaging==23.2 +packaging==24.0 # via sphinx pygments==2.17.2 # via diff --git a/requirements/requirements-tests.txt b/requirements/requirements-tests.txt index 1ac8734..b4bf403 100644 --- a/requirements/requirements-tests.txt +++ b/requirements/requirements-tests.txt @@ -1,12 +1,12 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile requirements/requirements-tests.in # iniconfig==2.0.0 # via pytest -packaging==23.2 +packaging==24.0 # via pytest pluggy==1.4.0 # via pytest diff --git a/requirements/requirements.in b/requirements/requirements.in index e4195ba..09e17a3 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -13,10 +13,6 @@ sentry-sdk rapidfuzz coloredlogs -# Database -psycopg[binary] -SQLAlchemy - # Utilities # utils/helpers tldextract @@ -26,7 +22,3 @@ tldextract arrow # exts/utilities/snekbox regex -# exts/fun/typeracer -wonderwords -# exts/fun/uwu -imsosorry diff --git a/requirements/requirements.txt b/requirements/requirements.txt index c53521d..498fe74 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,10 +1,10 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.12 # by the following command: # # pip-compile requirements/requirements.in # -aiodns==3.1.1 +aiodns==3.2.0 # via pydis-core aiohttp==3.9.4 # via @@ -32,46 +32,38 @@ discord-py==2.3.2 # via # -r requirements/requirements.in # pydis-core -filelock==3.13.1 +filelock==3.13.4 # via tldextract frozenlist==1.4.1 # via # aiohttp # aiosignal -greenlet==3.0.3 - # via sqlalchemy humanfriendly==10.0 # via coloredlogs -idna==3.6 +idna==3.7 # via # requests # tldextract # yarl -imsosorry==1.2.1 - # via -r requirements/requirements.in multidict==6.0.5 # via # aiohttp # yarl -psycopg[binary]==3.1.18 - # via -r requirements/requirements.in -psycopg-binary==3.1.18 - # via psycopg pycares==4.4.0 # via aiodns -pycparser==2.21 +pycparser==2.22 # via cffi -pydantic==2.6.1 +pydantic==2.7.0 # via # pydantic-settings # pydis-core -pydantic-core==2.16.2 +pydantic-core==2.18.1 # via pydantic pydantic-settings==2.2.1 # via -r requirements/requirements.in pydis-core==11.1.0 # via -r requirements/requirements.in -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via arrow python-dotenv==1.0.1 # via pydantic-settings @@ -89,25 +81,19 @@ sentry-sdk==1.45.0 # via -r requirements/requirements.in six==1.16.0 # via python-dateutil -sqlalchemy==2.0.29 - # via -r requirements/requirements.in statsd==4.0.1 # via pydis-core tldextract==5.1.2 # via -r requirements/requirements.in -types-python-dateutil==2.8.19.20240106 +types-python-dateutil==2.9.0.20240316 # via arrow -typing-extensions==4.9.0 +typing-extensions==4.11.0 # via - # psycopg # pydantic # pydantic-core - # sqlalchemy urllib3==2.2.1 # via # requests # sentry-sdk -wonderwords==2.2.0 - # via -r requirements/requirements.in yarl==1.9.4 # via aiohttp diff --git a/src/bot/__init__.py b/src/bot/__init__.py index 90cff79..9f219f2 100644 --- a/src/bot/__init__.py +++ b/src/bot/__init__.py @@ -1,23 +1,9 @@ """Anubis, a fancy Discord bot.""" -import logging - -import sentry_sdk -from sentry_sdk.integrations.logging import LoggingIntegration +from pydis_core.utils import apply_monkey_patches from bot import log -from bot.constants import GIT_SHA, Sentry - -sentry_logging = LoggingIntegration(level=logging.DEBUG, event_level=logging.WARNING) - -sentry_sdk.init( - dsn=Sentry.dsn, - integrations=[ - sentry_logging, - ], - release=f"{Sentry.release_prefix}@{GIT_SHA}", - traces_sample_rate=0.5, - profiles_sample_rate=0.5, -) log.setup() + +apply_monkey_patches() diff --git a/src/bot/__main__.py b/src/bot/__main__.py index 64dc029..c2a5944 100644 --- a/src/bot/__main__.py +++ b/src/bot/__main__.py @@ -8,18 +8,22 @@ from bot import constants from bot.bot import Bot - -intents = discord.Intents.default() -intents.message_content = True +from bot.log import setup_sentry async def main() -> None: """Run the bot.""" + setup_sentry() + + allowed_roles = list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}) + intents = discord.Intents.default() + intents.message_content = True + bot = Bot( guild_id=constants.Guild.id, http_session=aiohttp.ClientSession(), - allowed_roles=list({discord.Object(id_) for id_ in constants.MODERATION_ROLES}), - command_prefix=commands.when_mentioned, + allowed_roles=allowed_roles, + command_prefix=commands.when_mentioned_or(constants.Bot.prefix), intents=intents, ) diff --git a/src/bot/bot.py b/src/bot/bot.py index a6d02be..4c7492a 100644 --- a/src/bot/bot.py +++ b/src/bot/bot.py @@ -3,7 +3,7 @@ from typing import Self from pydis_core import BotBase -from sentry_sdk import push_scope +from sentry_sdk import push_scope, start_transaction from bot import exts from bot.log import get_logger @@ -25,6 +25,11 @@ class Bot(BotBase): def __init__(self: Self, *args: list, **kwargs: dict) -> None: super().__init__(*args, **kwargs) + async def load_extension(self, name: str, *args: list, **kwargs: dict) -> None: + """Extend D.py's load_extension function to also record sentry performance stats.""" + with start_transaction(op="cog-load", name=name): + await super().load_extension(name, *args, **kwargs) + async def setup_hook(self: Self) -> None: """Default async initialisation method for discord.py.""" # noqa: D401 await super().setup_hook() diff --git a/src/bot/constants.py b/src/bot/constants.py index 07fe936..1711864 100644 --- a/src/bot/constants.py +++ b/src/bot/constants.py @@ -7,9 +7,7 @@ """ from os import getenv -from typing import Self -from pydantic import root_validator from pydantic_settings import BaseSettings, SettingsConfigDict @@ -41,7 +39,6 @@ class _Bot(EnvConfig, env_prefix="bot_"): """Bot data.""" prefix: str = "!" - database_dsn: str = "postgresql+psycopg://postgres:postgres@localhost:5432/anubis" token: str = "" trace_loggers: str = "*" @@ -94,7 +91,7 @@ class _Roles(EnvConfig, env_prefix="roles_"): class _Guild(EnvConfig, env_prefix="guild_"): """Guild constants.""" - id: int = 1033456860864466995 # - variable is nested + id: int = 1033456860864466995 moderation_roles: tuple[int, ...] = (Roles.administrators, Roles.moderators) staff_roles: tuple[int, ...] = (Roles.administrators, Roles.moderators, Roles.staff) @@ -117,7 +114,7 @@ class _BaseURLs(EnvConfig, env_prefix="urls_"): github_bot_repo: str = "https://github.com/letsbuilda/anubis" - paste: str = "https://paste.pythondiscord.com" + paste_url: str = "https://paste.pythondiscord.com" BaseURLs = _BaseURLs() @@ -131,7 +128,6 @@ class _URLs(_BaseURLs): connect_max_retries: int = 3 connect_cooldown: int = 5 - paste_service: str = f"{BaseURLs.paste}/{{key}}" site_logs_view: str = "https://pythondiscord.com/staff/bot/logs" @@ -236,17 +232,19 @@ class _Icons(EnvConfig, env_prefix="icons_"): filtering: str = "https://cdn.discordapp.com/emojis/472472638594482195.png" - green_checkmark: str = "https://raw.githubusercontent.com/python-discord/branding/main/icons/checkmark/green-checkmark-dist.png" - green_questionmark: str = "https://raw.githubusercontent.com/python-discord/branding/main/icons/checkmark/green-question-mark-dist.png" + green_checkmark: str = ( + "https://raw.githubusercontent.com/python-discord/branding/main/icons/checkmark/green-checkmark-dist.png" + ) + green_questionmark: str = ( + "https://raw.githubusercontent.com/python-discord/branding/main/icons/checkmark/green-question-mark-dist.png" + ) guild_update: str = "https://cdn.discordapp.com/emojis/469954765141442561.png" hash_blurple: str = "https://cdn.discordapp.com/emojis/469950142942806017.png" hash_green: str = "https://cdn.discordapp.com/emojis/469950144918585344.png" hash_red: str = "https://cdn.discordapp.com/emojis/469950145413251072.png" - message_bulk_delete: str = ( - "https://cdn.discordapp.com/emojis/469952898994929668.png" - ) + message_bulk_delete: str = "https://cdn.discordapp.com/emojis/469952898994929668.png" message_delete: str = "https://cdn.discordapp.com/emojis/472472641320648704.png" message_edit: str = "https://cdn.discordapp.com/emojis/472472638976163870.png" @@ -264,9 +262,7 @@ class _Icons(EnvConfig, env_prefix="icons_"): superstarify: str = "https://cdn.discordapp.com/emojis/636288153044516874.png" unsuperstarify: str = "https://cdn.discordapp.com/emojis/636288201258172446.png" - token_removed: str = ( - "https://cdn.discordapp.com/emojis/470326273298792469.png" # - false positive - ) + token_removed: str = "https://cdn.discordapp.com/emojis/470326273298792469.png" # - false positive user_ban: str = "https://cdn.discordapp.com/emojis/469952898026045441.png" user_timeout: str = "https://cdn.discordapp.com/emojis/472472640100106250.png" @@ -303,13 +299,6 @@ class _Colours(EnvConfig, env_prefix="colours_"): grass_green: int = 0x66FF00 gold: int = 0xE6C200 - @root_validator(pre=True) - def parse_hex_values(cls: type[Self], values: dict[str, int]) -> dict[str, int]: # noqa: N805 - check this - """Verify that colors are valid hex.""" - for key, value in values.items(): - values[key] = int(value, 16) - return values - Colours = _Colours() diff --git a/src/bot/exts/core/error_handler.py b/src/bot/exts/core/error_handler.py index ce26d17..ce52376 100644 --- a/src/bot/exts/core/error_handler.py +++ b/src/bot/exts/core/error_handler.py @@ -33,7 +33,7 @@ def revert_cooldown_counter(command: commands.Command, message: Message) -> None bucket = command._buckets.get_bucket(message) bucket._tokens = min(bucket.rate, bucket._tokens + 1) logging.debug( - "Cooldown counter reverted as the command was not used correctly." + "Cooldown counter reverted as the command was not used correctly.", ) @staticmethod @@ -56,7 +56,7 @@ async def on_command_error( # noqa: PLR0911 - uh... """Activates when a command raises an error.""" if getattr(error, "handled", False): logging.debug( - f"Command {ctx.command} had its error already handled locally; ignoring." + f"Command {ctx.command} had its error already handled locally; ignoring.", ) return @@ -81,7 +81,7 @@ async def on_command_error( # noqa: PLR0911 - uh... self.revert_cooldown_counter(ctx.command, ctx.message) usage = f"```\n{ctx.prefix}{parent_command}{ctx.command} {ctx.command.signature}\n```" embed = self.error_embed( - f"Your input was invalid: {error}\n\nUsage:{usage}" + f"Your input was invalid: {error}\n\nUsage:{usage}", ) await ctx.send(embed=embed) return @@ -98,16 +98,18 @@ async def on_command_error( # noqa: PLR0911 - uh... if isinstance(error, commands.DisabledCommand): await ctx.send( embed=self.error_embed( - "This command has been disabled.", NEGATIVE_REPLIES - ) + "This command has been disabled.", + NEGATIVE_REPLIES, + ), ) return if isinstance(error, commands.NoPrivateMessage): await ctx.send( embed=self.error_embed( - "This command can only be used in the server. ", NEGATIVE_REPLIES - ) + "This command can only be used in the server. ", + NEGATIVE_REPLIES, + ), ) return @@ -123,8 +125,9 @@ async def on_command_error( # noqa: PLR0911 - uh... if isinstance(error, commands.CheckFailure): await ctx.send( embed=self.error_embed( - "You are not authorized to use this command.", NEGATIVE_REPLIES - ) + "You are not authorized to use this command.", + NEGATIVE_REPLIES, + ), ) return @@ -147,7 +150,7 @@ async def on_command_error( # noqa: PLR0911 - uh... if isinstance(error, commands.MaxConcurrencyReached): embed = self.error_embed( - "This command can only be used 1 time per channel concurrently." + "This command can only be used 1 time per channel concurrently.", ) await ctx.send(embed=embed) return @@ -167,12 +170,15 @@ async def on_command_error( # noqa: PLR0911 - uh... log.exception(f"Unhandled command error: {error!s}", exc_info=error) async def send_command_suggestion( - self: Self, ctx: commands.Context, command_name: str + self: Self, + ctx: commands.Context, + command_name: str, ) -> None: """Send user similar commands if any can be found.""" command_suggestions = [] if similar_command_names := get_command_suggestions( - list(self.bot.all_commands.keys()), command_name + list(self.bot.all_commands.keys()), + command_name, ): for similar_command_name in similar_command_names: similar_command = self.bot.get_command(similar_command_name) @@ -180,9 +186,7 @@ async def send_command_suggestion( if not similar_command: continue - log_msg = ( - "Cancelling attempt to suggest a command due to failed checks." - ) + log_msg = "Cancelling attempt to suggest a command due to failed checks." try: if not await similar_command.can_run(ctx): log.debug(log_msg) @@ -197,8 +201,7 @@ async def send_command_suggestion( embed = Embed() embed.set_author(name="Did you mean:", icon_url=QUESTION_MARK_ICON) embed.description = "\n".join( - misspelled_content.replace(command_name, cmd, 1) - for cmd in command_suggestions + misspelled_content.replace(command_name, cmd, 1) for cmd in command_suggestions ) await ctx.send(embed=embed, delete_after=7.5) diff --git a/src/bot/exts/core/log.py b/src/bot/exts/core/log.py index 43e5fa4..5f53187 100644 --- a/src/bot/exts/core/log.py +++ b/src/bot/exts/core/log.py @@ -37,7 +37,7 @@ async def send_log_message( """Generate log embed and send to logging channel.""" # Truncate string directly here to avoid removing newlines embed = discord.Embed( - description=text[:4093] + "..." if len(text) > 4096 else text + description=text[:4093] + "..." if len(text) > 4096 else text, ) if title and icon_url: @@ -53,11 +53,7 @@ async def send_log_message( embed.set_thumbnail(url=thumbnail) if ping_mods: - content = ( - f"<@&{Roles.moderators}> {content}" - if content - else f"<@&{Roles.moderators}>" - ) + content = f"<@&{Roles.moderators}> {content}" if content else f"<@&{Roles.moderators}>" # Truncate content to 2000 characters and append an ellipsis. if content and len(content) > 2000: @@ -71,7 +67,7 @@ async def send_log_message( await channel.send(embed=additional_embed) return await self.bot.get_context( - log_message + log_message, ) # Optionally return for use with antispam diff --git a/src/bot/exts/core/logging.py b/src/bot/exts/core/logging.py new file mode 100644 index 0000000..a5bbc94 --- /dev/null +++ b/src/bot/exts/core/logging.py @@ -0,0 +1,43 @@ +"""Logging.""" + +from discord import Embed +from discord.ext.commands import Cog +from pydis_core.utils import scheduling + +from bot.bot import Bot +from bot.constants import DEBUG_MODE, Channels +from bot.log import get_logger + +log = get_logger(__name__) + + +class Logging(Cog): + """Debug logging module.""" + + def __init__(self, bot: Bot) -> None: + self.bot = bot + + scheduling.create_task(self.startup_greeting()) + + async def startup_greeting(self) -> None: + """Announce our presence to the configured devlog channel.""" + await self.bot.wait_until_guild_available() + log.info("Bot connected!") + + embed = Embed(description="Connected!") + embed.set_author( + name="Anubis", + url="https://github.com/letsbuilda/anubis", + # icon_url=( + # "https://raw.githubusercontent.com/" + # "python-discord/branding/main/logos/logo_circle/logo_circle_large.png" + # ) + ) + + if not DEBUG_MODE: + await self.bot.get_channel(Channels.dev_log).send(embed=embed) + + +async def setup(bot: Bot) -> None: + """Load the Logging cog.""" + await bot.add_cog(Logging(bot)) diff --git a/src/bot/exts/filters/webhook_remover.py b/src/bot/exts/filters/webhook_remover.py index eefdc36..f783bfa 100644 --- a/src/bot/exts/filters/webhook_remover.py +++ b/src/bot/exts/filters/webhook_remover.py @@ -42,7 +42,9 @@ def log(self: Self) -> Log | None: return self.bot.get_cog("Log") async def delete_and_respond( - self: Self, message: Message, matches: Match[str] + self: Self, + message: Message, + matches: Match[str], ) -> None: """Delete `message` and send a warning that it contained a Discord webhook.""" webhook_url = matches[0] @@ -71,13 +73,13 @@ async def delete_and_respond( await message.delete() except NotFound: log.debug( - f"Failed to remove webhook in message {message.id}: message already deleted." + f"Failed to remove webhook in message {message.id}: message already deleted.", ) return # Log to user await message.channel.send( - ALERT_MESSAGE_TEMPLATE.format(user=message.author.mention) + ALERT_MESSAGE_TEMPLATE.format(user=message.author.mention), ) if deleted_successfully: @@ -87,7 +89,7 @@ async def delete_and_respond( # Display Bot icon as thumbnail if webhook_metadata.get("avatar") is not None: - thumb = f"https://cdn.discordapp.com/avatars/{webhook_metadata['id']}/{webhook_metadata['avatar']}.webp" + thumb = f"https://cdn.discordapp.com/avatars/{webhook_metadata["id"]}/{webhook_metadata["avatar"]}.webp" else: thumb = message.author.display_avatar.url diff --git a/src/bot/exts/fun/__init__.py b/src/bot/exts/fun/__init__.py deleted file mode 100644 index ddb76ef..0000000 --- a/src/bot/exts/fun/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Fun extensions.""" diff --git a/src/bot/exts/fun/dice.py b/src/bot/exts/fun/dice.py deleted file mode 100644 index b828696..0000000 --- a/src/bot/exts/fun/dice.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Dice rolling.""" - -from random import randint -from typing import Self - -import discord -from discord import app_commands -from discord.ext import commands - -from bot.bot import Bot - - -class Dice(commands.Cog): - """Roll dice.""" - - def __init__(self: Self, bot: Bot) -> None: - self.bot = bot - - @app_commands.command(name="roll") - async def roll( - self: Self, - interaction: discord.Interaction, - number_of_dice: int, - number_of_sides: int, - ) -> None: - """Roll dice.""" - rolls = ", ".join([ - str(randint(1, number_of_sides)) for _ in range(number_of_dice) - ]) - await interaction.response.send_message(rolls) - - -async def setup(bot: Bot) -> None: - """Load the Ping cog.""" - await bot.add_cog(Dice(bot)) diff --git a/src/bot/exts/fun/typeracer.py b/src/bot/exts/fun/typeracer.py deleted file mode 100644 index 68c21eb..0000000 --- a/src/bot/exts/fun/typeracer.py +++ /dev/null @@ -1,163 +0,0 @@ -"""Typeracer game.""" - -from asyncio import TimeoutError -from collections import defaultdict -from typing import Self - -from discord import Colour, Embed, Message -from discord.ext import commands -from discord.ext.commands import BucketType, Context -from wonderwords import RandomWord - -from bot.bot import Bot - - -class Race: - """Class to manage each typerace.""" - - def __init__(self: Self, ctx: Context, number_of_words: int) -> None: - self.ctx = ctx - self.word_generator = RandomWord() - self.word_list = self.word_generator.random_words(number_of_words) - self.i = 0 - self.scores = defaultdict(int) - self.players = defaultdict(str) - - def _next_word(self: Self) -> None: - """Get the next word.""" - self.i += 1 - - def _word_display(self: Self) -> str: - """Return word to display with Zero Width Space inserted to prevent copy pasting.""" - return "\u200b".join(self.word_list[self.i]) - - def _update_scoreboard(self: Self, user_id: int, username: str) -> None: - """Update the scoreboard.""" - self.players[user_id] = username - self.scores[user_id] += 1 - - def game_over(self: Self) -> bool: - """Check if the game is over.""" - return self.i >= len(self.word_list) - - def word_embed(self: Self) -> Embed: - """Build the word embed.""" - embed = Embed( - title="The word is:", - description=f"`{self._word_display()}`", - colour=Colour.yellow(), - ) - embed.set_footer(text=f"{self.i + 1}/{len(self.word_list)}") - return embed - - def is_correct(self: Self, answer: str) -> bool: - """Check if the answer is correct.""" - return answer == self.word_list[self.i] - - def is_copypaste(self: Self, answer: str) -> bool: - """Check if answer is copypasted.""" - return answer == self._word_display() - - def check_message(self: Self, message: Message) -> bool: - """Check function for processing valid inputs.""" - return ( - message.content == self.word_list[self.i] - or message.content == self._word_display() - ) and message.channel == self.ctx.channel - - def scoreboard_embed(self: Self) -> Embed: - """Build the scoreboard embed.""" - embed = Embed(title="Final Scoreboard", colour=Colour.blue()) - scoreboard_list = [ - (self.players[user_id], self.scores[user_id]) for user_id in self.players - ] - scoreboard_list = sorted( - scoreboard_list, key=lambda pair: pair[1], reverse=True - ) - prev = None - offset = 0 - for i, pair in enumerate(scoreboard_list): - username = pair[0] - score = pair[1] - if score == prev: - offset += 1 - else: - offset = 0 - prev = score - embed.add_field( - name=f"{i + 1 - offset}. {username}", - value=f"**{score} words**", - inline=False, - ) - return embed - - def process_correct_answer(self: Self, message: Message) -> Embed: - """Process if the answer is correct.""" - user_id = message.author.id - username = message.author.name - icon_url = message.author.avatar.url - embed = Embed( - title="The word was:", - description=self.word_list[self.i], - colour=Colour.green(), - ) - embed.set_author(name=f"{username} got it right!", icon_url=icon_url) - embed.set_footer(text=f"{self.i + 1}/{len(self.word_list)}") - self._update_scoreboard(user_id, username) - self._next_word() - return embed - - -class Typeracer(commands.Cog): - """Play typeracer.""" - - def __init__(self: Self, bot: Bot) -> None: - self.bot = bot - - @commands.command(name="typeracer") - @commands.max_concurrency(1, per=BucketType.channel) - async def typeracer(self: Self, ctx: Context, number_of_words: int = 5) -> None: - """Play Typeracer. - - Number of words defaults to 5 if not specified. - Max number of words is 30. - """ - if number_of_words < 1 or number_of_words > 30: - msg = "Number of words must be between 1 and 30." - raise commands.UserInputError(msg) - - race = Race(ctx, number_of_words) - await ctx.send("*The race has started!\nThe word to type is...*") - - while not race.game_over(): - # sending typeracer word - embed = race.word_embed() - message = await ctx.send(embed=embed) - try: - # only process messages that satisfy the check function - answer = await self.bot.wait_for( - "message", check=race.check_message, timeout=30 - ) - except TimeoutError: - await message.edit( - content=message.content + "\n\nThe game has timed out!" - ) - # display scoreboard if game times out - if race.scores: - await ctx.send(embed=race.scoreboard_embed()) - return - else: - # check if answer is correct OR copy pasted - if race.is_correct(answer.content): - embed = race.process_correct_answer(answer) - await message.edit(embed=embed) - elif race.is_copypaste(answer.content): - await ctx.send(f"{answer.author.mention} No copy & pasting!") - - # send final scoreboard - await ctx.send(embed=race.scoreboard_embed()) - - -async def setup(bot: Bot) -> None: - """Load the Typeracer cog.""" - await bot.add_cog(Typeracer(bot)) diff --git a/src/bot/exts/fun/uwu.py b/src/bot/exts/fun/uwu.py deleted file mode 100644 index 3151695..0000000 --- a/src/bot/exts/fun/uwu.py +++ /dev/null @@ -1,77 +0,0 @@ -"""The ancient arts of Uwuification.""" - -from functools import partial -from typing import Self - -import discord -from discord.ext import commands -from discord.ext.commands import Cog, Context, clean_content -from imsosorry.uwuification import uwuify - -from bot.bot import Bot -from bot.utils import helpers, messages - - -class Uwu(Cog): - """Cog for the uwu command.""" - - def __init__(self: Self, bot: Bot) -> None: - self.bot = bot - - @commands.command( - name="uwu", - aliases=( - "uwuwize", - "uwuify", - ), - ) - async def uwu_command(self: Self, ctx: Context, *, text: str | None = None) -> None: - """ - Echo an uwuified version the passed text. - - Example: - '.uwu Hello, my name is John' returns something like - 'hewwo, m-my name is j-john nyaa~'. - """ - # If `text` isn't provided then we try to get message content of a replied message - text = text or getattr(ctx.message.reference, "resolved", None) - if isinstance(text, discord.Message): - embeds = text.embeds - text = text.content - else: - embeds = None - - if text is None: - # If we weren't able to get the content of a replied message - msg = "Your message must have content or you must reply to a message." - raise commands.UserInputError(msg) - - await clean_content(fix_channel_mentions=True).convert(ctx, text) - - newufy = partial(uwuify, emoji_strength=1.0) - - # Grabs the text from the embed for uwuification - if embeds: - embed = messages.convert_embed(newufy, embeds[0]) - else: - # Parse potential message links in text - text, embed = await messages.get_text_and_embed(ctx, text) - - # If an embed is found, grab and uwuify its text - if embed: - embed = messages.convert_embed(newufy, embed) - - # Adds the text harvested from an embed to be put into another quote block. - if text: - converted_text = newufy(text) - converted_text = helpers.suppress_links(converted_text) - converted_text = f">>> {converted_text.lstrip('> ')}" - else: - converted_text = None - - await ctx.send(content=converted_text, embed=embed) - - -async def setup(bot: Bot) -> None: - """Load the uwu cog.""" - await bot.add_cog(Uwu(bot)) diff --git a/src/bot/exts/info/code_snippets.py b/src/bot/exts/info/code_snippets.py index d55d16f..593e3c5 100644 --- a/src/bot/exts/info/code_snippets.py +++ b/src/bot/exts/info/code_snippets.py @@ -62,11 +62,16 @@ def __init__(self: Self, bot: Bot) -> None: ] async def _fetch_response( - self: Self, url: str, response_format: str, **kwargs: dict + self: Self, + url: str, + response_format: str, + **kwargs: dict, ) -> str | dict | None: """Make http requests using aiohttp.""" async with self.bot.http_session.get( - url, raise_for_status=True, **kwargs + url, + raise_for_status=True, + **kwargs, ) as response: if response_format == "text": return await response.text() @@ -87,7 +92,11 @@ def _find_ref(self: Self, path: str, refs: tuple) -> tuple: return ref, file_path async def _fetch_github_snippet( - self: Self, repo: str, path: str, start_line: str, end_line: str + self: Self, + repo: str, + path: str, + start_line: str, + end_line: str, ) -> str: """Fetch a snippet from a GitHub repo.""" # Search the GitHub API for the specified branch @@ -97,7 +106,9 @@ async def _fetch_github_snippet( headers=GITHUB_HEADERS, ) tags = await self._fetch_response( - f"https://api.github.com/repos/{repo}/tags", "json", headers=GITHUB_HEADERS + f"https://api.github.com/repos/{repo}/tags", + "json", + headers=GITHUB_HEADERS, ) refs = branches + tags ref, file_path = self._find_ref(path, refs) @@ -108,7 +119,10 @@ async def _fetch_github_snippet( headers=GITHUB_HEADERS, ) return self._snippet_to_codeblock( - file_contents, file_path, start_line, end_line + file_contents, + file_path, + start_line, + end_line, ) async def _fetch_github_gist_snippet( @@ -134,12 +148,19 @@ async def _fetch_github_gist_snippet( "text", ) return self._snippet_to_codeblock( - file_contents, gist_file, start_line, end_line + file_contents, + gist_file, + start_line, + end_line, ) return "" async def _fetch_gitlab_snippet( - self: Self, repo: str, path: str, start_line: str, end_line: str + self: Self, + repo: str, + path: str, + start_line: str, + end_line: str, ) -> str: """Fetch a snippet from a GitLab repo.""" enc_repo = quote_plus(repo) @@ -150,7 +171,8 @@ async def _fetch_gitlab_snippet( "json", ) tags = await self._fetch_response( - f"https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags", "json" + f"https://gitlab.com/api/v4/projects/{enc_repo}/repository/tags", + "json", ) refs = branches + tags ref, file_path = self._find_ref(path, refs) @@ -162,7 +184,10 @@ async def _fetch_gitlab_snippet( "text", ) return self._snippet_to_codeblock( - file_contents, file_path, start_line, end_line + file_contents, + file_path, + start_line, + end_line, ) async def _fetch_bitbucket_snippet( @@ -179,11 +204,18 @@ async def _fetch_bitbucket_snippet( "text", ) return self._snippet_to_codeblock( - file_contents, file_path, start_line, end_line + file_contents, + file_path, + start_line, + end_line, ) def _snippet_to_codeblock( - self: Self, file_contents: str, file_path: str, start_line: str, end_line: str + self: Self, + file_contents: str, + file_path: str, + start_line: str, + end_line: str, ) -> str: """ Given the entire file contents and target lines, creates a code block. @@ -247,9 +279,7 @@ async def _parse_snippets(self: Self, content: str) -> str: except ClientResponseError as error: error_message = error.message log.log( - logging.DEBUG - if error.status == HTTPStatus.NOT_FOUND - else logging.ERROR, + logging.DEBUG if error.status == HTTPStatus.NOT_FOUND else logging.ERROR, f"Failed to fetch code snippet from {match[0]!r}: {error.status} " f"{error_message} for GET {error.request_info.real_url.human_repr()}", ) diff --git a/src/bot/exts/info/github.py b/src/bot/exts/info/github.py index 5539503..3347e77 100644 --- a/src/bot/exts/info/github.py +++ b/src/bot/exts/info/github.py @@ -82,7 +82,10 @@ def remove_codeblocks(message: str) -> str: return CODE_BLOCK_RE.sub("", message) async def fetch_issue( - self: Self, number: int, repository: str, user: str + self: Self, + number: int, + repository: str, + user: str, ) -> IssueState | FetchError: """ Retrieve an issue from a GitHub repository. @@ -112,11 +115,7 @@ async def fetch_issue( # from issues: if the 'issues' key is present in the response then we can pull the data we # need from the initial API call. if "issues" in json_data["html_url"]: - emoji = ( - Emojis.issue_open - if json_data.get("state") == "open" - else Emojis.issue_closed - ) + emoji = Emojis.issue_open if json_data.get("state") == "open" else Emojis.issue_closed # If the 'issues' key is not contained in the API response and there is no error code, then # we know that a PR has been requested and a call to the pulls API endpoint is necessary @@ -136,7 +135,11 @@ async def fetch_issue( issue_url = json_data.get("html_url") return IssueState( - repository, number, issue_url, json_data.get("title", ""), emoji + repository, + number, + issue_url, + json_data.get("title", ""), + emoji, ) @staticmethod @@ -153,7 +156,8 @@ def format_embed(results: list[IssueState | FetchError]) -> discord.Embed: description_list.append(f":x: [{result.return_code}] {result.message}") resp = discord.Embed( - colour=Colours.bright_green, description="\n".join(description_list) + colour=Colours.bright_green, + description="\n".join(description_list), ) resp.set_author(name="GitHub") @@ -180,7 +184,7 @@ async def on_message(self: Self, message: discord.Message) -> None: issues = [ FoundIssue(*match.group("org", "repo", "number")) for match in AUTOMATIC_REGEX.finditer( - self.remove_codeblocks(message.content) + self.remove_codeblocks(message.content), ) ] links = [] @@ -226,7 +230,9 @@ async def fetch_data(self: Self, url: str) -> tuple[dict[str], ClientResponse]: @github_group.command(name="user", aliases=("userinfo",)) async def github_user_info( - self: Self, ctx: commands.Context, username: str + self: Self, + ctx: commands.Context, + username: str, ) -> None: """Fetch a user's GitHub information.""" async with ctx.typing(): @@ -244,10 +250,7 @@ async def github_user_info( return org_data, _ = await self.fetch_data(user_data["organizations_url"]) - orgs = [ - f"[{org['login']}](https://github.com/{org['login']})" - for org in org_data - ] + orgs = [f"[{org["login"]}](https://github.com/{org["login"]})" for org in org_data] orgs_to_add = " | ".join(orgs) gists = user_data["public_gists"] @@ -256,19 +259,18 @@ async def github_user_info( if user_data["blog"].startswith("http"): # Blog link is complete blog = user_data["blog"] elif user_data["blog"]: # Blog exists but the link is not complete - blog = f"https://{user_data['blog']}" + blog = f"https://{user_data["blog"]}" else: blog = "No website link available" embed = discord.Embed( - title=f"`{user_data['login']}`'s GitHub profile info", - description=f"```\n{user_data['bio']}\n```\n" - if user_data["bio"] - else "", + title=f"`{user_data["login"]}`'s GitHub profile info", + description=f"```\n{user_data["bio"]}\n```\n" if user_data["bio"] else "", colour=discord.Colour.og_blurple(), url=user_data["html_url"], timestamp=datetime.strptime( - user_data["created_at"], "%Y-%m-%dT%H:%M:%SZ" + user_data["created_at"], + "%Y-%m-%dT%H:%M:%SZ", ), ) embed.set_thumbnail(url=user_data["avatar_url"]) @@ -277,26 +279,26 @@ async def github_user_info( if user_data["type"] == "User": embed.add_field( name="Followers", - value=f"[{user_data['followers']}]({user_data['html_url']}?tab=followers)", + value=f"[{user_data["followers"]}]({user_data["html_url"]}?tab=followers)", ) embed.add_field( name="Following", - value=f"[{user_data['following']}]({user_data['html_url']}?tab=following)", + value=f"[{user_data["following"]}]({user_data["html_url"]}?tab=following)", ) embed.add_field( name="Public repos", - value=f"[{user_data['public_repos']}]({user_data['html_url']}?tab=repositories)", + value=f"[{user_data["public_repos"]}]({user_data["html_url"]}?tab=repositories)", ) if user_data["type"] == "User": embed.add_field( name="Gists", - value=f"[{gists}](https://gist.github.com/{quote(username, safe='')})", + value=f"[{gists}](https://gist.github.com/{quote(username, safe="")})", ) embed.add_field( - name=f"Organization{'s' if len(orgs) != 1 else ''}", + name=f"Organization{"s" if len(orgs) != 1 else ""}", value=orgs_to_add if orgs else "No organizations.", ) embed.add_field(name="Website", value=blog) @@ -323,7 +325,7 @@ async def github_repo_info(self: Self, ctx: commands.Context, *repo: str) -> Non async with ctx.typing(): repo_data, _ = await self.fetch_data( - f"{GITHUB_API_URL}/repos/{quote(repo)}" + f"{GITHUB_API_URL}/repos/{quote(repo)}", ) # There won't be a message key if this repo exists @@ -347,9 +349,7 @@ async def github_repo_info(self: Self, ctx: commands.Context, *repo: str) -> Non # If it's a fork, then it will have a parent key try: parent = repo_data["parent"] - embed.description += ( - f"\n\nForked from [{parent['full_name']}]({parent['html_url']})" - ) + embed.description += f"\n\nForked from [{parent["full_name"]}]({parent["html_url"]})" except KeyError: log.debug("Repository is not a fork.") @@ -362,16 +362,18 @@ async def github_repo_info(self: Self, ctx: commands.Context, *repo: str) -> Non ) repo_created_at = datetime.strptime( - repo_data["created_at"], "%Y-%m-%dT%H:%M:%SZ" + repo_data["created_at"], + "%Y-%m-%dT%H:%M:%SZ", ).strftime("%d/%m/%Y") last_pushed = datetime.strptime( - repo_data["pushed_at"], "%Y-%m-%dT%H:%M:%SZ" + repo_data["pushed_at"], + "%Y-%m-%dT%H:%M:%SZ", ).strftime("%d/%m/%Y at %H:%M") embed.set_footer( text=( - f"{repo_data['forks_count']} ⑂ " - f"• {repo_data['stargazers_count']} ⭐ " + f"{repo_data["forks_count"]} ⑂ " + f"• {repo_data["stargazers_count"]} ⭐ " f"• Created At {repo_created_at} " f"• Last Commit {last_pushed}" ), diff --git a/src/bot/exts/utilities/internal.py b/src/bot/exts/utilities/internal.py index 1990618..5f7cb32 100644 --- a/src/bot/exts/utilities/internal.py +++ b/src/bot/exts/utilities/internal.py @@ -8,21 +8,17 @@ import traceback from collections import Counter from io import StringIO -from typing import Any, Self +from typing import Any import arrow import discord from discord.ext.commands import Cog, Context, group, has_any_role, is_owner +from pydis_core.utils.paste_service import PasteFile, PasteTooLongError, PasteUploadError, send_to_paste_service from bot.bot import Bot -from bot.constants import DEBUG_MODE, Roles +from bot.constants import DEBUG_MODE, BaseURLs, Roles from bot.log import get_logger -from bot.utils import ( - PasteTooLongError, - PasteUploadError, - find_nth_occurrence, - send_to_paste_service, -) +from bot.utils import find_nth_occurrence log = get_logger(__name__) @@ -30,7 +26,7 @@ class Internal(Cog): """Administrator and Core Developer commands.""" - def __init__(self: Self, bot: Bot) -> None: + def __init__(self, bot: Bot) -> None: self.bot = bot self.env = {} self.ln = 0 @@ -44,16 +40,12 @@ def __init__(self: Self, bot: Bot) -> None: self.eval.add_check(is_owner().predicate) @Cog.listener() - async def on_socket_event_type(self: Self, event_type: str) -> None: + async def on_socket_event_type(self, event_type: str) -> None: """When a websocket event is received, increase our counters.""" self.socket_event_total += 1 self.socket_events[event_type] += 1 - def _format( # noqa: PLR0912 - double check this - self: Self, - inp: str, - out: Any, # noqa: ANN401 - double check this - ) -> tuple[str, discord.Embed | None]: + def _format(self, inp: str, out: Any) -> tuple[str, discord.Embed | None]: """Format the eval output into a string & attempt to format it into an Embed.""" self._ = out @@ -70,7 +62,7 @@ def _format( # noqa: PLR0912 - double check this # Create the input dialog for i, line in enumerate(lines): - if i == 0: # - spread out for docs + if i == 0: # Start dialog start = f"In [{self.ln}]: " @@ -94,7 +86,7 @@ def _format( # noqa: PLR0912 - double check this start = "...: ".rjust(len(str(self.ln)) + 7) if i == len(lines) - 2 and line.startswith("return"): - line = line[6:].strip() # noqa: PLW2901 - is literally fine + line = line[6:].strip() # Combine everything res += start + line + "\n" @@ -119,17 +111,11 @@ def _format( # noqa: PLR0912 - double check this res = (res, out) else: - if isinstance(out, str) and out.startswith( - "Traceback (most recent call last):\n" - ): + if isinstance(out, str) and out.startswith("Traceback (most recent call last):\n"): # Leave out the traceback message out = "\n" + "\n".join(out.split("\n")[1:]) - pretty = ( - out - if isinstance(out, str) - else pprint.pformat(out, compact=True, width=60) - ) + pretty = out if isinstance(out, str) else pprint.pformat(out, compact=True, width=60) if pretty != str(out): # We're using the pretty version, start on the next line @@ -151,7 +137,7 @@ def _format( # noqa: PLR0912 - double check this return res # Return (text, embed) - async def _eval(self: Self, ctx: Context, code: str) -> discord.Message | None: + async def _eval(self, ctx: Context, code: str) -> discord.Message | None: """Eval the input code string & send an embed to the invoking context.""" self.ln += 1 @@ -166,7 +152,7 @@ async def _eval(self: Self, ctx: Context, code: str) -> discord.Message | None: "channel": ctx.channel, "guild": ctx.guild, "ctx": ctx, - "self": Self, + "self": self, "bot": self.bot, "inspect": inspect, "discord": discord, @@ -187,16 +173,14 @@ async def func(): # (None,) -> Any return _ finally: self.env.update(locals()) -""".format( - textwrap.indent(code, " "), - ) +""".format(textwrap.indent(code, " ")) try: exec(code_, self.env) # noqa: S102 func = self.env["func"] res = await func() - except Exception: # noqa: BLE001 - eh... + except Exception: res = traceback.format_exc() out, embed = self._format(code, res) @@ -211,16 +195,19 @@ async def func(): # (None,) -> Any truncate_index = newline_truncate_index if len(out) > truncate_index: + file = PasteFile(content=out) try: - paste_link = await send_to_paste_service( - self.bot.http_session, out, extension="py" + resp = await send_to_paste_service( + files=[file], + http_session=self.bot.http_session, + paste_url=BaseURLs.paste_url, ) except PasteTooLongError: paste_text = "too long to upload to paste service." except PasteUploadError: paste_text = "failed to upload contents to paste service." else: - paste_text = f"full contents at {paste_link}" + paste_text = f"full contents at {resp.link}" await ctx.send( f"```py\n{out[:truncate_index]}\n```... response truncated; {paste_text}", @@ -233,16 +220,14 @@ async def func(): # (None,) -> Any @group(name="internal", aliases=("int",)) @has_any_role(Roles.administrators, Roles.core_developers) - async def internal_group(self: Self, ctx: Context) -> None: - """Internal commands. Top secret!.""" # noqa: D401 - formatting + async def internal_group(self, ctx: Context) -> None: + """Internal commands. Top secret!.""" if not ctx.invoked_subcommand: await ctx.send_help(ctx.command) @internal_group.command(name="eval", aliases=("e",)) @has_any_role(Roles.administrators) - async def eval( - self: Self, ctx: Context, *, code: str - ) -> None: # - uh... good point + async def eval(self, ctx: Context, *, code: str) -> None: """Run eval in a REPL-like format.""" code = code.strip("`") if re.match("py(thon)?\n", code): @@ -256,13 +241,13 @@ async def eval( ) and len(code.split("\n")) == 1 ): - code = "_ = " + code + code += "_ = " await self._eval(ctx, code) @internal_group.command(name="socketstats", aliases=("socket", "stats")) @has_any_role(Roles.administrators, Roles.core_developers) - async def socketstats(self: Self, ctx: Context) -> None: + async def socketstats(self, ctx: Context) -> None: """Fetch information on the socket events received from Discord.""" running_s = (arrow.utcnow() - self.socket_since).total_seconds() diff --git a/src/bot/exts/utilities/snekbox/_cog.py b/src/bot/exts/utilities/snekbox/_cog.py index b4ef526..817035a 100644 --- a/src/bot/exts/utilities/snekbox/_cog.py +++ b/src/bot/exts/utilities/snekbox/_cog.py @@ -5,7 +5,6 @@ from textwrap import dedent from typing import Literal, NamedTuple, Self -from aiohttp import ClientSession from discord import ( AllowedMentions, HTTPException, @@ -18,15 +17,14 @@ ui, ) from discord.ext.commands import Cog, Command, Context, Converter, command, guild_only -from pydis_core.utils import interactions +from pydis_core.utils import interactions, paste_service +from pydis_core.utils.paste_service import PasteFile, send_to_paste_service from pydis_core.utils.regex import FORMATTED_CODE_REGEX, RAW_CODE_REGEX from bot.bot import Bot -from bot.constants import MODERATION_ROLES, TXT_LIKE_FILES, Emojis, URLs +from bot.constants import MODERATION_ROLES, TXT_LIKE_FILES, BaseURLs, Emojis, URLs from bot.log import get_logger -from bot.utils import send_to_paste_service from bot.utils.lock import LockedResourceError, lock_arg -from bot.utils.services import PasteTooLongError, PasteUploadError from ._eval import EvalJob, EvalResult from ._io import FileAttachment @@ -128,9 +126,7 @@ async def convert( code, block, lang, delim = match.group("code", "block", "lang", "delim") codeblocks = [dedent(code)] if block: - info = ( - f"'{lang}' highlighted" if lang else "plain" - ) + " code block" + info = (f"'{lang}' highlighted" if lang else "plain") + " code block" else: info = f"{delim}-enclosed inline code" else: @@ -154,7 +150,8 @@ def __init__( ) -> None: self.version_to_switch_to = version_to_switch_to super().__init__( - label=f"Run in {self.version_to_switch_to}", style=enums.ButtonStyle.primary + label=f"Run in {self.version_to_switch_to}", + style=enums.ButtonStyle.primary, ) self.snekbox_cog = snekbox_cog @@ -177,7 +174,8 @@ async def callback(self: Self, interaction: Interaction) -> None: await interaction.message.delete() await self.snekbox_cog.run_job( - self.ctx, self.job.as_version(self.version_to_switch_to) + self.ctx, + self.job.as_version(self.version_to_switch_to), ) @@ -217,22 +215,27 @@ async def post_job(self: Self, job: EvalJob) -> EvalResult: data = job.to_dict() async with self.bot.http_session.post( - URLs.snekbox_eval_api, json=data, raise_for_status=True + URLs.snekbox_eval_api, + json=data, + raise_for_status=True, ) as resp: return EvalResult.from_dict(await resp.json()) - @staticmethod - async def upload_output(http_session: ClientSession, output: str) -> str | None: + async def upload_output(self, output: str) -> str | None: """Upload the job's output to a paste service and return a URL to it if successful.""" log.trace("Uploading full output to paste service...") + file = PasteFile(content=output, lexer="text") try: - return await send_to_paste_service( - http_session, output, extension="txt", max_length=MAX_PASTE_LENGTH + paste_response = await send_to_paste_service( + files=[file], + http_session=self.bot.http_session, + paste_url=BaseURLs.paste_url, ) - except PasteTooLongError: + return paste_response.link + except paste_service.PasteTooLongError: return "too long to upload" - except PasteUploadError: + except paste_service.PasteUploadError: return "unable to upload" @staticmethod @@ -274,9 +277,7 @@ async def format_output( output = output.replace(" max_lines: truncated = True if len(output) >= max_chars: - output = ( - f"{output[:max_chars]}\n... (truncated - too long, too many lines)" - ) + output = f"{output[:max_chars]}\n... (truncated - too long, too many lines)" else: output = f"{output}\n... (truncated - too many lines)" elif len(output) >= max_chars: @@ -304,9 +303,7 @@ async def format_output( output = f"{output[:max_chars]}\n... (truncated - too long)" if truncated: - paste_link = await self.upload_output( - self.bot.http_session, original_output - ) + paste_link = await self.upload_output(original_output) if output_default and not output: output = output_default @@ -314,7 +311,10 @@ async def format_output( return output, paste_link def _filter_files( - self: Self, ctx: Context, files: list[FileAttachment], blocked_exts: set[str] + self: Self, + ctx: Context, + files: list[FileAttachment], + blocked_exts: set[str], ) -> FilteredFiles: """Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists.""" # Filter files into allowed and blocked @@ -357,13 +357,8 @@ async def send_job(self: Self, ctx: Context, job: EvalJob) -> Message: # noqa: # This is done to make sure the last line of output contains the error # and the error is not manually printed by the author with a syntax error. - if ( - result.stdout.rstrip().endswith("EOFError: EOF when reading a line") - and result.returncode == 1 - ): - msg += ( - "\n:warning: Note: `input` is not supported by the bot :warning:\n" - ) + if result.stdout.rstrip().endswith("EOFError: EOF when reading a line") and result.returncode == 1: + msg += "\n:warning: Note: `input` is not supported by the bot :warning:\n" # Skip output if it's empty and there are file uploads if result.stdout or not result.has_files: @@ -411,11 +406,13 @@ async def send_job(self: Self, ctx: Context, job: EvalJob) -> Message: # noqa: total_files = result.files + failed_files if filter_cog: block_output, blocked_exts = await filter_cog.filter_snekbox_output( - msg, total_files, ctx.message + msg, + total_files, + ctx.message, ) if block_output: return await ctx.send( - "Attempt to circumvent filter detected. Moderator team has been alerted." + "Attempt to circumvent filter detected. Moderator team has been alerted.", ) # Filter file extensions @@ -430,7 +427,9 @@ async def send_job(self: Self, ctx: Context, job: EvalJob) -> Message: # noqa: # Both elif "" in blocked_sorted: blocked_str = ", ".join(ext for ext in blocked_sorted if ext) - blocked_msg = f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" + blocked_msg = ( + f"Files with no extension or disallowed extensions can't be uploaded: **{blocked_str}**" + ) else: blocked_str = ", ".join(blocked_sorted) blocked_msg = f"Files with disallowed extensions can't be uploaded: **{blocked_str}**" @@ -440,21 +439,29 @@ async def send_job(self: Self, ctx: Context, job: EvalJob) -> Message: # noqa: # Upload remaining non-text files files = [f.to_file() for f in allowed if f not in text_files] allowed_mentions = AllowedMentions( - everyone=False, roles=False, users=[ctx.author] + everyone=False, + roles=False, + users=[ctx.author], ) view = self.build_python_version_switcher_view(job.version, ctx, job) response = await ctx.send( - msg, allowed_mentions=allowed_mentions, view=view, files=files + msg, + allowed_mentions=allowed_mentions, + view=view, + files=files, ) view.message = response log.info( - f"{ctx.author}'s {job.name} job had a return code of {result.returncode}" + f"{ctx.author}'s {job.name} job had a return code of {result.returncode}", ) return response async def continue_job( - self: Self, ctx: Context, response: Message, job_name: str + self: Self, + ctx: Context, + response: Message, + job_name: str, ) -> EvalJob | None: """ Check if the job's session should continue. @@ -474,7 +481,9 @@ async def continue_job( ) await ctx.message.add_reaction(REDO_EMOJI) await self.bot.wait_for( - "reaction_add", check=_predicate_emoji_reaction, timeout=10 + "reaction_add", + check=_predicate_emoji_reaction, + timeout=10, ) # Ensure the response that's about to be edited is still the most recent. @@ -625,8 +634,4 @@ def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) -> def predicate_emoji_reaction(ctx: Context, reaction: Reaction, user: User) -> bool: """Return True if the reaction REDO_EMOJI was added by the context message author on this message.""" - return ( - reaction.message.id == ctx.message.id - and user.id == ctx.author.id - and str(reaction) == REDO_EMOJI - ) + return reaction.message.id == ctx.message.id and user.id == ctx.author.id and str(reaction) == REDO_EMOJI diff --git a/src/bot/exts/utilities/snekbox/_eval.py b/src/bot/exts/utilities/snekbox/_eval.py index af8c83c..a0bb665 100644 --- a/src/bot/exts/utilities/snekbox/_eval.py +++ b/src/bot/exts/utilities/snekbox/_eval.py @@ -160,7 +160,8 @@ def get_message(self: Self, job: EvalJob) -> str: @classmethod def from_dict( - cls: type[Self], data: dict[str, str | int | list[dict[str, str]]] + cls: type[Self], + data: dict[str, str | int | list[dict[str, str]]], ) -> Self: """Create an EvalResult from a dict.""" res = cls( diff --git a/src/bot/exts/utilities/snekbox/_io.py b/src/bot/exts/utilities/snekbox/_io.py index 90836ca..a83be2b 100644 --- a/src/bot/exts/utilities/snekbox/_io.py +++ b/src/bot/exts/utilities/snekbox/_io.py @@ -72,7 +72,9 @@ def name(self: Self) -> str: @classmethod def from_dict( - cls: type[Self], data: dict, size_limit: int = FILE_SIZE_LIMIT + cls: type[Self], + data: dict, + size_limit: int = FILE_SIZE_LIMIT, ) -> Self: """Create a FileAttachment from a dict response.""" size = data.get("size") diff --git a/src/bot/log.py b/src/bot/log.py index 1d2c42d..105f278 100644 --- a/src/bot/log.py +++ b/src/bot/log.py @@ -3,61 +3,29 @@ import logging import os import sys -from logging import Logger, handlers +from logging import handlers from pathlib import Path -from typing import TYPE_CHECKING, Self, cast import coloredlogs import sentry_sdk +from pydis_core.utils import logging as core_logging +from sentry_sdk.integrations.asyncio import AsyncioIntegration from sentry_sdk.integrations.logging import LoggingIntegration from bot import constants -TRACE_LEVEL = 5 - - -LoggerClass = Logger if TYPE_CHECKING else logging.getLoggerClass() - - -class CustomLogger(LoggerClass): - """Custom implementation of the `Logger` class with an added `trace` method.""" - - def trace(self: Self, msg: str, *args: list, **kwargs: dict) -> None: - """ - Log 'msg % args' with severity 'TRACE'. - - To pass exception information, use the keyword argument exc_info with - a true value, e.g. - - logger.trace("Houston, we have an %s", "interesting problem", exc_info=1) - """ - if self.isEnabledFor(TRACE_LEVEL): - self.log(TRACE_LEVEL, msg, *args, **kwargs) - - -def get_logger(name: str | None = None) -> CustomLogger: - """Make mypy recognise that logger is of type `CustomLogger`.""" - return cast(CustomLogger, logging.getLogger(name)) +get_logger = core_logging.get_logger def setup() -> None: """Set up loggers.""" - logging.TRACE = TRACE_LEVEL - logging.addLevelName(TRACE_LEVEL, "TRACE") - logging.setLoggerClass(CustomLogger) - root_log = get_logger() - format_string = "%(asctime)s | %(name)s | %(levelname)s | %(message)s" - log_format = logging.Formatter(format_string) - if constants.FILE_LOGS: log_file = Path("logs", "bot.log") log_file.parent.mkdir(exist_ok=True) - file_handler = handlers.RotatingFileHandler( - log_file, maxBytes=5242880, backupCount=7, encoding="utf8" - ) - file_handler.setFormatter(log_format) + file_handler = handlers.RotatingFileHandler(log_file, maxBytes=5242880, backupCount=7, encoding="utf8") + file_handler.setFormatter(core_logging.log_format) root_log.addHandler(file_handler) if "COLOREDLOGS_LEVEL_STYLES" not in os.environ: @@ -69,18 +37,11 @@ def setup() -> None: } if "COLOREDLOGS_LOG_FORMAT" not in os.environ: - coloredlogs.DEFAULT_LOG_FORMAT = format_string + coloredlogs.DEFAULT_LOG_FORMAT = core_logging.log_format._fmt - coloredlogs.install(level=TRACE_LEVEL, logger=root_log, stream=sys.stdout) + coloredlogs.install(level=core_logging.TRACE_LEVEL, logger=root_log, stream=sys.stdout) root_log.setLevel(logging.DEBUG if constants.DEBUG_MODE else logging.INFO) - get_logger("discord").setLevel(logging.WARNING) - get_logger("websockets").setLevel(logging.WARNING) - get_logger("chardet").setLevel(logging.WARNING) - get_logger("async_rediscache").setLevel(logging.WARNING) - - # Set back to the default of INFO even if asyncio's debug mode is enabled. - get_logger("asyncio").setLevel(logging.INFO) _set_trace_loggers() @@ -88,15 +49,17 @@ def setup() -> None: def setup_sentry() -> None: """Set up the Sentry logging integrations.""" sentry_logging = LoggingIntegration( - level=logging.DEBUG, event_level=logging.WARNING + level=logging.DEBUG, + event_level=logging.WARNING, ) sentry_sdk.init( - dsn=constants.Bot.sentry_dsn, + dsn=constants.Sentry.dsn, integrations=[ sentry_logging, + AsyncioIntegration(), ], - release=f"bot@{constants.GIT_SHA}", + release=f"{constants.Sentry.release_prefix}@{constants.GIT_SHA}", traces_sample_rate=0.5, profiles_sample_rate=0.5, ) @@ -117,13 +80,13 @@ def _set_trace_loggers() -> None: level_filter = constants.Bot.trace_loggers if level_filter: if level_filter.startswith("*"): - get_logger().setLevel(TRACE_LEVEL) + get_logger().setLevel(core_logging.TRACE_LEVEL) elif level_filter.startswith("!"): - get_logger().setLevel(TRACE_LEVEL) + get_logger().setLevel(core_logging.TRACE_LEVEL) for logger_name in level_filter.strip("!,").split(","): get_logger(logger_name).setLevel(logging.DEBUG) else: for logger_name in level_filter.strip(",").split(","): - get_logger(logger_name).setLevel(TRACE_LEVEL) + get_logger(logger_name).setLevel(core_logging.TRACE_LEVEL) diff --git a/src/bot/orm/__init__.py b/src/bot/orm/__init__.py deleted file mode 100644 index f9b72f0..0000000 --- a/src/bot/orm/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Database functions.""" - -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -from bot.constants import Bot - -engine = create_engine(Bot.database_dsn) -Session = sessionmaker(engine) diff --git a/src/bot/orm/models.py b/src/bot/orm/models.py deleted file mode 100644 index 952a6af..0000000 --- a/src/bot/orm/models.py +++ /dev/null @@ -1,16 +0,0 @@ -"""Database models.""" - -from sqlalchemy import BigInteger -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column - - -class Base(DeclarativeBase): - """DeclarativeBase.""" - - -class Guild(Base): - """A Discord guild.""" - - __tablename__ = "guilds" - - guild_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) diff --git a/src/bot/utils/__init__.py b/src/bot/utils/__init__.py index 721eee6..1a7c848 100644 --- a/src/bot/utils/__init__.py +++ b/src/bot/utils/__init__.py @@ -1,18 +1,10 @@ """Internal utilites.""" from bot.utils.helpers import CogABCMeta, find_nth_occurrence, has_lines, pad_base64 -from bot.utils.services import ( - PasteTooLongError, - PasteUploadError, - send_to_paste_service, -) __all__ = [ "CogABCMeta", - "PasteTooLongError", - "PasteUploadError", "find_nth_occurrence", "has_lines", "pad_base64", - "send_to_paste_service", ] diff --git a/src/bot/utils/commands.py b/src/bot/utils/commands.py index 6f3dd1c..8e675a6 100644 --- a/src/bot/utils/commands.py +++ b/src/bot/utils/commands.py @@ -4,7 +4,11 @@ def get_command_suggestions( - all_commands: list[str], query: str, *, cutoff: int = 60, limit: int = 3 + all_commands: list[str], + query: str, + *, + cutoff: int = 60, + limit: int = 3, ) -> list[str]: """Get similar command names.""" results = process.extract(query, all_commands, score_cutoff=cutoff, limit=limit) diff --git a/src/bot/utils/exceptions.py b/src/bot/utils/exceptions.py index efd2d8a..adb3ea9 100644 --- a/src/bot/utils/exceptions.py +++ b/src/bot/utils/exceptions.py @@ -8,7 +8,10 @@ class APIError(Exception): """Raised when an external API (eg. Wikipedia) returns an error response.""" def __init__( - self: Self, api: str, status_code: int, error_msg: str | None = None + self: Self, + api: str, + status_code: int, + error_msg: str | None = None, ) -> None: super().__init__() self.api = api diff --git a/src/bot/utils/extensions.py b/src/bot/utils/extensions.py index ce5ce2f..15c559e 100644 --- a/src/bot/utils/extensions.py +++ b/src/bot/utils/extensions.py @@ -45,7 +45,9 @@ def on_error(name: str) -> NoReturn: modules = set() for module_info in pkgutil.walk_packages( - module.__path__, f"{module.__name__}.", onerror=on_error + module.__path__, + f"{module.__name__}.", + onerror=on_error, ): if ignore_module(module_info): # Ignore modules/packages that have a name starting with an underscore anywhere in their trees. diff --git a/src/bot/utils/function.py b/src/bot/utils/function.py index 1ceb9ec..082ad97 100644 --- a/src/bot/utils/function.py +++ b/src/bot/utils/function.py @@ -112,28 +112,20 @@ def update_wrapper_globals( as this can cause incorrect objects being used by discordpy's converters. """ annotation_global_names = ( - ann.split(".", maxsplit=1)[0] - for ann in wrapped.__annotations__.values() - if isinstance(ann, str) + ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str) ) # Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations. shared_globals = set(wrapper.__code__.co_names) & set(annotation_global_names) - shared_globals &= ( - set(wrapped.__globals__) & set(wrapper.__globals__) - ignored_conflict_names - ) + shared_globals &= set(wrapped.__globals__) & set(wrapper.__globals__) - ignored_conflict_names if shared_globals: msg = ( - f"wrapper and the wrapped function share the following global names used by annotations: {', '.join(shared_globals)}." # noqa: E501 - strings + f"wrapper and the wrapped function share the following global names used by annotations: {", ".join(shared_globals)}." # noqa: E501 - strings "Resolve the conflicts or add the name to the `ignored_conflict_names` set to suppress this error if this is intentional." # noqa: E501 - strings ) raise GlobalNameConflictError(msg) new_globals = wrapper.__globals__.copy() - new_globals.update( - (k, v) - for k, v in wrapped.__globals__.items() - if k not in wrapper.__code__.co_names - ) + new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names) return types.FunctionType( code=wrapper.__code__, globals=new_globals, @@ -155,7 +147,9 @@ def command_wraps( def decorator(wrapper: types.FunctionType) -> types.FunctionType: return functools.update_wrapper( update_wrapper_globals( - wrapper, wrapped, ignored_conflict_names=ignored_conflict_names + wrapper, + wrapped, + ignored_conflict_names=ignored_conflict_names, ), wrapped, assigned, diff --git a/src/bot/utils/lock.py b/src/bot/utils/lock.py index 843f39e..0a23f15 100644 --- a/src/bot/utils/lock.py +++ b/src/bot/utils/lock.py @@ -98,7 +98,7 @@ async def wrapper(*args: list, **kwargs: dict) -> Callable | None: id_ = resource_id log.trace( - f"{name}: getting the lock object for resource {namespace!r}:{id_!r}" + f"{name}: getting the lock object for resource {namespace!r}:{id_!r}", ) # Get the lock for the ID. Create a lock if one doesn't exist yet. @@ -111,13 +111,13 @@ async def wrapper(*args: list, **kwargs: dict) -> Callable | None: # 3. awaits only yield execution to the event loop at actual I/O boundaries if wait or not lock_.locked(): log.debug( - f"{name}: acquiring lock for resource {namespace!r}:{id_!r}..." + f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...", ) async with lock_: return await func(*args, **kwargs) else: log.info( - f"{name}: aborted because resource {namespace!r}:{id_!r} is locked" + f"{name}: aborted because resource {namespace!r}:{id_!r} is locked", ) if raise_error: raise LockedResourceError(str(namespace), id_) diff --git a/src/bot/utils/services.py b/src/bot/utils/services.py deleted file mode 100644 index e805946..0000000 --- a/src/bot/utils/services.py +++ /dev/null @@ -1,95 +0,0 @@ -"""Services.""" - -from aiohttp import ClientConnectorError, ClientSession - -from bot.constants import URLs -from bot.log import get_logger - -log = get_logger(__name__) - -FAILED_REQUEST_ATTEMPTS = 3 -MAX_PASTE_LENGTH = 100_000 - - -class PasteUploadError(Exception): - """Raised when an error is encountered uploading to the paste service.""" - - -class PasteTooLongError(Exception): - """Raised when content is too large to upload to the paste service.""" - - -async def send_to_paste_service( - http_session: ClientSession, - contents: str, - *, - extension: str = "", - max_length: int = MAX_PASTE_LENGTH, -) -> str: - """ - Upload `contents` to the paste service. - - Add `extension` to the output URL. Use `max_length` to limit the allowed contents length - to lower than the maximum allowed by the paste service. - - Raise `ValueError` if `max_length` is greater than the maximum allowed by the paste service. - Raise `PasteTooLongError` if `contents` is too long to upload, and `PasteUploadError` if uploading fails. - - Return the generated URL with the extension. - """ - if max_length > MAX_PASTE_LENGTH: - msg = f"`max_length` must not be greater than {MAX_PASTE_LENGTH}" - raise ValueError(msg) - - extension = extension and f".{extension}" - - contents_size = len(contents.encode()) - if contents_size > max_length: - log.info("Contents too large to send to paste service.") - msg = f"Contents of size {contents_size} greater than maximum size {max_length}" - raise PasteTooLongError(msg) - - log.debug(f"Sending contents of size {contents_size} bytes to paste service.") - paste_url = URLs.paste_service.format(key="documents") - for attempt in range(1, FAILED_REQUEST_ATTEMPTS + 1): - try: - async with http_session.post(paste_url, data=contents) as response: - response_json = await response.json() - except ClientConnectorError: - log.warning( - f"Failed to connect to paste service at url {paste_url}, " - f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS}).", - ) - continue - except Exception: - log.exception( - f"An unexpected error has occurred during handling of the request, " - f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS}).", - ) - continue - - if "message" in response_json: - log.warning( - f"Paste service returned error {response_json['message']} with status code {response.status}, " - f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS}).", - ) - continue - if "key" in response_json: - log.info( - f"Successfully uploaded contents to paste service behind key {response_json['key']}." - ) - - paste_link = URLs.paste_service.format(key=response_json["key"]) + extension - - if extension == ".py": - return paste_link - - return paste_link + "?noredirect" - - log.warning( - f"Got unexpected JSON response from paste service: {response_json}\n" - f"trying again ({attempt}/{FAILED_REQUEST_ATTEMPTS}).", - ) - - msg = "Failed to upload contents to paste service" - raise PasteUploadError(msg)