From 0d4e6a0865802b3d01292f730f11d099c3354c5d Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Fri, 18 Jan 2019 11:26:33 +1100 Subject: [PATCH 1/9] Fix MongoDB to JSON migration and warn about Mongo driver (#2373) Resolves #2372. Signed-off-by: Toby Harradine --- redbot/setup.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/redbot/setup.py b/redbot/setup.py index 3f93c6da4..ed9bf7eba 100644 --- a/redbot/setup.py +++ b/redbot/setup.py @@ -114,7 +114,7 @@ def get_storage_type(): print() print("Please choose your storage backend (if you're unsure, choose 1).") print("1. JSON (file storage, requires no database).") - print("2. MongoDB") + print("2. MongoDB (not recommended, currently unstable)") storage = input("> ") try: storage = int(storage) @@ -198,22 +198,23 @@ async def mongo_to_json(current_data_dir: Path, storage_details: dict): m = Mongo("Core", "0", **storage_details) db = m.db - collection_names = await db.collection_names(include_system_collections=False) - for c_name in collection_names: - if c_name == "Core": + collection_names = await db.list_collection_names() + for collection_name in collection_names: + if collection_name == "Core": c_data_path = current_data_dir / "core" else: - c_data_path = current_data_dir / "cogs/{}".format(c_name) - output = {} - docs = await db[c_name].find().to_list(None) - c_id = None - for item in docs: - item_id = item.pop("_id") - if not c_id: - c_id = str(hash(item_id)) - output[item_id] = item - target = JSON(c_name, c_id, data_path_override=c_data_path) - await target.jsonIO._threadsafe_save_json(output) + c_data_path = current_data_dir / "cogs" / collection_name + c_data_path.mkdir(parents=True, exist_ok=True) + # Every cog name has its own collection + collection = db[collection_name] + async for document in collection.find(): + # Every cog has its own document. + # This means if two cogs have the same name but different identifiers, they will + # be two separate documents in the same collection + cog_id = document.pop("_id") + driver = JSON(collection_name, cog_id, data_path_override=c_data_path) + for key, value in document.items(): + await driver.set(key, value=value) async def edit_instance(): From 849ade6e585572f4dd7c9a8210fa98df32d461f9 Mon Sep 17 00:00:00 2001 From: Michael H Date: Thu, 17 Jan 2019 22:45:34 -0500 Subject: [PATCH 2/9] Reconcile permission hooks with ctx.permission_state (#2375) Resolves #2374. See mod.py's voice mute for an example of why this may be necessary. --- redbot/core/bot.py | 7 ++++++- redbot/core/commands/requires.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/redbot/core/bot.py b/redbot/core/bot.py index e0778f53c..4f5de7bee 100644 --- a/redbot/core/bot.py +++ b/redbot/core/bot.py @@ -500,7 +500,12 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): if result is not None: hook_results.append(result) if hook_results: - return all(hook_results) + if all(hook_results): + ctx.permission_state = commands.PermState.ALLOWED_BY_HOOK + return True + else: + ctx.permission_state = commands.PermState.DENIED_BY_HOOK + return False class Red(RedBase, discord.AutoShardedClient): diff --git a/redbot/core/commands/requires.py b/redbot/core/commands/requires.py index 2b98d69c7..7c546b4e8 100644 --- a/redbot/core/commands/requires.py +++ b/redbot/core/commands/requires.py @@ -168,6 +168,17 @@ class PermState(enum.Enum): chain. """ + # The below are valid states, but should not be transitioned to + # They should be set if they apply. + + ALLOWED_BY_HOOK = enum.auto() + """This command has been actively allowed by a permission hook. + check validation doesn't need this, but is useful to developers""" + + DENIED_BY_HOOK = enum.auto() + """This command has been actively denied by a permission hook + check validation doesn't need this, but is useful to developers""" + def transition_to( self, next_state: "PermState" ) -> Tuple[Optional[bool], Union["PermState", Dict[bool, "PermState"]]]: From 1c4193cce24ac4891446425f5625602e97c1d826 Mon Sep 17 00:00:00 2001 From: Michael H Date: Thu, 17 Jan 2019 22:48:00 -0500 Subject: [PATCH 3/9] [Permissions] Quick extra comment of importance (#2379) --- redbot/core/commands/requires.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/redbot/core/commands/requires.py b/redbot/core/commands/requires.py index 7c546b4e8..9759919ba 100644 --- a/redbot/core/commands/requires.py +++ b/redbot/core/commands/requires.py @@ -93,6 +93,10 @@ DM_PERMS.update( class PrivilegeLevel(enum.IntEnum): """Enumeration for special privileges.""" + # Maintainer Note: do NOT re-order these. + # Each privelege level also implies access to the ones before it. + # Inserting new privelege levels at a later point is fine if that is considered. + NONE = enum.auto() """No special privilege level.""" From 158c4f741b761aa69aae06d5f17bfc986637250a Mon Sep 17 00:00:00 2001 From: Iangit1 <43935737+Iangit1@users.noreply.github.com> Date: Sun, 20 Jan 2019 22:07:55 +0000 Subject: [PATCH 4/9] Grammar in ask_sentry and interactive_config (#2383) --- redbot/core/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redbot/core/cli.py b/redbot/core/cli.py index 8d31b8f0a..55e0ba3c0 100644 --- a/redbot/core/cli.py +++ b/redbot/core/cli.py @@ -30,7 +30,7 @@ def interactive_config(red, token_set, prefix_set): "\nPick a prefix. A prefix is what you type before a " "command. Example:\n" "!help\n^ The exclamation mark is the prefix in this case.\n" - "Can be multiple characters. You will be able to change it " + "The prefix can be multiple characters. You will be able to change it " "later and add more of them.\nChoose your prefix:\n" ) while not prefix: @@ -51,7 +51,7 @@ def ask_sentry(red: Red): loop = asyncio.get_event_loop() print( "\nThank you for installing Red V3! Red is constantly undergoing\n" - " improvements, and we would like ask if you are comfortable with\n" + " improvements, and we would like to ask if you are comfortable with\n" " the bot automatically submitting fatal error logs to the development\n" ' team. If you wish to opt into the process please type "yes":\n' ) From 348277bcbde542ebb39c17d791be266724f55071 Mon Sep 17 00:00:00 2001 From: Caleb Johnson Date: Sun, 27 Jan 2019 19:43:21 -0600 Subject: [PATCH 5/9] [Audio] Lavalink 3.0/3.1 compatibility updates (#2272) - Update to red-lavalink v0.2.0 (blocked by Cog-Creators/Red-Lavalink#41) - Force lavalink to use TLSv1.2 on java 11+ (blocked by #2270) I would add equalizer support, but there's no way to know the full Lavalink version and thus whether it's supported ahead of time. --- redbot/cogs/audio/manager.py | 5 ++++- setup.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/redbot/cogs/audio/manager.py b/redbot/cogs/audio/manager.py index ae10bb3be..355749242 100644 --- a/redbot/cogs/audio/manager.py +++ b/redbot/cogs/audio/manager.py @@ -96,9 +96,12 @@ async def start_lavalink_server(loop): if not java_available: raise RuntimeError("You must install Java 1.8+ for Lavalink to run.") - extra_flags = "" if java_version == (1, 8): extra_flags = "-Dsun.zip.disableMemoryMapping=true" + elif java_version >= (11, 0): + extra_flags = "-Djdk.tls.client.protocols=TLSv1.2" + else: + extra_flags = "" from . import LAVALINK_DOWNLOAD_DIR, LAVALINK_JAR_FILE diff --git a/setup.py b/setup.py index 0c463f01a..bdb05ec3a 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ extras_require = { "sphinxcontrib-websupport==1.1.0", "urllib3==1.24.1", ], - "voice": ["red-lavalink==0.1.2"], + "voice": ["red-lavalink==0.2.0"], "style": ["black==18.9b0", "click==7.0", "toml==0.10.0"], } From 05bef917aeca69eda253606c2a1d77f282733896 Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Mon, 28 Jan 2019 14:14:36 +1100 Subject: [PATCH 6/9] Vendor discord.py (#2387) Signed-off-by: Toby Harradine --- .readthedocs.yml | 2 - MANIFEST.in | 3 - Makefile | 6 + Pipfile | 1 - Pipfile.lock | 18 +- README.md | 5 + dependency_links.txt | 1 - discord/__init__.py | 64 ++ discord/__main__.py | 337 +++++++ discord/abc.py | 1030 ++++++++++++++++++++ discord/activity.py | 613 ++++++++++++ discord/audit_logs.py | 366 +++++++ discord/backoff.py | 86 ++ discord/calls.py | 157 +++ discord/channel.py | 986 +++++++++++++++++++ discord/client.py | 1074 ++++++++++++++++++++ discord/colour.py | 234 +++++ discord/context_managers.py | 69 ++ discord/embeds.py | 492 ++++++++++ discord/emoji.py | 269 +++++ discord/enums.py | 274 ++++++ discord/errors.py | 183 ++++ discord/ext/commands/__init__.py | 19 + discord/ext/commands/bot.py | 1049 ++++++++++++++++++++ discord/ext/commands/context.py | 225 +++++ discord/ext/commands/converter.py | 560 +++++++++++ discord/ext/commands/cooldowns.py | 148 +++ discord/ext/commands/core.py | 1517 +++++++++++++++++++++++++++++ discord/ext/commands/errors.py | 279 ++++++ discord/ext/commands/formatter.py | 365 +++++++ discord/ext/commands/view.py | 201 ++++ discord/file.py | 81 ++ discord/gateway.py | 701 +++++++++++++ discord/guild.py | 1419 +++++++++++++++++++++++++++ discord/http.py | 911 +++++++++++++++++ discord/invite.py | 176 ++++ discord/iterators.py | 489 ++++++++++ discord/member.py | 621 ++++++++++++ discord/message.py | 799 +++++++++++++++ discord/mixins.py | 44 + discord/object.py | 71 ++ discord/opus.py | 286 ++++++ discord/permissions.py | 636 ++++++++++++ discord/player.py | 356 +++++++ discord/raw_models.py | 151 +++ discord/reaction.py | 151 +++ discord/relationship.py | 79 ++ discord/role.py | 297 ++++++ discord/shard.py | 370 +++++++ discord/state.py | 1048 ++++++++++++++++++++ discord/user.py | 699 +++++++++++++ discord/utils.py | 353 +++++++ discord/voice_client.py | 438 +++++++++ discord/webhook.py | 703 +++++++++++++ docs/Makefile | 3 - docs/guide_cog_creation.rst | 2 +- docs/install_linux_mac.rst | 6 +- docs/install_windows.rst | 6 +- make.bat | 15 +- redbot/launcher.py | 13 +- setup.py | 16 +- tox.ini | 2 - 62 files changed, 21516 insertions(+), 59 deletions(-) delete mode 100644 MANIFEST.in delete mode 100644 dependency_links.txt create mode 100644 discord/__init__.py create mode 100644 discord/__main__.py create mode 100644 discord/abc.py create mode 100644 discord/activity.py create mode 100644 discord/audit_logs.py create mode 100644 discord/backoff.py create mode 100644 discord/calls.py create mode 100644 discord/channel.py create mode 100644 discord/client.py create mode 100644 discord/colour.py create mode 100644 discord/context_managers.py create mode 100644 discord/embeds.py create mode 100644 discord/emoji.py create mode 100644 discord/enums.py create mode 100644 discord/errors.py create mode 100644 discord/ext/commands/__init__.py create mode 100644 discord/ext/commands/bot.py create mode 100644 discord/ext/commands/context.py create mode 100644 discord/ext/commands/converter.py create mode 100644 discord/ext/commands/cooldowns.py create mode 100644 discord/ext/commands/core.py create mode 100644 discord/ext/commands/errors.py create mode 100644 discord/ext/commands/formatter.py create mode 100644 discord/ext/commands/view.py create mode 100644 discord/file.py create mode 100644 discord/gateway.py create mode 100644 discord/guild.py create mode 100644 discord/http.py create mode 100644 discord/invite.py create mode 100644 discord/iterators.py create mode 100644 discord/member.py create mode 100644 discord/message.py create mode 100644 discord/mixins.py create mode 100644 discord/object.py create mode 100644 discord/opus.py create mode 100644 discord/permissions.py create mode 100644 discord/player.py create mode 100644 discord/raw_models.py create mode 100644 discord/reaction.py create mode 100644 discord/relationship.py create mode 100644 discord/role.py create mode 100644 discord/shard.py create mode 100644 discord/state.py create mode 100644 discord/user.py create mode 100644 discord/utils.py create mode 100644 discord/voice_client.py create mode 100644 discord/webhook.py diff --git a/.readthedocs.yml b/.readthedocs.yml index fc3d63898..2512e080e 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -4,8 +4,6 @@ formats: build: image: latest -requirements_file: dependency_links.txt - python: version: 3.6 pip_install: true diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 0f59a790e..000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -include README.md -include LICENSE -include dependency_links.txt diff --git a/Makefile b/Makefile index f2702b0b9..beecb6d28 100644 --- a/Makefile +++ b/Makefile @@ -5,3 +5,9 @@ stylecheck: gettext: redgettext --command-docstrings --verbose --recursive redbot --exclude-files "redbot/pytest/**/*" crowdin upload + +REF?=rewrite +update_vendor: + pip install --upgrade --no-deps -t . https://github.com/Rapptz/discord.py/archive/$(REF).tar.gz#egg=discord.py + rm -r discord.py*.egg-info + $(MAKE) reformat diff --git a/Pipfile b/Pipfile index 3c0b9e0ee..a4cd61513 100644 --- a/Pipfile +++ b/Pipfile @@ -4,7 +4,6 @@ verify_ssl = true name = "pypi" [packages] -"discord.py" = { git = 'git://github.com/Rapptz/discord.py', ref = 'rewrite', editable = true } "e1839a8" = { path = ".", editable = true, extras = ['mongo', 'voice'] } [dev-packages] diff --git a/Pipfile.lock b/Pipfile.lock index 69d5be9a3..a85d89025 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "edd35f353e1fadc20094e40de6627db77fd61303da01794214c44d748e99838b" + "sha256": "57184ef83392116db24a1966022ad358f54048bb43d428d47a6e31f1a88fc434" }, "pipfile-spec": 6, "requires": {}, @@ -83,16 +83,6 @@ ], "version": "==0.4.1" }, - "discord-py": { - "editable": true, - "git": "git://github.com/Rapptz/discord.py", - "ref": "7f4c57dd5ad20b7fa10aea485f674a4bc24b9547" - }, - "discord.py": { - "editable": true, - "git": "git://github.com/Rapptz/discord.py", - "ref": "rewrite" - }, "distro": { "hashes": [ "sha256:224041cef9600e72d19ae41ba006e71c05c4dc802516da715d7fda55ba3d8742", @@ -715,11 +705,11 @@ }, "tox": { "hashes": [ - "sha256:2a8d8a63660563e41e64e3b5b677e81ce1ffa5e2a93c2c565d3768c287445800", - "sha256:edfca7809925f49bdc110d0a2d9966bbf35a0c25637216d9586e7a5c5de17bfb" + "sha256:04f8f1aa05de8e76d7a266ccd14e0d665d429977cd42123bc38efa9b59964e9e", + "sha256:25ef928babe88c71e3ed3af0c464d1160b01fca2dd1870a5bb26c2dea61a17fc" ], "index": "pypi", - "version": "==3.6.1" + "version": "==3.7.0" }, "urllib3": { "hashes": [ diff --git a/README.md b/README.md index 303c5decc..99f16e048 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,11 @@ Join us on our [Official Discord Server](https://discord.gg/red)! Released under the [GNU GPL v3](https://www.gnu.org/licenses/gpl-3.0.en.html) license. +This project vendors the +[discord.py library by Rapptz](https://github.com/Rapptz/discord.py/tree/rewrite), which is +licensed under the [MIT License](https://opensource.org/licenses/MIT). This amounts to everything +within the *discord* folder of this repository. + Red is named after the main character of "Transistor", a video game by [Super Giant Games](https://www.supergiantgames.com/games/transistor/). diff --git a/dependency_links.txt b/dependency_links.txt deleted file mode 100644 index c05bff902..000000000 --- a/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ -https://github.com/Rapptz/discord.py/tarball/7f4c57dd5ad20b7fa10aea485f674a4bc24b9547#egg=discord.py-1.0.0a0 diff --git a/discord/__init__.py b/discord/__init__.py new file mode 100644 index 000000000..2f8bdd637 --- /dev/null +++ b/discord/__init__.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +""" +Discord API Wrapper +~~~~~~~~~~~~~~~~~~~ + +A basic wrapper for the Discord API. + +:copyright: (c) 2015-2017 Rapptz +:license: MIT, see LICENSE for more details. + +""" + +__title__ = "discord" +__author__ = "Rapptz" +__license__ = "MIT" +__copyright__ = "Copyright 2015-2017 Rapptz" +__version__ = "1.0.0a" + +from collections import namedtuple +import logging + +from .client import Client, AppInfo +from .user import User, ClientUser, Profile +from .emoji import Emoji, PartialEmoji +from .activity import * +from .channel import * +from .guild import Guild +from .relationship import Relationship +from .member import Member, VoiceState +from .message import Message, Attachment +from .errors import * +from .calls import CallMessage, GroupCall +from .permissions import Permissions, PermissionOverwrite +from .role import Role +from .file import File +from .colour import Color, Colour +from .invite import Invite +from .object import Object +from .reaction import Reaction +from . import utils, opus, abc +from .enums import * +from .embeds import Embed +from .shard import AutoShardedClient +from .player import * +from .webhook import * +from .voice_client import VoiceClient +from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff +from .raw_models import * + +VersionInfo = namedtuple("VersionInfo", "major minor micro releaselevel serial") + +version_info = VersionInfo(major=1, minor=0, micro=0, releaselevel="alpha", serial=0) + +try: + from logging import NullHandler +except ImportError: + + class NullHandler(logging.Handler): + def emit(self, record): + pass + + +logging.getLogger(__name__).addHandler(NullHandler()) diff --git a/discord/__main__.py b/discord/__main__.py new file mode 100644 index 000000000..33bb961d6 --- /dev/null +++ b/discord/__main__.py @@ -0,0 +1,337 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import argparse +import sys +from pathlib import Path + +import discord + + +def core(parser, args): + pass + + +bot_template = """#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from discord.ext import commands +import discord +import config + +class Bot(commands.{base}): + def __init__(self, **kwargs): + super().__init__(command_prefix=commands.when_mentioned_or('{prefix}'), **kwargs) + for cog in config.cogs: + try: + self.load_extension(cog) + except Exception as exc: + print('Could not load extension {{0}} due to {{1.__class__.__name__}}: {{1}}'.format(cog, exc)) + + async def on_ready(self): + print('Logged on as {{0}} (ID: {{0.id}})'.format(self.user)) + + +bot = Bot() + +# write general commands here + +bot.run(config.token) +""" + +gitignore_template = """# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +*.egg-info/ +.installed.cfg +*.egg + +# Our configuration files +config.py +""" + +cog_template = '''# -*- coding: utf-8 -*- + +from discord.ext import commands +import discord + +class {name}: + """The description for {name} goes here.""" + + def __init__(self, bot): + self.bot = bot +{extra} +def setup(bot): + bot.add_cog({name}(bot)) +''' + +cog_extras = """ + def __unload(self): + # clean up logic goes here + pass + + async def __local_check(self, ctx): + # checks that apply to every command in here + return True + + async def __global_check(self, ctx): + # checks that apply to every command to the bot + return True + + async def __global_check_once(self, ctx): + # check that apply to every command but is guaranteed to be called only once + return True + + async def __error(self, ctx, error): + # error handling to every command in here + pass + + async def __before_invoke(self, ctx): + # called before a command is called here + pass + + async def __after_invoke(self, ctx): + # called after a command is called here + pass + +""" + + +# certain file names and directory names are forbidden +# see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx +# although some of this doesn't apply to Linux, we might as well be consistent +_base_table = { + "<": "-", + ">": "-", + ":": "-", + '"': "-", + # '/': '-', these are fine + # '\\': '-', + "|": "-", + "?": "-", + "*": "-", +} + +# +_base_table.update((chr(i), None) for i in range(32)) + +translation_table = str.maketrans(_base_table) + + +def to_path(parser, name, *, replace_spaces=False): + if isinstance(name, Path): + return name + + if sys.platform == "win32": + forbidden = ( + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", + ) + if len(name) <= 4 and name.upper() in forbidden: + parser.error("invalid directory name given, use a different one") + + name = name.translate(translation_table) + if replace_spaces: + name = name.replace(" ", "-") + return Path(name) + + +def newbot(parser, args): + if sys.version_info < (3, 5): + parser.error("python version is older than 3.5, consider upgrading.") + + new_directory = to_path(parser, args.directory) / to_path(parser, args.name) + + # as a note exist_ok for Path is a 3.5+ only feature + # since we already checked above that we're >3.5 + try: + new_directory.mkdir(exist_ok=True, parents=True) + except OSError as exc: + parser.error("could not create our bot directory ({})".format(exc)) + + cogs = new_directory / "cogs" + + try: + cogs.mkdir(exist_ok=True) + init = cogs / "__init__.py" + init.touch() + except OSError as exc: + print("warning: could not create cogs directory ({})".format(exc)) + + try: + with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp: + fp.write('token = "place your token here"\ncogs = []\n') + except OSError as exc: + parser.error("could not create config file ({})".format(exc)) + + try: + with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp: + base = "Bot" if not args.sharded else "AutoShardedBot" + fp.write(bot_template.format(base=base, prefix=args.prefix)) + except OSError as exc: + parser.error("could not create bot file ({})".format(exc)) + + if not args.no_git: + try: + with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp: + fp.write(gitignore_template) + except OSError as exc: + print("warning: could not create .gitignore file ({})".format(exc)) + + print("successfully made bot at", new_directory) + + +def newcog(parser, args): + if sys.version_info < (3, 5): + parser.error("python version is older than 3.5, consider upgrading.") + + cog_dir = to_path(parser, args.directory) + try: + cog_dir.mkdir(exist_ok=True) + except OSError as exc: + print("warning: could not create cogs directory ({})".format(exc)) + + directory = cog_dir / to_path(parser, args.name) + directory = directory.with_suffix(".py") + try: + with open(str(directory), "w", encoding="utf-8") as fp: + extra = cog_extras if args.full else "" + if args.class_name: + name = args.class_name + else: + name = str(directory.stem) + if "-" in name: + name = name.replace("-", " ").title().replace(" ", "") + else: + name = name.title() + fp.write(cog_template.format(name=name, extra=extra)) + except OSError as exc: + parser.error("could not create cog file ({})".format(exc)) + else: + print("successfully made cog at", directory) + + +def add_newbot_args(subparser): + parser = subparser.add_parser("newbot", help="creates a command bot project quickly") + parser.set_defaults(func=newbot) + + parser.add_argument("name", help="the bot project name") + parser.add_argument( + "directory", + help="the directory to place it in (default: .)", + nargs="?", + default=Path.cwd(), + ) + parser.add_argument( + "--prefix", help="the bot prefix (default: $)", default="$", metavar="" + ) + parser.add_argument("--sharded", help="whether to use AutoShardedBot", action="store_true") + parser.add_argument( + "--no-git", help="do not create a .gitignore file", action="store_true", dest="no_git" + ) + + +def add_newcog_args(subparser): + parser = subparser.add_parser("newcog", help="creates a new cog template quickly") + parser.set_defaults(func=newcog) + + parser.add_argument("name", help="the cog name") + parser.add_argument( + "directory", + help="the directory to place it in (default: cogs)", + nargs="?", + default=Path("cogs"), + ) + parser.add_argument( + "--class-name", help="the class name of the cog (default: )", dest="class_name" + ) + parser.add_argument("--full", help="add all special methods as well", action="store_true") + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="discord", description="Tools for helping with discord.py" + ) + + version = "discord.py v{0.__version__} for Python {1[0]}.{1[1]}.{1[2]}".format( + discord, sys.version_info + ) + parser.add_argument( + "-v", "--version", action="version", version=version, help="shows the library version" + ) + parser.set_defaults(func=core) + + subparser = parser.add_subparsers(dest="subcommand", title="subcommands") + add_newbot_args(subparser) + add_newcog_args(subparser) + return parser, parser.parse_args() + + +def main(): + parser, args = parse_args() + args.func(parser, args) + + +main() diff --git a/discord/abc.py b/discord/abc.py new file mode 100644 index 000000000..146f80925 --- /dev/null +++ b/discord/abc.py @@ -0,0 +1,1030 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import abc +import copy +import asyncio +from collections import namedtuple + +from .iterators import HistoryIterator +from .context_managers import Typing +from .errors import InvalidArgument, ClientException, HTTPException +from .permissions import PermissionOverwrite, Permissions +from .role import Role +from .invite import Invite +from .file import File +from .voice_client import VoiceClient +from . import utils + + +class _Undefined: + def __repr__(self): + return "see-below" + + +_undefined = _Undefined() + + +class Snowflake(metaclass=abc.ABCMeta): + """An ABC that details the common operations on a Discord model. + + Almost all :ref:`Discord models ` meet this + abstract base class. + + Attributes + ----------- + id: :class:`int` + The model's unique ID. + """ + + __slots__ = () + + @property + @abc.abstractmethod + def created_at(self): + """Returns the model's creation time in UTC.""" + raise NotImplementedError + + @classmethod + def __subclasshook__(cls, C): + if cls is Snowflake: + mro = C.__mro__ + for attr in ("created_at", "id"): + for base in mro: + if attr in base.__dict__: + break + else: + return NotImplemented + return True + return NotImplemented + + +class User(metaclass=abc.ABCMeta): + """An ABC that details the common operations on a Discord user. + + The following implement this ABC: + + - :class:`User` + - :class:`ClientUser` + - :class:`Member` + + This ABC must also implement :class:`abc.Snowflake`. + + Attributes + ----------- + name: :class:`str` + The user's username. + discriminator: :class:`str` + The user's discriminator. + avatar: Optional[:class:`str`] + The avatar hash the user has. + bot: :class:`bool` + If the user is a bot account. + """ + + __slots__ = () + + @property + @abc.abstractmethod + def display_name(self): + """Returns the user's display name.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def mention(self): + """Returns a string that allows you to mention the given user.""" + raise NotImplementedError + + @classmethod + def __subclasshook__(cls, C): + if cls is User: + if Snowflake.__subclasshook__(C) is NotImplemented: + return NotImplemented + + mro = C.__mro__ + for attr in ("display_name", "mention", "name", "avatar", "discriminator", "bot"): + for base in mro: + if attr in base.__dict__: + break + else: + return NotImplemented + return True + return NotImplemented + + +class PrivateChannel(metaclass=abc.ABCMeta): + """An ABC that details the common operations on a private Discord channel. + + The following implement this ABC: + + - :class:`DMChannel` + - :class:`GroupChannel` + + This ABC must also implement :class:`abc.Snowflake`. + + Attributes + ----------- + me: :class:`ClientUser` + The user presenting yourself. + """ + + __slots__ = () + + @classmethod + def __subclasshook__(cls, C): + if cls is PrivateChannel: + if Snowflake.__subclasshook__(C) is NotImplemented: + return NotImplemented + + mro = C.__mro__ + for base in mro: + if "me" in base.__dict__: + return True + return NotImplemented + return NotImplemented + + +_Overwrites = namedtuple("_Overwrites", "id allow deny type") + + +class GuildChannel: + """An ABC that details the common operations on a Discord guild channel. + + The following implement this ABC: + + - :class:`TextChannel` + - :class:`VoiceChannel` + - :class:`CategoryChannel` + + This ABC must also implement :class:`abc.Snowflake`. + + Attributes + ----------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. + e.g. the top channel is position 0. + """ + + __slots__ = () + + def __str__(self): + return self.name + + async def _move(self, position, parent_id=None, lock_permissions=False, *, reason): + if position < 0: + raise InvalidArgument("Channel position cannot be less than 0.") + + http = self._state.http + cls = type(self) + channels = [c for c in self.guild.channels if isinstance(c, cls)] + + if position >= len(channels): + raise InvalidArgument( + "Channel position cannot be greater than {}".format(len(channels) - 1) + ) + + channels.sort(key=lambda c: c.position) + + try: + # remove ourselves from the channel list + channels.remove(self) + except ValueError: + # not there somehow lol + return + else: + # add ourselves at our designated position + channels.insert(position, self) + + payload = [] + for index, c in enumerate(channels): + d = {"id": c.id, "position": index} + if parent_id is not _undefined and c.id == self.id: + d.update(parent_id=parent_id, lock_permissions=lock_permissions) + payload.append(d) + + await http.bulk_channel_update(self.guild.id, payload, reason=reason) + self.position = position + if parent_id is not _undefined: + self.category_id = int(parent_id) if parent_id else None + + async def _edit(self, options, reason): + try: + parent = options.pop("category") + except KeyError: + parent_id = _undefined + else: + parent_id = parent and parent.id + + try: + options["rate_limit_per_user"] = options.pop("slowmode_delay") + except KeyError: + pass + + lock_permissions = options.pop("sync_permissions", False) + + try: + position = options.pop("position") + except KeyError: + if parent_id is not _undefined: + if lock_permissions: + category = self.guild.get_channel(parent_id) + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] + options["parent_id"] = parent_id + elif lock_permissions and self.category_id is not None: + # if we're syncing permissions on a pre-existing channel category without changing it + # we need to update the permissions to point to the pre-existing category + category = self.guild.get_channel(self.category_id) + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] + else: + await self._move( + position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason + ) + + if options: + data = await self._state.http.edit_channel(self.id, reason=reason, **options) + self._update(self.guild, data) + + def _fill_overwrites(self, data): + self._overwrites = [] + everyone_index = 0 + everyone_id = self.guild.id + + for index, overridden in enumerate(data.get("permission_overwrites", [])): + overridden_id = int(overridden.pop("id")) + self._overwrites.append(_Overwrites(id=overridden_id, **overridden)) + + if overridden["type"] == "member": + continue + + if overridden_id == everyone_id: + # the @everyone role is not guaranteed to be the first one + # in the list of permission overwrites, however the permission + # resolution code kind of requires that it is the first one in + # the list since it is special. So we need the index so we can + # swap it to be the first one. + everyone_index = index + + # do the swap + tmp = self._overwrites + if tmp: + tmp[everyone_index], tmp[0] = tmp[0], tmp[everyone_index] + + @property + def changed_roles(self): + """Returns a :class:`list` of :class:`Roles` that have been overridden from + their default values in the :attr:`Guild.roles` attribute.""" + ret = [] + g = self.guild + for overwrite in filter(lambda o: o.type == "role", self._overwrites): + role = g.get_role(overwrite.id) + if role is None: + continue + + role = copy.copy(role) + role.permissions.handle_overwrite(overwrite.allow, overwrite.deny) + ret.append(role) + return ret + + @property + def mention(self): + """:class:`str` : The string that allows you to mention the channel.""" + return "<#%s>" % self.id + + @property + def created_at(self): + """Returns the channel's creation time in UTC.""" + return utils.snowflake_time(self.id) + + def overwrites_for(self, obj): + """Returns the channel-specific overwrites for a member or a role. + + Parameters + ----------- + obj + The :class:`Role` or :class:`abc.User` denoting + whose overwrite to get. + + Returns + --------- + :class:`PermissionOverwrite` + The permission overwrites for this object. + """ + + if isinstance(obj, User): + predicate = lambda p: p.type == "member" + elif isinstance(obj, Role): + predicate = lambda p: p.type == "role" + else: + predicate = lambda p: True + + for overwrite in filter(predicate, self._overwrites): + if overwrite.id == obj.id: + allow = Permissions(overwrite.allow) + deny = Permissions(overwrite.deny) + return PermissionOverwrite.from_pair(allow, deny) + + return PermissionOverwrite() + + @property + def overwrites(self): + """Returns all of the channel's overwrites. + + This is returned as a list of two-element tuples containing the target, + which can be either a :class:`Role` or a :class:`Member` and the overwrite + as the second element as a :class:`PermissionOverwrite`. + + Returns + -------- + List[Tuple[Union[:class:`Role`, :class:`Member`], :class:`PermissionOverwrite`]]: + The channel's permission overwrites. + """ + ret = [] + for ow in self._overwrites: + allow = Permissions(ow.allow) + deny = Permissions(ow.deny) + overwrite = PermissionOverwrite.from_pair(allow, deny) + + if ow.type == "role": + target = self.guild.get_role(ow.id) + elif ow.type == "member": + target = self.guild.get_member(ow.id) + + ret.append((target, overwrite)) + return ret + + @property + def category(self): + """Optional[:class:`CategoryChannel`]: The category this channel belongs to. + + If there is no category then this is ``None``. + """ + return self.guild.get_channel(self.category_id) + + def permissions_for(self, member): + """Handles permission resolution for the current :class:`Member`. + + This function takes into consideration the following cases: + + - Guild owner + - Guild roles + - Channel overrides + - Member overrides + + Parameters + ---------- + member : :class:`Member` + The member to resolve permissions for. + + Returns + ------- + :class:`Permissions` + The resolved permissions for the member. + """ + + # The current cases can be explained as: + # Guild owner get all permissions -- no questions asked. Otherwise... + # The @everyone role gets the first application. + # After that, the applied roles that the user has in the channel + # (or otherwise) are then OR'd together. + # After the role permissions are resolved, the member permissions + # have to take into effect. + # After all that is done.. you have to do the following: + + # If manage permissions is True, then all permissions are set to True. + + # The operation first takes into consideration the denied + # and then the allowed. + + o = self.guild.owner + if o is not None and member.id == o.id: + return Permissions.all() + + default = self.guild.default_role + base = Permissions(default.permissions.value) + roles = member.roles + + # Apply guild roles that the member has. + for role in roles: + base.value |= role.permissions.value + + # Guild-wide Administrator -> True for everything + # Bypass all channel-specific overrides + if base.administrator: + return Permissions.all() + + # Apply @everyone allow/deny first since it's special + try: + maybe_everyone = self._overwrites[0] + if maybe_everyone.id == self.guild.id: + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) + remaining_overwrites = self._overwrites[1:] + else: + remaining_overwrites = self._overwrites + except IndexError: + remaining_overwrites = self._overwrites + + # not sure if doing member._roles.get(...) is better than the + # set approach. While this is O(N) to re-create into a set for O(1) + # the direct approach would just be O(log n) for searching with no + # extra memory overhead. For now, I'll keep the set cast + # Note that the member.roles accessor up top also creates a + # temporary list + member_role_ids = {r.id for r in roles} + denies = 0 + allows = 0 + + # Apply channel specific role permission overwrites + for overwrite in remaining_overwrites: + if overwrite.type == "role" and overwrite.id in member_role_ids: + denies |= overwrite.deny + allows |= overwrite.allow + + base.handle_overwrite(allow=allows, deny=denies) + + # Apply member specific permission overwrites + for overwrite in remaining_overwrites: + if overwrite.type == "member" and overwrite.id == member.id: + base.handle_overwrite(allow=overwrite.allow, deny=overwrite.deny) + break + + # if you can't send a message in a channel then you can't have certain + # permissions as well + if not base.send_messages: + base.send_tts_messages = False + base.mention_everyone = False + base.embed_links = False + base.attach_files = False + + # if you can't read a channel then you have no permissions there + if not base.read_messages: + denied = Permissions.all_channel() + base.value &= ~denied.value + + return base + + async def delete(self, *, reason=None): + """|coro| + + Deletes the channel. + + You must have :attr:`~.Permissions.manage_channels` permission to use this. + + Parameters + ----------- + reason: Optional[str] + The reason for deleting this channel. + Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have proper permissions to delete the channel. + NotFound + The channel was not found or was already deleted. + HTTPException + Deleting the channel failed. + """ + await self._state.http.delete_channel(self.id, reason=reason) + + async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): + r"""|coro| + + Sets the channel specific permission overwrites for a target in the + channel. + + The ``target`` parameter should either be a :class:`Member` or a + :class:`Role` that belongs to guild. + + The ``overwrite`` parameter, if given, must either be ``None`` or + :class:`PermissionOverwrite`. For convenience, you can pass in + keyword arguments denoting :class:`Permissions` attributes. If this is + done, then you cannot mix the keyword arguments with the ``overwrite`` + parameter. + + If the ``overwrite`` parameter is ``None``, then the permission + overwrites are deleted. + + You must have the :attr:`~Permissions.manage_roles` permission to use this. + + Examples + ---------- + + Setting allow and deny: :: + + await message.channel.set_permissions(message.author, read_messages=True, + send_messages=False) + + Deleting overwrites :: + + await channel.set_permissions(member, overwrite=None) + + Using :class:`PermissionOverwrite` :: + + overwrite = PermissionOverwrite() + overwrite.send_messages = False + overwrite.read_messages = True + await channel.set_permissions(member, overwrite=overwrite) + + Parameters + ----------- + target + The :class:`Member` or :class:`Role` to overwrite permissions for. + overwrite: :class:`PermissionOverwrite` + The permissions to allow and deny to the target. + \*\*permissions + A keyword argument list of permissions to set for ease of use. + Cannot be mixed with ``overwrite``. + reason: Optional[str] + The reason for doing this action. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to edit channel specific permissions. + HTTPException + Editing channel specific permissions failed. + NotFound + The role or member being edited is not part of the guild. + InvalidArgument + The overwrite parameter invalid or the target type was not + :class:`Role` or :class:`Member`. + """ + + http = self._state.http + + if isinstance(target, User): + perm_type = "member" + elif isinstance(target, Role): + perm_type = "role" + else: + raise InvalidArgument("target parameter must be either Member or Role") + + if isinstance(overwrite, _Undefined): + if len(permissions) == 0: + raise InvalidArgument("No overwrite provided.") + try: + overwrite = PermissionOverwrite(**permissions) + except (ValueError, TypeError): + raise InvalidArgument("Invalid permissions given to keyword arguments.") + else: + if len(permissions) > 0: + raise InvalidArgument("Cannot mix overwrite and keyword arguments.") + + # TODO: wait for event + + if overwrite is None: + await http.delete_channel_permissions(self.id, target.id, reason=reason) + elif isinstance(overwrite, PermissionOverwrite): + (allow, deny) = overwrite.pair() + await http.edit_channel_permissions( + self.id, target.id, allow.value, deny.value, perm_type, reason=reason + ) + else: + raise InvalidArgument("Invalid overwrite type provided.") + + async def create_invite(self, *, reason=None, **fields): + """|coro| + + Creates an instant invite. + + You must have :attr:`~.Permissions.create_instant_invite` permission to + do this. + + Parameters + ------------ + max_age : int + How long the invite should last. If it's 0 then the invite + doesn't expire. Defaults to 0. + max_uses : int + How many uses the invite could be used for. If it's 0 then there + are unlimited uses. Defaults to 0. + temporary : bool + Denotes that the invite grants temporary membership + (i.e. they get kicked after they disconnect). Defaults to False. + unique: bool + Indicates if a unique invite URL should be created. Defaults to True. + If this is set to False then it will return a previously created + invite. + reason: Optional[str] + The reason for creating this invite. Shows up on the audit log. + + Raises + ------- + HTTPException + Invite creation failed. + + Returns + -------- + :class:`Invite` + The invite that was created. + """ + + data = await self._state.http.create_invite(self.id, reason=reason, **fields) + return Invite.from_incomplete(data=data, state=self._state) + + async def invites(self): + """|coro| + + Returns a list of all active instant invites from this channel. + + You must have :attr:`~.Permissions.manage_guild` to get this information. + + Raises + ------- + Forbidden + You do not have proper permissions to get the information. + HTTPException + An error occurred while fetching the information. + + Returns + ------- + List[:class:`Invite`] + The list of invites that are currently active. + """ + + state = self._state + data = await state.http.invites_from_channel(self.id) + result = [] + + for invite in data: + invite["channel"] = self + invite["guild"] = self.guild + result.append(Invite(state=state, data=invite)) + + return result + + +class Messageable(metaclass=abc.ABCMeta): + """An ABC that details the common operations on a model that can send messages. + + The following implement this ABC: + + - :class:`TextChannel` + - :class:`DMChannel` + - :class:`GroupChannel` + - :class:`User` + - :class:`Member` + - :class:`~ext.commands.Context` + + This ABC must also implement :class:`abc.Snowflake`. + """ + + __slots__ = () + + @abc.abstractmethod + async def _get_channel(self): + raise NotImplementedError + + async def send( + self, + content=None, + *, + tts=False, + embed=None, + file=None, + files=None, + delete_after=None, + nonce=None + ): + """|coro| + + Sends a message to the destination with the content given. + + The content must be a type that can convert to a string through ``str(content)``. + If the content is set to ``None`` (the default), then the ``embed`` parameter must + be provided. + + To upload a single file, the ``file`` parameter should be used with a + single :class:`File` object. To upload multiple files, the ``files`` + parameter should be used with a :class:`list` of :class:`File` objects. + **Specifying both parameters will lead to an exception**. + + If the ``embed`` parameter is provided, it must be of type :class:`Embed` and + it must be a rich embed type. + + Parameters + ------------ + content + The content of the message to send. + tts: bool + Indicates if the message should be sent using text-to-speech. + embed: :class:`Embed` + The rich embed for the content. + file: :class:`File` + The file to upload. + files: List[:class:`File`] + A list of files to upload. Must be a maximum of 10. + nonce: int + The nonce to use for sending this message. If the message was successfully sent, + then the message will have a nonce with this value. + delete_after: float + If provided, the number of seconds to wait in the background + before deleting the message we just sent. If the deletion fails, + then it is silently ignored. + + Raises + -------- + HTTPException + Sending the message failed. + Forbidden + You do not have the proper permissions to send the message. + InvalidArgument + The ``files`` list is not of the appropriate size or + you specified both ``file`` and ``files``. + + Returns + --------- + :class:`Message` + The message that was sent. + """ + + channel = await self._get_channel() + state = self._state + content = str(content) if content is not None else None + if embed is not None: + embed = embed.to_dict() + + if file is not None and files is not None: + raise InvalidArgument("cannot pass both file and files parameter to send()") + + if file is not None: + if not isinstance(file, File): + raise InvalidArgument("file parameter must be File") + + try: + data = await state.http.send_files( + channel.id, + files=[(file.open_file(), file.filename)], + content=content, + tts=tts, + embed=embed, + nonce=nonce, + ) + finally: + file.close() + + elif files is not None: + if len(files) > 10: + raise InvalidArgument("files parameter must be a list of up to 10 elements") + + try: + param = [(f.open_file(), f.filename) for f in files] + data = await state.http.send_files( + channel.id, files=param, content=content, tts=tts, embed=embed, nonce=nonce + ) + finally: + for f in files: + f.close() + else: + data = await state.http.send_message( + channel.id, content, tts=tts, embed=embed, nonce=nonce + ) + + ret = state.create_message(channel=channel, data=data) + if delete_after is not None: + + async def delete(): + await asyncio.sleep(delete_after, loop=state.loop) + try: + await ret.delete() + except HTTPException: + pass + + asyncio.ensure_future(delete(), loop=state.loop) + return ret + + async def trigger_typing(self): + """|coro| + + Triggers a *typing* indicator to the destination. + + *Typing* indicator will go away after 10 seconds, or after a message is sent. + """ + + channel = await self._get_channel() + await self._state.http.send_typing(channel.id) + + def typing(self): + """Returns a context manager that allows you to type for an indefinite period of time. + + This is useful for denoting long computations in your bot. + + .. note:: + + This is both a regular context manager and an async context manager. + This means that both ``with`` and ``async with`` work with this. + + Example Usage: :: + + async with channel.typing(): + # do expensive stuff here + await channel.send('done!') + + """ + return Typing(self) + + async def get_message(self, id): + """|coro| + + Retrieves a single :class:`Message` from the destination. + + This can only be used by bot accounts. + + Parameters + ------------ + id: int + The message ID to look for. + + Returns + -------- + :class:`Message` + The message asked for. + + Raises + -------- + NotFound + The specified message was not found. + Forbidden + You do not have the permissions required to get a message. + HTTPException + Retrieving the message failed. + """ + + channel = await self._get_channel() + data = await self._state.http.get_message(channel.id, id) + return self._state.create_message(channel=channel, data=data) + + async def pins(self): + """|coro| + + Returns a :class:`list` of :class:`Message` that are currently pinned. + + Raises + ------- + HTTPException + Retrieving the pinned messages failed. + """ + + channel = await self._get_channel() + state = self._state + data = await state.http.pins_from(channel.id) + return [state.create_message(channel=channel, data=m) for m in data] + + def history(self, *, limit=100, before=None, after=None, around=None, reverse=None): + """Return an :class:`AsyncIterator` that enables receiving the destination's message history. + + You must have :attr:`~.Permissions.read_message_history` permissions to use this. + + All parameters are optional. + + Parameters + ----------- + limit: Optional[int] + The number of messages to retrieve. + If ``None``, retrieves every message in the channel. Note, however, + that this would make it a slow operation. + before: :class:`Message` or `datetime` + Retrieve messages before this date or message. + If a date is provided it must be a timezone-naive datetime representing UTC time. + after: :class:`Message` or `datetime` + Retrieve messages after this date or message. + If a date is provided it must be a timezone-naive datetime representing UTC time. + around: :class:`Message` or `datetime` + Retrieve messages around this date or message. + If a date is provided it must be a timezone-naive datetime representing UTC time. + When using this argument, the maximum limit is 101. Note that if the limit is an + even number then this will return at most limit + 1 messages. + reverse: bool + If set to true, return messages in oldest->newest order. If unspecified, + this defaults to ``False`` for most cases. However if passing in a + ``after`` parameter then this is set to ``True``. This avoids getting messages + out of order in the ``after`` case. + + Raises + ------ + Forbidden + You do not have permissions to get channel message history. + HTTPException + The request to get message history failed. + + Yields + ------- + :class:`Message` + The message with the message data parsed. + + Examples + --------- + + Usage :: + + counter = 0 + async for message in channel.history(limit=200): + if message.author == client.user: + counter += 1 + + Flattening into a list: :: + + messages = await channel.history(limit=123).flatten() + # messages is now a list of Message... + """ + return HistoryIterator( + self, limit=limit, before=before, after=after, around=around, reverse=reverse + ) + + +class Connectable(metaclass=abc.ABCMeta): + """An ABC that details the common operations on a channel that can + connect to a voice server. + + The following implement this ABC: + + - :class:`VoiceChannel` + """ + + __slots__ = () + + @abc.abstractmethod + def _get_voice_client_key(self): + raise NotImplementedError + + @abc.abstractmethod + def _get_voice_state_pair(self): + raise NotImplementedError + + async def connect(self, *, timeout=60.0, reconnect=True): + """|coro| + + Connects to voice and creates a :class:`VoiceClient` to establish + your connection to the voice server. + + Parameters + ----------- + timeout: float + The timeout in seconds to wait for the voice endpoint. + reconnect: bool + Whether the bot should automatically attempt + a reconnect if a part of the handshake fails + or the gateway goes down. + + Raises + ------- + asyncio.TimeoutError + Could not connect to the voice channel in time. + ClientException + You are already connected to a voice channel. + OpusNotLoaded + The opus library has not been loaded. + + Returns + ------- + :class:`VoiceClient` + A voice client that is fully connected to the voice server. + """ + key_id, _ = self._get_voice_client_key() + state = self._state + + if state._get_voice_client(key_id): + raise ClientException("Already connected to a voice channel.") + + voice = VoiceClient(state=state, timeout=timeout, channel=self) + state._add_voice_client(key_id, voice) + + try: + await voice.connect(reconnect=reconnect) + except asyncio.TimeoutError: + try: + await voice.disconnect(force=True) + except Exception: + # we don't care if disconnect failed because connection failed + pass + raise # re-raise + + return voice diff --git a/discord/activity.py b/discord/activity.py new file mode 100644 index 000000000..edc50e9fb --- /dev/null +++ b/discord/activity.py @@ -0,0 +1,613 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import datetime + +from .enums import ActivityType, try_enum +from .colour import Colour + +__all__ = ["Activity", "Streaming", "Game", "Spotify"] + +"""If curious, this is the current schema for an activity. + +It's fairly long so I will document it here: + +All keys are optional. + +state: str (max: 128), +details: str (max: 128) +timestamps: dict + start: int (min: 1) + end: int (min: 1) +assets: dict + large_image: str (max: 32) + large_text: str (max: 128) + small_image: str (max: 32) + small_text: str (max: 128) +party: dict + id: str (max: 128), + size: List[int] (max-length: 2) + elem: int (min: 1) +secrets: dict + match: str (max: 128) + join: str (max: 128) + spectate: str (max: 128) +instance: bool +application_id: str +name: str (max: 128) +url: str +type: int +sync_id: str +session_id: str +flags: int + +There are also activity flags which are mostly uninteresting for the library atm. + +t.ActivityFlags = { + INSTANCE: 1, + JOIN: 2, + SPECTATE: 4, + JOIN_REQUEST: 8, + SYNC: 16, + PLAY: 32 +} +""" + + +class _ActivityTag: + __slots__ = () + + +class Activity(_ActivityTag): + """Represents an activity in Discord. + + This could be an activity such as streaming, playing, listening + or watching. + + For memory optimisation purposes, some activities are offered in slimmed + down versions: + + - :class:`Game` + - :class:`Streaming` + + Attributes + ------------ + application_id: :class:`str` + The application ID of the game. + name: :class:`str` + The name of the activity. + url: :class:`str` + A stream URL that the activity could be doing. + type: :class:`ActivityType` + The type of activity currently being done. + state: :class:`str` + The user's current state. For example, "In Game". + details: :class:`str` + The detail of the user's current activity. + timestamps: :class:`dict` + A dictionary of timestamps. It contains the following optional keys: + + - ``start``: Corresponds to when the user started doing the + activity in milliseconds since Unix epoch. + - ``end``: Corresponds to when the user will finish doing the + activity in milliseconds since Unix epoch. + + assets: :class:`dict` + A dictionary representing the images and their hover text of an activity. + It contains the following optional keys: + + - ``large_image``: A string representing the ID for the large image asset. + - ``large_text``: A string representing the text when hovering over the large image asset. + - ``small_image``: A string representing the ID for the small image asset. + - ``small_text``: A string representing the text when hovering over the small image asset. + + party: :class:`dict` + A dictionary representing the activity party. It contains the following optional keys: + + - ``id``: A string representing the party ID. + - ``size``: A list of up to two integer elements denoting (current_size, maximum_size). + """ + + __slots__ = ( + "state", + "details", + "timestamps", + "assets", + "party", + "flags", + "sync_id", + "session_id", + "type", + "name", + "url", + "application_id", + ) + + def __init__(self, **kwargs): + self.state = kwargs.pop("state", None) + self.details = kwargs.pop("details", None) + self.timestamps = kwargs.pop("timestamps", {}) + self.assets = kwargs.pop("assets", {}) + self.party = kwargs.pop("party", {}) + self.application_id = kwargs.pop("application_id", None) + self.name = kwargs.pop("name", None) + self.url = kwargs.pop("url", None) + self.flags = kwargs.pop("flags", 0) + self.sync_id = kwargs.pop("sync_id", None) + self.session_id = kwargs.pop("session_id", None) + self.type = try_enum(ActivityType, kwargs.pop("type", -1)) + + def to_dict(self): + ret = {} + for attr in self.__slots__: + value = getattr(self, attr, None) + if value is None: + continue + + if isinstance(value, dict) and len(value) == 0: + continue + + ret[attr] = value + ret["type"] = int(self.type) + return ret + + @property + def start(self): + """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable.""" + try: + return datetime.datetime.utcfromtimestamp(self.timestamps["start"] / 1000) + except KeyError: + return None + + @property + def end(self): + """Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable.""" + try: + return datetime.datetime.utcfromtimestamp(self.timestamps["end"] / 1000) + except KeyError: + return None + + @property + def large_image_url(self): + """Optional[:class:`str`]: Returns a URL pointing to the large image asset of this activity if applicable.""" + if self.application_id is None: + return None + + try: + large_image = self.assets["large_image"] + except KeyError: + return None + else: + return "https://cdn.discordapp.com/app-assets/{0}/{1}.png".format( + self.application_id, large_image + ) + + @property + def small_image_url(self): + """Optional[:class:`str`]: Returns a URL pointing to the small image asset of this activity if applicable.""" + if self.application_id is None: + return None + + try: + small_image = self.assets["small_image"] + except KeyError: + return None + else: + return "https://cdn.discordapp.com/app-assets/{0}/{1}.png".format( + self.application_id, small_image + ) + + @property + def large_image_text(self): + """Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable.""" + return self.assets.get("large_text", None) + + @property + def small_image_text(self): + """Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable.""" + return self.assets.get("small_text", None) + + +class Game(_ActivityTag): + """A slimmed down version of :class:`Activity` that represents a Discord game. + + This is typically displayed via **Playing** on the official Discord client. + + .. container:: operations + + .. describe:: x == y + + Checks if two games are equal. + + .. describe:: x != y + + Checks if two games are not equal. + + .. describe:: hash(x) + + Returns the game's hash. + + .. describe:: str(x) + + Returns the game's name. + + Parameters + ----------- + name: :class:`str` + The game's name. + start: Optional[:class:`datetime.datetime`] + A naive UTC timestamp representing when the game started. Keyword-only parameter. Ignored for bots. + end: Optional[:class:`datetime.datetime`] + A naive UTC timestamp representing when the game ends. Keyword-only parameter. Ignored for bots. + + Attributes + ----------- + name: :class:`str` + The game's name. + """ + + __slots__ = ("name", "_end", "_start") + + def __init__(self, name, **extra): + self.name = name + + try: + timestamps = extra["timestamps"] + except KeyError: + self._extract_timestamp(extra, "start") + self._extract_timestamp(extra, "end") + else: + self._start = timestamps.get("start", 0) + self._end = timestamps.get("end", 0) + + def _extract_timestamp(self, data, key): + try: + dt = data[key] + except KeyError: + setattr(self, "_" + key, 0) + else: + setattr(self, "_" + key, dt.timestamp() * 1000.0) + + @property + def type(self): + """Returns the game's type. This is for compatibility with :class:`Activity`. + + It always returns :attr:`ActivityType.playing`. + """ + return ActivityType.playing + + @property + def start(self): + """Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable.""" + if self._start: + return datetime.datetime.utcfromtimestamp(self._start / 1000) + return None + + @property + def end(self): + """Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable.""" + if self._end: + return datetime.datetime.utcfromtimestamp(self._end / 1000) + return None + + def __str__(self): + return str(self.name) + + def __repr__(self): + return "".format(self) + + def to_dict(self): + timestamps = {} + if self._start: + timestamps["start"] = self._start + + if self._end: + timestamps["end"] = self._end + + return { + "type": ActivityType.playing.value, + "name": str(self.name), + "timestamps": timestamps, + } + + def __eq__(self, other): + return isinstance(other, Game) and other.name == self.name + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.name) + + +class Streaming(_ActivityTag): + """A slimmed down version of :class:`Activity` that represents a Discord streaming status. + + This is typically displayed via **Streaming** on the official Discord client. + + .. container:: operations + + .. describe:: x == y + + Checks if two streams are equal. + + .. describe:: x != y + + Checks if two streams are not equal. + + .. describe:: hash(x) + + Returns the stream's hash. + + .. describe:: str(x) + + Returns the stream's name. + + Attributes + ----------- + name: :class:`str` + The stream's name. + url: :class:`str` + The stream's URL. Currently only twitch.tv URLs are supported. Anything else is silently + discarded. + details: Optional[:class:`str`] + If provided, typically the game the streamer is playing. + assets: :class:`dict` + A dictionary comprising of similar keys than those in :attr:`Activity.assets`. + """ + + __slots__ = ("name", "url", "details", "assets") + + def __init__(self, *, name, url, **extra): + self.name = name + self.url = url + self.details = extra.pop("details", None) + self.assets = extra.pop("assets", {}) + + @property + def type(self): + """Returns the game's type. This is for compatibility with :class:`Activity`. + + It always returns :attr:`ActivityType.streaming`. + """ + return ActivityType.streaming + + def __str__(self): + return str(self.name) + + def __repr__(self): + return "".format(self) + + @property + def twitch_name(self): + """Optional[:class:`str`]: If provided, the twitch name of the user streaming. + + This corresponds to the ``large_image`` key of the :attr:`Streaming.assets` + dictionary if it starts with ``twitch:``. Typically set by the Discord client. + """ + + try: + name = self.assets["large_image"] + except KeyError: + return None + else: + return name[7:] if name[:7] == "twitch:" else None + + def to_dict(self): + ret = { + "type": ActivityType.streaming.value, + "name": str(self.name), + "url": str(self.url), + "assets": self.assets, + } + if self.details: + ret["details"] = self.details + return ret + + def __eq__(self, other): + return isinstance(other, Streaming) and other.name == self.name and other.url == self.url + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.name) + + +class Spotify: + """Represents a Spotify listening activity from Discord. This is a special case of + :class:`Activity` that makes it easier to work with the Spotify integration. + + .. container:: operations + + .. describe:: x == y + + Checks if two activities are equal. + + .. describe:: x != y + + Checks if two activities are not equal. + + .. describe:: hash(x) + + Returns the activity's hash. + + .. describe:: str(x) + + Returns the string 'Spotify'. + """ + + __slots__ = ( + "_state", + "_details", + "_timestamps", + "_assets", + "_party", + "_sync_id", + "_session_id", + ) + + def __init__(self, **data): + self._state = data.pop("state", None) + self._details = data.pop("details", None) + self._timestamps = data.pop("timestamps", {}) + self._assets = data.pop("assets", {}) + self._party = data.pop("party", {}) + self._sync_id = data.pop("sync_id") + self._session_id = data.pop("session_id") + + @property + def type(self): + """Returns the activity's type. This is for compatibility with :class:`Activity`. + + It always returns :attr:`ActivityType.listening`. + """ + return ActivityType.listening + + @property + def colour(self): + """Returns the Spotify integration colour, as a :class:`Colour`. + + There is an alias for this named :meth:`color`""" + return Colour(0x1DB954) + + @property + def color(self): + """Returns the Spotify integration colour, as a :class:`Colour`. + + There is an alias for this named :meth:`colour`""" + return self.colour + + def to_dict(self): + return { + "flags": 48, # SYNC | PLAY + "name": "Spotify", + "assets": self._assets, + "party": self._party, + "sync_id": self._sync_id, + "session_id": self._session_id, + "timestamps": self._timestamps, + "details": self._details, + "state": self._state, + } + + @property + def name(self): + """:class:`str`: The activity's name. This will always return "Spotify".""" + return "Spotify" + + def __eq__(self, other): + return isinstance(other, Spotify) and other._session_id == self._session_id + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self._session_id) + + def __str__(self): + return "Spotify" + + def __repr__(self): + return "".format( + self + ) + + @property + def title(self): + """:class:`str`: The title of the song being played.""" + return self._details + + @property + def artists(self): + """List[:class:`str`]: The artists of the song being played.""" + return self._state.split("; ") + + @property + def artist(self): + """:class:`str`: The artist of the song being played. + + This does not attempt to split the artist information into + multiple artists. Useful if there's only a single artist. + """ + return self._state + + @property + def album(self): + """:class:`str`: The album that the song being played belongs to.""" + return self._assets.get("large_text", "") + + @property + def album_cover_url(self): + """:class:`str`: The album cover image URL from Spotify's CDN.""" + large_image = self._assets.get("large_image", "") + if large_image[:8] != "spotify:": + return "" + album_image_id = large_image[8:] + return "https://i.scdn.co/image/" + album_image_id + + @property + def track_id(self): + """:class:`str`: The track ID used by Spotify to identify this song.""" + return self._sync_id + + @property + def start(self): + """:class:`datetime.datetime`: When the user started playing this song in UTC.""" + return datetime.datetime.utcfromtimestamp(self._timestamps["start"] / 1000) + + @property + def end(self): + """:class:`datetime.datetime`: When the user will stop playing this song in UTC.""" + return datetime.datetime.utcfromtimestamp(self._timestamps["end"] / 1000) + + @property + def duration(self): + """:class:`datetime.timedelta`: The duration of the song being played.""" + return self.end - self.start + + @property + def party_id(self): + """:class:`str`: The party ID of the listening party.""" + return self._party.get("id", "") + + +def create_activity(data): + if not data: + return None + + game_type = try_enum(ActivityType, data.get("type", -1)) + if game_type is ActivityType.playing: + if "application_id" in data or "session_id" in data: + return Activity(**data) + return Game(**data) + elif game_type is ActivityType.streaming: + if "url" in data: + return Streaming(**data) + return Activity(**data) + elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data: + return Spotify(**data) + return Activity(**data) diff --git a/discord/audit_logs.py b/discord/audit_logs.py new file mode 100644 index 000000000..c21e0afde --- /dev/null +++ b/discord/audit_logs.py @@ -0,0 +1,366 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from . import utils, enums +from .object import Object +from .permissions import PermissionOverwrite, Permissions +from .colour import Colour +from .invite import Invite + + +def _transform_verification_level(entry, data): + return enums.try_enum(enums.VerificationLevel, data) + + +def _transform_default_notifications(entry, data): + return enums.try_enum(enums.NotificationLevel, data) + + +def _transform_explicit_content_filter(entry, data): + return enums.try_enum(enums.ContentFilter, data) + + +def _transform_permissions(entry, data): + return Permissions(data) + + +def _transform_color(entry, data): + return Colour(data) + + +def _transform_snowflake(entry, data): + return int(data) + + +def _transform_channel(entry, data): + if data is None: + return None + channel = entry.guild.get_channel(int(data)) or Object(id=data) + return channel + + +def _transform_owner_id(entry, data): + if data is None: + return None + return entry._get_member(int(data)) + + +def _transform_inviter_id(entry, data): + if data is None: + return None + return entry._get_member(int(data)) + + +def _transform_overwrites(entry, data): + overwrites = [] + for elem in data: + allow = Permissions(elem["allow"]) + deny = Permissions(elem["deny"]) + ow = PermissionOverwrite.from_pair(allow, deny) + + ow_type = elem["type"] + ow_id = int(elem["id"]) + if ow_type == "role": + target = entry.guild.get_role(ow_id) + else: + target = entry._get_member(ow_id) + + if target is None: + target = Object(id=ow_id) + + overwrites.append((target, ow)) + + return overwrites + + +class AuditLogDiff: + def __len__(self): + return len(self.__dict__) + + def __iter__(self): + return iter(self.__dict__.items()) + + def __repr__(self): + return "".format(tuple(self.__dict__)) + + +class AuditLogChanges: + TRANSFORMERS = { + "verification_level": (None, _transform_verification_level), + "explicit_content_filter": (None, _transform_explicit_content_filter), + "allow": (None, _transform_permissions), + "deny": (None, _transform_permissions), + "permissions": (None, _transform_permissions), + "id": (None, _transform_snowflake), + "color": ("colour", _transform_color), + "owner_id": ("owner", _transform_owner_id), + "inviter_id": ("inviter", _transform_inviter_id), + "channel_id": ("channel", _transform_channel), + "afk_channel_id": ("afk_channel", _transform_channel), + "system_channel_id": ("system_channel", _transform_channel), + "widget_channel_id": ("widget_channel", _transform_channel), + "permission_overwrites": ("overwrites", _transform_overwrites), + "splash_hash": ("splash", None), + "icon_hash": ("icon", None), + "avatar_hash": ("avatar", None), + "rate_limit_per_user": ("slowmode_delay", None), + "default_message_notifications": ( + "default_notifications", + _transform_default_notifications, + ), + } + + def __init__(self, entry, data): + self.before = AuditLogDiff() + self.after = AuditLogDiff() + + for elem in data: + attr = elem["key"] + + # special cases for role add/remove + if attr == "$add": + self._handle_role(self.before, self.after, entry, elem["new_value"]) + continue + elif attr == "$remove": + self._handle_role(self.after, self.before, entry, elem["new_value"]) + continue + + transformer = self.TRANSFORMERS.get(attr) + if transformer: + key, transformer = transformer + if key: + attr = key + + try: + before = elem["old_value"] + except KeyError: + before = None + else: + if transformer: + before = transformer(entry, before) + + setattr(self.before, attr, before) + + try: + after = elem["new_value"] + except KeyError: + after = None + else: + if transformer: + after = transformer(entry, after) + + setattr(self.after, attr, after) + + # add an alias + if hasattr(self.after, "colour"): + self.after.color = self.after.colour + self.before.color = self.before.colour + + def _handle_role(self, first, second, entry, elem): + if not hasattr(first, "roles"): + setattr(first, "roles", []) + + data = [] + g = entry.guild + + for e in elem: + role_id = int(e["id"]) + role = g.get_role(role_id) + + if role is None: + role = Object(id=role_id) + role.name = e["name"] + + data.append(role) + + setattr(second, "roles", data) + + +class AuditLogEntry: + r"""Represents an Audit Log entry. + + You retrieve these via :meth:`Guild.audit_logs`. + + Attributes + ----------- + action: :class:`AuditLogAction` + The action that was done. + user: :class:`abc.User` + The user who initiated this action. Usually a :class:`Member`\, unless gone + then it's a :class:`User`. + id: :class:`int` + The entry ID. + target: Any + The target that got changed. The exact type of this depends on + the action being done. + reason: Optional[:class:`str`] + The reason this action was done. + extra: Any + Extra information that this entry has that might be useful. + For most actions, this is ``None``. However in some cases it + contains extra information. See :class:`AuditLogAction` for + which actions have this field filled out. + """ + + def __init__(self, *, users, data, guild): + self._state = guild._state + self.guild = guild + self._users = users + self._from_data(data) + + def _from_data(self, data): + self.action = enums.AuditLogAction(data["action_type"]) + self.id = int(data["id"]) + + # this key is technically not usually present + self.reason = data.get("reason") + self.extra = data.get("options") + + if self.extra: + if self.action is enums.AuditLogAction.member_prune: + # member prune has two keys with useful information + self.extra = type( + "_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()} + )() + elif self.action is enums.AuditLogAction.message_delete: + channel_id = int(self.extra["channel_id"]) + elems = { + "count": int(self.extra["count"]), + "channel": self.guild.get_channel(channel_id) or Object(id=channel_id), + } + self.extra = type("_AuditLogProxy", (), elems)() + elif self.action.name.startswith("overwrite_"): + # the overwrite_ actions have a dict with some information + instance_id = int(self.extra["id"]) + the_type = self.extra.get("type") + if the_type == "member": + self.extra = self._get_member(instance_id) + else: + role = self.guild.get_role(instance_id) + if role is None: + role = Object(id=instance_id) + role.name = self.extra.get("role_name") + self.extra = role + + # this key is not present when the above is present, typically. + # It's a list of { new_value: a, old_value: b, key: c } + # where new_value and old_value are not guaranteed to be there depending + # on the action type, so let's just fetch it for now and only turn it + # into meaningful data when requested + self._changes = data.get("changes", []) + + self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) + self._target_id = utils._get_as_snowflake(data, "target_id") + + def _get_member(self, user_id): + return self.guild.get_member(user_id) or self._users.get(user_id) + + def __repr__(self): + return "".format(self) + + @utils.cached_property + def created_at(self): + """Returns the entry's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @utils.cached_property + def target(self): + try: + converter = getattr(self, "_convert_target_" + self.action.target_type) + except AttributeError: + return Object(id=self._target_id) + else: + return converter(self._target_id) + + @utils.cached_property + def category(self): + """Optional[:class:`AuditLogActionCategory`]: The category of the action, if applicable.""" + return self.action.category + + @utils.cached_property + def changes(self): + """:class:`AuditLogChanges`: The list of changes this entry has.""" + obj = AuditLogChanges(self, self._changes) + del self._changes + return obj + + @utils.cached_property + def before(self): + """:class:`AuditLogDiff`: The target's prior state.""" + return self.changes.before + + @utils.cached_property + def after(self): + """:class:`AuditLogDiff`: The target's subsequent state.""" + return self.changes.after + + def _convert_target_guild(self, target_id): + return self.guild + + def _convert_target_channel(self, target_id): + ch = self.guild.get_channel(target_id) + if ch is None: + return Object(id=target_id) + return ch + + def _convert_target_user(self, target_id): + return self._get_member(target_id) + + def _convert_target_role(self, target_id): + role = self.guild.get_role(target_id) + if role is None: + return Object(id=target_id) + return role + + def _convert_target_invite(self, target_id): + # invites have target_id set to null + # so figure out which change has the full invite data + changeset = ( + self.before if self.action is enums.AuditLogAction.invite_delete else self.after + ) + + fake_payload = { + "max_age": changeset.max_age, + "max_uses": changeset.max_uses, + "code": changeset.code, + "temporary": changeset.temporary, + "channel": changeset.channel, + "uses": changeset.uses, + "guild": self.guild, + } + + obj = Invite(state=self._state, data=fake_payload) + try: + obj.inviter = changeset.inviter + except AttributeError: + pass + return obj + + def _convert_target_emoji(self, target_id): + return self._state.get_emoji(target_id) or Object(id=target_id) + + def _convert_target_message(self, target_id): + return self._get_member(target_id) diff --git a/discord/backoff.py b/discord/backoff.py new file mode 100644 index 000000000..ca289505f --- /dev/null +++ b/discord/backoff.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import time +import random + + +class ExponentialBackoff: + """An implementation of the exponential backoff algorithm + + Provides a convenient interface to implement an exponential backoff + for reconnecting or retrying transmissions in a distributed network. + + Once instantiated, the delay method will return the next interval to + wait for when retrying a connection or transmission. The maximum + delay increases exponentially with each retry up to a maximum of + 2^10 * base, and is reset if no more attempts are needed in a period + of 2^11 * base seconds. + + Parameters + ---------- + base: int + The base delay in seconds. The first retry-delay will be up to + this many seconds. + integral: bool + Set to True if whole periods of base is desirable, otherwise any + number in between may be returned. + """ + + def __init__(self, base=1, *, integral=False): + self._base = base + + self._exp = 0 + self._max = 10 + self._reset_time = base * 2 ** 11 + self._last_invocation = time.monotonic() + + # Use our own random instance to avoid messing with global one + rand = random.Random() + rand.seed() + + self._randfunc = rand.randrange if integral else rand.uniform + + def delay(self): + """Compute the next delay + + Returns the next delay to wait according to the exponential + backoff algorithm. This is a value between 0 and base * 2^exp + where exponent starts off at 1 and is incremented at every + invocation of this method up to a maximum of 10. + + If a period of more than base * 2^11 has passed since the last + retry, the exponent is reset to 1. + """ + invocation = time.monotonic() + interval = invocation - self._last_invocation + self._last_invocation = invocation + + if interval > self._reset_time: + self._exp = 0 + + self._exp = min(self._exp + 1, self._max) + return self._randfunc(0, self._base * 2 ** self._exp) diff --git a/discord/calls.py b/discord/calls.py new file mode 100644 index 000000000..63bc96ed7 --- /dev/null +++ b/discord/calls.py @@ -0,0 +1,157 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import datetime + +from . import utils +from .enums import VoiceRegion, try_enum +from .member import VoiceState + + +class CallMessage: + """Represents a group call message from Discord. + + This is only received in cases where the message type is equivalent to + :attr:`MessageType.call`. + + Attributes + ----------- + ended_timestamp: Optional[datetime.datetime] + A naive UTC datetime object that represents the time that the call has ended. + participants: List[:class:`User`] + The list of users that are participating in this call. + message: :class:`Message` + The message associated with this call message. + """ + + def __init__(self, message, **kwargs): + self.message = message + self.ended_timestamp = utils.parse_time(kwargs.get("ended_timestamp")) + self.participants = kwargs.get("participants") + + @property + def call_ended(self): + """:obj:`bool`: Indicates if the call has ended.""" + return self.ended_timestamp is not None + + @property + def channel(self): + r""":class:`GroupChannel`\: The private channel associated with this message.""" + return self.message.channel + + @property + def duration(self): + """Queries the duration of the call. + + If the call has not ended then the current duration will + be returned. + + Returns + --------- + datetime.timedelta + The timedelta object representing the duration. + """ + if self.ended_timestamp is None: + return datetime.datetime.utcnow() - self.message.created_at + else: + return self.ended_timestamp - self.message.created_at + + +class GroupCall: + """Represents the actual group call from Discord. + + This is accompanied with a :class:`CallMessage` denoting the information. + + Attributes + ----------- + call: :class:`CallMessage` + The call message associated with this group call. + unavailable: :obj:`bool` + Denotes if this group call is unavailable. + ringing: List[:class:`User`] + A list of users that are currently being rung to join the call. + region: :class:`VoiceRegion` + The guild region the group call is being hosted on. + """ + + def __init__(self, **kwargs): + self.call = kwargs.get("call") + self.unavailable = kwargs.get("unavailable") + self._voice_states = {} + + for state in kwargs.get("voice_states", []): + self._update_voice_state(state) + + self._update(**kwargs) + + def _update(self, **kwargs): + self.region = try_enum(VoiceRegion, kwargs.get("region")) + lookup = {u.id: u for u in self.call.channel.recipients} + me = self.call.channel.me + lookup[me.id] = me + self.ringing = list(filter(None, map(lookup.get, kwargs.get("ringing", [])))) + + def _update_voice_state(self, data): + user_id = int(data["user_id"]) + # left the voice channel? + if data["channel_id"] is None: + self._voice_states.pop(user_id, None) + else: + self._voice_states[user_id] = VoiceState(data=data, channel=self.channel) + + @property + def connected(self): + """A property that returns the :obj:`list` of :class:`User` that are currently in this call.""" + ret = [u for u in self.channel.recipients if self.voice_state_for(u) is not None] + me = self.channel.me + if self.voice_state_for(me) is not None: + ret.append(me) + + return ret + + @property + def channel(self): + r""":class:`GroupChannel`\: Returns the channel the group call is in.""" + return self.call.channel + + def voice_state_for(self, user): + """Retrieves the :class:`VoiceState` for a specified :class:`User`. + + If the :class:`User` has no voice state then this function returns + ``None``. + + Parameters + ------------ + user: :class:`User` + The user to retrieve the voice state for. + + Returns + -------- + Optional[:class:`VoiceState`] + The voice state associated with this user. + """ + + return self._voice_states.get(user.id) diff --git a/discord/channel.py b/discord/channel.py new file mode 100644 index 000000000..b1b39ee71 --- /dev/null +++ b/discord/channel.py @@ -0,0 +1,986 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import time +import asyncio + +import discord.abc +from .permissions import Permissions +from .enums import ChannelType, try_enum +from .mixins import Hashable +from . import utils +from .errors import ClientException, NoMoreItems +from .webhook import Webhook + +__all__ = [ + "TextChannel", + "VoiceChannel", + "DMChannel", + "CategoryChannel", + "GroupChannel", + "_channel_factory", +] + + +async def _single_delete_strategy(messages): + for m in messages: + await m.delete() + + +class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): + """Represents a Discord guild text channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ----------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + id: :class:`int` + The channel ID. + category_id: :class:`int` + The category channel ID this channel belongs to. + topic: Optional[:class:`str`] + The channel's topic. None if it doesn't exist. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. + slowmode_delay: :class:`int` + The number of seconds a member must wait between sending messages + in this channel. A value of `0` denotes that it is disabled. + Bots and users with :attr:`~Permissions.manage_channels` or + :attr:`~Permissions.manage_messages` bypass slowmode. + """ + + __slots__ = ( + "name", + "id", + "guild", + "topic", + "_state", + "nsfw", + "category_id", + "position", + "slowmode_delay", + "_overwrites", + ) + + def __init__(self, *, state, guild, data): + self._state = state + self.id = int(data["id"]) + self._update(guild, data) + + def __repr__(self): + return "".format(self) + + def _update(self, guild, data): + self.guild = guild + self.name = data["name"] + self.category_id = utils._get_as_snowflake(data, "parent_id") + self.topic = data.get("topic") + self.position = data["position"] + self.nsfw = data.get("nsfw", False) + # Does this need coercion into `int`? No idea yet. + self.slowmode_delay = data.get("rate_limit_per_user", 0) + self._fill_overwrites(data) + + async def _get_channel(self): + return self + + def permissions_for(self, member): + base = super().permissions_for(member) + + # text channels do not have voice related permissions + denied = Permissions.voice() + base.value &= ~denied.value + return base + + permissions_for.__doc__ = discord.abc.GuildChannel.permissions_for.__doc__ + + @property + def members(self): + """Returns a :class:`list` of :class:`Member` that can see this channel.""" + return [m for m in self.guild.members if self.permissions_for(m).read_messages] + + def is_nsfw(self): + """Checks if the channel is NSFW.""" + n = self.name + return self.nsfw or n == "nsfw" or n[:5] == "nsfw-" + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + Parameters + ---------- + name: :class:`str` + The new channel name. + topic: :class:`str` + The new channel's topic. + position: :class:`int` + The new channel's position. + nsfw: :class:`bool` + To mark the channel as NSFW or not. + sync_permissions: :class:`bool` + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. + slowmode_delay: :class:`int` + Specifies the slowmode rate limit for user in this channel. A value of + `0` disables slowmode. The maximum value possible is `120`. + reason: Optional[:class:`str`] + The reason for editing this channel. Shows up on the audit log. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of channels. + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + await self._edit(options, reason=reason) + + async def delete_messages(self, messages): + """|coro| + + Deletes a list of messages. This is similar to :meth:`Message.delete` + except it bulk deletes multiple messages. + + As a special case, if the number of messages is 0, then nothing + is done. If the number of messages is 1 then single message + delete is done. If it's more than two, then bulk delete is used. + + You cannot bulk delete more than 100 messages or messages that + are older than 14 days old. + + You must have the :attr:`~Permissions.manage_messages` permission to + use this. + + Usable only by bot accounts. + + Parameters + ----------- + messages: Iterable[:class:`abc.Snowflake`] + An iterable of messages denoting which ones to bulk delete. + + Raises + ------ + ClientException + The number of messages to delete was more than 100. + Forbidden + You do not have proper permissions to delete the messages or + you're not using a bot account. + HTTPException + Deleting the messages failed. + """ + if not isinstance(messages, (list, tuple)): + messages = list(messages) + + if len(messages) == 0: + return # do nothing + + if len(messages) == 1: + message_id = messages[0].id + await self._state.http.delete_message(self.id, message_id) + return + + if len(messages) > 100: + raise ClientException("Can only bulk delete messages up to 100 messages") + + message_ids = [m.id for m in messages] + await self._state.http.delete_messages(self.id, message_ids) + + async def purge( + self, + *, + limit=100, + check=None, + before=None, + after=None, + around=None, + reverse=False, + bulk=True + ): + """|coro| + + Purges a list of messages that meet the criteria given by the predicate + ``check``. If a ``check`` is not provided then all messages are deleted + without discrimination. + + You must have the :attr:`~Permissions.manage_messages` permission to + delete messages even if they are your own (unless you are a user + account). The :attr:`~Permissions.read_message_history` permission is + also needed to retrieve message history. + + Internally, this employs a different number of strategies depending + on the conditions met such as if a bulk delete is possible or if + the account is a user bot or not. + + Parameters + ----------- + limit: int + The number of messages to search through. This is not the number + of messages that will be deleted, though it can be. + check: predicate + The function used to check if a message should be deleted. + It must take a :class:`Message` as its sole parameter. + before + Same as ``before`` in :meth:`history`. + after + Same as ``after`` in :meth:`history`. + around + Same as ``around`` in :meth:`history`. + reverse + Same as ``reverse`` in :meth:`history`. + bulk: bool + If True, use bulk delete. bulk=False is useful for mass-deleting + a bot's own messages without manage_messages. When True, will fall + back to single delete if current account is a user bot, or if + messages are older than two weeks. + + Raises + ------- + Forbidden + You do not have proper permissions to do the actions required. + HTTPException + Purging the messages failed. + + Examples + --------- + + Deleting bot's messages :: + + def is_me(m): + return m.author == client.user + + deleted = await channel.purge(limit=100, check=is_me) + await channel.send('Deleted {} message(s)'.format(len(deleted))) + + Returns + -------- + list + The list of messages that were deleted. + """ + + if check is None: + check = lambda m: True + + iterator = self.history( + limit=limit, before=before, after=after, reverse=reverse, around=around + ) + ret = [] + count = 0 + + minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 + strategy = self.delete_messages if self._state.is_bot and bulk else _single_delete_strategy + + while True: + try: + msg = await iterator.next() + except NoMoreItems: + # no more messages to poll + if count >= 2: + # more than 2 messages -> bulk delete + to_delete = ret[-count:] + await strategy(to_delete) + elif count == 1: + # delete a single message + await ret[-1].delete() + + return ret + else: + if count == 100: + # we've reached a full 'queue' + to_delete = ret[-100:] + await strategy(to_delete) + count = 0 + await asyncio.sleep(1) + + if check(msg): + if msg.id < minimum_time: + # older than 14 days old + if count == 1: + await ret[-1].delete() + elif count >= 2: + to_delete = ret[-count:] + await strategy(to_delete) + + count = 0 + strategy = _single_delete_strategy + + count += 1 + ret.append(msg) + + async def webhooks(self): + """|coro| + + Gets the list of webhooks from this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + Raises + ------- + Forbidden + You don't have permissions to get the webhooks. + + Returns + -------- + List[:class:`Webhook`] + The webhooks for this channel. + """ + + data = await self._state.http.channel_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] + + async def create_webhook(self, *, name, avatar=None): + """|coro| + + Creates a webhook for this channel. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + Parameters + ------------- + name: str + The webhook's name. + avatar: Optional[bytes] + A :term:`py:bytes-like object` representing the webhook's default avatar. + This operates similarly to :meth:`~ClientUser.edit`. + + Raises + ------- + HTTPException + Creating the webhook failed. + Forbidden + You do not have permissions to create a webhook. + + Returns + -------- + :class:`Webhook` + The created webhook. + """ + + if avatar is not None: + avatar = utils._bytes_to_base64_data(avatar) + + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar) + return Webhook.from_state(data, state=self._state) + + +class VoiceChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): + """Represents a Discord guild voice channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns the channel's name. + + Attributes + ----------- + name: :class:`str` + The channel name. + guild: :class:`Guild` + The guild the channel belongs to. + id: :class:`int` + The channel ID. + category_id: :class:`int` + The category channel ID this channel belongs to. + position: :class:`int` + The position in the channel list. This is a number that starts at 0. e.g. the + top channel is position 0. + bitrate: :class:`int` + The channel's preferred audio bitrate in bits per second. + user_limit: :class:`int` + The channel's limit for number of members that can be in a voice channel. + """ + + __slots__ = ( + "name", + "id", + "guild", + "bitrate", + "user_limit", + "_state", + "position", + "_overwrites", + "category_id", + ) + + def __init__(self, *, state, guild, data): + self._state = state + self.id = int(data["id"]) + self._update(guild, data) + + def __repr__(self): + return "".format(self) + + def _get_voice_client_key(self): + return self.guild.id, "guild_id" + + def _get_voice_state_pair(self): + return self.guild.id, self.id + + def _update(self, guild, data): + self.guild = guild + self.name = data["name"] + self.category_id = utils._get_as_snowflake(data, "parent_id") + self.position = data["position"] + self.bitrate = data.get("bitrate") + self.user_limit = data.get("user_limit") + self._fill_overwrites(data) + + @property + def members(self): + """Returns a list of :class:`Member` that are currently inside this voice channel.""" + ret = [] + for user_id, state in self.guild._voice_states.items(): + if state.channel.id == self.id: + member = self.guild.get_member(user_id) + if member is not None: + ret.append(member) + return ret + + def permissions_for(self, member): + base = super().permissions_for(member) + + # voice channels cannot be edited by people who can't connect to them + # It also implicitly denies all other voice perms + if not base.connect: + denied = Permissions.voice() + denied.update(manage_channels=True, manage_roles=True) + base.value &= ~denied.value + return base + + permissions_for.__doc__ = discord.abc.GuildChannel.permissions_for.__doc__ + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + Parameters + ---------- + name: str + The new channel's name. + bitrate: int + The new channel's bitrate. + user_limit: int + The new channel's user limit. + position: int + The new channel's position. + sync_permissions: bool + Whether to sync permissions with the channel's new or pre-existing + category. Defaults to ``False``. + category: Optional[:class:`CategoryChannel`] + The new category for this channel. Can be ``None`` to remove the + category. + reason: Optional[str] + The reason for editing this channel. Shows up on the audit log. + + Raises + ------ + Forbidden + You do not have permissions to edit the channel. + HTTPException + Editing the channel failed. + """ + + await self._edit(options, reason=reason) + + +class CategoryChannel(discord.abc.GuildChannel, Hashable): + """Represents a Discord channel category. + + These are useful to group channels to logical compartments. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the category's hash. + + .. describe:: str(x) + + Returns the category's name. + + Attributes + ----------- + name: :class:`str` + The category name. + guild: :class:`Guild` + The guild the category belongs to. + id: :class:`int` + The category channel ID. + position: :class:`int` + The position in the category list. This is a number that starts at 0. e.g. the + top category is position 0. + """ + + __slots__ = ("name", "id", "guild", "nsfw", "_state", "position", "_overwrites", "category_id") + + def __init__(self, *, state, guild, data): + self._state = state + self.id = int(data["id"]) + self._update(guild, data) + + def __repr__(self): + return "".format(self) + + def _update(self, guild, data): + self.guild = guild + self.name = data["name"] + self.category_id = utils._get_as_snowflake(data, "parent_id") + self.nsfw = data.get("nsfw", False) + self.position = data["position"] + self._fill_overwrites(data) + + def is_nsfw(self): + """Checks if the category is NSFW.""" + n = self.name + return self.nsfw or n == "nsfw" or n[:5] == "nsfw-" + + async def edit(self, *, reason=None, **options): + """|coro| + + Edits the channel. + + You must have the :attr:`~Permissions.manage_channels` permission to + use this. + + Parameters + ---------- + name: str + The new category's name. + position: int + The new category's position. + nsfw: bool + To mark the category as NSFW or not. + reason: Optional[str] + The reason for editing this category. Shows up on the audit log. + + Raises + ------ + InvalidArgument + If position is less than 0 or greater than the number of categories. + Forbidden + You do not have permissions to edit the category. + HTTPException + Editing the category failed. + """ + + try: + position = options.pop("position") + except KeyError: + pass + else: + await self._move(position, reason=reason) + self.position = position + + if options: + data = await self._state.http.edit_channel(self.id, reason=reason, **options) + self._update(self.guild, data) + + @property + def channels(self): + """List[:class:`abc.GuildChannel`]: Returns the channels that are under this category. + + These are sorted by the official Discord UI, which places voice channels below the text channels. + """ + + def comparator(channel): + return (not isinstance(channel, TextChannel), channel.position) + + ret = [c for c in self.guild.channels if c.category_id == self.id] + ret.sort(key=comparator) + return ret + + +class DMChannel(discord.abc.Messageable, Hashable): + """Represents a Discord direct message channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns a string representation of the channel + + Attributes + ---------- + recipient: :class:`User` + The user you are participating with in the direct message channel. + me: :class:`ClientUser` + The user presenting yourself. + id: :class:`int` + The direct message channel ID. + """ + + __slots__ = ("id", "recipient", "me", "_state") + + def __init__(self, *, me, state, data): + self._state = state + self.recipient = state.store_user(data["recipients"][0]) + self.me = me + self.id = int(data["id"]) + + async def _get_channel(self): + return self + + def __str__(self): + return "Direct Message with %s" % self.recipient + + def __repr__(self): + return "".format(self) + + @property + def created_at(self): + """Returns the direct message channel's creation time in UTC.""" + return utils.snowflake_time(self.id) + + def permissions_for(self, user=None): + """Handles permission resolution for a :class:`User`. + + This function is there for compatibility with other channel types. + + Actual direct messages do not really have the concept of permissions. + + This returns all the Text related permissions set to true except: + + - send_tts_messages: You cannot send TTS messages in a DM. + - manage_messages: You cannot delete others messages in a DM. + + Parameters + ----------- + user: :class:`User` + The user to check permissions for. This parameter is ignored + but kept for compatibility. + + Returns + -------- + :class:`Permissions` + The resolved permissions. + """ + + base = Permissions.text() + base.send_tts_messages = False + base.manage_messages = False + return base + + +class GroupChannel(discord.abc.Messageable, Hashable): + """Represents a Discord group channel. + + .. container:: operations + + .. describe:: x == y + + Checks if two channels are equal. + + .. describe:: x != y + + Checks if two channels are not equal. + + .. describe:: hash(x) + + Returns the channel's hash. + + .. describe:: str(x) + + Returns a string representation of the channel + + Attributes + ---------- + recipients: :class:`list` of :class:`User` + The users you are participating with in the group channel. + me: :class:`ClientUser` + The user presenting yourself. + id: :class:`int` + The group channel ID. + owner: :class:`User` + The user that owns the group channel. + icon: Optional[:class:`str`] + The group channel's icon hash if provided. + name: Optional[:class:`str`] + The group channel's name if provided. + """ + + __slots__ = ("id", "recipients", "owner", "icon", "name", "me", "_state") + + def __init__(self, *, me, state, data): + self._state = state + self.id = int(data["id"]) + self.me = me + self._update_group(data) + + def _update_group(self, data): + owner_id = utils._get_as_snowflake(data, "owner_id") + self.icon = data.get("icon") + self.name = data.get("name") + + try: + self.recipients = [self._state.store_user(u) for u in data["recipients"]] + except KeyError: + pass + + if owner_id == self.me.id: + self.owner = self.me + else: + self.owner = utils.find(lambda u: u.id == owner_id, self.recipients) + + async def _get_channel(self): + return self + + def __str__(self): + if self.name: + return self.name + + if len(self.recipients) == 0: + return "Unnamed" + + return ", ".join(map(lambda x: x.name, self.recipients)) + + def __repr__(self): + return "".format(self) + + @property + def icon_url(self): + """Returns the channel's icon URL if available or an empty string otherwise.""" + if self.icon is None: + return "" + + return "https://cdn.discordapp.com/channel-icons/{0.id}/{0.icon}.jpg".format(self) + + @property + def created_at(self): + """Returns the channel's creation time in UTC.""" + return utils.snowflake_time(self.id) + + def permissions_for(self, user): + """Handles permission resolution for a :class:`User`. + + This function is there for compatibility with other channel types. + + Actual direct messages do not really have the concept of permissions. + + This returns all the Text related permissions set to true except: + + - send_tts_messages: You cannot send TTS messages in a DM. + - manage_messages: You cannot delete others messages in a DM. + + This also checks the kick_members permission if the user is the owner. + + Parameters + ----------- + user: :class:`User` + The user to check permissions for. + + Returns + -------- + :class:`Permissions` + The resolved permissions for the user. + """ + + base = Permissions.text() + base.send_tts_messages = False + base.manage_messages = False + base.mention_everyone = True + + if user.id == self.owner.id: + base.kick_members = True + + return base + + async def add_recipients(self, *recipients): + r"""|coro| + + Adds recipients to this group. + + A group can only have a maximum of 10 members. + Attempting to add more ends up in an exception. To + add a recipient to the group, you must have a relationship + with the user of type :attr:`RelationshipType.friend`. + + Parameters + ----------- + \*recipients: :class:`User` + An argument list of users to add to this group. + + Raises + ------- + HTTPException + Adding a recipient to this group failed. + """ + + # TODO: wait for the corresponding WS event + + req = self._state.http.add_group_recipient + for recipient in recipients: + await req(self.id, recipient.id) + + async def remove_recipients(self, *recipients): + r"""|coro| + + Removes recipients from this group. + + Parameters + ----------- + \*recipients: :class:`User` + An argument list of users to remove from this group. + + Raises + ------- + HTTPException + Removing a recipient from this group failed. + """ + + # TODO: wait for the corresponding WS event + + req = self._state.http.remove_group_recipient + for recipient in recipients: + await req(self.id, recipient.id) + + async def edit(self, **fields): + """|coro| + + Edits the group. + + Parameters + ----------- + name: Optional[str] + The new name to change the group to. + Could be ``None`` to remove the name. + icon: Optional[bytes] + A :term:`py:bytes-like object` representing the new icon. + Could be ``None`` to remove the icon. + + Raises + ------- + HTTPException + Editing the group failed. + """ + + try: + icon_bytes = fields["icon"] + except KeyError: + pass + else: + if icon_bytes is not None: + fields["icon"] = utils._bytes_to_base64_data(icon_bytes) + + data = await self._state.http.edit_group(self.id, **fields) + self._update_group(data) + + async def leave(self): + """|coro| + + Leave the group. + + If you are the only one in the group, this deletes it as well. + + Raises + ------- + HTTPException + Leaving the group failed. + """ + + await self._state.http.leave_group(self.id) + + +def _channel_factory(channel_type): + value = try_enum(ChannelType, channel_type) + if value is ChannelType.text: + return TextChannel, value + elif value is ChannelType.voice: + return VoiceChannel, value + elif value is ChannelType.private: + return DMChannel, value + elif value is ChannelType.category: + return CategoryChannel, value + elif value is ChannelType.group: + return GroupChannel, value + else: + return None, value diff --git a/discord/client.py b/discord/client.py new file mode 100644 index 000000000..41d5d6826 --- /dev/null +++ b/discord/client.py @@ -0,0 +1,1074 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +from collections import namedtuple +import logging +import re +import signal +import sys +import traceback + +import aiohttp +import websockets + +from .user import User, Profile +from .invite import Invite +from .object import Object +from .guild import Guild +from .errors import * +from .enums import Status, VoiceRegion +from .gateway import * +from .activity import _ActivityTag, create_activity +from .voice_client import VoiceClient +from .http import HTTPClient +from .state import ConnectionState +from . import utils +from .backoff import ExponentialBackoff +from .webhook import Webhook + +log = logging.getLogger(__name__) + +AppInfo = namedtuple( + "AppInfo", "id name description rpc_origins bot_public bot_require_code_grant icon owner" +) + + +def app_info_icon_url(self): + """Retrieves the application's icon_url if it exists. Empty string otherwise.""" + if not self.icon: + return "" + + return "https://cdn.discordapp.com/app-icons/{0.id}/{0.icon}.jpg".format(self) + + +AppInfo.icon_url = property(app_info_icon_url) + + +class Client: + r"""Represents a client connection that connects to Discord. + This class is used to interact with the Discord WebSocket and API. + + A number of options can be passed to the :class:`Client`. + + .. _event loop: https://docs.python.org/3/library/asyncio-eventloops.html + .. _connector: http://aiohttp.readthedocs.org/en/stable/client_reference.html#connectors + .. _ProxyConnector: http://aiohttp.readthedocs.org/en/stable/client_reference.html#proxyconnector + + Parameters + ----------- + max_messages : Optional[:class:`int`] + The maximum number of messages to store in the internal message cache. + This defaults to 5000. Passing in `None` or a value less than 100 + will use the default instead of the passed in value. + loop : Optional[event loop] + The `event loop`_ to use for asynchronous operations. Defaults to ``None``, + in which case the default event loop is used via ``asyncio.get_event_loop()``. + connector : aiohttp.BaseConnector + The `connector`_ to use for connection pooling. + proxy : Optional[:class:`str`] + Proxy URL. + proxy_auth : Optional[aiohttp.BasicAuth] + An object that represents proxy HTTP Basic Authorization. + shard_id : Optional[:class:`int`] + Integer starting at 0 and less than shard_count. + shard_count : Optional[:class:`int`] + The total number of shards. + fetch_offline_members: :class:`bool` + Indicates if :func:`on_ready` should be delayed to fetch all offline + members from the guilds the bot belongs to. If this is ``False``\, then + no offline members are received and :meth:`request_offline_members` + must be used to fetch the offline members of the guild. + status: Optional[:class:`Status`] + A status to start your presence with upon logging on to Discord. + activity: Optional[Union[:class:`Activity`, :class:`Game`, :class:`Streaming`]] + An activity to start your presence with upon logging on to Discord. + heartbeat_timeout: :class:`float` + The maximum numbers of seconds before timing out and restarting the + WebSocket in the case of not receiving a HEARTBEAT_ACK. Useful if + processing the initial packets take too long to the point of disconnecting + you. The default timeout is 60 seconds. + + Attributes + ----------- + ws + The websocket gateway the client is currently connected to. Could be None. + loop + The `event loop`_ that the client uses for HTTP requests and websocket operations. + """ + + def __init__(self, *, loop=None, **options): + self.ws = None + self.loop = asyncio.get_event_loop() if loop is None else loop + self._listeners = {} + self.shard_id = options.get("shard_id") + self.shard_count = options.get("shard_count") + + connector = options.pop("connector", None) + proxy = options.pop("proxy", None) + proxy_auth = options.pop("proxy_auth", None) + self.http = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, loop=self.loop) + + self._handlers = {"ready": self._handle_ready} + + self._connection = ConnectionState( + dispatch=self.dispatch, + chunker=self._chunker, + handlers=self._handlers, + syncer=self._syncer, + http=self.http, + loop=self.loop, + **options + ) + + self._connection.shard_count = self.shard_count + self._closed = asyncio.Event(loop=self.loop) + self._ready = asyncio.Event(loop=self.loop) + self._connection._get_websocket = lambda g: self.ws + + if VoiceClient.warn_nacl: + VoiceClient.warn_nacl = False + log.warning("PyNaCl is not installed, voice will NOT be supported") + + # internals + + async def _syncer(self, guilds): + await self.ws.request_sync(guilds) + + async def _chunker(self, guild): + try: + guild_id = guild.id + except AttributeError: + guild_id = [s.id for s in guild] + + payload = {"op": 8, "d": {"guild_id": guild_id, "query": "", "limit": 0}} + + await self.ws.send_as_json(payload) + + def _handle_ready(self): + self._ready.set() + + def _resolve_invite(self, invite): + if isinstance(invite, Invite) or isinstance(invite, Object): + return invite.id + else: + rx = r"(?:https?\:\/\/)?discord(?:\.gg|app\.com\/invite)\/(.+)" + m = re.match(rx, invite) + if m: + return m.group(1) + return invite + + @property + def latency(self): + """:obj:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. + + This could be referred to as the Discord WebSocket protocol latency. + """ + ws = self.ws + return float("nan") if not ws else ws.latency + + @property + def user(self): + """Optional[:class:`ClientUser`]: Represents the connected client. None if not logged in.""" + return self._connection.user + + @property + def guilds(self): + """List[:class:`Guild`]: The guilds that the connected client is a member of.""" + return self._connection.guilds + + @property + def emojis(self): + """List[:class:`Emoji`]: The emojis that the connected client has.""" + return self._connection.emojis + + @property + def private_channels(self): + """List[:class:`abc.PrivateChannel`]: The private channels that the connected client is participating on. + + .. note:: + + This returns only up to 128 most recent private channels due to an internal working + on how Discord deals with private channels. + """ + return self._connection.private_channels + + @property + def voice_clients(self): + """List[:class:`VoiceClient`]: Represents a list of voice connections.""" + return self._connection.voice_clients + + def is_ready(self): + """:obj:`bool`: Specifies if the client's internal cache is ready for use.""" + return self._ready.is_set() + + async def _run_event(self, coro, event_name, *args, **kwargs): + try: + await coro(*args, **kwargs) + except asyncio.CancelledError: + pass + except Exception: + try: + await self.on_error(event_name, *args, **kwargs) + except asyncio.CancelledError: + pass + + def dispatch(self, event, *args, **kwargs): + log.debug("Dispatching event %s", event) + method = "on_" + event + + listeners = self._listeners.get(event) + if listeners: + removed = [] + for i, (future, condition) in enumerate(listeners): + if future.cancelled(): + removed.append(i) + continue + + try: + result = condition(*args) + except Exception as exc: + future.set_exception(exc) + removed.append(i) + else: + if result: + if len(args) == 0: + future.set_result(None) + elif len(args) == 1: + future.set_result(args[0]) + else: + future.set_result(args) + removed.append(i) + + if len(removed) == len(listeners): + self._listeners.pop(event) + else: + for idx in reversed(removed): + del listeners[idx] + + try: + coro = getattr(self, method) + except AttributeError: + pass + else: + asyncio.ensure_future(self._run_event(coro, method, *args, **kwargs), loop=self.loop) + + async def on_error(self, event_method, *args, **kwargs): + """|coro| + + The default error handler provided by the client. + + By default this prints to ``sys.stderr`` however it could be + overridden to have a different implementation. + Check :func:`discord.on_error` for more details. + """ + print("Ignoring exception in {}".format(event_method), file=sys.stderr) + traceback.print_exc() + + async def request_offline_members(self, *guilds): + r"""|coro| + + Requests previously offline members from the guild to be filled up + into the :attr:`Guild.members` cache. This function is usually not + called. It should only be used if you have the ``fetch_offline_members`` + parameter set to ``False``. + + When the client logs on and connects to the websocket, Discord does + not provide the library with offline members if the number of members + in the guild is larger than 250. You can check if a guild is large + if :attr:`Guild.large` is ``True``. + + Parameters + ----------- + \*guilds + An argument list of guilds to request offline members for. + + Raises + ------- + InvalidArgument + If any guild is unavailable or not large in the collection. + """ + if any(not g.large or g.unavailable for g in guilds): + raise InvalidArgument("An unavailable or non-large guild was passed.") + + await self._connection.request_offline_members(guilds) + + # login state management + + async def login(self, token, *, bot=True): + """|coro| + + Logs in the client with the specified credentials. + + This function can be used in two different ways. + + .. warning:: + + Logging on with a user token is against the Discord + `Terms of Service `_ + and doing so might potentially get your account banned. + Use this at your own risk. + + Parameters + ----------- + token: str + The authentication token. Do not prefix this token with + anything as the library will do it for you. + bot: bool + Keyword argument that specifies if the account logging on is a bot + token or not. + + Raises + ------ + LoginFailure + The wrong credentials are passed. + HTTPException + An unknown HTTP related error occurred, + usually when it isn't 200 or the known incorrect credentials + passing status code. + """ + + log.info("logging in using static token") + await self.http.static_login(token, bot=bot) + self._connection.is_bot = bot + + async def logout(self): + """|coro| + + Logs out of Discord and closes all connections. + """ + await self.close() + + async def _connect(self): + coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id) + self.ws = await asyncio.wait_for(coro, timeout=180.0, loop=self.loop) + while True: + try: + await self.ws.poll_event() + except ResumeWebSocket: + log.info("Got a request to RESUME the websocket.") + coro = DiscordWebSocket.from_client( + self, + shard_id=self.shard_id, + session=self.ws.session_id, + sequence=self.ws.sequence, + resume=True, + ) + self.ws = await asyncio.wait_for(coro, timeout=180.0, loop=self.loop) + + async def connect(self, *, reconnect=True): + """|coro| + + Creates a websocket connection and lets the websocket listen + to messages from discord. This is a loop that runs the entire + event system and miscellaneous aspects of the library. Control + is not resumed until the WebSocket connection is terminated. + + Parameters + ----------- + reconnect: bool + If we should attempt reconnecting, either due to internet + failure or a specific failure on Discord's part. Certain + disconnects that lead to bad state will not be handled (such as + invalid sharding payloads or bad tokens). + + Raises + ------- + GatewayNotFound + If the gateway to connect to discord is not found. Usually if this + is thrown then there is a discord API outage. + ConnectionClosed + The websocket connection has been terminated. + """ + + backoff = ExponentialBackoff() + while not self.is_closed(): + try: + await self._connect() + except ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + websockets.InvalidHandshake, + websockets.WebSocketProtocolError, + ) as exc: + + if not reconnect: + await self.close() + if isinstance(exc, ConnectionClosed) and exc.code == 1000: + # clean close, don't re-raise this + return + raise + + if self.is_closed(): + return + + # We should only get this when an unhandled close code happens, + # such as a clean disconnect (1000) or a bad state (bad token, no sharding, etc) + # sometimes, discord sends us 1000 for unknown reasons so we should reconnect + # regardless and rely on is_closed instead + if isinstance(exc, ConnectionClosed): + if exc.code != 1000: + await self.close() + raise + + retry = backoff.delay() + log.exception("Attempting a reconnect in %.2fs", retry) + await asyncio.sleep(retry, loop=self.loop) + + async def close(self): + """|coro| + + Closes the connection to discord. + """ + if self.is_closed(): + return + + self._closed.set() + + for voice in self.voice_clients: + try: + await voice.disconnect() + except Exception: + # if an error happens during disconnects, disregard it. + pass + + if self.ws is not None and self.ws.open: + await self.ws.close() + + await self.http.close() + self._ready.clear() + + def clear(self): + """Clears the internal state of the bot. + + After this, the bot can be considered "re-opened", i.e. :meth:`.is_closed` + and :meth:`.is_ready` both return ``False`` along with the bot's internal + cache cleared. + """ + self._closed.clear() + self._ready.clear() + self._connection.clear() + self.http.recreate() + + async def start(self, *args, **kwargs): + """|coro| + + A shorthand coroutine for :meth:`login` + :meth:`connect`. + """ + + bot = kwargs.pop("bot", True) + reconnect = kwargs.pop("reconnect", True) + await self.login(*args, bot=bot) + await self.connect(reconnect=reconnect) + + def _do_cleanup(self): + log.info("Cleaning up event loop.") + loop = self.loop + if loop.is_closed(): + return # we're already cleaning up + + task = asyncio.ensure_future(self.close(), loop=loop) + + def _silence_gathered(fut): + try: + fut.result() + except Exception: + pass + finally: + loop.stop() + + def when_future_is_done(fut): + pending = asyncio.Task.all_tasks(loop=loop) + if pending: + log.info("Cleaning up after %s tasks", len(pending)) + gathered = asyncio.gather(*pending, loop=loop) + gathered.cancel() + gathered.add_done_callback(_silence_gathered) + else: + loop.stop() + + task.add_done_callback(when_future_is_done) + if not loop.is_running(): + loop.run_forever() + else: + # on Linux, we're still running because we got triggered via + # the signal handler rather than the natural KeyboardInterrupt + # Since that's the case, we're going to return control after + # registering the task for the event loop to handle later + return None + + try: + return task.result() # suppress unused task warning + except Exception: + return None + + def run(self, *args, **kwargs): + """A blocking call that abstracts away the `event loop`_ + initialisation from you. + + If you want more control over the event loop then this + function should not be used. Use :meth:`start` coroutine + or :meth:`connect` + :meth:`login`. + + Roughly Equivalent to: :: + + try: + loop.run_until_complete(start(*args, **kwargs)) + except KeyboardInterrupt: + loop.run_until_complete(logout()) + # cancel all tasks lingering + finally: + loop.close() + + Warning + -------- + This function must be the last function to call due to the fact that it + is blocking. That means that registration of events or anything being + called after this function call will not execute until it returns. + """ + is_windows = sys.platform == "win32" + loop = self.loop + if not is_windows: + loop.add_signal_handler(signal.SIGINT, self._do_cleanup) + loop.add_signal_handler(signal.SIGTERM, self._do_cleanup) + + task = asyncio.ensure_future(self.start(*args, **kwargs), loop=loop) + + def stop_loop_on_finish(fut): + loop.stop() + + task.add_done_callback(stop_loop_on_finish) + + try: + loop.run_forever() + except KeyboardInterrupt: + log.info("Received signal to terminate bot and event loop.") + finally: + task.remove_done_callback(stop_loop_on_finish) + if is_windows: + self._do_cleanup() + + loop.close() + if task.cancelled() or not task.done(): + return None + return task.result() + + # properties + + def is_closed(self): + """:obj:`bool`: Indicates if the websocket connection is closed.""" + return self._closed.is_set() + + @property + def activity(self): + """Optional[Union[:class:`Activity`, :class:`Game`, :class:`Streaming`]]: The activity being used upon logging in.""" + return create_activity(self._connection._activity) + + @activity.setter + def activity(self, value): + if value is None: + self._connection._activity = None + elif isinstance(value, _ActivityTag): + self._connection._activity = value.to_dict() + else: + raise TypeError("activity must be one of Game, Streaming, or Activity.") + + # helpers/getters + + @property + def users(self): + """Returns a :obj:`list` of all the :class:`User` the bot can see.""" + return list(self._connection._users.values()) + + def get_channel(self, id): + """Returns a :class:`abc.GuildChannel` or :class:`abc.PrivateChannel` with the following ID. + + If not found, returns None. + """ + return self._connection.get_channel(id) + + def get_guild(self, id): + """Returns a :class:`Guild` with the given ID. If not found, returns None.""" + return self._connection._get_guild(id) + + def get_user(self, id): + """Returns a :class:`User` with the given ID. If not found, returns None.""" + return self._connection.get_user(id) + + def get_emoji(self, id): + """Returns a :class:`Emoji` with the given ID. If not found, returns None.""" + return self._connection.get_emoji(id) + + def get_all_channels(self): + """A generator that retrieves every :class:`abc.GuildChannel` the client can 'access'. + + This is equivalent to: :: + + for guild in client.guilds: + for channel in guild.channels: + yield channel + + Note + ----- + Just because you receive a :class:`abc.GuildChannel` does not mean that + you can communicate in said channel. :meth:`abc.GuildChannel.permissions_for` should + be used for that. + """ + + for guild in self.guilds: + for channel in guild.channels: + yield channel + + def get_all_members(self): + """Returns a generator with every :class:`Member` the client can see. + + This is equivalent to: :: + + for guild in client.guilds: + for member in guild.members: + yield member + + """ + for guild in self.guilds: + for member in guild.members: + yield member + + # listeners/waiters + + async def wait_until_ready(self): + """|coro| + + Waits until the client's internal cache is all ready. + """ + await self._ready.wait() + + def wait_for(self, event, *, check=None, timeout=None): + """|coro| + + Waits for a WebSocket event to be dispatched. + + This could be used to wait for a user to reply to a message, + or to react to a message, or to edit a message in a self-contained + way. + + The ``timeout`` parameter is passed onto :func:`asyncio.wait_for`. By default, + it does not timeout. Note that this does propagate the + :exc:`asyncio.TimeoutError` for you in case of timeout and is provided for + ease of use. + + In case the event returns multiple arguments, a :obj:`tuple` containing those + arguments is returned instead. Please check the + :ref:`documentation ` for a list of events and their + parameters. + + This function returns the **first event that meets the requirements**. + + Examples + --------- + + Waiting for a user reply: :: + + @client.event + async def on_message(message): + if message.content.startswith('$greet'): + channel = message.channel + await channel.send('Say hello!') + + def check(m): + return m.content == 'hello' and m.channel == channel + + msg = await client.wait_for('message', check=check) + await channel.send('Hello {.author}!'.format(msg)) + + Waiting for a thumbs up reaction from the message author: :: + + @client.event + async def on_message(message): + if message.content.startswith('$thumb'): + channel = message.channel + await channel.send('Send me that \N{THUMBS UP SIGN} reaction, mate') + + def check(reaction, user): + return user == message.author and str(reaction.emoji) == '\N{THUMBS UP SIGN}' + + try: + reaction, user = await client.wait_for('reaction_add', timeout=60.0, check=check) + except asyncio.TimeoutError: + await channel.send('\N{THUMBS DOWN SIGN}') + else: + await channel.send('\N{THUMBS UP SIGN}') + + + Parameters + ------------ + event: str + The event name, similar to the :ref:`event reference `, + but without the ``on_`` prefix, to wait for. + check: Optional[predicate] + A predicate to check what to wait for. The arguments must meet the + parameters of the event being waited for. + timeout: Optional[float] + The number of seconds to wait before timing out and raising + :exc:`asyncio.TimeoutError`. + + Raises + ------- + asyncio.TimeoutError + If a timeout is provided and it was reached. + + Returns + -------- + Any + Returns no arguments, a single argument, or a :obj:`tuple` of multiple + arguments that mirrors the parameters passed in the + :ref:`event reference `. + """ + + future = self.loop.create_future() + if check is None: + + def _check(*args): + return True + + check = _check + + ev = event.lower() + try: + listeners = self._listeners[ev] + except KeyError: + listeners = [] + self._listeners[ev] = listeners + + listeners.append((future, check)) + return asyncio.wait_for(future, timeout, loop=self.loop) + + # event registration + + def event(self, coro): + """A decorator that registers an event to listen to. + + You can find more info about the events on the :ref:`documentation below `. + + The events must be a |corourl|_, if not, :exc:`ClientException` is raised. + + Example + --------- + + :: + @client.event + async def on_ready(): + print('Ready!') + + """ + + if not asyncio.iscoroutinefunction(coro): + raise ClientException("event registered must be a coroutine function") + + setattr(self, coro.__name__, coro) + log.debug("%s has successfully been registered as an event", coro.__name__) + return coro + + async def change_presence(self, *, activity=None, status=None, afk=False): + """|coro| + + Changes the client's presence. + + The activity parameter is a :class:`Activity` object (not a string) that represents + the activity being done currently. This could also be the slimmed down versions, + :class:`Game` and :class:`Streaming`. + + Example: :: + + game = discord.Game("with the API") + await client.change_presence(status=discord.Status.idle, activity=game) + + Parameters + ---------- + activity: Optional[Union[:class:`Game`, :class:`Streaming`, :class:`Activity`]] + The activity being done. ``None`` if no currently active activity is done. + status: Optional[:class:`Status`] + Indicates what status to change to. If None, then + :attr:`Status.online` is used. + afk: bool + Indicates if you are going AFK. This allows the discord + client to know how to handle push notifications better + for you in case you are actually idle and not lying. + + Raises + ------ + InvalidArgument + If the ``activity`` parameter is not the proper type. + """ + + if status is None: + status = "online" + status_enum = Status.online + elif status is Status.offline: + status = "invisible" + status_enum = Status.offline + else: + status_enum = status + status = str(status) + + await self.ws.change_presence(activity=activity, status=status, afk=afk) + + for guild in self._connection.guilds: + me = guild.me + if me is None: + continue + + me.activities = (activity,) + me.status = status_enum + + # Guild stuff + + async def create_guild(self, name, region=None, icon=None): + """|coro| + + Creates a :class:`Guild`. + + Bot accounts in more than 10 guilds are not allowed to create guilds. + + Parameters + ---------- + name: str + The name of the guild. + region: :class:`VoiceRegion` + The region for the voice communication server. + Defaults to :attr:`VoiceRegion.us_west`. + icon: bytes + The :term:`py:bytes-like object` representing the icon. See :meth:`~ClientUser.edit` + for more details on what is expected. + + Raises + ------ + HTTPException + Guild creation failed. + InvalidArgument + Invalid icon image format given. Must be PNG or JPG. + + Returns + ------- + :class:`Guild` + The guild created. This is not the same guild that is + added to cache. + """ + if icon is not None: + icon = utils._bytes_to_base64_data(icon) + + if region is None: + region = VoiceRegion.us_west.value + else: + region = region.value + + data = await self.http.create_guild(name, region, icon) + return Guild(data=data, state=self._connection) + + # Invite management + + async def get_invite(self, url): + """|coro| + + Gets an :class:`Invite` from a discord.gg URL or ID. + + Note + ------ + If the invite is for a guild you have not joined, the guild and channel + attributes of the returned invite will be :class:`Object` with the names + patched in. + + Parameters + ----------- + url : str + The discord invite ID or URL (must be a discord.gg URL). + + Raises + ------- + NotFound + The invite has expired or is invalid. + HTTPException + Getting the invite failed. + + Returns + -------- + :class:`Invite` + The invite from the URL/ID. + """ + + invite_id = self._resolve_invite(url) + data = await self.http.get_invite(invite_id) + return Invite.from_incomplete(state=self._connection, data=data) + + async def delete_invite(self, invite): + """|coro| + + Revokes an :class:`Invite`, URL, or ID to an invite. + + You must have the :attr:`~Permissions.manage_channels` permission in + the associated guild to do this. + + Parameters + ---------- + invite + The invite to revoke. + + Raises + ------- + Forbidden + You do not have permissions to revoke invites. + NotFound + The invite is invalid or expired. + HTTPException + Revoking the invite failed. + """ + + invite_id = self._resolve_invite(invite) + await self.http.delete_invite(invite_id) + + # Miscellaneous stuff + + async def application_info(self): + """|coro| + + Retrieve's the bot's application information. + + Returns + -------- + :class:`AppInfo` + A namedtuple representing the application info. + + Raises + ------- + HTTPException + Retrieving the information failed somehow. + """ + data = await self.http.application_info() + if "rpc_origins" not in data: + data["rpc_origins"] = None + return AppInfo( + id=int(data["id"]), + name=data["name"], + description=data["description"], + icon=data["icon"], + rpc_origins=data["rpc_origins"], + bot_public=data["bot_public"], + bot_require_code_grant=data["bot_require_code_grant"], + owner=User(state=self._connection, data=data["owner"]), + ) + + async def get_user_info(self, user_id): + """|coro| + + Retrieves a :class:`User` based on their ID. This can only + be used by bot accounts. You do not have to share any guilds + with the user to get this information, however many operations + do require that you do. + + Parameters + ----------- + user_id: int + The user's ID to fetch from. + + Returns + -------- + :class:`User` + The user you requested. + + Raises + ------- + NotFound + A user with this ID does not exist. + HTTPException + Fetching the user failed. + """ + data = await self.http.get_user_info(user_id) + return User(state=self._connection, data=data) + + async def get_user_profile(self, user_id): + """|coro| + + Gets an arbitrary user's profile. This can only be used by non-bot accounts. + + Parameters + ------------ + user_id: int + The ID of the user to fetch their profile for. + + Raises + ------- + Forbidden + Not allowed to fetch profiles. + HTTPException + Fetching the profile failed. + + Returns + -------- + :class:`Profile` + The profile of the user. + """ + + state = self._connection + data = await self.http.get_user_profile(user_id) + + def transform(d): + return state._get_guild(int(d["id"])) + + since = data.get("premium_since") + mutual_guilds = list(filter(None, map(transform, data.get("mutual_guilds", [])))) + user = data["user"] + return Profile( + flags=user.get("flags", 0), + premium_since=utils.parse_time(since), + mutual_guilds=mutual_guilds, + user=User(data=user, state=state), + connected_accounts=data["connected_accounts"], + ) + + async def get_webhook_info(self, webhook_id): + """|coro| + + Retrieves a :class:`Webhook` with the specified ID. + + Raises + -------- + HTTPException + Retrieving the webhook failed. + NotFound + Invalid webhook ID. + Forbidden + You do not have permission to fetch this webhook. + + Returns + --------- + :class:`Webhook` + The webhook you requested. + """ + data = await self.http.get_webhook(webhook_id) + return Webhook.from_state(data, state=self._connection) diff --git a/discord/colour.py b/discord/colour.py new file mode 100644 index 000000000..f741dd3ef --- /dev/null +++ b/discord/colour.py @@ -0,0 +1,234 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import colorsys + + +class Colour: + """Represents a Discord role colour. This class is similar + to an (red, green, blue) :class:`tuple`. + + There is an alias for this called Color. + + .. container:: operations + + .. describe:: x == y + + Checks if two colours are equal. + + .. describe:: x != y + + Checks if two colours are not equal. + + .. describe:: hash(x) + + Return the colour's hash. + + .. describe:: str(x) + + Returns the hex format for the colour. + + Attributes + ------------ + value: :class:`int` + The raw integer colour value. + """ + + __slots__ = ("value",) + + def __init__(self, value): + if not isinstance(value, int): + raise TypeError( + "Expected int parameter, received %s instead." % value.__class__.__name__ + ) + + self.value = value + + def _get_byte(self, byte): + return (self.value >> (8 * byte)) & 0xFF + + def __eq__(self, other): + return isinstance(other, Colour) and self.value == other.value + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "#{:0>6x}".format(self.value) + + def __repr__(self): + return "" % self.value + + def __hash__(self): + return hash(self.value) + + @property + def r(self): + """Returns the red component of the colour.""" + return self._get_byte(2) + + @property + def g(self): + """Returns the green component of the colour.""" + return self._get_byte(1) + + @property + def b(self): + """Returns the blue component of the colour.""" + return self._get_byte(0) + + def to_rgb(self): + """Returns an (r, g, b) tuple representing the colour.""" + return (self.r, self.g, self.b) + + @classmethod + def from_rgb(cls, r, g, b): + """Constructs a :class:`Colour` from an RGB tuple.""" + return cls((r << 16) + (g << 8) + b) + + @classmethod + def from_hsv(cls, h, s, v): + """Constructs a :class:`Colour` from an HSV tuple.""" + rgb = colorsys.hsv_to_rgb(h, s, v) + return cls.from_rgb(*(int(x * 255) for x in rgb)) + + @classmethod + def default(cls): + """A factory method that returns a :class:`Colour` with a value of 0.""" + return cls(0) + + @classmethod + def teal(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``.""" + return cls(0x1ABC9C) + + @classmethod + def dark_teal(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x11806a``.""" + return cls(0x11806A) + + @classmethod + def green(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``.""" + return cls(0x2ECC71) + + @classmethod + def dark_green(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``.""" + return cls(0x1F8B4C) + + @classmethod + def blue(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x3498db``.""" + return cls(0x3498DB) + + @classmethod + def dark_blue(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x206694``.""" + return cls(0x206694) + + @classmethod + def purple(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``.""" + return cls(0x9B59B6) + + @classmethod + def dark_purple(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x71368a``.""" + return cls(0x71368A) + + @classmethod + def magenta(cls): + """A factory method that returns a :class:`Colour` with a value of ``0xe91e63``.""" + return cls(0xE91E63) + + @classmethod + def dark_magenta(cls): + """A factory method that returns a :class:`Colour` with a value of ``0xad1457``.""" + return cls(0xAD1457) + + @classmethod + def gold(cls): + """A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``.""" + return cls(0xF1C40F) + + @classmethod + def dark_gold(cls): + """A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``.""" + return cls(0xC27C0E) + + @classmethod + def orange(cls): + """A factory method that returns a :class:`Colour` with a value of ``0xe67e22``.""" + return cls(0xE67E22) + + @classmethod + def dark_orange(cls): + """A factory method that returns a :class:`Colour` with a value of ``0xa84300``.""" + return cls(0xA84300) + + @classmethod + def red(cls): + """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" + return cls(0xE74C3C) + + @classmethod + def dark_red(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x992d22``.""" + return cls(0x992D22) + + @classmethod + def lighter_grey(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" + return cls(0x95A5A6) + + @classmethod + def dark_grey(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" + return cls(0x607D8B) + + @classmethod + def light_grey(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" + return cls(0x979C9F) + + @classmethod + def darker_grey(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" + return cls(0x546E7A) + + @classmethod + def blurple(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x7289da``.""" + return cls(0x7289DA) + + @classmethod + def greyple(cls): + """A factory method that returns a :class:`Colour` with a value of ``0x99aab5``.""" + return cls(0x99AAB5) + + +Color = Colour diff --git a/discord/context_managers.py b/discord/context_managers.py new file mode 100644 index 000000000..b7f21825b --- /dev/null +++ b/discord/context_managers.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio + + +def _typing_done_callback(fut): + # just retrieve any exception and call it a day + try: + fut.exception() + except Exception: + pass + + +class Typing: + def __init__(self, messageable): + self.loop = messageable._state.loop + self.messageable = messageable + + async def do_typing(self): + try: + channel = self._channel + except AttributeError: + channel = await self.messageable._get_channel() + + typing = channel._state.http.send_typing + + while True: + await typing(channel.id) + await asyncio.sleep(5) + + def __enter__(self): + self.task = asyncio.ensure_future(self.do_typing(), loop=self.loop) + self.task.add_done_callback(_typing_done_callback) + return self + + def __exit__(self, exc_type, exc, tb): + self.task.cancel() + + async def __aenter__(self): + self._channel = channel = await self.messageable._get_channel() + await channel._state.http.send_typing(channel.id) + return self.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + self.task.cancel() diff --git a/discord/embeds.py b/discord/embeds.py new file mode 100644 index 000000000..2c3fb9c1e --- /dev/null +++ b/discord/embeds.py @@ -0,0 +1,492 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import datetime + +from . import utils +from .colour import Colour + + +class _EmptyEmbed: + def __bool__(self): + return False + + def __repr__(self): + return "Embed.Empty" + + +EmptyEmbed = _EmptyEmbed() + + +class EmbedProxy: + def __init__(self, layer): + self.__dict__.update(layer) + + def __len__(self): + return len(self.__dict__) + + def __repr__(self): + return "EmbedProxy(%s)" % ", ".join( + ("%s=%r" % (k, v) for k, v in self.__dict__.items() if not k.startswith("_")) + ) + + def __getattr__(self, attr): + return EmptyEmbed + + +class Embed: + """Represents a Discord embed. + + The following attributes can be set during creation + of the object: + + Certain properties return an ``EmbedProxy``. Which is a type + that acts similar to a regular :class:`dict` except access the attributes + via dotted access, e.g. ``embed.author.icon_url``. If the attribute + is invalid or empty, then a special sentinel value is returned, + :attr:`Embed.Empty`. + + For ease of use, all parameters that expect a :class:`str` are implicitly + casted to :class:`str` for you. + + Attributes + ----------- + title: :class:`str` + The title of the embed. + type: :class:`str` + The type of embed. Usually "rich". + description: :class:`str` + The description of the embed. + url: :class:`str` + The URL of the embed. + timestamp: `datetime.datetime` + The timestamp of the embed content. This could be a naive or aware datetime. + colour: :class:`Colour` or :class:`int` + The colour code of the embed. Aliased to ``color`` as well. + Empty + A special sentinel value used by ``EmbedProxy`` and this class + to denote that the value or attribute is empty. + """ + + __slots__ = ( + "title", + "url", + "type", + "_timestamp", + "_colour", + "_footer", + "_image", + "_thumbnail", + "_video", + "_provider", + "_author", + "_fields", + "description", + ) + + Empty = EmptyEmbed + + def __init__(self, **kwargs): + # swap the colour/color aliases + try: + colour = kwargs["colour"] + except KeyError: + colour = kwargs.get("color", EmptyEmbed) + + self.colour = colour + self.title = kwargs.get("title", EmptyEmbed) + self.type = kwargs.get("type", "rich") + self.url = kwargs.get("url", EmptyEmbed) + self.description = kwargs.get("description", EmptyEmbed) + + try: + timestamp = kwargs["timestamp"] + except KeyError: + pass + else: + self.timestamp = timestamp + + @classmethod + def from_data(cls, data): + # we are bypassing __init__ here since it doesn't apply here + self = cls.__new__(cls) + + # fill in the basic fields + + self.title = data.get("title", EmptyEmbed) + self.type = data.get("type", EmptyEmbed) + self.description = data.get("description", EmptyEmbed) + self.url = data.get("url", EmptyEmbed) + + # try to fill in the more rich fields + + try: + self._colour = Colour(value=data["color"]) + except KeyError: + pass + + try: + self._timestamp = utils.parse_time(data["timestamp"]) + except KeyError: + pass + + for attr in ("thumbnail", "video", "provider", "author", "fields", "image", "footer"): + try: + value = data[attr] + except KeyError: + continue + else: + setattr(self, "_" + attr, value) + + return self + + @property + def colour(self): + return getattr(self, "_colour", EmptyEmbed) + + @colour.setter + def colour(self, value): + if isinstance(value, (Colour, _EmptyEmbed)): + self._colour = value + elif isinstance(value, int): + self._colour = Colour(value=value) + else: + raise TypeError( + "Expected discord.Colour, int, or Embed.Empty but received %s instead." + % value.__class__.__name__ + ) + + color = colour + + @property + def timestamp(self): + return getattr(self, "_timestamp", EmptyEmbed) + + @timestamp.setter + def timestamp(self, value): + if isinstance(value, (datetime.datetime, _EmptyEmbed)): + self._timestamp = value + else: + raise TypeError( + "Expected datetime.datetime or Embed.Empty received %s instead" + % value.__class__.__name__ + ) + + @property + def footer(self): + """Returns an ``EmbedProxy`` denoting the footer contents. + + See :meth:`set_footer` for possible values you can access. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_footer", {})) + + def set_footer(self, *, text=EmptyEmbed, icon_url=EmptyEmbed): + """Sets the footer for the embed content. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ----------- + text: str + The footer text. + icon_url: str + The URL of the footer icon. Only HTTP(S) is supported. + """ + + self._footer = {} + if text is not EmptyEmbed: + self._footer["text"] = str(text) + + if icon_url is not EmptyEmbed: + self._footer["icon_url"] = str(icon_url) + + return self + + @property + def image(self): + """Returns an ``EmbedProxy`` denoting the image contents. + + Possible attributes you can access are: + + - ``url`` + - ``proxy_url`` + - ``width`` + - ``height`` + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_image", {})) + + def set_image(self, *, url): + """Sets the image for the embed content. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ----------- + url: str + The source URL for the image. Only HTTP(S) is supported. + """ + + self._image = {"url": str(url)} + + return self + + @property + def thumbnail(self): + """Returns an ``EmbedProxy`` denoting the thumbnail contents. + + Possible attributes you can access are: + + - ``url`` + - ``proxy_url`` + - ``width`` + - ``height`` + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_thumbnail", {})) + + def set_thumbnail(self, *, url): + """Sets the thumbnail for the embed content. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ----------- + url: str + The source URL for the thumbnail. Only HTTP(S) is supported. + """ + + self._thumbnail = {"url": str(url)} + + return self + + @property + def video(self): + """Returns an ``EmbedProxy`` denoting the video contents. + + Possible attributes include: + + - ``url`` for the video URL. + - ``height`` for the video height. + - ``width`` for the video width. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_video", {})) + + @property + def provider(self): + """Returns an ``EmbedProxy`` denoting the provider contents. + + The only attributes that might be accessed are ``name`` and ``url``. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_provider", {})) + + @property + def author(self): + """Returns an ``EmbedProxy`` denoting the author contents. + + See :meth:`set_author` for possible values you can access. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return EmbedProxy(getattr(self, "_author", {})) + + def set_author(self, *, name, url=EmptyEmbed, icon_url=EmptyEmbed): + """Sets the author for the embed content. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ----------- + name: str + The name of the author. + url: str + The URL for the author. + icon_url: str + The URL of the author icon. Only HTTP(S) is supported. + """ + + self._author = {"name": str(name)} + + if url is not EmptyEmbed: + self._author["url"] = str(url) + + if icon_url is not EmptyEmbed: + self._author["icon_url"] = str(icon_url) + + return self + + @property + def fields(self): + """Returns a :class:`list` of ``EmbedProxy`` denoting the field contents. + + See :meth:`add_field` for possible values you can access. + + If the attribute has no value then :attr:`Empty` is returned. + """ + return [EmbedProxy(d) for d in getattr(self, "_fields", [])] + + def add_field(self, *, name, value, inline=True): + """Adds a field to the embed object. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ----------- + name: str + The name of the field. + value: str + The value of the field. + inline: bool + Whether the field should be displayed inline. + """ + + field = {"inline": inline, "name": str(name), "value": str(value)} + + try: + self._fields.append(field) + except AttributeError: + self._fields = [field] + + return self + + def clear_fields(self): + """Removes all fields from this embed.""" + try: + self._fields.clear() + except AttributeError: + self._fields = [] + + def remove_field(self, index): + """Removes a field at a specified index. + + If the index is invalid or out of bounds then the error is + silently swallowed. + + .. note:: + + When deleting a field by index, the index of the other fields + shift to fill the gap just like a regular list. + + Parameters + ----------- + index: int + The index of the field to remove. + """ + try: + del self._fields[index] + except (AttributeError, IndexError): + pass + + def set_field_at(self, index, *, name, value, inline=True): + """Modifies a field to the embed object. + + The index must point to a valid pre-existing field. + + This function returns the class instance to allow for fluent-style + chaining. + + Parameters + ----------- + index: int + The index of the field to modify. + name: str + The name of the field. + value: str + The value of the field. + inline: bool + Whether the field should be displayed inline. + + Raises + ------- + IndexError + An invalid index was provided. + """ + + try: + field = self._fields[index] + except (TypeError, IndexError, AttributeError): + raise IndexError("field index out of range") + + field["name"] = str(name) + field["value"] = str(value) + field["inline"] = inline + return self + + def to_dict(self): + """Converts this embed object into a dict.""" + + # add in the raw data into the dict + result = { + key[1:]: getattr(self, key) + for key in self.__slots__ + if key[0] == "_" and hasattr(self, key) + } + + # deal with basic convenience wrappers + + try: + colour = result.pop("colour") + except KeyError: + pass + else: + if colour: + result["color"] = colour.value + + try: + timestamp = result.pop("timestamp") + except KeyError: + pass + else: + if timestamp: + result["timestamp"] = timestamp.isoformat() + + # add in the non raw attribute ones + if self.type: + result["type"] = self.type + + if self.description: + result["description"] = self.description + + if self.url: + result["url"] = self.url + + if self.title: + result["title"] = self.title + + return result diff --git a/discord/emoji.py b/discord/emoji.py new file mode 100644 index 000000000..11653db43 --- /dev/null +++ b/discord/emoji.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from collections import namedtuple + +from . import utils +from .mixins import Hashable + + +class PartialEmoji(namedtuple("PartialEmoji", "animated name id")): + """Represents a "partial" emoji. + + This model will be given in two scenarios: + + - "Raw" data events such as :func:`on_raw_reaction_add` + - Custom emoji that the bot cannot see from e.g. :attr:`Message.reactions` + + .. container:: operations + + .. describe:: x == y + + Checks if two emoji are the same. + + .. describe:: x != y + + Checks if two emoji are not the same. + + .. describe:: hash(x) + + Return the emoji's hash. + + .. describe:: str(x) + + Returns the emoji rendered for discord. + + Attributes + ----------- + name: :class:`str` + The custom emoji name, if applicable, or the unicode codepoint + of the non-custom emoji. + animated: :class:`bool` + Whether the emoji is animated or not. + id: Optional[:class:`int`] + The ID of the custom emoji, if applicable. + """ + + __slots__ = () + + def __str__(self): + if self.id is None: + return self.name + if self.animated: + return "" % (self.name, self.id) + return "<:%s:%s>" % (self.name, self.id) + + def is_custom_emoji(self): + """Checks if this is a custom non-Unicode emoji.""" + return self.id is not None + + def is_unicode_emoji(self): + """Checks if this is a Unicode emoji.""" + return self.id is None + + def _as_reaction(self): + if self.id is None: + return self.name + return "%s:%s" % (self.name, self.id) + + @property + def url(self): + """Returns a URL version of the emoji, if it is custom.""" + if self.is_unicode_emoji(): + return None + + _format = "gif" if self.animated else "png" + return "https://cdn.discordapp.com/emojis/{0.id}.{1}".format(self, _format) + + +class Emoji(Hashable): + """Represents a custom emoji. + + Depending on the way this object was created, some of the attributes can + have a value of ``None``. + + .. container:: operations + + .. describe:: x == y + + Checks if two emoji are the same. + + .. describe:: x != y + + Checks if two emoji are not the same. + + .. describe:: hash(x) + + Return the emoji's hash. + + .. describe:: iter(x) + + Returns an iterator of ``(field, value)`` pairs. This allows this class + to be used as an iterable in list/dict/etc constructions. + + .. describe:: str(x) + + Returns the emoji rendered for discord. + + Attributes + ----------- + name: :class:`str` + The name of the emoji. + id: :class:`int` + The emoji's ID. + require_colons: :class:`bool` + If colons are required to use this emoji in the client (:PJSalt: vs PJSalt). + animated: :class:`bool` + Whether an emoji is animated or not. + managed: :class:`bool` + If this emoji is managed by a Twitch integration. + guild_id: :class:`int` + The guild ID the emoji belongs to. + """ + + __slots__ = ( + "require_colons", + "animated", + "managed", + "id", + "name", + "_roles", + "guild_id", + "_state", + ) + + def __init__(self, *, guild, state, data): + self.guild_id = guild.id + self._state = state + self._from_data(data) + + def _from_data(self, emoji): + self.require_colons = emoji["require_colons"] + self.managed = emoji["managed"] + self.id = int(emoji["id"]) + self.name = emoji["name"] + self.animated = emoji.get("animated", False) + self._roles = utils.SnowflakeList(map(int, emoji.get("roles", []))) + + def _iterator(self): + for attr in self.__slots__: + if attr[0] != "_": + value = getattr(self, attr, None) + if value is not None: + yield (attr, value) + + def __iter__(self): + return self._iterator() + + def __str__(self): + if self.animated: + return "".format(self) + return "<:{0.name}:{0.id}>".format(self) + + def __repr__(self): + return "".format(self) + + @property + def created_at(self): + """Returns the emoji's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @property + def url(self): + """Returns a URL version of the emoji.""" + _format = "gif" if self.animated else "png" + return "https://cdn.discordapp.com/emojis/{0.id}.{1}".format(self, _format) + + @property + def roles(self): + """List[:class:`Role`]: A :class:`list` of roles that is allowed to use this emoji. + + If roles is empty, the emoji is unrestricted. + """ + guild = self.guild + if guild is None: + return [] + + return [role for role in guild.roles if self._roles.has(role.id)] + + @property + def guild(self): + """:class:`Guild`: The guild this emoji belongs to.""" + return self._state._get_guild(self.guild_id) + + async def delete(self, *, reason=None): + """|coro| + + Deletes the custom emoji. + + You must have :attr:`~Permissions.manage_emojis` permission to + do this. + + Parameters + ----------- + reason: Optional[str] + The reason for deleting this emoji. Shows up on the audit log. + + Raises + ------- + Forbidden + You are not allowed to delete emojis. + HTTPException + An error occurred deleting the emoji. + """ + + await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) + + async def edit(self, *, name, roles=None, reason=None): + r"""|coro| + + Edits the custom emoji. + + You must have :attr:`~Permissions.manage_emojis` permission to + do this. + + Parameters + ----------- + name: str + The new emoji name. + roles: Optional[list[:class:`Role`]] + A :class:`list` of :class:`Role`\s that can use this emoji. Leave empty to make it available to everyone. + reason: Optional[str] + The reason for editing this emoji. Shows up on the audit log. + + Raises + ------- + Forbidden + You are not allowed to edit emojis. + HTTPException + An error occurred editing the emoji. + """ + + if roles: + roles = [role.id for role in roles] + await self._state.http.edit_custom_emoji( + self.guild.id, self.id, name=name, roles=roles, reason=reason + ) diff --git a/discord/enums.py b/discord/enums.py new file mode 100644 index 000000000..8262b80c5 --- /dev/null +++ b/discord/enums.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from enum import Enum, IntEnum + +__all__ = [ + "ChannelType", + "MessageType", + "VoiceRegion", + "VerificationLevel", + "ContentFilter", + "Status", + "DefaultAvatar", + "RelationshipType", + "AuditLogAction", + "AuditLogActionCategory", + "UserFlags", + "ActivityType", + "HypeSquadHouse", + "NotificationLevel", +] + + +class ChannelType(Enum): + text = 0 + private = 1 + voice = 2 + group = 3 + category = 4 + + def __str__(self): + return self.name + + +class MessageType(Enum): + default = 0 + recipient_add = 1 + recipient_remove = 2 + call = 3 + channel_name_change = 4 + channel_icon_change = 5 + pins_add = 6 + new_member = 7 + + +class VoiceRegion(Enum): + us_west = "us-west" + us_east = "us-east" + us_south = "us-south" + us_central = "us-central" + eu_west = "eu-west" + eu_central = "eu-central" + singapore = "singapore" + london = "london" + sydney = "sydney" + amsterdam = "amsterdam" + frankfurt = "frankfurt" + brazil = "brazil" + hongkong = "hongkong" + russia = "russia" + japan = "japan" + southafrica = "southafrica" + vip_us_east = "vip-us-east" + vip_us_west = "vip-us-west" + vip_amsterdam = "vip-amsterdam" + + def __str__(self): + return self.value + + +class VerificationLevel(IntEnum): + none = 0 + low = 1 + medium = 2 + high = 3 + table_flip = 3 + extreme = 4 + double_table_flip = 4 + + def __str__(self): + return self.name + + +class ContentFilter(IntEnum): + disabled = 0 + no_role = 1 + all_members = 2 + + def __str__(self): + return self.name + + +class Status(Enum): + online = "online" + offline = "offline" + idle = "idle" + dnd = "dnd" + do_not_disturb = "dnd" + invisible = "invisible" + + def __str__(self): + return self.value + + +class DefaultAvatar(Enum): + blurple = 0 + grey = 1 + gray = 1 + green = 2 + orange = 3 + red = 4 + + def __str__(self): + return self.name + + +class RelationshipType(Enum): + friend = 1 + blocked = 2 + incoming_request = 3 + outgoing_request = 4 + + +class NotificationLevel(IntEnum): + all_messages = 0 + only_mentions = 1 + + +class AuditLogActionCategory(Enum): + create = 1 + delete = 2 + update = 3 + + +class AuditLogAction(Enum): + guild_update = 1 + channel_create = 10 + channel_update = 11 + channel_delete = 12 + overwrite_create = 13 + overwrite_update = 14 + overwrite_delete = 15 + kick = 20 + member_prune = 21 + ban = 22 + unban = 23 + member_update = 24 + member_role_update = 25 + role_create = 30 + role_update = 31 + role_delete = 32 + invite_create = 40 + invite_update = 41 + invite_delete = 42 + webhook_create = 50 + webhook_update = 51 + webhook_delete = 52 + emoji_create = 60 + emoji_update = 61 + emoji_delete = 62 + message_delete = 72 + + @property + def category(self): + lookup = { + AuditLogAction.guild_update: AuditLogActionCategory.update, + AuditLogAction.channel_create: AuditLogActionCategory.create, + AuditLogAction.channel_update: AuditLogActionCategory.update, + AuditLogAction.channel_delete: AuditLogActionCategory.delete, + AuditLogAction.overwrite_create: AuditLogActionCategory.create, + AuditLogAction.overwrite_update: AuditLogActionCategory.update, + AuditLogAction.overwrite_delete: AuditLogActionCategory.delete, + AuditLogAction.kick: None, + AuditLogAction.member_prune: None, + AuditLogAction.ban: None, + AuditLogAction.unban: None, + AuditLogAction.member_update: AuditLogActionCategory.update, + AuditLogAction.member_role_update: AuditLogActionCategory.update, + AuditLogAction.role_create: AuditLogActionCategory.create, + AuditLogAction.role_update: AuditLogActionCategory.update, + AuditLogAction.role_delete: AuditLogActionCategory.delete, + AuditLogAction.invite_create: AuditLogActionCategory.create, + AuditLogAction.invite_update: AuditLogActionCategory.update, + AuditLogAction.invite_delete: AuditLogActionCategory.delete, + AuditLogAction.webhook_create: AuditLogActionCategory.create, + AuditLogAction.webhook_update: AuditLogActionCategory.update, + AuditLogAction.webhook_delete: AuditLogActionCategory.delete, + AuditLogAction.emoji_create: AuditLogActionCategory.create, + AuditLogAction.emoji_update: AuditLogActionCategory.update, + AuditLogAction.emoji_delete: AuditLogActionCategory.delete, + AuditLogAction.message_delete: AuditLogActionCategory.delete, + } + return lookup[self] + + @property + def target_type(self): + v = self.value + if v == -1: + return "all" + elif v < 10: + return "guild" + elif v < 20: + return "channel" + elif v < 30: + return "user" + elif v < 40: + return "role" + elif v < 50: + return "invite" + elif v < 60: + return "webhook" + elif v < 70: + return "emoji" + elif v < 80: + return "message" + + +class UserFlags(Enum): + staff = 1 + partner = 2 + hypesquad = 4 + bug_hunter = 8 + hypesquad_bravery = 64 + hypesquad_brilliance = 128 + hypesquad_balance = 256 + early_supporter = 512 + + +class ActivityType(IntEnum): + unknown = -1 + playing = 0 + streaming = 1 + listening = 2 + watching = 3 + + +class HypeSquadHouse(Enum): + bravery = 1 + brilliance = 2 + balance = 3 + + +def try_enum(cls, val): + """A function that tries to turn the value into enum ``cls``. + + If it fails it returns the value instead. + """ + try: + return cls(val) + except ValueError: + return val diff --git a/discord/errors.py b/discord/errors.py new file mode 100644 index 000000000..8a66d1d0b --- /dev/null +++ b/discord/errors.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + + +class DiscordException(Exception): + """Base exception class for discord.py + + Ideally speaking, this could be caught to handle any exceptions thrown from this library. + """ + + pass + + +class ClientException(DiscordException): + """Exception that's thrown when an operation in the :class:`Client` fails. + + These are usually for exceptions that happened due to user input. + """ + + pass + + +class NoMoreItems(DiscordException): + """Exception that is thrown when an async iteration operation has no more + items.""" + + pass + + +class GatewayNotFound(DiscordException): + """An exception that is usually thrown when the gateway hub + for the :class:`Client` websocket is not found.""" + + def __init__(self): + message = "The gateway to connect to discord was not found." + super(GatewayNotFound, self).__init__(message) + + +def flatten_error_dict(d, key=""): + items = [] + for k, v in d.items(): + new_key = key + "." + k if key else k + + if isinstance(v, dict): + try: + _errors = v["_errors"] + except KeyError: + items.extend(flatten_error_dict(v, new_key).items()) + else: + items.append((new_key, " ".join(x.get("message", "") for x in _errors))) + else: + items.append((new_key, v)) + + return dict(items) + + +class HTTPException(DiscordException): + """Exception that's thrown when an HTTP request operation fails. + + Attributes + ------------ + response: aiohttp.ClientResponse + The response of the failed HTTP request. This is an + instance of `aiohttp.ClientResponse`__. In some cases + this could also be a ``requests.Response``. + + __ http://aiohttp.readthedocs.org/en/stable/client_reference.html#aiohttp.ClientResponse + + text: :class:`str` + The text of the error. Could be an empty string. + status: :class:`int` + The status code of the HTTP request. + code: :class:`int` + The Discord specific error code for the failure. + """ + + def __init__(self, response, message): + self.response = response + self.status = response.status + if isinstance(message, dict): + self.code = message.get("code", 0) + base = message.get("message", "") + errors = message.get("errors") + if errors: + errors = flatten_error_dict(errors) + helpful = "\n".join("In %s: %s" % t for t in errors.items()) + self.text = base + "\n" + helpful + else: + self.text = base + else: + self.text = message + self.code = 0 + + fmt = "{0.reason} (status code: {0.status})" + if len(self.text): + fmt = fmt + ": {1}" + + super().__init__(fmt.format(self.response, self.text)) + + +class Forbidden(HTTPException): + """Exception that's thrown for when status code 403 occurs. + + Subclass of :exc:`HTTPException` + """ + + pass + + +class NotFound(HTTPException): + """Exception that's thrown for when status code 404 occurs. + + Subclass of :exc:`HTTPException` + """ + + pass + + +class InvalidArgument(ClientException): + """Exception that's thrown when an argument to a function + is invalid some way (e.g. wrong value or wrong type). + + This could be considered the analogous of ``ValueError`` and + ``TypeError`` except derived from :exc:`ClientException` and thus + :exc:`DiscordException`. + """ + + pass + + +class LoginFailure(ClientException): + """Exception that's thrown when the :meth:`Client.login` function + fails to log you in from improper credentials or some other misc. + failure. + """ + + pass + + +class ConnectionClosed(ClientException): + """Exception that's thrown when the gateway connection is + closed for reasons that could not be handled internally. + + Attributes + ----------- + code: :class:`int` + The close code of the websocket. + reason: :class:`str` + The reason provided for the closure. + shard_id: Optional[:class:`int`] + The shard ID that got closed if applicable. + """ + + def __init__(self, original, *, shard_id): + # This exception is just the same exception except + # reconfigured to subclass ClientException for users + self.code = original.code + self.reason = original.reason + self.shard_id = shard_id + super().__init__(str(original)) diff --git a/discord/ext/commands/__init__.py b/discord/ext/commands/__init__.py new file mode 100644 index 000000000..f5f7bb263 --- /dev/null +++ b/discord/ext/commands/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- + +""" +discord.ext.commands +~~~~~~~~~~~~~~~~~~~~~ + +An extension module to facilitate creation of bot commands. + +:copyright: (c) 2017 Rapptz +:license: MIT, see LICENSE for more details. +""" + +from .bot import Bot, AutoShardedBot, when_mentioned, when_mentioned_or +from .context import Context +from .core import * +from .errors import * +from .formatter import HelpFormatter, Paginator +from .converter import * +from .cooldowns import * diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py new file mode 100644 index 000000000..a13989524 --- /dev/null +++ b/discord/ext/commands/bot.py @@ -0,0 +1,1049 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +import collections +import inspect +import importlib +import sys +import traceback +import re + +import discord + +from .core import GroupMixin, Command +from .view import StringView +from .context import Context +from .errors import CommandNotFound, CommandError +from .formatter import HelpFormatter + + +def when_mentioned(bot, msg): + """A callable that implements a command prefix equivalent to being mentioned. + + These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. + """ + return [bot.user.mention + " ", "<@!%s> " % bot.user.id] + + +def when_mentioned_or(*prefixes): + """A callable that implements when mentioned or other prefixes provided. + + These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. + + Example + -------- + + .. code-block:: python3 + + bot = commands.Bot(command_prefix=commands.when_mentioned_or('!')) + + + .. note:: + + This callable returns another callable, so if this is done inside a custom + callable, you must call the returned callable, for example: + + .. code-block:: python3 + + async def get_prefix(bot, message): + extras = await prefixes_for(message.guild) # returns a list + return commands.when_mentioned_or(*extras)(bot, message) + + + See Also + ---------- + :func:`.when_mentioned` + """ + + def inner(bot, msg): + r = list(prefixes) + r = when_mentioned(bot, msg) + r + return r + + return inner + + +_mentions_transforms = {"@everyone": "@\u200beveryone", "@here": "@\u200bhere"} + +_mention_pattern = re.compile("|".join(_mentions_transforms.keys())) + + +def _is_submodule(parent, child): + return parent == child or child.startswith(parent + ".") + + +async def _default_help_command(ctx, *commands: str): + """Shows this message.""" + bot = ctx.bot + destination = ctx.message.author if bot.pm_help else ctx.message.channel + + def repl(obj): + return _mentions_transforms.get(obj.group(0), "") + + # help by itself just lists our own commands. + if len(commands) == 0: + pages = await bot.formatter.format_help_for(ctx, bot) + elif len(commands) == 1: + # try to see if it is a cog name + name = _mention_pattern.sub(repl, commands[0]) + command = None + if name in bot.cogs: + command = bot.cogs[name] + else: + command = bot.all_commands.get(name) + if command is None: + await destination.send(bot.command_not_found.format(name)) + return + + pages = await bot.formatter.format_help_for(ctx, command) + else: + name = _mention_pattern.sub(repl, commands[0]) + command = bot.all_commands.get(name) + if command is None: + await destination.send(bot.command_not_found.format(name)) + return + + for key in commands[1:]: + try: + key = _mention_pattern.sub(repl, key) + command = command.all_commands.get(key) + if command is None: + await destination.send(bot.command_not_found.format(key)) + return + except AttributeError: + await destination.send(bot.command_has_no_subcommands.format(command, key)) + return + + pages = await bot.formatter.format_help_for(ctx, command) + + if bot.pm_help is None: + characters = sum(map(len, pages)) + # modify destination based on length of pages. + if characters > 1000: + destination = ctx.message.author + + for page in pages: + await destination.send(page) + + +class BotBase(GroupMixin): + def __init__(self, command_prefix, formatter=None, description=None, pm_help=False, **options): + super().__init__(**options) + self.command_prefix = command_prefix + self.extra_events = {} + self.cogs = {} + self.extensions = {} + self._checks = [] + self._check_once = [] + self._before_invoke = None + self._after_invoke = None + self.description = inspect.cleandoc(description) if description else "" + self.pm_help = pm_help + self.owner_id = options.get("owner_id") + self.command_not_found = options.pop("command_not_found", 'No command called "{}" found.') + self.command_has_no_subcommands = options.pop( + "command_has_no_subcommands", "Command {0.name} has no subcommands." + ) + + if options.pop("self_bot", False): + self._skip_check = lambda x, y: x != y + else: + self._skip_check = lambda x, y: x == y + + self.help_attrs = options.pop("help_attrs", {}) + + if "name" not in self.help_attrs: + self.help_attrs["name"] = "help" + + if formatter is not None: + if not isinstance(formatter, HelpFormatter): + raise discord.ClientException("Formatter must be a subclass of HelpFormatter") + self.formatter = formatter + else: + self.formatter = HelpFormatter() + + # pay no mind to this ugliness. + self.command(**self.help_attrs)(_default_help_command) + + # internal helpers + + def dispatch(self, event_name, *args, **kwargs): + super().dispatch(event_name, *args, **kwargs) + ev = "on_" + event_name + for event in self.extra_events.get(ev, []): + coro = self._run_event(event, event_name, *args, **kwargs) + asyncio.ensure_future(coro, loop=self.loop) + + async def close(self): + for extension in tuple(self.extensions): + try: + self.unload_extension(extension) + except Exception: + pass + + for cog in tuple(self.cogs): + try: + self.remove_cog(cog) + except Exception: + pass + + await super().close() + + async def on_command_error(self, context, exception): + """|coro| + + The default command error handler provided by the bot. + + By default this prints to ``sys.stderr`` however it could be + overridden to have a different implementation. + + This only fires if you do not specify any listeners for command error. + """ + if self.extra_events.get("on_command_error", None): + return + + if hasattr(context.command, "on_error"): + return + + cog = context.cog + if cog: + attr = "_{0.__class__.__name__}__error".format(cog) + if hasattr(cog, attr): + return + + print("Ignoring exception in command {}:".format(context.command), file=sys.stderr) + traceback.print_exception( + type(exception), exception, exception.__traceback__, file=sys.stderr + ) + + # global check registration + + def check(self, func): + r"""A decorator that adds a global check to the bot. + + A global check is similar to a :func:`.check` that is applied + on a per command basis except it is run before any command checks + have been verified and applies to every command the bot has. + + .. note:: + + This function can either be a regular function or a coroutine. + + Similar to a command :func:`.check`\, this takes a single parameter + of type :class:`.Context` and can only raise exceptions derived from + :exc:`.CommandError`. + + Example + --------- + + .. code-block:: python3 + + @bot.check + def check_commands(ctx): + return ctx.command.qualified_name in allowed_commands + + """ + self.add_check(func) + return func + + def add_check(self, func, *, call_once=False): + """Adds a global check to the bot. + + This is the non-decorator interface to :meth:`.check` + and :meth:`.check_once`. + + Parameters + ----------- + func + The function that was used as a global check. + call_once: bool + If the function should only be called once per + :meth:`.Command.invoke` call. + """ + + if call_once: + self._check_once.append(func) + else: + self._checks.append(func) + + def remove_check(self, func, *, call_once=False): + """Removes a global check from the bot. + + This function is idempotent and will not raise an exception + if the function is not in the global checks. + + Parameters + ----------- + func + The function to remove from the global checks. + call_once: bool + If the function was added with ``call_once=True`` in + the :meth:`.Bot.add_check` call or using :meth:`.check_once`. + """ + l = self._check_once if call_once else self._checks + + try: + l.remove(func) + except ValueError: + pass + + def check_once(self, func): + r"""A decorator that adds a "call once" global check to the bot. + + Unlike regular global checks, this one is called only once + per :meth:`.Command.invoke` call. + + Regular global checks are called whenever a command is called + or :meth:`.Command.can_run` is called. This type of check + bypasses that and ensures that it's called only once, even inside + the default help command. + + .. note:: + + This function can either be a regular function or a coroutine. + + Similar to a command :func:`.check`\, this takes a single parameter + of type :class:`.Context` and can only raise exceptions derived from + :exc:`.CommandError`. + + Example + --------- + + .. code-block:: python3 + + @bot.check_once + def whitelist(ctx): + return ctx.message.author.id in my_whitelist + + """ + self.add_check(func, call_once=True) + return func + + async def can_run(self, ctx, *, call_once=False): + data = self._check_once if call_once else self._checks + + if len(data) == 0: + return True + + return await discord.utils.async_all(f(ctx) for f in data) + + async def is_owner(self, user): + """Checks if a :class:`.User` or :class:`.Member` is the owner of + this bot. + + If an :attr:`owner_id` is not set, it is fetched automatically + through the use of :meth:`~.Bot.application_info`. + + Parameters + ----------- + user: :class:`.abc.User` + The user to check for. + """ + + if self.owner_id is None: + app = await self.application_info() + self.owner_id = owner_id = app.owner.id + return user.id == owner_id + return user.id == self.owner_id + + def before_invoke(self, coro): + """A decorator that registers a coroutine as a pre-invoke hook. + + A pre-invoke hook is called directly before the command is + called. This makes it a useful function to set up database + connections or any type of set up required. + + This pre-invoke hook takes a sole parameter, a :class:`.Context`. + + .. note:: + + The :meth:`~.Bot.before_invoke` and :meth:`~.Bot.after_invoke` hooks are + only called if all checks and argument parsing procedures pass + without error. If any check or argument parsing procedures fail + then the hooks are not called. + + Parameters + ----------- + coro + The coroutine to register as the pre-invoke hook. + + Raises + ------- + :exc:`.ClientException` + The coroutine is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException("The pre-invoke hook must be a coroutine.") + + self._before_invoke = coro + return coro + + def after_invoke(self, coro): + r"""A decorator that registers a coroutine as a post-invoke hook. + + A post-invoke hook is called directly after the command is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + + This post-invoke hook takes a sole parameter, a :class:`.Context`. + + .. note:: + + Similar to :meth:`~.Bot.before_invoke`\, this is not called unless + checks and argument parsing procedures succeed. This hook is, + however, **always** called regardless of the internal command + callback raising an error (i.e. :exc:`.CommandInvokeError`\). + This makes it ideal for clean-up scenarios. + + Parameters + ----------- + coro + The coroutine to register as the post-invoke hook. + + Raises + ------- + :exc:`.ClientException` + The coroutine is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException("The post-invoke hook must be a coroutine.") + + self._after_invoke = coro + return coro + + # listener registration + + def add_listener(self, func, name=None): + """The non decorator alternative to :meth:`.listen`. + + Parameters + ----------- + func : :ref:`coroutine ` + The extra event to listen to. + name : Optional[str] + The name of the command to use. Defaults to ``func.__name__``. + + Example + -------- + + .. code-block:: python3 + + async def on_ready(): pass + async def my_message(message): pass + + bot.add_listener(on_ready) + bot.add_listener(my_message, 'on_message') + + """ + name = func.__name__ if name is None else name + + if not asyncio.iscoroutinefunction(func): + raise discord.ClientException("Listeners must be coroutines") + + if name in self.extra_events: + self.extra_events[name].append(func) + else: + self.extra_events[name] = [func] + + def remove_listener(self, func, name=None): + """Removes a listener from the pool of listeners. + + Parameters + ----------- + func + The function that was used as a listener to remove. + name + The name of the event we want to remove. Defaults to + ``func.__name__``. + """ + + name = func.__name__ if name is None else name + + if name in self.extra_events: + try: + self.extra_events[name].remove(func) + except ValueError: + pass + + def listen(self, name=None): + """A decorator that registers another function as an external + event listener. Basically this allows you to listen to multiple + events from different places e.g. such as :func:`.on_ready` + + The functions being listened to must be a coroutine. + + Example + -------- + + .. code-block:: python3 + + @bot.listen() + async def on_message(message): + print('one') + + # in some other file... + + @bot.listen('on_message') + async def my_message(message): + print('two') + + Would print one and two in an unspecified order. + + Raises + ------- + :exc:`.ClientException` + The function being listened to is not a coroutine. + """ + + def decorator(func): + self.add_listener(func, name) + return func + + return decorator + + # cogs + + def add_cog(self, cog): + """Adds a "cog" to the bot. + + A cog is a class that has its own event listeners and commands. + + They are meant as a way to organize multiple relevant commands + into a singular class that shares some state or no state at all. + + The cog can also have a ``__global_check`` member function that allows + you to define a global check. See :meth:`.check` for more info. If + the name is ``__global_check_once`` then it's equivalent to the + :meth:`.check_once` decorator. + + More information will be documented soon. + + Parameters + ----------- + cog + The cog to register to the bot. + """ + + self.cogs[type(cog).__name__] = cog + + try: + check = getattr(cog, "_{.__class__.__name__}__global_check".format(cog)) + except AttributeError: + pass + else: + self.add_check(check) + + try: + check = getattr(cog, "_{.__class__.__name__}__global_check_once".format(cog)) + except AttributeError: + pass + else: + self.add_check(check, call_once=True) + + members = inspect.getmembers(cog) + for name, member in members: + # register commands the cog has + if isinstance(member, Command): + if member.parent is None: + self.add_command(member) + continue + + # register event listeners the cog has + if name.startswith("on_"): + self.add_listener(member, name) + + def get_cog(self, name): + """Gets the cog instance requested. + + If the cog is not found, ``None`` is returned instead. + + Parameters + ----------- + name : str + The name of the cog you are requesting. + """ + return self.cogs.get(name) + + def get_cog_commands(self, name): + """Gets a unique set of the cog's registered commands + without aliases. + + If the cog is not found, an empty set is returned. + + Parameters + ------------ + name: str + The name of the cog whose commands you are requesting. + + Returns + --------- + Set[:class:`.Command`] + A unique set of commands without aliases that belong + to the cog. + """ + + try: + cog = self.cogs[name] + except KeyError: + return set() + + return {c for c in self.all_commands.values() if c.instance is cog} + + def remove_cog(self, name): + """Removes a cog from the bot. + + All registered commands and event listeners that the + cog has registered will be removed as well. + + If no cog is found then this method has no effect. + + If the cog defines a special member function named ``__unload`` + then it is called when removal has completed. This function + **cannot** be a coroutine. It must be a regular function. + + Parameters + ----------- + name : str + The name of the cog to remove. + """ + + cog = self.cogs.pop(name, None) + if cog is None: + return + + members = inspect.getmembers(cog) + for name, member in members: + # remove commands the cog has + if isinstance(member, Command): + if member.parent is None: + self.remove_command(member.name) + continue + + # remove event listeners the cog has + if name.startswith("on_"): + self.remove_listener(member) + + try: + check = getattr(cog, "_{0.__class__.__name__}__global_check".format(cog)) + except AttributeError: + pass + else: + self.remove_check(check) + + try: + check = getattr(cog, "_{0.__class__.__name__}__global_check_once".format(cog)) + except AttributeError: + pass + else: + self.remove_check(check, call_once=True) + + unloader_name = "_{0.__class__.__name__}__unload".format(cog) + try: + unloader = getattr(cog, unloader_name) + except AttributeError: + pass + else: + unloader() + + del cog + + # extensions + + def load_extension(self, name): + """Loads an extension. + + An extension is a python module that contains commands, cogs, or + listeners. + + An extension must have a global function, ``setup`` defined as + the entry point on what to do when the extension is loaded. This entry + point must have a single argument, the ``bot``. + + Parameters + ------------ + name: str + The extension name to load. It must be dot separated like + regular Python imports if accessing a sub-module. e.g. + ``foo.test`` if you want to import ``foo/test.py``. + + Raises + -------- + ClientException + The extension does not have a setup function. + ImportError + The extension could not be imported. + """ + + if name in self.extensions: + return + + lib = importlib.import_module(name) + if not hasattr(lib, "setup"): + del lib + del sys.modules[name] + raise discord.ClientException("extension does not have a setup function") + + lib.setup(self) + self.extensions[name] = lib + + def unload_extension(self, name): + """Unloads an extension. + + When the extension is unloaded, all commands, listeners, and cogs are + removed from the bot and the module is un-imported. + + The extension can provide an optional global function, ``teardown``, + to do miscellaneous clean-up if necessary. This function takes a single + parameter, the ``bot``, similar to ``setup`` from + :func:`~.Bot.load_extension`. + + Parameters + ------------ + name: str + The extension name to unload. It must be dot separated like + regular Python imports if accessing a sub-module. e.g. + ``foo.test`` if you want to import ``foo/test.py``. + """ + + lib = self.extensions.get(name) + if lib is None: + return + + lib_name = lib.__name__ + + # find all references to the module + + # remove the cogs registered from the module + for cogname, cog in self.cogs.copy().items(): + if _is_submodule(lib_name, cog.__module__): + self.remove_cog(cogname) + + # remove all the commands from the module + for cmd in self.all_commands.copy().values(): + if cmd.module is not None and _is_submodule(lib_name, cmd.module): + if isinstance(cmd, GroupMixin): + cmd.recursively_remove_all_commands() + self.remove_command(cmd.name) + + # remove all the listeners from the module + for event_list in self.extra_events.copy().values(): + remove = [] + for index, event in enumerate(event_list): + if event.__module__ is not None and _is_submodule(lib_name, event.__module__): + remove.append(index) + + for index in reversed(remove): + del event_list[index] + + try: + func = getattr(lib, "teardown") + except AttributeError: + pass + else: + try: + func(self) + except Exception: + pass + finally: + # finally remove the import.. + del lib + del self.extensions[name] + del sys.modules[name] + for module in list(sys.modules.keys()): + if _is_submodule(lib_name, module): + del sys.modules[module] + + # command processing + + async def get_prefix(self, message): + """|coro| + + Retrieves the prefix the bot is listening to + with the message as a context. + + Parameters + ----------- + message: :class:`discord.Message` + The message context to get the prefix of. + + Returns + -------- + Union[List[str], str] + A list of prefixes or a single prefix that the bot is + listening for. + """ + prefix = ret = self.command_prefix + if callable(prefix): + ret = await discord.utils.maybe_coroutine(prefix, self, message) + + if not isinstance(ret, str): + try: + ret = list(ret) + except TypeError: + # It's possible that a generator raised this exception. Don't + # replace it with our own error if that's the case. + if isinstance(ret, collections.Iterable): + raise + + raise TypeError( + "command_prefix must be plain string, iterable of strings, or callable " + "returning either of these, not {}".format(ret.__class__.__name__) + ) + + if not ret: + raise ValueError("Iterable command_prefix must contain at least one prefix") + + return ret + + async def get_context(self, message, *, cls=Context): + r"""|coro| + + Returns the invocation context from the message. + + This is a more low-level counter-part for :meth:`.process_commands` + to allow users more fine grained control over the processing. + + The returned context is not guaranteed to be a valid invocation + context, :attr:`.Context.valid` must be checked to make sure it is. + If the context is not valid then it is not a valid candidate to be + invoked under :meth:`~.Bot.invoke`. + + Parameters + ----------- + message: :class:`discord.Message` + The message to get the invocation context from. + cls + The factory class that will be used to create the context. + By default, this is :class:`.Context`. Should a custom + class be provided, it must be similar enough to :class:`.Context`\'s + interface. + + Returns + -------- + :class:`.Context` + The invocation context. The type of this can change via the + ``cls`` parameter. + """ + + view = StringView(message.content) + ctx = cls(prefix=None, view=view, bot=self, message=message) + + if self._skip_check(message.author.id, self.user.id): + return ctx + + prefix = await self.get_prefix(message) + invoked_prefix = prefix + + if isinstance(prefix, str): + if not view.skip_string(prefix): + return ctx + else: + try: + # if the context class' __init__ consumes something from the view this + # will be wrong. That seems unreasonable though. + if message.content.startswith(tuple(prefix)): + invoked_prefix = discord.utils.find(view.skip_string, prefix) + else: + return ctx + + except TypeError: + if not isinstance(prefix, list): + raise TypeError( + "get_prefix must return either a string or a list of string, " + "not {}".format(prefix.__class__.__name__) + ) + + # It's possible a bad command_prefix got us here. + for value in prefix: + if not isinstance(value, str): + raise TypeError( + "Iterable command_prefix or list returned from get_prefix must " + "contain only strings, not {}".format(value.__class__.__name__) + ) + + # Getting here shouldn't happen + raise + + invoker = view.get_word() + ctx.invoked_with = invoker + ctx.prefix = invoked_prefix + ctx.command = self.all_commands.get(invoker) + return ctx + + async def invoke(self, ctx): + """|coro| + + Invokes the command given under the invocation context and + handles all the internal event dispatch mechanisms. + + Parameters + ----------- + ctx: :class:`.Context` + The invocation context to invoke. + """ + if ctx.command is not None: + self.dispatch("command", ctx) + try: + if await self.can_run(ctx, call_once=True): + await ctx.command.invoke(ctx) + except CommandError as exc: + await ctx.command.dispatch_error(ctx, exc) + else: + self.dispatch("command_completion", ctx) + elif ctx.invoked_with: + exc = CommandNotFound('Command "{}" is not found'.format(ctx.invoked_with)) + self.dispatch("command_error", ctx, exc) + + async def process_commands(self, message): + """|coro| + + This function processes the commands that have been registered + to the bot and other groups. Without this coroutine, none of the + commands will be triggered. + + By default, this coroutine is called inside the :func:`.on_message` + event. If you choose to override the :func:`.on_message` event, then + you should invoke this coroutine as well. + + This is built using other low level tools, and is equivalent to a + call to :meth:`~.Bot.get_context` followed by a call to :meth:`~.Bot.invoke`. + + This also checks if the message's author is a bot and doesn't + call :meth:`~.Bot.get_context` or :meth:`~.Bot.invoke` if so. + + Parameters + ----------- + message: :class:`discord.Message` + The message to process commands for. + """ + if message.author.bot: + return + + ctx = await self.get_context(message) + await self.invoke(ctx) + + async def on_message(self, message): + await self.process_commands(message) + + +class Bot(BotBase, discord.Client): + """Represents a discord bot. + + This class is a subclass of :class:`discord.Client` and as a result + anything that you can do with a :class:`discord.Client` you can do with + this bot. + + .. _deque: https://docs.python.org/3.4/library/collections.html#collections.deque + .. _event loop: https://docs.python.org/3/library/asyncio-eventloops.html + + This class also subclasses :class:`.GroupMixin` to provide the functionality + to manage commands. + + Attributes + ----------- + command_prefix + The command prefix is what the message content must contain initially + to have a command invoked. This prefix could either be a string to + indicate what the prefix should be, or a callable that takes in the bot + as its first parameter and :class:`discord.Message` as its second + parameter and returns the prefix. This is to facilitate "dynamic" + command prefixes. This callable can be either a regular function or + a coroutine. + + An empty string as the prefix always matches, enabling prefix-less + command invocation. While this may be useful in DMs it should be avoided + in servers, as it's likely to cause performance issues and unintended + command invocations. + + The command prefix could also be an iterable of strings indicating that + multiple checks for the prefix should be used and the first one to + match will be the invocation prefix. You can get this prefix via + :attr:`.Context.prefix`. To avoid confusion empty iterables are not + allowed. + + .. note:: + + When passing multiple prefixes be careful to not pass a prefix + that matches a longer prefix occuring later in the sequence. For + example, if the command prefix is ``('!', '!?')`` the ``'!?'`` + prefix will never be matched to any message as the previous one + matches messages starting with ``!?``. This is especially important + when passing an empty string, it should always be last as no prefix + after it will be matched. + case_insensitive: :class:`bool` + Whether the commands should be case insensitive. Defaults to ``False``. This + attribute does not carry over to groups. You must set it to every group if + you require group commands to be case insensitive as well. + description : :class:`str` + The content prefixed into the default help message. + self_bot : :class:`bool` + If ``True``, the bot will only listen to commands invoked by itself rather + than ignoring itself. If ``False`` (the default) then the bot will ignore + itself. This cannot be changed once initialised. + formatter : :class:`.HelpFormatter` + The formatter used to format the help message. By default, it uses + the :class:`.HelpFormatter`. Check it for more info on how to override it. + If you want to change the help command completely (add aliases, etc) then + a call to :meth:`~.Bot.remove_command` with 'help' as the argument would do the + trick. + pm_help : Optional[:class:`bool`] + A tribool that indicates if the help command should PM the user instead of + sending it to the channel it received it from. If the boolean is set to + ``True``, then all help output is PM'd. If ``False``, none of the help + output is PM'd. If ``None``, then the bot will only PM when the help + message becomes too long (dictated by more than 1000 characters). + Defaults to ``False``. + help_attrs : :class:`dict` + A dictionary of options to pass in for the construction of the help command. + This allows you to change the command behaviour without actually changing + the implementation of the command. The attributes will be the same as the + ones passed in the :class:`.Command` constructor. Note that ``pass_context`` + will always be set to ``True`` regardless of what you pass in. + command_not_found : :class:`str` + The format string used when the help command is invoked with a command that + is not found. Useful for i18n. Defaults to ``"No command called {} found."``. + The only format argument is the name of the command passed. + command_has_no_subcommands : :class:`str` + The format string used when the help command is invoked with requests for a + subcommand but the command does not have any subcommands. Defaults to + ``"Command {0.name} has no subcommands."``. The first format argument is the + :class:`.Command` attempted to get a subcommand and the second is the name. + owner_id: Optional[:class:`int`] + The ID that owns the bot. If this is not set and is then queried via + :meth:`.is_owner` then it is fetched automatically using + :meth:`~.Bot.application_info`. + """ + + pass + + +class AutoShardedBot(BotBase, discord.AutoShardedClient): + """This is similar to :class:`.Bot` except that it is derived from + :class:`discord.AutoShardedClient` instead. + """ + + pass diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py new file mode 100644 index 000000000..b84a0d655 --- /dev/null +++ b/discord/ext/commands/context.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import discord.abc +import discord.utils + + +class Context(discord.abc.Messageable): + r"""Represents the context in which a command is being invoked under. + + This class contains a lot of meta data to help you understand more about + the invocation context. This class is not created manually and is instead + passed around to commands as the first parameter. + + This class implements the :class:`abc.Messageable` ABC. + + Attributes + ----------- + message: :class:`discord.Message` + The message that triggered the command being executed. + bot: :class:`.Bot` + The bot that contains the command being executed. + args: :class:`list` + The list of transformed arguments that were passed into the command. + If this is accessed during the :func:`on_command_error` event + then this list could be incomplete. + kwargs: :class:`dict` + A dictionary of transformed arguments that were passed into the command. + Similar to :attr:`args`\, if this is accessed in the + :func:`on_command_error` event then this dict could be incomplete. + prefix: :class:`str` + The prefix that was used to invoke the command. + command + The command (i.e. :class:`.Command` or its superclasses) that is being + invoked currently. + invoked_with: :class:`str` + The command name that triggered this invocation. Useful for finding out + which alias called the command. + invoked_subcommand + The subcommand (i.e. :class:`.Command` or its superclasses) that was + invoked. If no valid subcommand was invoked then this is equal to + `None`. + subcommand_passed: Optional[:class:`str`] + The string that was attempted to call a subcommand. This does not have + to point to a valid registered subcommand and could just point to a + nonsense string. If nothing was passed to attempt a call to a + subcommand then this is set to `None`. + command_failed: :class:`bool` + A boolean that indicates if the command failed to be parsed, checked, + or invoked. + """ + + def __init__(self, **attrs): + self.message = attrs.pop("message", None) + self.bot = attrs.pop("bot", None) + self.args = attrs.pop("args", []) + self.kwargs = attrs.pop("kwargs", {}) + self.prefix = attrs.pop("prefix") + self.command = attrs.pop("command", None) + self.view = attrs.pop("view", None) + self.invoked_with = attrs.pop("invoked_with", None) + self.invoked_subcommand = attrs.pop("invoked_subcommand", None) + self.subcommand_passed = attrs.pop("subcommand_passed", None) + self.command_failed = attrs.pop("command_failed", False) + self._state = self.message._state + + async def invoke(self, *args, **kwargs): + r"""|coro| + + Calls a command with the arguments given. + + This is useful if you want to just call the callback that a + :class:`.Command` holds internally. + + Note + ------ + You do not pass in the context as it is done for you. + + Warning + --------- + The first parameter passed **must** be the command being invoked. + + Parameters + ----------- + command: :class:`.Command` + A command or superclass of a command that is going to be called. + \*args + The arguments to to use. + \*\*kwargs + The keyword arguments to use. + """ + + try: + command = args[0] + except IndexError: + raise TypeError("Missing command to invoke.") from None + + arguments = [] + if command.instance is not None: + arguments.append(command.instance) + + arguments.append(self) + arguments.extend(args[1:]) + + ret = await command.callback(*arguments, **kwargs) + return ret + + async def reinvoke(self, *, call_hooks=False, restart=True): + """|coro| + + Calls the command again. + + This is similar to :meth:`~.Context.invoke` except that it bypasses + checks, cooldowns, and error handlers. + + .. note:: + + If you want to bypass :exc:`.UserInputError` derived exceptions, + it is recommended to use the regular :meth:`~.Context.invoke` + as it will work more naturally. After all, this will end up + using the old arguments the user has used and will thus just + fail again. + + Parameters + ------------ + call_hooks: bool + Whether to call the before and after invoke hooks. + restart: bool + Whether to start the call chain from the very beginning + or where we left off (i.e. the command that caused the error). + The default is to start where we left off. + """ + cmd = self.command + view = self.view + if cmd is None: + raise ValueError("This context is not valid.") + + # some state to revert to when we're done + index, previous = view.index, view.previous + invoked_with = self.invoked_with + invoked_subcommand = self.invoked_subcommand + subcommand_passed = self.subcommand_passed + + if restart: + to_call = cmd.root_parent or cmd + view.index = len(self.prefix) + view.previous = 0 + view.get_word() # advance to get the root command + else: + to_call = cmd + + try: + await to_call.reinvoke(self, call_hooks=call_hooks) + finally: + self.command = cmd + view.index = index + view.previous = previous + self.invoked_with = invoked_with + self.invoked_subcommand = invoked_subcommand + self.subcommand_passed = subcommand_passed + + @property + def valid(self): + """Checks if the invocation context is valid to be invoked with.""" + return self.prefix is not None and self.command is not None + + async def _get_channel(self): + return self.channel + + @property + def cog(self): + """Returns the cog associated with this context's command. None if it does not exist.""" + + if self.command is None: + return None + return self.command.instance + + @discord.utils.cached_property + def guild(self): + """Returns the guild associated with this context's command. None if not available.""" + return self.message.guild + + @discord.utils.cached_property + def channel(self): + """Returns the channel associated with this context's command. Shorthand for :attr:`Message.channel`.""" + return self.message.channel + + @discord.utils.cached_property + def author(self): + """Returns the author associated with this context's command. Shorthand for :attr:`Message.author`""" + return self.message.author + + @discord.utils.cached_property + def me(self): + """Similar to :attr:`Guild.me` except it may return the :class:`ClientUser` in private message contexts.""" + return self.guild.me if self.guild is not None else self.bot.user + + @property + def voice_client(self): + r"""Optional[:class:`VoiceClient`]: A shortcut to :attr:`Guild.voice_client`\, if applicable.""" + g = self.guild + return g.voice_client if g else None diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py new file mode 100644 index 000000000..a93123c1f --- /dev/null +++ b/discord/ext/commands/converter.py @@ -0,0 +1,560 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import re +import inspect + +import discord + +from .errors import BadArgument, NoPrivateMessage + +__all__ = [ + "Converter", + "MemberConverter", + "UserConverter", + "TextChannelConverter", + "InviteConverter", + "RoleConverter", + "GameConverter", + "ColourConverter", + "VoiceChannelConverter", + "EmojiConverter", + "PartialEmojiConverter", + "CategoryChannelConverter", + "IDConverter", + "clean_content", + "Greedy", +] + + +def _get_from_guilds(bot, getter, argument): + result = None + for guild in bot.guilds: + result = getattr(guild, getter)(argument) + if result: + return result + return result + + +class Converter: + """The base class of custom converters that require the :class:`.Context` + to be passed to be useful. + + This allows you to implement converters that function similar to the + special cased ``discord`` classes. + + Classes that derive from this should override the :meth:`~.Converter.convert` + method to do its conversion logic. This method must be a coroutine. + """ + + async def convert(self, ctx, argument): + """|coro| + + The method to override to do conversion logic. + + If an error is found while converting, it is recommended to + raise a :exc:`.CommandError` derived exception as it will + properly propagate to the error handlers. + + Parameters + ----------- + ctx: :class:`.Context` + The invocation context that the argument is being used in. + argument: str + The argument that is being converted. + """ + raise NotImplementedError("Derived classes need to implement this.") + + +class IDConverter(Converter): + def __init__(self): + self._id_regex = re.compile(r"([0-9]{15,21})$") + super().__init__() + + def _get_id_match(self, argument): + return self._id_regex.match(argument) + + +class MemberConverter(IDConverter): + """Converts to a :class:`Member`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name#discrim + 4. Lookup by name + 5. Lookup by nickname + """ + + async def convert(self, ctx, argument): + bot = ctx.bot + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]+)>$", argument) + guild = ctx.guild + result = None + if match is None: + # not a mention... + if guild: + result = guild.get_member_named(argument) + else: + result = _get_from_guilds(bot, "get_member_named", argument) + else: + user_id = int(match.group(1)) + if guild: + result = guild.get_member(user_id) + else: + result = _get_from_guilds(bot, "get_member", user_id) + + if result is None: + raise BadArgument('Member "{}" not found'.format(argument)) + + return result + + +class UserConverter(IDConverter): + """Converts to a :class:`User`. + + All lookups are via the global user cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name#discrim + 4. Lookup by name + """ + + async def convert(self, ctx, argument): + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]+)>$", argument) + result = None + state = ctx._state + + if match is not None: + user_id = int(match.group(1)) + result = ctx.bot.get_user(user_id) + else: + arg = argument + # check for discriminator if it exists + if len(arg) > 5 and arg[-5] == "#": + discrim = arg[-4:] + name = arg[:-5] + predicate = lambda u: u.name == name and u.discriminator == discrim + result = discord.utils.find(predicate, state._users.values()) + if result is not None: + return result + + predicate = lambda u: u.name == arg + result = discord.utils.find(predicate, state._users.values()) + + if result is None: + raise BadArgument('User "{}" not found'.format(argument)) + + return result + + +class TextChannelConverter(IDConverter): + """Converts to a :class:`TextChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + """ + + async def convert(self, ctx, argument): + bot = ctx.bot + + match = self._get_id_match(argument) or re.match(r"<#([0-9]+)>$", argument) + result = None + guild = ctx.guild + + if match is None: + # not a mention + if guild: + result = discord.utils.get(guild.text_channels, name=argument) + else: + + def check(c): + return isinstance(c, discord.TextChannel) and c.name == argument + + result = discord.utils.find(check, bot.get_all_channels()) + else: + channel_id = int(match.group(1)) + if guild: + result = guild.get_channel(channel_id) + else: + result = _get_from_guilds(bot, "get_channel", channel_id) + + if not isinstance(result, discord.TextChannel): + raise BadArgument('Channel "{}" not found.'.format(argument)) + + return result + + +class VoiceChannelConverter(IDConverter): + """Converts to a :class:`VoiceChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + """ + + async def convert(self, ctx, argument): + bot = ctx.bot + match = self._get_id_match(argument) or re.match(r"<#([0-9]+)>$", argument) + result = None + guild = ctx.guild + + if match is None: + # not a mention + if guild: + result = discord.utils.get(guild.voice_channels, name=argument) + else: + + def check(c): + return isinstance(c, discord.VoiceChannel) and c.name == argument + + result = discord.utils.find(check, bot.get_all_channels()) + else: + channel_id = int(match.group(1)) + if guild: + result = guild.get_channel(channel_id) + else: + result = _get_from_guilds(bot, "get_channel", channel_id) + + if not isinstance(result, discord.VoiceChannel): + raise BadArgument('Channel "{}" not found.'.format(argument)) + + return result + + +class CategoryChannelConverter(IDConverter): + """Converts to a :class:`CategoryChannel`. + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + """ + + async def convert(self, ctx, argument): + bot = ctx.bot + + match = self._get_id_match(argument) or re.match(r"<#([0-9]+)>$", argument) + result = None + guild = ctx.guild + + if match is None: + # not a mention + if guild: + result = discord.utils.get(guild.categories, name=argument) + else: + + def check(c): + return isinstance(c, discord.CategoryChannel) and c.name == argument + + result = discord.utils.find(check, bot.get_all_channels()) + else: + channel_id = int(match.group(1)) + if guild: + result = guild.get_channel(channel_id) + else: + result = _get_from_guilds(bot, "get_channel", channel_id) + + if not isinstance(result, discord.CategoryChannel): + raise BadArgument('Channel "{}" not found.'.format(argument)) + + return result + + +class ColourConverter(Converter): + """Converts to a :class:`Colour`. + + The following formats are accepted: + + - ``0x`` + - ``#`` + - ``0x#`` + - Any of the ``classmethod`` in :class:`Colour` + + - The ``_`` in the name can be optionally replaced with spaces. + """ + + async def convert(self, ctx, argument): + arg = argument.replace("0x", "").lower() + + if arg[0] == "#": + arg = arg[1:] + try: + value = int(arg, base=16) + return discord.Colour(value=value) + except ValueError: + method = getattr(discord.Colour, arg.replace(" ", "_"), None) + if method is None or not inspect.ismethod(method): + raise BadArgument('Colour "{}" is invalid.'.format(arg)) + return method() + + +class RoleConverter(IDConverter): + """Converts to a :class:`Role`. + + + All lookups are via the local guild. If in a DM context, then the lookup + is done by the global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by mention. + 3. Lookup by name + """ + + async def convert(self, ctx, argument): + guild = ctx.guild + if not guild: + raise NoPrivateMessage() + + match = self._get_id_match(argument) or re.match(r"<@&([0-9]+)>$", argument) + if match: + result = guild.get_role(int(match.group(1))) + else: + result = discord.utils.get(guild._roles.values(), name=argument) + + if result is None: + raise BadArgument('Role "{}" not found.'.format(argument)) + return result + + +class GameConverter(Converter): + """Converts to :class:`Game`.""" + + async def convert(self, ctx, argument): + return discord.Game(name=argument) + + +class InviteConverter(Converter): + """Converts to a :class:`Invite`. + + This is done via an HTTP request using :meth:`.Bot.get_invite`. + """ + + async def convert(self, ctx, argument): + try: + invite = await ctx.bot.get_invite(argument) + return invite + except Exception as exc: + raise BadArgument("Invite is invalid or expired") from exc + + +class EmojiConverter(IDConverter): + """Converts to a :class:`Emoji`. + + + All lookups are done for the local guild first, if available. If that lookup + fails, then it checks the client's global cache. + + The lookup strategy is as follows (in order): + + 1. Lookup by ID. + 2. Lookup by extracting ID from the emoji. + 3. Lookup by name + """ + + async def convert(self, ctx, argument): + match = self._get_id_match(argument) or re.match( + r"$", argument + ) + result = None + bot = ctx.bot + guild = ctx.guild + + if match is None: + # Try to get the emoji by name. Try local guild first. + if guild: + result = discord.utils.get(guild.emojis, name=argument) + + if result is None: + result = discord.utils.get(bot.emojis, name=argument) + else: + emoji_id = int(match.group(1)) + + # Try to look up emoji by id. + if guild: + result = discord.utils.get(guild.emojis, id=emoji_id) + + if result is None: + result = discord.utils.get(bot.emojis, id=emoji_id) + + if result is None: + raise BadArgument('Emoji "{}" not found.'.format(argument)) + + return result + + +class PartialEmojiConverter(Converter): + """Converts to a :class:`PartialEmoji`. + + + This is done by extracting the animated flag, name and ID from the emoji. + """ + + async def convert(self, ctx, argument): + match = re.match(r"<(a?):([a-zA-Z0-9\_]+):([0-9]+)>$", argument) + + if match: + emoji_animated = bool(match.group(1)) + emoji_name = match.group(2) + emoji_id = int(match.group(3)) + + return discord.PartialEmoji(animated=emoji_animated, name=emoji_name, id=emoji_id) + + raise BadArgument('Couldn\'t convert "{}" to PartialEmoji.'.format(argument)) + + +class clean_content(Converter): + """Converts the argument to mention scrubbed version of + said content. + + This behaves similarly to :attr:`.Message.clean_content`. + + Attributes + ------------ + fix_channel_mentions: :obj:`bool` + Whether to clean channel mentions. + use_nicknames: :obj:`bool` + Whether to use nicknames when transforming mentions. + escape_markdown: :obj:`bool` + Whether to also escape special markdown characters. + """ + + def __init__(self, *, fix_channel_mentions=False, use_nicknames=True, escape_markdown=False): + self.fix_channel_mentions = fix_channel_mentions + self.use_nicknames = use_nicknames + self.escape_markdown = escape_markdown + + async def convert(self, ctx, argument): + message = ctx.message + transformations = {} + + if self.fix_channel_mentions and ctx.guild: + + def resolve_channel(id, *, _get=ctx.guild.get_channel): + ch = _get(id) + return ("<#%s>" % id), ("#" + ch.name if ch else "#deleted-channel") + + transformations.update( + resolve_channel(channel) for channel in message.raw_channel_mentions + ) + + if self.use_nicknames and ctx.guild: + + def resolve_member(id, *, _get=ctx.guild.get_member): + m = _get(id) + return "@" + m.display_name if m else "@deleted-user" + + else: + + def resolve_member(id, *, _get=ctx.bot.get_user): + m = _get(id) + return "@" + m.name if m else "@deleted-user" + + transformations.update( + ("<@%s>" % member_id, resolve_member(member_id)) for member_id in message.raw_mentions + ) + + transformations.update( + ("<@!%s>" % member_id, resolve_member(member_id)) for member_id in message.raw_mentions + ) + + if ctx.guild: + + def resolve_role(_id, *, _find=ctx.guild.get_role): + r = _find(_id) + return "@" + r.name if r else "@deleted-role" + + transformations.update( + ("<@&%s>" % role_id, resolve_role(role_id)) + for role_id in message.raw_role_mentions + ) + + def repl(obj): + return transformations.get(obj.group(0), "") + + pattern = re.compile("|".join(transformations.keys())) + result = pattern.sub(repl, argument) + + if self.escape_markdown: + transformations = {re.escape(c): "\\" + c for c in ("*", "`", "_", "~", "\\")} + + def replace(obj): + return transformations.get(re.escape(obj.group(0)), "") + + pattern = re.compile("|".join(transformations.keys())) + result = pattern.sub(replace, result) + + # Completely ensure no mentions escape: + return re.sub(r"@(everyone|here|[!&]?[0-9]{17,21})", "@\u200b\\1", result) + + +class _Greedy: + __slots__ = ("converter",) + + def __init__(self, *, converter=None): + self.converter = converter + + def __getitem__(self, params): + if not isinstance(params, tuple): + params = (params,) + if len(params) != 1: + raise TypeError("Greedy[...] only takes a single argument") + converter = params[0] + + if not inspect.isclass(converter): + raise TypeError("Greedy[...] expects a type.") + + if converter is str or converter is type(None) or converter is _Greedy: + raise TypeError("Greedy[%s] is invalid." % converter.__name__) + + return self.__class__(converter=converter) + + +Greedy = _Greedy() diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py new file mode 100644 index 000000000..d3079029c --- /dev/null +++ b/discord/ext/commands/cooldowns.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import enum +import time + +__all__ = ["BucketType", "Cooldown", "CooldownMapping"] + + +class BucketType(enum.Enum): + default = 0 + user = 1 + guild = 2 + channel = 3 + member = 4 + category = 5 + + +class Cooldown: + __slots__ = ("rate", "per", "type", "_window", "_tokens", "_last") + + def __init__(self, rate, per, type): + self.rate = int(rate) + self.per = float(per) + self.type = type + self._window = 0.0 + self._tokens = self.rate + self._last = 0.0 + + if not isinstance(self.type, BucketType): + raise TypeError("Cooldown type must be a BucketType") + + def get_tokens(self, current=None): + if not current: + current = time.time() + + tokens = self._tokens + + if current > self._window + self.per: + tokens = self.rate + return tokens + + def update_rate_limit(self): + current = time.time() + self._last = current + + self._tokens = self.get_tokens(current) + + # first token used means that we start a new rate limit window + if self._tokens == self.rate: + self._window = current + + # check if we are rate limited + if self._tokens == 0: + return self.per - (current - self._window) + + # we're not so decrement our tokens + self._tokens -= 1 + + # see if we got rate limited due to this token change, and if + # so update the window to point to our current time frame + if self._tokens == 0: + self._window = current + + def reset(self): + self._tokens = self.rate + self._last = 0.0 + + def copy(self): + return Cooldown(self.rate, self.per, self.type) + + def __repr__(self): + return "".format( + self + ) + + +class CooldownMapping: + def __init__(self, original): + self._cache = {} + self._cooldown = original + + @property + def valid(self): + return self._cooldown is not None + + @classmethod + def from_cooldown(cls, rate, per, type): + return cls(Cooldown(rate, per, type)) + + def _bucket_key(self, msg): + bucket_type = self._cooldown.type + if bucket_type is BucketType.user: + return msg.author.id + elif bucket_type is BucketType.guild: + return (msg.guild or msg.author).id + elif bucket_type is BucketType.channel: + return msg.channel.id + elif bucket_type is BucketType.member: + return ((msg.guild and msg.guild.id), msg.author.id) + elif bucket_type is BucketType.category: + return (msg.channel.category or msg.channel).id + + def _verify_cache_integrity(self): + # we want to delete all cache objects that haven't been used + # in a cooldown window. e.g. if we have a command that has a + # cooldown of 60s and it has not been used in 60s then that key should be deleted + current = time.time() + dead_keys = [k for k, v in self._cache.items() if current > v._last + v.per] + for k in dead_keys: + del self._cache[k] + + def get_bucket(self, message): + if self._cooldown.type is BucketType.default: + return self._cooldown + + self._verify_cache_integrity() + key = self._bucket_key(message) + if key not in self._cache: + bucket = self._cooldown.copy() + self._cache[key] = bucket + else: + bucket = self._cache[key] + + return bucket diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py new file mode 100644 index 000000000..d499923bc --- /dev/null +++ b/discord/ext/commands/core.py @@ -0,0 +1,1517 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +import functools +import inspect +import typing + +import discord + +from .errors import * +from .cooldowns import Cooldown, BucketType, CooldownMapping +from .view import quoted_word +from . import converter as converters + +__all__ = [ + "Command", + "Group", + "GroupMixin", + "command", + "group", + "has_role", + "has_permissions", + "has_any_role", + "check", + "bot_has_role", + "bot_has_permissions", + "bot_has_any_role", + "cooldown", + "guild_only", + "is_owner", + "is_nsfw", +] + + +def wrap_callback(coro): + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + try: + ret = await coro(*args, **kwargs) + except CommandError: + raise + except asyncio.CancelledError: + return + except Exception as exc: + raise CommandInvokeError(exc) from exc + return ret + + return wrapped + + +def hooked_wrapped_callback(command, ctx, coro): + @functools.wraps(coro) + async def wrapped(*args, **kwargs): + try: + ret = await coro(*args, **kwargs) + except CommandError: + ctx.command_failed = True + raise + except asyncio.CancelledError: + ctx.command_failed = True + return + except Exception as exc: + ctx.command_failed = True + raise CommandInvokeError(exc) from exc + finally: + await command.call_after_hooks(ctx) + return ret + + return wrapped + + +def _convert_to_bool(argument): + lowered = argument.lower() + if lowered in ("yes", "y", "true", "t", "1", "enable", "on"): + return True + elif lowered in ("no", "n", "false", "f", "0", "disable", "off"): + return False + else: + raise BadArgument(lowered + " is not a recognised boolean option") + + +class _CaseInsensitiveDict(dict): + def __contains__(self, k): + return super().__contains__(k.lower()) + + def __delitem__(self, k): + return super().__delitem__(k.lower()) + + def __getitem__(self, k): + return super().__getitem__(k.lower()) + + def get(self, k, default=None): + return super().get(k.lower(), default) + + def pop(self, k, default=None): + return super().pop(k.lower(), default) + + def __setitem__(self, k, v): + super().__setitem__(k.lower(), v) + + +class Command: + r"""A class that implements the protocol for a bot text command. + + These are not created manually, instead they are created via the + decorator or functional interface. + + Attributes + ----------- + name: :class:`str` + The name of the command. + callback: :ref:`coroutine ` + The coroutine that is executed when the command is called. + help: :class:`str` + The long help text for the command. + brief: :class:`str` + The short help text for the command. If this is not specified + then the first line of the long help text is used instead. + usage: :class:`str` + A replacement for arguments in the default help text. + aliases: :class:`list` + The list of aliases the command can be invoked under. + enabled: :class:`bool` + A boolean that indicates if the command is currently enabled. + If the command is invoked while it is disabled, then + :exc:`.DisabledCommand` is raised to the :func:`.on_command_error` + event. Defaults to ``True``. + parent: Optional[command] + The parent command that this command belongs to. ``None`` if there + isn't one. + checks + A list of predicates that verifies if the command could be executed + with the given :class:`.Context` as the sole parameter. If an exception + is necessary to be thrown to signal failure, then one derived from + :exc:`.CommandError` should be used. Note that if the checks fail then + :exc:`.CheckFailure` exception is raised to the :func:`.on_command_error` + event. + description: :class:`str` + The message prefixed into the default help command. + hidden: :class:`bool` + If ``True``\, the default help command does not show this in the + help output. + rest_is_raw: :class:`bool` + If ``False`` and a keyword-only argument is provided then the keyword + only argument is stripped and handled as if it was a regular argument + that handles :exc:`.MissingRequiredArgument` and default values in a + regular matter rather than passing the rest completely raw. If ``True`` + then the keyword-only argument will pass in the rest of the arguments + in a completely raw matter. Defaults to ``False``. + ignore_extra: :class:`bool` + If ``True``\, ignores extraneous strings passed to a command if all its + requirements are met (e.g. ``?foo a b c`` when only expecting ``a`` + and ``b``). Otherwise :func:`.on_command_error` and local error handlers + are called with :exc:`.TooManyArguments`. Defaults to ``True``. + """ + + def __init__(self, name, callback, **kwargs): + self.name = name + if not isinstance(name, str): + raise TypeError("Name of a command must be a string.") + + self.callback = callback + self.enabled = kwargs.get("enabled", True) + self.help = kwargs.get("help") + self.brief = kwargs.get("brief") + self.usage = kwargs.get("usage") + self.rest_is_raw = kwargs.get("rest_is_raw", False) + self.aliases = kwargs.get("aliases", []) + + if not isinstance(self.aliases, (list, tuple)): + raise TypeError("Aliases of a command must be a list of strings.") + + self.description = inspect.cleandoc(kwargs.get("description", "")) + self.hidden = kwargs.get("hidden", False) + + self.checks = kwargs.get("checks", []) + self.ignore_extra = kwargs.get("ignore_extra", True) + self.instance = None + self.parent = None + self._buckets = CooldownMapping(kwargs.get("cooldown")) + self._before_invoke = None + self._after_invoke = None + + @property + def callback(self): + return self._callback + + @callback.setter + def callback(self, function): + self._callback = function + self.module = function.__module__ + + signature = inspect.signature(function) + self.params = signature.parameters.copy() + + # PEP-563 allows postponing evaluation of annotations with a __future__ + # import. When postponed, Parameter.annotation will be a string and must + # be replaced with the real value for the converters to work later on + for key, value in self.params.items(): + if isinstance(value.annotation, str): + self.params[key] = value = value.replace( + annotation=eval(value.annotation, function.__globals__) + ) + + # fail early for when someone passes an unparameterized Greedy type + if value.annotation is converters.Greedy: + raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") + + async def dispatch_error(self, ctx, error): + ctx.command_failed = True + cog = self.instance + try: + coro = self.on_error + except AttributeError: + pass + else: + injected = wrap_callback(coro) + if cog is not None: + await injected(cog, ctx, error) + else: + await injected(ctx, error) + + try: + local = getattr(cog, "_{0.__class__.__name__}__error".format(cog)) + except AttributeError: + pass + else: + wrapped = wrap_callback(local) + await wrapped(ctx, error) + finally: + ctx.bot.dispatch("command_error", ctx, error) + + def __get__(self, instance, owner): + if instance is not None: + self.instance = instance + return self + + async def _actual_conversion(self, ctx, converter, argument, param): + if converter is bool: + return _convert_to_bool(argument) + + try: + module = converter.__module__ + except AttributeError: + pass + else: + if module.startswith("discord.") and not module.endswith("converter"): + converter = getattr(converters, converter.__name__ + "Converter") + + try: + if inspect.isclass(converter): + if issubclass(converter, converters.Converter): + instance = converter() + ret = await instance.convert(ctx, argument) + return ret + else: + method = getattr(converter, "convert", None) + if method is not None and inspect.ismethod(method): + ret = await method(ctx, argument) + return ret + elif isinstance(converter, converters.Converter): + ret = await converter.convert(ctx, argument) + return ret + except CommandError: + raise + except Exception as exc: + raise ConversionError(converter, exc) from exc + + try: + return converter(argument) + except CommandError: + raise + except Exception as exc: + try: + name = converter.__name__ + except AttributeError: + name = converter.__class__.__name__ + + raise BadArgument( + 'Converting to "{}" failed for parameter "{}".'.format(name, param.name) + ) from exc + + async def do_conversion(self, ctx, converter, argument, param): + try: + origin = converter.__origin__ + except AttributeError: + pass + else: + if origin is typing.Union: + errors = [] + _NoneType = type(None) + for conv in converter.__args__: + # if we got to this part in the code, then the previous conversions have failed + # so we should just undo the view, return the default, and allow parsing to continue + # with the other parameters + if conv is _NoneType and param.kind != param.VAR_POSITIONAL: + ctx.view.undo() + return None if param.default is param.empty else param.default + + try: + value = await self._actual_conversion(ctx, conv, argument, param) + except CommandError as exc: + errors.append(exc) + else: + return value + + # if we're here, then we failed all the converters + raise BadUnionArgument(param, converter.__args__, errors) + + return await self._actual_conversion(ctx, converter, argument, param) + + def _get_converter(self, param): + converter = param.annotation + if converter is param.empty: + if param.default is not param.empty: + converter = str if param.default is None else type(param.default) + else: + converter = str + return converter + + async def transform(self, ctx, param): + required = param.default is param.empty + converter = self._get_converter(param) + consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw + view = ctx.view + view.skip_ws() + + # The greedy converter is simple -- it keeps going until it fails in which case, + # it undos the view ready for the next parameter to use instead + if type(converter) is converters._Greedy: + if param.kind == param.POSITIONAL_OR_KEYWORD: + return await self._transform_greedy_pos(ctx, param, required, converter.converter) + elif param.kind == param.VAR_POSITIONAL: + return await self._transform_greedy_var_pos(ctx, param, converter.converter) + else: + # if we're here, then it's a KEYWORD_ONLY param type + # since this is mostly useless, we'll helpfully transform Greedy[X] + # into just X and do the parsing that way. + converter = converter.converter + + if view.eof: + if param.kind == param.VAR_POSITIONAL: + raise RuntimeError() # break the loop + if required: + raise MissingRequiredArgument(param) + return param.default + + previous = view.index + if consume_rest_is_special: + argument = view.read_rest().strip() + else: + argument = quoted_word(view) + view.previous = previous + + return await self.do_conversion(ctx, converter, argument, param) + + async def _transform_greedy_pos(self, ctx, param, required, converter): + view = ctx.view + result = [] + while not view.eof: + # for use with a manual undo + previous = view.index + + # parsing errors get propagated + view.skip_ws() + argument = quoted_word(view) + try: + value = await self.do_conversion(ctx, converter, argument, param) + except CommandError: + if not result: + if required: + raise + else: + view.index = previous + return param.default + view.index = previous + break + else: + result.append(value) + return result + + async def _transform_greedy_var_pos(self, ctx, param, converter): + view = ctx.view + previous = view.index + argument = quoted_word(view) + try: + value = await self.do_conversion(ctx, converter, argument, param) + except CommandError: + view.index = previous + raise RuntimeError() from None # break loop + else: + return value + + @property + def clean_params(self): + """Retrieves the parameter OrderedDict without the context or self parameters. + + Useful for inspecting signature. + """ + result = self.params.copy() + if self.instance is not None: + # first parameter is self + result.popitem(last=False) + + try: + # first/second parameter is context + result.popitem(last=False) + except Exception: + raise ValueError("Missing context parameter") from None + + return result + + @property + def full_parent_name(self): + """Retrieves the fully qualified parent command name. + + This the base command name required to execute it. For example, + in ``?one two three`` the parent name would be ``one two``. + """ + entries = [] + command = self + while command.parent is not None: + command = command.parent + entries.append(command.name) + + return " ".join(reversed(entries)) + + @property + def root_parent(self): + """Retrieves the root parent of this command. + + If the command has no parents then it returns ``None``. + + For example in commands ``?a b c test``, the root parent is + ``a``. + """ + entries = [] + command = self + while command.parent is not None: + command = command.parent + entries.append(command) + + if len(entries) == 0: + return None + + return entries[-1] + + @property + def qualified_name(self): + """Retrieves the fully qualified command name. + + This is the full parent name with the command name as well. + For example, in ``?one two three`` the qualified name would be + ``one two three``. + """ + + parent = self.full_parent_name + if parent: + return parent + " " + self.name + else: + return self.name + + def __str__(self): + return self.qualified_name + + async def _parse_arguments(self, ctx): + ctx.args = [ctx] if self.instance is None else [self.instance, ctx] + ctx.kwargs = {} + args = ctx.args + kwargs = ctx.kwargs + + view = ctx.view + iterator = iter(self.params.items()) + + if self.instance is not None: + # we have 'self' as the first parameter so just advance + # the iterator and resume parsing + try: + next(iterator) + except StopIteration: + fmt = 'Callback for {0.name} command is missing "self" parameter.' + raise discord.ClientException(fmt.format(self)) + + # next we have the 'ctx' as the next parameter + try: + next(iterator) + except StopIteration: + fmt = 'Callback for {0.name} command is missing "ctx" parameter.' + raise discord.ClientException(fmt.format(self)) + + for name, param in iterator: + if param.kind == param.POSITIONAL_OR_KEYWORD: + transformed = await self.transform(ctx, param) + args.append(transformed) + elif param.kind == param.KEYWORD_ONLY: + # kwarg only param denotes "consume rest" semantics + if self.rest_is_raw: + converter = self._get_converter(param) + argument = view.read_rest() + kwargs[name] = await self.do_conversion(ctx, converter, argument, param) + else: + kwargs[name] = await self.transform(ctx, param) + break + elif param.kind == param.VAR_POSITIONAL: + while not view.eof: + try: + transformed = await self.transform(ctx, param) + args.append(transformed) + except RuntimeError: + break + + if not self.ignore_extra: + if not view.eof: + raise TooManyArguments("Too many arguments passed to " + self.qualified_name) + + async def _verify_checks(self, ctx): + if not self.enabled: + raise DisabledCommand("{0.name} command is disabled".format(self)) + + if not await self.can_run(ctx): + raise CheckFailure( + "The check functions for command {0.qualified_name} failed.".format(self) + ) + + async def call_before_hooks(self, ctx): + # now that we're done preparing we can call the pre-command hooks + # first, call the command local hook: + cog = self.instance + if self._before_invoke is not None: + if cog is None: + await self._before_invoke(ctx) + else: + await self._before_invoke(cog, ctx) + + # call the cog local hook if applicable: + try: + hook = getattr(cog, "_{0.__class__.__name__}__before_invoke".format(cog)) + except AttributeError: + pass + else: + await hook(ctx) + + # call the bot global hook if necessary + hook = ctx.bot._before_invoke + if hook is not None: + await hook(ctx) + + async def call_after_hooks(self, ctx): + cog = self.instance + if self._after_invoke is not None: + if cog is None: + await self._after_invoke(ctx) + else: + await self._after_invoke(cog, ctx) + + try: + hook = getattr(cog, "_{0.__class__.__name__}__after_invoke".format(cog)) + except AttributeError: + pass + else: + await hook(ctx) + + hook = ctx.bot._after_invoke + if hook is not None: + await hook(ctx) + + async def prepare(self, ctx): + ctx.command = self + await self._verify_checks(ctx) + + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx.message) + retry_after = bucket.update_rate_limit() + if retry_after: + raise CommandOnCooldown(bucket, retry_after) + + await self._parse_arguments(ctx) + await self.call_before_hooks(ctx) + + def is_on_cooldown(self, ctx): + """Checks whether the command is currently on cooldown. + + Parameters + ----------- + ctx: :class:`.Context.` + The invocation context to use when checking the commands cooldown status. + + Returns + -------- + bool + A boolean indicating if the command is on cooldown. + """ + if not self._buckets.valid: + return False + + bucket = self._buckets.get_bucket(ctx.message) + return bucket.get_tokens() == 0 + + def reset_cooldown(self, ctx): + """Resets the cooldown on this command. + + Parameters + ----------- + ctx: :class:`.Context` + The invocation context to reset the cooldown under. + """ + if self._buckets.valid: + bucket = self._buckets.get_bucket(ctx.message) + bucket.reset() + + async def invoke(self, ctx): + await self.prepare(ctx) + + # terminate the invoked_subcommand chain. + # since we're in a regular command (and not a group) then + # the invoked subcommand is None. + ctx.invoked_subcommand = None + injected = hooked_wrapped_callback(self, ctx, self.callback) + await injected(*ctx.args, **ctx.kwargs) + + async def reinvoke(self, ctx, *, call_hooks=False): + ctx.command = self + await self._parse_arguments(ctx) + + if call_hooks: + await self.call_before_hooks(ctx) + + ctx.invoked_subcommand = None + try: + await self.callback(*ctx.args, **ctx.kwargs) + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + await self.call_after_hooks(ctx) + + def error(self, coro): + """A decorator that registers a coroutine as a local error handler. + + A local error handler is an :func:`.on_command_error` event limited to + a single command. However, the :func:`.on_command_error` is still + invoked afterwards as the catch-all. + + Parameters + ----------- + coro : :ref:`coroutine ` + The coroutine to register as the local error handler. + + Raises + ------- + discord.ClientException + The coroutine is not actually a coroutine. + """ + + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException("The error handler must be a coroutine.") + + self.on_error = coro + return coro + + def before_invoke(self, coro): + """A decorator that registers a coroutine as a pre-invoke hook. + + A pre-invoke hook is called directly before the command is + called. This makes it a useful function to set up database + connections or any type of set up required. + + This pre-invoke hook takes a sole parameter, a :class:`.Context`. + + See :meth:`.Bot.before_invoke` for more info. + + Parameters + ----------- + coro + The coroutine to register as the pre-invoke hook. + + Raises + ------- + :exc:`.ClientException` + The coroutine is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException("The pre-invoke hook must be a coroutine.") + + self._before_invoke = coro + return coro + + def after_invoke(self, coro): + """A decorator that registers a coroutine as a post-invoke hook. + + A post-invoke hook is called directly after the command is + called. This makes it a useful function to clean-up database + connections or any type of clean up required. + + This post-invoke hook takes a sole parameter, a :class:`.Context`. + + See :meth:`.Bot.after_invoke` for more info. + + Parameters + ----------- + coro + The coroutine to register as the post-invoke hook. + + Raises + ------- + :exc:`.ClientException` + The coroutine is not actually a coroutine. + """ + if not asyncio.iscoroutinefunction(coro): + raise discord.ClientException("The post-invoke hook must be a coroutine.") + + self._after_invoke = coro + return coro + + @property + def cog_name(self): + """The name of the cog this command belongs to. None otherwise.""" + return type(self.instance).__name__ if self.instance is not None else None + + @property + def short_doc(self): + """Gets the "short" documentation of a command. + + By default, this is the :attr:`brief` attribute. + If that lookup leads to an empty string then the first line of the + :attr:`help` attribute is used instead. + """ + if self.brief: + return self.brief + if self.help: + return self.help.split("\n", 1)[0] + return "" + + @property + def signature(self): + """Returns a POSIX-like signature useful for help command output.""" + result = [] + parent = self.full_parent_name + if len(self.aliases) > 0: + aliases = "|".join(self.aliases) + fmt = "[%s|%s]" % (self.name, aliases) + if parent: + fmt = parent + " " + fmt + result.append(fmt) + else: + name = self.name if not parent else parent + " " + self.name + result.append(name) + + if self.usage: + result.append(self.usage) + return " ".join(result) + + params = self.clean_params + if not params: + return " ".join(result) + + for name, param in params.items(): + if param.default is not param.empty: + # We don't want None or '' to trigger the [name=value] case and instead it should + # do [name] since [name=None] or [name=] are not exactly useful for the user. + should_print = ( + param.default if isinstance(param.default, str) else param.default is not None + ) + if should_print: + result.append("[%s=%s]" % (name, param.default)) + else: + result.append("[%s]" % name) + elif param.kind == param.VAR_POSITIONAL: + result.append("[%s...]" % name) + else: + result.append("<%s>" % name) + + return " ".join(result) + + async def can_run(self, ctx): + """|coro| + + Checks if the command can be executed by checking all the predicates + inside the :attr:`.checks` attribute. + + Parameters + ----------- + ctx: :class:`.Context` + The ctx of the command currently being invoked. + + Raises + ------- + :class:`CommandError` + Any command error that was raised during a check call will be propagated + by this function. + + Returns + -------- + bool + A boolean indicating if the command can be invoked. + """ + + original = ctx.command + ctx.command = self + + try: + if not await ctx.bot.can_run(ctx): + raise CheckFailure( + "The global check functions for command {0.qualified_name} failed.".format( + self + ) + ) + + cog = self.instance + if cog is not None: + try: + local_check = getattr(cog, "_{0.__class__.__name__}__local_check".format(cog)) + except AttributeError: + pass + else: + ret = await discord.utils.maybe_coroutine(local_check, ctx) + if not ret: + return False + + predicates = self.checks + if not predicates: + # since we have no checks, then we just return True. + return True + + return await discord.utils.async_all(predicate(ctx) for predicate in predicates) + finally: + ctx.command = original + + +class GroupMixin: + """A mixin that implements common functionality for classes that behave + similar to :class:`.Group` and are allowed to register commands. + + Attributes + ----------- + all_commands: :class:`dict` + A mapping of command name to :class:`.Command` or superclass + objects. + case_insensitive: :class:`bool` + Whether the commands should be case insensitive. Defaults to ``False``. + """ + + def __init__(self, **kwargs): + case_insensitive = kwargs.get("case_insensitive", False) + self.all_commands = _CaseInsensitiveDict() if case_insensitive else {} + self.case_insensitive = case_insensitive + super().__init__(**kwargs) + + @property + def commands(self): + """Set[:class:`.Command`]: A unique set of commands without aliases that are registered.""" + return set(self.all_commands.values()) + + def recursively_remove_all_commands(self): + for command in self.all_commands.copy().values(): + if isinstance(command, GroupMixin): + command.recursively_remove_all_commands() + self.remove_command(command.name) + + def add_command(self, command): + """Adds a :class:`.Command` or its superclasses into the internal list + of commands. + + This is usually not called, instead the :meth:`~.GroupMixin.command` or + :meth:`~.GroupMixin.group` shortcut decorators are used instead. + + Parameters + ----------- + command + The command to add. + + Raises + ------- + :exc:`.ClientException` + If the command is already registered. + TypeError + If the command passed is not a subclass of :class:`.Command`. + """ + + if not isinstance(command, Command): + raise TypeError("The command passed must be a subclass of Command") + + if isinstance(self, Command): + command.parent = self + + if command.name in self.all_commands: + raise discord.ClientException( + "Command {0.name} is already registered.".format(command) + ) + + self.all_commands[command.name] = command + for alias in command.aliases: + if alias in self.all_commands: + raise discord.ClientException( + "The alias {} is already an existing command or alias.".format(alias) + ) + self.all_commands[alias] = command + + def remove_command(self, name): + """Remove a :class:`.Command` or subclasses from the internal list + of commands. + + This could also be used as a way to remove aliases. + + Parameters + ----------- + name: str + The name of the command to remove. + + Returns + -------- + :class:`.Command` or subclass + The command that was removed. If the name is not valid then + `None` is returned instead. + """ + command = self.all_commands.pop(name, None) + + # does not exist + if command is None: + return None + + if name in command.aliases: + # we're removing an alias so we don't want to remove the rest + return command + + # we're not removing the alias so let's delete the rest of them. + for alias in command.aliases: + self.all_commands.pop(alias, None) + return command + + def walk_commands(self): + """An iterator that recursively walks through all commands and subcommands.""" + for command in tuple(self.all_commands.values()): + yield command + if isinstance(command, GroupMixin): + yield from command.walk_commands() + + def get_command(self, name): + """Get a :class:`.Command` or subclasses from the internal list + of commands. + + This could also be used as a way to get aliases. + + The name could be fully qualified (e.g. ``'foo bar'``) will get + the subcommand ``bar`` of the group command ``foo``. If a + subcommand is not found then ``None`` is returned just as usual. + + Parameters + ----------- + name: str + The name of the command to get. + + Returns + -------- + Command or subclass + The command that was requested. If not found, returns ``None``. + """ + + names = name.split() + obj = self.all_commands.get(names[0]) + if not isinstance(obj, GroupMixin): + return obj + + for name in names[1:]: + try: + obj = obj.all_commands[name] + except (AttributeError, KeyError): + return None + + return obj + + def command(self, *args, **kwargs): + """A shortcut decorator that invokes :func:`.command` and adds it to + the internal command list via :meth:`~.GroupMixin.add_command`. + """ + + def decorator(func): + result = command(*args, **kwargs)(func) + self.add_command(result) + return result + + return decorator + + def group(self, *args, **kwargs): + """A shortcut decorator that invokes :func:`.group` and adds it to + the internal command list via :meth:`~.GroupMixin.add_command`. + """ + + def decorator(func): + result = group(*args, **kwargs)(func) + self.add_command(result) + return result + + return decorator + + +class Group(GroupMixin, Command): + """A class that implements a grouping protocol for commands to be + executed as subcommands. + + This class is a subclass of :class:`.Command` and thus all options + valid in :class:`.Command` are valid in here as well. + + Attributes + ----------- + invoke_without_command: :class:`bool` + Indicates if the group callback should begin parsing and + invocation only if no subcommand was found. Useful for + making it an error handling function to tell the user that + no subcommand was found or to have different functionality + in case no subcommand was found. If this is ``False``, then + the group callback will always be invoked first. This means + that the checks and the parsing dictated by its parameters + will be executed. Defaults to ``False``. + case_insensitive: :class:`bool` + Indicates if the group's commands should be case insensitive. + Defaults to ``False``. + """ + + def __init__(self, **attrs): + self.invoke_without_command = attrs.pop("invoke_without_command", False) + super().__init__(**attrs) + + async def invoke(self, ctx): + early_invoke = not self.invoke_without_command + if early_invoke: + await self.prepare(ctx) + + view = ctx.view + previous = view.index + view.skip_ws() + trigger = view.get_word() + + if trigger: + ctx.subcommand_passed = trigger + ctx.invoked_subcommand = self.all_commands.get(trigger, None) + + if early_invoke: + injected = hooked_wrapped_callback(self, ctx, self.callback) + await injected(*ctx.args, **ctx.kwargs) + + if trigger and ctx.invoked_subcommand: + ctx.invoked_with = trigger + await ctx.invoked_subcommand.invoke(ctx) + elif not early_invoke: + # undo the trigger parsing + view.index = previous + view.previous = previous + await super().invoke(ctx) + + async def reinvoke(self, ctx, *, call_hooks=False): + early_invoke = not self.invoke_without_command + if early_invoke: + ctx.command = self + await self._parse_arguments(ctx) + + if call_hooks: + await self.call_before_hooks(ctx) + + view = ctx.view + previous = view.index + view.skip_ws() + trigger = view.get_word() + + if trigger: + ctx.subcommand_passed = trigger + ctx.invoked_subcommand = self.all_commands.get(trigger, None) + + if early_invoke: + try: + await self.callback(*ctx.args, **ctx.kwargs) + except: + ctx.command_failed = True + raise + finally: + if call_hooks: + await self.call_after_hooks(ctx) + + if trigger and ctx.invoked_subcommand: + ctx.invoked_with = trigger + await ctx.invoked_subcommand.reinvoke(ctx, call_hooks=call_hooks) + elif not early_invoke: + # undo the trigger parsing + view.index = previous + view.previous = previous + await super().reinvoke(ctx, call_hooks=call_hooks) + + +# Decorators + + +def command(name=None, cls=None, **attrs): + """A decorator that transforms a function into a :class:`.Command` + or if called with :func:`.group`, :class:`.Group`. + + By default the ``help`` attribute is received automatically from the + docstring of the function and is cleaned up with the use of + ``inspect.cleandoc``. If the docstring is ``bytes``, then it is decoded + into :class:`str` using utf-8 encoding. + + All checks added using the :func:`.check` & co. decorators are added into + the function. There is no way to supply your own checks through this + decorator. + + Parameters + ----------- + name: str + The name to create the command with. By default this uses the + function name unchanged. + cls + The class to construct with. By default this is :class:`.Command`. + You usually do not change this. + attrs + Keyword arguments to pass into the construction of the class denoted + by ``cls``. + + Raises + ------- + TypeError + If the function is not a coroutine or is already a command. + """ + if cls is None: + cls = Command + + def decorator(func): + if isinstance(func, Command): + raise TypeError("Callback is already a command.") + if not asyncio.iscoroutinefunction(func): + raise TypeError("Callback must be a coroutine.") + + try: + checks = func.__commands_checks__ + checks.reverse() + del func.__commands_checks__ + except AttributeError: + checks = [] + + try: + cooldown = func.__commands_cooldown__ + del func.__commands_cooldown__ + except AttributeError: + cooldown = None + + help_doc = attrs.get("help") + if help_doc is not None: + help_doc = inspect.cleandoc(help_doc) + else: + help_doc = inspect.getdoc(func) + if isinstance(help_doc, bytes): + help_doc = help_doc.decode("utf-8") + + attrs["help"] = help_doc + fname = name or func.__name__ + return cls(name=fname, callback=func, checks=checks, cooldown=cooldown, **attrs) + + return decorator + + +def group(name=None, **attrs): + """A decorator that transforms a function into a :class:`.Group`. + + This is similar to the :func:`.command` decorator but creates a + :class:`.Group` instead of a :class:`.Command`. + """ + return command(name=name, cls=Group, **attrs) + + +def check(predicate): + r"""A decorator that adds a check to the :class:`.Command` or its + subclasses. These checks could be accessed via :attr:`.Command.checks`. + + These checks should be predicates that take in a single parameter taking + a :class:`.Context`. If the check returns a ``False``\-like value then + during invocation a :exc:`.CheckFailure` exception is raised and sent to + the :func:`.on_command_error` event. + + If an exception should be thrown in the predicate then it should be a + subclass of :exc:`.CommandError`. Any exception not subclassed from it + will be propagated while those subclassed will be sent to + :func:`.on_command_error`. + + .. note:: + + These functions can either be regular functions or coroutines. + + Parameters + ----------- + predicate + The predicate to check if the command should be invoked. + + Examples + --------- + + Creating a basic check to see if the command invoker is you. + + .. code-block:: python3 + + def check_if_it_is_me(ctx): + return ctx.message.author.id == 85309593344815104 + + @bot.command() + @commands.check(check_if_it_is_me) + async def only_for_me(ctx): + await ctx.send('I know you!') + + Transforming common checks into its own decorator: + + .. code-block:: python3 + + def is_me(): + def predicate(ctx): + return ctx.message.author.id == 85309593344815104 + return commands.check(predicate) + + @bot.command() + @is_me() + async def only_me(ctx): + await ctx.send('Only you!') + + """ + + def decorator(func): + if isinstance(func, Command): + func.checks.append(predicate) + else: + if not hasattr(func, "__commands_checks__"): + func.__commands_checks__ = [] + + func.__commands_checks__.append(predicate) + + return func + + return decorator + + +def has_role(item): + """A :func:`.check` that is added that checks if the member invoking the + command has the role specified via the name or ID specified. + + If a string is specified, you must give the exact name of the role, including + caps and spelling. + + If an integer is specified, you must give the exact snowflake ID of the role. + + If the message is invoked in a private message context then the check will + return ``False``. + + Parameters + ----------- + item: Union[int, str] + The name or ID of the role to check. + """ + + def predicate(ctx): + if not isinstance(ctx.channel, discord.abc.GuildChannel): + return False + + if isinstance(item, int): + role = discord.utils.get(ctx.author.roles, id=item) + else: + role = discord.utils.get(ctx.author.roles, name=item) + return role is not None + + return check(predicate) + + +def has_any_role(*items): + r"""A :func:`.check` that is added that checks if the member invoking the + command has **any** of the roles specified. This means that if they have + one out of the three roles specified, then this check will return `True`. + + Similar to :func:`.has_role`\, the names or IDs passed in must be exact. + + Parameters + ----------- + items + An argument list of names or IDs to check that the member has roles wise. + + Example + -------- + + .. code-block:: python3 + + @bot.command() + @commands.has_any_role('Library Devs', 'Moderators', 492212595072434186) + async def cool(ctx): + await ctx.send('You are cool indeed') + """ + + def predicate(ctx): + if not isinstance(ctx.channel, discord.abc.GuildChannel): + return False + + getter = functools.partial(discord.utils.get, ctx.author.roles) + return any( + getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None + for item in items + ) + + return check(predicate) + + +def has_permissions(**perms): + """A :func:`.check` that is added that checks if the member has any of + the permissions necessary. + + The permissions passed in must be exactly like the properties shown under + :class:`.discord.Permissions`. + + This check raises a special exception, :exc:`.MissingPermissions` + that is derived from :exc:`.CheckFailure`. + + Parameters + ------------ + perms + An argument list of permissions to check for. + + Example + --------- + + .. code-block:: python3 + + @bot.command() + @commands.has_permissions(manage_messages=True) + async def test(ctx): + await ctx.send('You can manage messages.') + + """ + + def predicate(ctx): + ch = ctx.channel + permissions = ch.permissions_for(ctx.author) + + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm, None) != value + ] + + if not missing: + return True + + raise MissingPermissions(missing) + + return check(predicate) + + +def bot_has_role(item): + """Similar to :func:`.has_role` except checks if the bot itself has the + role. + """ + + def predicate(ctx): + ch = ctx.channel + if not isinstance(ch, discord.abc.GuildChannel): + return False + me = ch.guild.me + if isinstance(item, int): + role = discord.utils.get(me.roles, id=item) + else: + role = discord.utils.get(me.roles, name=item) + return role is not None + + return check(predicate) + + +def bot_has_any_role(*items): + """Similar to :func:`.has_any_role` except checks if the bot itself has + any of the roles listed. + """ + + def predicate(ctx): + ch = ctx.channel + if not isinstance(ch, discord.abc.GuildChannel): + return False + me = ch.guild.me + getter = functools.partial(discord.utils.get, me.roles) + return any( + getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None + for item in items + ) + + return check(predicate) + + +def bot_has_permissions(**perms): + """Similar to :func:`.has_permissions` except checks if the bot itself has + the permissions listed. + + This check raises a special exception, :exc:`.BotMissingPermissions` + that is derived from :exc:`.CheckFailure`. + """ + + def predicate(ctx): + guild = ctx.guild + me = guild.me if guild is not None else ctx.bot.user + permissions = ctx.channel.permissions_for(me) + + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm, None) != value + ] + + if not missing: + return True + + raise BotMissingPermissions(missing) + + return check(predicate) + + +def guild_only(): + """A :func:`.check` that indicates this command must only be used in a + guild context only. Basically, no private messages are allowed when + using the command. + + This check raises a special exception, :exc:`.NoPrivateMessage` + that is derived from :exc:`.CheckFailure`. + """ + + def predicate(ctx): + if ctx.guild is None: + raise NoPrivateMessage("This command cannot be used in private messages.") + return True + + return check(predicate) + + +def is_owner(): + """A :func:`.check` that checks if the person invoking this command is the + owner of the bot. + + This is powered by :meth:`.Bot.is_owner`. + + This check raises a special exception, :exc:`.NotOwner` that is derived + from :exc:`.CheckFailure`. + """ + + async def predicate(ctx): + if not await ctx.bot.is_owner(ctx.author): + raise NotOwner("You do not own this bot.") + return True + + return check(predicate) + + +def is_nsfw(): + """A :func:`.check` that checks if the channel is a NSFW channel.""" + + def pred(ctx): + return isinstance(ctx.channel, discord.TextChannel) and ctx.channel.is_nsfw() + + return check(pred) + + +def cooldown(rate, per, type=BucketType.default): + """A decorator that adds a cooldown to a :class:`.Command` + or its subclasses. + + A cooldown allows a command to only be used a specific amount + of times in a specific time frame. These cooldowns can be based + either on a per-guild, per-channel, per-user, or global basis. + Denoted by the third argument of ``type`` which must be of enum + type ``BucketType`` which could be either: + + - ``BucketType.default`` for a global basis. + - ``BucketType.user`` for a per-user basis. + - ``BucketType.guild`` for a per-guild basis. + - ``BucketType.channel`` for a per-channel basis. + - ``BucketType.member`` for a per-member basis. + - ``BucketType.category`` for a per-category basis. + + If a cooldown is triggered, then :exc:`.CommandOnCooldown` is triggered in + :func:`.on_command_error` and the local error handler. + + A command can only have a single cooldown. + + Parameters + ------------ + rate: int + The number of times a command can be used before triggering a cooldown. + per: float + The amount of seconds to wait for a cooldown when it's been triggered. + type: ``BucketType`` + The type of cooldown to have. + """ + + def decorator(func): + if isinstance(func, Command): + func._buckets = CooldownMapping(Cooldown(rate, per, type)) + else: + func.__commands_cooldown__ = Cooldown(rate, per, type) + return func + + return decorator diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py new file mode 100644 index 000000000..744bcb01e --- /dev/null +++ b/discord/ext/commands/errors.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from discord.errors import DiscordException + + +__all__ = [ + "CommandError", + "MissingRequiredArgument", + "BadArgument", + "NoPrivateMessage", + "CheckFailure", + "CommandNotFound", + "DisabledCommand", + "CommandInvokeError", + "TooManyArguments", + "UserInputError", + "CommandOnCooldown", + "NotOwner", + "MissingPermissions", + "BotMissingPermissions", + "ConversionError", + "BadUnionArgument", +] + + +class CommandError(DiscordException): + r"""The base exception type for all command related errors. + + This inherits from :exc:`discord.DiscordException`. + + This exception and exceptions derived from it are handled + in a special way as they are caught and passed into a special event + from :class:`.Bot`\, :func:`on_command_error`. + """ + + def __init__(self, message=None, *args): + if message is not None: + # clean-up @everyone and @here mentions + m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") + super().__init__(m, *args) + else: + super().__init__(*args) + + +class ConversionError(CommandError): + """Exception raised when a Converter class raises non-CommandError. + + This inherits from :exc:`.CommandError`. + + Attributes + ---------- + converter: :class:`discord.ext.commands.Converter` + The converter that failed. + original + The original exception that was raised. You can also get this via + the ``__cause__`` attribute. + """ + + def __init__(self, converter, original): + self.converter = converter + self.original = original + + +class UserInputError(CommandError): + """The base exception type for errors that involve errors + regarding user input. + + This inherits from :exc:`.CommandError`. + """ + + pass + + +class CommandNotFound(CommandError): + """Exception raised when a command is attempted to be invoked + but no command under that name is found. + + This is not raised for invalid subcommands, rather just the + initial main command that is attempted to be invoked. + """ + + pass + + +class MissingRequiredArgument(UserInputError): + """Exception raised when parsing a command and a parameter + that is required is not encountered. + + Attributes + ----------- + param: :class:`inspect.Parameter` + The argument that is missing. + """ + + def __init__(self, param): + self.param = param + super().__init__("{0.name} is a required argument that is missing.".format(param)) + + +class TooManyArguments(UserInputError): + """Exception raised when the command was passed too many arguments and its + :attr:`.Command.ignore_extra` attribute was not set to ``True``. + """ + + pass + + +class BadArgument(UserInputError): + """Exception raised when a parsing or conversion failure is encountered + on an argument to pass into a command. + """ + + pass + + +class CheckFailure(CommandError): + """Exception raised when the predicates in :attr:`.Command.checks` have failed.""" + + pass + + +class NoPrivateMessage(CheckFailure): + """Exception raised when an operation does not work in private message + contexts. + """ + + pass + + +class NotOwner(CheckFailure): + """Exception raised when the message author is not the owner of the bot.""" + + pass + + +class DisabledCommand(CommandError): + """Exception raised when the command being invoked is disabled.""" + + pass + + +class CommandInvokeError(CommandError): + """Exception raised when the command being invoked raised an exception. + + Attributes + ----------- + original + The original exception that was raised. You can also get this via + the ``__cause__`` attribute. + """ + + def __init__(self, e): + self.original = e + super().__init__("Command raised an exception: {0.__class__.__name__}: {0}".format(e)) + + +class CommandOnCooldown(CommandError): + """Exception raised when the command being invoked is on cooldown. + + Attributes + ----------- + cooldown: Cooldown + A class with attributes ``rate``, ``per``, and ``type`` similar to + the :func:`.cooldown` decorator. + retry_after: :class:`float` + The amount of seconds to wait before you can retry again. + """ + + def __init__(self, cooldown, retry_after): + self.cooldown = cooldown + self.retry_after = retry_after + super().__init__("You are on cooldown. Try again in {:.2f}s".format(retry_after)) + + +class MissingPermissions(CheckFailure): + """Exception raised when the command invoker lacks permissions to run + command. + + Attributes + ----------- + missing_perms: :class:`list` + The required permissions that are missing. + """ + + def __init__(self, missing_perms, *args): + self.missing_perms = missing_perms + + missing = [ + perm.replace("_", " ").replace("guild", "server").title() for perm in missing_perms + ] + + if len(missing) > 2: + fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1]) + else: + fmt = " and ".join(missing) + message = "You are missing {} permission(s) to run command.".format(fmt) + super().__init__(message, *args) + + +class BotMissingPermissions(CheckFailure): + """Exception raised when the bot lacks permissions to run command. + + Attributes + ----------- + missing_perms: :class:`list` + The required permissions that are missing. + """ + + def __init__(self, missing_perms, *args): + self.missing_perms = missing_perms + + missing = [ + perm.replace("_", " ").replace("guild", "server").title() for perm in missing_perms + ] + + if len(missing) > 2: + fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1]) + else: + fmt = " and ".join(missing) + message = "Bot requires {} permission(s) to run command.".format(fmt) + super().__init__(message, *args) + + +class BadUnionArgument(UserInputError): + """Exception raised when a :class:`typing.Union` converter fails for all + its associated types. + + Attributes + ----------- + param: :class:`inspect.Parameter` + The parameter that failed being converted. + converters: Tuple[Type, ...] + A tuple of converters attempted in conversion, in order of failure. + errors: List[:class:`CommandError`] + A list of errors that were caught from failing the conversion. + """ + + def __init__(self, param, converters, errors): + self.param = param + self.converters = converters + self.errors = errors + + def _get_name(x): + try: + return x.__name__ + except AttributeError: + return x.__class__.__name__ + + to_string = [_get_name(x) for x in converters] + if len(to_string) > 2: + fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1]) + else: + fmt = " or ".join(to_string) + + super().__init__('Could not convert "{0.name}" into {1}.'.format(param, fmt)) diff --git a/discord/ext/commands/formatter.py b/discord/ext/commands/formatter.py new file mode 100644 index 000000000..407c56448 --- /dev/null +++ b/discord/ext/commands/formatter.py @@ -0,0 +1,365 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import itertools +import inspect + +from .core import GroupMixin, Command +from .errors import CommandError + +# from discord.iterators import _FilteredAsyncIterator + +# help -> shows info of bot on top/bottom and lists subcommands +# help command -> shows detailed info of command +# help command -> same as above + +# + +# + +# + +# Cog: +# +# +# Other Cog: +# +# No Category: +# + +# Type help command for more info on a command. +# You can also type help category for more info on a category. + + +class Paginator: + """A class that aids in paginating code blocks for Discord messages. + + Attributes + ----------- + prefix: :class:`str` + The prefix inserted to every page. e.g. three backticks. + suffix: :class:`str` + The suffix appended at the end of every page. e.g. three backticks. + max_size: :class:`int` + The maximum amount of codepoints allowed in a page. + """ + + def __init__(self, prefix="```", suffix="```", max_size=2000): + self.prefix = prefix + self.suffix = suffix + self.max_size = max_size - len(suffix) + self._current_page = [prefix] + self._count = len(prefix) + 1 # prefix + newline + self._pages = [] + + def add_line(self, line="", *, empty=False): + """Adds a line to the current page. + + If the line exceeds the :attr:`max_size` then an exception + is raised. + + Parameters + ----------- + line: str + The line to add. + empty: bool + Indicates if another empty line should be added. + + Raises + ------ + RuntimeError + The line was too big for the current :attr:`max_size`. + """ + if len(line) > self.max_size - len(self.prefix) - 2: + raise RuntimeError( + "Line exceeds maximum page size %s" % (self.max_size - len(self.prefix) - 2) + ) + + if self._count + len(line) + 1 > self.max_size: + self.close_page() + + self._count += len(line) + 1 + self._current_page.append(line) + + if empty: + self._current_page.append("") + self._count += 1 + + def close_page(self): + """Prematurely terminate a page.""" + self._current_page.append(self.suffix) + self._pages.append("\n".join(self._current_page)) + self._current_page = [self.prefix] + self._count = len(self.prefix) + 1 # prefix + newline + + @property + def pages(self): + """Returns the rendered list of pages.""" + # we have more than just the prefix in our current page + if len(self._current_page) > 1: + self.close_page() + return self._pages + + def __repr__(self): + fmt = "" + return fmt.format(self) + + +class HelpFormatter: + """The default base implementation that handles formatting of the help + command. + + To override the behaviour of the formatter, :meth:`~.HelpFormatter.format` + should be overridden. A number of utility functions are provided for use + inside that method. + + Attributes + ----------- + show_hidden: :class:`bool` + Dictates if hidden commands should be shown in the output. + Defaults to ``False``. + show_check_failure: :class:`bool` + Dictates if commands that have their :attr:`.Command.checks` failed + shown. Defaults to ``False``. + width: :class:`int` + The maximum number of characters that fit in a line. + Defaults to 80. + """ + + def __init__(self, show_hidden=False, show_check_failure=False, width=80): + self.width = width + self.show_hidden = show_hidden + self.show_check_failure = show_check_failure + + def has_subcommands(self): + """:class:`bool`: Specifies if the command has subcommands.""" + return isinstance(self.command, GroupMixin) + + def is_bot(self): + """:class:`bool`: Specifies if the command being formatted is the bot itself.""" + return self.command is self.context.bot + + def is_cog(self): + """:class:`bool`: Specifies if the command being formatted is actually a cog.""" + return not self.is_bot() and not isinstance(self.command, Command) + + def shorten(self, text): + """Shortens text to fit into the :attr:`width`.""" + if len(text) > self.width: + return text[: self.width - 3] + "..." + return text + + @property + def max_name_size(self): + """:class:`int`: Returns the largest name length of a command or if it has subcommands + the largest subcommand name.""" + try: + commands = ( + self.command.all_commands if not self.is_cog() else self.context.bot.all_commands + ) + if commands: + return max( + map( + lambda c: len(c.name) if self.show_hidden or not c.hidden else 0, + commands.values(), + ) + ) + return 0 + except AttributeError: + return len(self.command.name) + + @property + def clean_prefix(self): + """The cleaned up invoke prefix. i.e. mentions are ``@name`` instead of ``<@id>``.""" + user = self.context.guild.me if self.context.guild else self.context.bot.user + # this breaks if the prefix mention is not the bot itself but I + # consider this to be an *incredibly* strange use case. I'd rather go + # for this common use case rather than waste performance for the + # odd one. + return self.context.prefix.replace(user.mention, "@" + user.display_name) + + def get_command_signature(self): + """Retrieves the signature portion of the help page.""" + prefix = self.clean_prefix + cmd = self.command + return prefix + cmd.signature + + def get_ending_note(self): + command_name = self.context.invoked_with + return ( + "Type {0}{1} command for more info on a command.\n" + "You can also type {0}{1} category for more info on a category.".format( + self.clean_prefix, command_name + ) + ) + + async def filter_command_list(self): + """Returns a filtered list of commands based on the two attributes + provided, :attr:`show_check_failure` and :attr:`show_hidden`. + Also filters based on if :meth:`~.HelpFormatter.is_cog` is valid. + + Returns + -------- + iterable + An iterable with the filter being applied. The resulting value is + a (key, value) :class:`tuple` of the command name and the command itself. + """ + + def sane_no_suspension_point_predicate(tup): + cmd = tup[1] + if self.is_cog(): + # filter commands that don't exist to this cog. + if cmd.instance is not self.command: + return False + + if cmd.hidden and not self.show_hidden: + return False + + return True + + async def predicate(tup): + if sane_no_suspension_point_predicate(tup) is False: + return False + + cmd = tup[1] + try: + return await cmd.can_run(self.context) + except CommandError: + return False + + iterator = ( + self.command.all_commands.items() + if not self.is_cog() + else self.context.bot.all_commands.items() + ) + if self.show_check_failure: + return filter(sane_no_suspension_point_predicate, iterator) + + # Gotta run every check and verify it + ret = [] + for elem in iterator: + valid = await predicate(elem) + if valid: + ret.append(elem) + + return ret + + def _add_subcommands_to_page(self, max_width, commands): + for name, command in commands: + if name in command.aliases: + # skip aliases + continue + + entry = " {0:<{width}} {1}".format(name, command.short_doc, width=max_width) + shortened = self.shorten(entry) + self._paginator.add_line(shortened) + + async def format_help_for(self, context, command_or_bot): + """Formats the help page and handles the actual heavy lifting of how + the help command looks like. To change the behaviour, override the + :meth:`~.HelpFormatter.format` method. + + Parameters + ----------- + context: :class:`.Context` + The context of the invoked help command. + command_or_bot: :class:`.Command` or :class:`.Bot` + The bot or command that we are getting the help of. + + Returns + -------- + list + A paginated output of the help command. + """ + self.context = context + self.command = command_or_bot + return await self.format() + + async def format(self): + """Handles the actual behaviour involved with formatting. + + To change the behaviour, this method should be overridden. + + Returns + -------- + list + A paginated output of the help command. + """ + self._paginator = Paginator() + + # we need a padding of ~80 or so + + description = ( + self.command.description if not self.is_cog() else inspect.getdoc(self.command) + ) + + if description: + # portion + self._paginator.add_line(description, empty=True) + + if isinstance(self.command, Command): + # + signature = self.get_command_signature() + self._paginator.add_line(signature, empty=True) + + # section + if self.command.help: + self._paginator.add_line(self.command.help, empty=True) + + # end it here if it's just a regular command + if not self.has_subcommands(): + self._paginator.close_page() + return self._paginator.pages + + max_width = self.max_name_size + + def category(tup): + cog = tup[1].cog_name + # we insert the zero width space there to give it approximate + # last place sorting position. + return cog + ":" if cog is not None else "\u200bNo Category:" + + filtered = await self.filter_command_list() + if self.is_bot(): + data = sorted(filtered, key=category) + for category, commands in itertools.groupby(data, key=category): + # there simply is no prettier way of doing this. + commands = sorted(commands) + if len(commands) > 0: + self._paginator.add_line(category) + + self._add_subcommands_to_page(max_width, commands) + else: + filtered = sorted(filtered) + if filtered: + self._paginator.add_line("Commands:") + self._add_subcommands_to_page(max_width, filtered) + + # add the ending note + self._paginator.add_line() + ending_note = self.get_ending_note() + self._paginator.add_line(ending_note) + return self._paginator.pages diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py new file mode 100644 index 000000000..c9a0107a1 --- /dev/null +++ b/discord/ext/commands/view.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .errors import BadArgument + + +class StringView: + def __init__(self, buffer): + self.index = 0 + self.buffer = buffer + self.end = len(buffer) + self.previous = 0 + + @property + def current(self): + return None if self.eof else self.buffer[self.index] + + @property + def eof(self): + return self.index >= self.end + + def undo(self): + self.index = self.previous + + def skip_ws(self): + pos = 0 + while not self.eof: + try: + current = self.buffer[self.index + pos] + if not current.isspace(): + break + pos += 1 + except IndexError: + break + + self.previous = self.index + self.index += pos + return self.previous != self.index + + def skip_string(self, string): + strlen = len(string) + if self.buffer[self.index : self.index + strlen] == string: + self.previous = self.index + self.index += strlen + return True + return False + + def read_rest(self): + result = self.buffer[self.index :] + self.previous = self.index + self.index = self.end + return result + + def read(self, n): + result = self.buffer[self.index : self.index + n] + self.previous = self.index + self.index += n + return result + + def get(self): + try: + result = self.buffer[self.index + 1] + except IndexError: + result = None + + self.previous = self.index + self.index += 1 + return result + + def get_word(self): + pos = 0 + while not self.eof: + try: + current = self.buffer[self.index + pos] + if current.isspace(): + break + pos += 1 + except IndexError: + break + self.previous = self.index + result = self.buffer[self.index : self.index + pos] + self.index += pos + return result + + def __repr__(self): + return "".format( + self + ) + + +# Parser + +# map from opening quotes to closing quotes +_quotes = { + '"': '"', + "‘": "’", + "‚": "‛", + "“": "”", + "„": "‟", + "⹂": "⹂", + "「": "」", + "『": "』", + "〝": "〞", + "﹁": "﹂", + "﹃": "﹄", + """: """, + "「": "」", + "«": "»", + "‹": "›", + "《": "》", + "〈": "〉", +} +_all_quotes = set(_quotes.keys()) | set(_quotes.values()) + + +def quoted_word(view): + current = view.current + + if current is None: + return None + + close_quote = _quotes.get(current) + is_quoted = bool(close_quote) + if is_quoted: + result = [] + _escaped_quotes = (current, close_quote) + else: + result = [current] + _escaped_quotes = _all_quotes + + while not view.eof: + current = view.get() + if not current: + if is_quoted: + # unexpected EOF + raise BadArgument("Expected closing {}.".format(close_quote)) + return "".join(result) + + # currently we accept strings in the format of "hello world" + # to embed a quote inside the string you must escape it: "a \"world\"" + if current == "\\": + next_char = view.get() + if not next_char: + # string ends with \ and no character after it + if is_quoted: + # if we're quoted then we're expecting a closing quote + raise BadArgument("Expected closing {}.".format(close_quote)) + # if we aren't then we just let it through + return "".join(result) + + if next_char in _escaped_quotes: + # escaped quote + result.append(next_char) + else: + # different escape character, ignore it + view.undo() + result.append(current) + continue + + if not is_quoted and current in _all_quotes: + # we aren't quoted + raise BadArgument("Unexpected quote mark in non-quoted string") + + # closing quote + if is_quoted and current == close_quote: + next_char = view.get() + valid_eof = not next_char or next_char.isspace() + if not valid_eof: + raise BadArgument("Expected space after closing quotation") + + # we're quoted so it's okay + return "".join(result) + + if current.isspace() and not is_quoted: + # end of word found + return "".join(result) + + result.append(current) diff --git a/discord/file.py b/discord/file.py new file mode 100644 index 000000000..7392e5ff4 --- /dev/null +++ b/discord/file.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import os.path + + +class File: + """A parameter object used for :meth:`abc.Messageable.send` + for sending file objects. + + Attributes + ----------- + fp: Union[:class:`str`, BinaryIO] + A file-like object opened in binary mode and read mode + or a filename representing a file in the hard drive to + open. + + .. note:: + + If the file-like object passed is opened via ``open`` then the + modes 'rb' should be used. + + To pass binary data, consider usage of ``io.BytesIO``. + + filename: Optional[:class:`str`] + The filename to display when uploading to Discord. + If this is not given then it defaults to ``fp.name`` or if ``fp`` is + a string then the ``filename`` will default to the string given. + spoiler: :class:`bool` + Whether the attachment is a spoiler. + """ + + __slots__ = ("fp", "filename", "_true_fp") + + def __init__(self, fp, filename=None, *, spoiler=False): + self.fp = fp + self._true_fp = None + + if filename is None: + if isinstance(fp, str): + _, self.filename = os.path.split(fp) + else: + self.filename = getattr(fp, "name", None) + else: + self.filename = filename + + if spoiler and not self.filename.startswith("SPOILER_"): + self.filename = "SPOILER_" + self.filename + + def open_file(self): + fp = self.fp + if isinstance(fp, str): + self._true_fp = fp = open(fp, "rb") + return fp + + def close(self): + if self._true_fp: + self._true_fp.close() diff --git a/discord/gateway.py b/discord/gateway.py new file mode 100644 index 000000000..bc4083711 --- /dev/null +++ b/discord/gateway.py @@ -0,0 +1,701 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +from collections import namedtuple +import json +import logging +import struct +import sys +import time +import threading +import zlib + +import websockets + +from . import utils +from .activity import _ActivityTag +from .errors import ConnectionClosed, InvalidArgument + +log = logging.getLogger(__name__) + +__all__ = [ + "DiscordWebSocket", + "KeepAliveHandler", + "VoiceKeepAliveHandler", + "DiscordVoiceWebSocket", + "ResumeWebSocket", +] + + +class ResumeWebSocket(Exception): + """Signals to initialise via RESUME opcode instead of IDENTIFY.""" + + def __init__(self, shard_id): + self.shard_id = shard_id + + +EventListener = namedtuple("EventListener", "predicate event result future") + + +class KeepAliveHandler(threading.Thread): + def __init__(self, *args, **kwargs): + ws = kwargs.pop("ws", None) + interval = kwargs.pop("interval", None) + shard_id = kwargs.pop("shard_id", None) + threading.Thread.__init__(self, *args, **kwargs) + self.ws = ws + self.interval = interval + self.daemon = True + self.shard_id = shard_id + self.msg = "Keeping websocket alive with sequence %s." + self._stop_ev = threading.Event() + self._last_ack = time.perf_counter() + self._last_send = time.perf_counter() + self.latency = float("inf") + self.heartbeat_timeout = ws._max_heartbeat_timeout + + def run(self): + while not self._stop_ev.wait(self.interval): + if self._last_ack + self.heartbeat_timeout < time.perf_counter(): + log.warning( + "Shard ID %s has stopped responding to the gateway. Closing and restarting.", + self.shard_id, + ) + coro = self.ws.close(4000) + f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) + + try: + f.result() + except Exception: + pass + finally: + self.stop() + return + + data = self.get_payload() + log.debug(self.msg, data["d"]) + coro = self.ws.send_as_json(data) + f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) + try: + # block until sending is complete + f.result() + except Exception: + self.stop() + else: + self._last_send = time.perf_counter() + + def get_payload(self): + return {"op": self.ws.HEARTBEAT, "d": self.ws.sequence} + + def stop(self): + self._stop_ev.set() + + def ack(self): + ack_time = time.perf_counter() + self._last_ack = ack_time + self.latency = ack_time - self._last_send + + +class VoiceKeepAliveHandler(KeepAliveHandler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.msg = "Keeping voice websocket alive with timestamp %s." + + def get_payload(self): + return {"op": self.ws.HEARTBEAT, "d": int(time.time() * 1000)} + + +class DiscordWebSocket(websockets.client.WebSocketClientProtocol): + """Implements a WebSocket for Discord's gateway v6. + + This is created through :func:`create_main_websocket`. Library + users should never create this manually. + + Attributes + ----------- + DISPATCH + Receive only. Denotes an event to be sent to Discord, such as READY. + HEARTBEAT + When received tells Discord to keep the connection alive. + When sent asks if your connection is currently alive. + IDENTIFY + Send only. Starts a new session. + PRESENCE + Send only. Updates your presence. + VOICE_STATE + Send only. Starts a new connection to a voice guild. + VOICE_PING + Send only. Checks ping time to a voice guild, do not use. + RESUME + Send only. Resumes an existing connection. + RECONNECT + Receive only. Tells the client to reconnect to a new gateway. + REQUEST_MEMBERS + Send only. Asks for the full member list of a guild. + INVALIDATE_SESSION + Receive only. Tells the client to optionally invalidate the session + and IDENTIFY again. + HELLO + Receive only. Tells the client the heartbeat interval. + HEARTBEAT_ACK + Receive only. Confirms receiving of a heartbeat. Not having it implies + a connection issue. + GUILD_SYNC + Send only. Requests a guild sync. + gateway + The gateway we are currently connected to. + token + The authentication token for discord. + """ + + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + PRESENCE = 3 + VOICE_STATE = 4 + VOICE_PING = 5 + RESUME = 6 + RECONNECT = 7 + REQUEST_MEMBERS = 8 + INVALIDATE_SESSION = 9 + HELLO = 10 + HEARTBEAT_ACK = 11 + GUILD_SYNC = 12 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_size = None + # an empty dispatcher to prevent crashes + self._dispatch = lambda *args: None + # generic event listeners + self._dispatch_listeners = [] + # the keep alive + self._keep_alive = None + + # ws related stuff + self.session_id = None + self.sequence = None + self._zlib = zlib.decompressobj() + self._buffer = bytearray() + + @classmethod + async def from_client( + cls, client, *, shard_id=None, session=None, sequence=None, resume=False + ): + """Creates a main websocket for Discord from a :class:`Client`. + + This is for internal use only. + """ + gateway = await client.http.get_gateway() + ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) + + # dynamically add attributes needed + ws.token = client.http.token + ws._connection = client._connection + ws._dispatch = client.dispatch + ws.gateway = gateway + ws.shard_id = shard_id + ws.shard_count = client._connection.shard_count + ws.session_id = session + ws.sequence = sequence + ws._max_heartbeat_timeout = client._connection.heartbeat_timeout + + client._connection._update_references(ws) + + log.info("Created websocket connected to %s", gateway) + + # poll event for OP Hello + await ws.poll_event() + + if not resume: + await ws.identify() + return ws + + await ws.resume() + try: + await ws.ensure_open() + except websockets.exceptions.ConnectionClosed: + # ws got closed so let's just do a regular IDENTIFY connect. + log.info( + "RESUME failed (the websocket decided to close) for Shard ID %s. Retrying.", + shard_id, + ) + return await cls.from_client(client, shard_id=shard_id) + else: + return ws + + def wait_for(self, event, predicate, result=None): + """Waits for a DISPATCH'd event that meets the predicate. + + Parameters + ----------- + event : str + The event name in all upper case to wait for. + predicate + A function that takes a data parameter to check for event + properties. The data parameter is the 'd' key in the JSON message. + result + A function that takes the same data parameter and executes to send + the result to the future. If None, returns the data. + + Returns + -------- + asyncio.Future + A future to wait for. + """ + + future = self.loop.create_future() + entry = EventListener(event=event, predicate=predicate, result=result, future=future) + self._dispatch_listeners.append(entry) + return future + + async def identify(self): + """Sends the IDENTIFY packet.""" + payload = { + "op": self.IDENTIFY, + "d": { + "token": self.token, + "properties": { + "$os": sys.platform, + "$browser": "discord.py", + "$device": "discord.py", + "$referrer": "", + "$referring_domain": "", + }, + "compress": True, + "large_threshold": 250, + "v": 3, + }, + } + + if not self._connection.is_bot: + payload["d"]["synced_guilds"] = [] + + if self.shard_id is not None and self.shard_count is not None: + payload["d"]["shard"] = [self.shard_id, self.shard_count] + + state = self._connection + if state._activity is not None or state._status is not None: + payload["d"]["presence"] = { + "status": state._status, + "game": state._activity, + "since": 0, + "afk": False, + } + + await self.send_as_json(payload) + log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id) + + async def resume(self): + """Sends the RESUME packet.""" + payload = { + "op": self.RESUME, + "d": {"seq": self.sequence, "session_id": self.session_id, "token": self.token}, + } + + await self.send_as_json(payload) + log.info("Shard ID %s has sent the RESUME payload.", self.shard_id) + + async def received_message(self, msg): + self._dispatch("socket_raw_receive", msg) + + if type(msg) is bytes: + self._buffer.extend(msg) + + if len(msg) >= 4: + if msg[-4:] == b"\x00\x00\xff\xff": + msg = self._zlib.decompress(self._buffer) + msg = msg.decode("utf-8") + self._buffer = bytearray() + else: + return + else: + return + + msg = json.loads(msg) + + log.debug("For Shard ID %s: WebSocket Event: %s", self.shard_id, msg) + self._dispatch("socket_response", msg) + + op = msg.get("op") + data = msg.get("d") + seq = msg.get("s") + if seq is not None: + self.sequence = seq + + if op != self.DISPATCH: + if op == self.RECONNECT: + # "reconnect" can only be handled by the Client + # so we terminate our connection and raise an + # internal exception signalling to reconnect. + log.info("Received RECONNECT opcode.") + await self.close() + raise ResumeWebSocket(self.shard_id) + + if op == self.HEARTBEAT_ACK: + self._keep_alive.ack() + return + + if op == self.HEARTBEAT: + beat = self._keep_alive.get_payload() + await self.send_as_json(beat) + return + + if op == self.HELLO: + interval = data["heartbeat_interval"] / 1000.0 + self._keep_alive = KeepAliveHandler( + ws=self, interval=interval, shard_id=self.shard_id + ) + # send a heartbeat immediately + await self.send_as_json(self._keep_alive.get_payload()) + self._keep_alive.start() + return + + if op == self.INVALIDATE_SESSION: + if data is True: + await asyncio.sleep(5.0, loop=self.loop) + await self.close() + raise ResumeWebSocket(self.shard_id) + + self.sequence = None + self.session_id = None + log.info("Shard ID %s session has been invalidated.", self.shard_id) + await self.identify() + return + + log.warning("Unknown OP code %s.", op) + return + + event = msg.get("t") + + if event == "READY": + self._trace = trace = data.get("_trace", []) + self.sequence = msg["s"] + self.session_id = data["session_id"] + log.info( + "Shard ID %s has connected to Gateway: %s (Session ID: %s).", + self.shard_id, + ", ".join(trace), + self.session_id, + ) + + elif event == "RESUMED": + self._trace = trace = data.get("_trace", []) + log.info( + "Shard ID %s has successfully RESUMED session %s under trace %s.", + self.shard_id, + self.session_id, + ", ".join(trace), + ) + + parser = "parse_" + event.lower() + + try: + func = getattr(self._connection, parser) + except AttributeError: + log.warning("Unknown event %s.", event) + else: + func(data) + + # remove the dispatched listeners + removed = [] + for index, entry in enumerate(self._dispatch_listeners): + if entry.event != event: + continue + + future = entry.future + if future.cancelled(): + removed.append(index) + continue + + try: + valid = entry.predicate(data) + except Exception as exc: + future.set_exception(exc) + removed.append(index) + else: + if valid: + ret = data if entry.result is None else entry.result(data) + future.set_result(ret) + removed.append(index) + + for index in reversed(removed): + del self._dispatch_listeners[index] + + @property + def latency(self): + """:obj:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" + heartbeat = self._keep_alive + return float("inf") if heartbeat is None else heartbeat.latency + + def _can_handle_close(self, code): + return code not in (1000, 4004, 4010, 4011) + + async def poll_event(self): + """Polls for a DISPATCH event and handles the general gateway loop. + + Raises + ------ + ConnectionClosed + The websocket connection was terminated for unhandled reasons. + """ + try: + msg = await self.recv() + await self.received_message(msg) + except websockets.exceptions.ConnectionClosed as exc: + if self._can_handle_close(exc.code): + log.info( + "Websocket closed with %s (%s), attempting a reconnect.", exc.code, exc.reason + ) + raise ResumeWebSocket(self.shard_id) from exc + else: + log.info("Websocket closed with %s (%s), cannot reconnect.", exc.code, exc.reason) + raise ConnectionClosed(exc, shard_id=self.shard_id) from exc + + async def send(self, data): + self._dispatch("socket_raw_send", data) + await super().send(data) + + async def send_as_json(self, data): + try: + await super().send(utils.to_json(data)) + except websockets.exceptions.ConnectionClosed as exc: + if not self._can_handle_close(exc.code): + raise ConnectionClosed(exc, shard_id=self.shard_id) from exc + + async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0): + if activity is not None: + if not isinstance(activity, _ActivityTag): + raise InvalidArgument("activity must be one of Game, Streaming, or Activity.") + activity = activity.to_dict() + + if status == "idle": + since = int(time.time() * 1000) + + payload = { + "op": self.PRESENCE, + "d": {"game": activity, "afk": afk, "since": since, "status": status}, + } + + sent = utils.to_json(payload) + log.debug('Sending "%s" to change status', sent) + await self.send(sent) + + async def request_sync(self, guild_ids): + payload = {"op": self.GUILD_SYNC, "d": list(guild_ids)} + await self.send_as_json(payload) + + async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): + payload = { + "op": self.VOICE_STATE, + "d": { + "guild_id": guild_id, + "channel_id": channel_id, + "self_mute": self_mute, + "self_deaf": self_deaf, + }, + } + + log.debug("Updating our voice state to %s.", payload) + await self.send_as_json(payload) + + async def close(self, code=1000, reason=""): + if self._keep_alive: + self._keep_alive.stop() + + await super().close(code, reason) + + async def close_connection(self, *args, **kwargs): + if self._keep_alive: + self._keep_alive.stop() + + await super().close_connection(*args, **kwargs) + + +class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): + """Implements the websocket protocol for handling voice connections. + + Attributes + ----------- + IDENTIFY + Send only. Starts a new voice session. + SELECT_PROTOCOL + Send only. Tells discord what encryption mode and how to connect for voice. + READY + Receive only. Tells the websocket that the initial connection has completed. + HEARTBEAT + Send only. Keeps your websocket connection alive. + SESSION_DESCRIPTION + Receive only. Gives you the secret key required for voice. + SPEAKING + Send only. Notifies the client if you are currently speaking. + HEARTBEAT_ACK + Receive only. Tells you your heartbeat has been acknowledged. + RESUME + Sent only. Tells the client to resume its session. + HELLO + Receive only. Tells you that your websocket connection was acknowledged. + INVALIDATE_SESSION + Sent only. Tells you that your RESUME request has failed and to re-IDENTIFY. + """ + + IDENTIFY = 0 + SELECT_PROTOCOL = 1 + READY = 2 + HEARTBEAT = 3 + SESSION_DESCRIPTION = 4 + SPEAKING = 5 + HEARTBEAT_ACK = 6 + RESUME = 7 + HELLO = 8 + INVALIDATE_SESSION = 9 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_size = None + self._keep_alive = None + + async def send_as_json(self, data): + log.debug("Sending voice websocket frame: %s.", data) + await self.send(utils.to_json(data)) + + async def resume(self): + state = self._connection + payload = { + "op": self.RESUME, + "d": { + "token": state.token, + "server_id": str(state.server_id), + "session_id": state.session_id, + }, + } + await self.send_as_json(payload) + + async def identify(self): + state = self._connection + payload = { + "op": self.IDENTIFY, + "d": { + "server_id": str(state.server_id), + "user_id": str(state.user.id), + "session_id": state.session_id, + "token": state.token, + }, + } + await self.send_as_json(payload) + + @classmethod + async def from_client(cls, client, *, resume=False): + """Creates a voice websocket for the :class:`VoiceClient`.""" + gateway = "wss://" + client.endpoint + "/?v=3" + ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) + ws.gateway = gateway + ws._connection = client + ws._max_heartbeat_timeout = 60.0 + + if resume: + await ws.resume() + else: + await ws.identify() + + return ws + + async def select_protocol(self, ip, port): + payload = { + "op": self.SELECT_PROTOCOL, + "d": { + "protocol": "udp", + "data": {"address": ip, "port": port, "mode": "xsalsa20_poly1305"}, + }, + } + + await self.send_as_json(payload) + + async def speak(self, is_speaking=True): + payload = {"op": self.SPEAKING, "d": {"speaking": is_speaking, "delay": 0}} + + await self.send_as_json(payload) + + async def received_message(self, msg): + log.debug("Voice websocket frame received: %s", msg) + op = msg["op"] + data = msg.get("d") + + if op == self.READY: + interval = data["heartbeat_interval"] / 1000.0 + self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=interval) + self._keep_alive.start() + await self.initial_connection(data) + elif op == self.HEARTBEAT_ACK: + self._keep_alive.ack() + elif op == self.INVALIDATE_SESSION: + log.info("Voice RESUME failed.") + await self.identify() + elif op == self.SESSION_DESCRIPTION: + await self.load_secret_key(data) + + async def initial_connection(self, data): + state = self._connection + state.ssrc = data["ssrc"] + state.voice_port = data["port"] + + packet = bytearray(70) + struct.pack_into(">I", packet, 0, state.ssrc) + state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) + recv = await self.loop.sock_recv(state.socket, 70) + log.debug("received packet in initial_connection: %s", recv) + + # the ip is ascii starting at the 4th byte and ending at the first null + ip_start = 4 + ip_end = recv.index(0, ip_start) + state.ip = recv[ip_start:ip_end].decode("ascii") + + # the port is a little endian unsigned short in the last two bytes + # yes, this is different endianness from everything else + state.port = struct.unpack_from("".format(self) + + def _update_voice_state(self, data, channel_id): + user_id = int(data["user_id"]) + channel = self.get_channel(channel_id) + try: + # check if we should remove the voice state from cache + if channel is None: + after = self._voice_states.pop(user_id) + else: + after = self._voice_states[user_id] + + before = copy.copy(after) + after._update(data, channel) + except KeyError: + # if we're here then we're getting added into the cache + after = VoiceState(data=data, channel=channel) + before = VoiceState(data=data, channel=None) + self._voice_states[user_id] = after + + member = self.get_member(user_id) + return member, before, after + + def _add_role(self, role): + # roles get added to the bottom (position 1, pos 0 is @everyone) + # so since self.roles has the @everyone role, we can't increment + # its position because it's stuck at position 0. Luckily x += False + # is equivalent to adding 0. So we cast the position to a bool and + # increment it. + for r in self._roles.values(): + r.position += not r.is_default() + + self._roles[role.id] = role + + def _remove_role(self, role_id): + # this raises KeyError if it fails.. + role = self._roles.pop(role_id) + + # since it didn't, we can change the positions now + # basically the same as above except we only decrement + # the position if we're above the role we deleted. + for r in self._roles.values(): + r.position -= r.position > role.position + + return role + + def _from_data(self, guild): + # according to Stan, this is always available even if the guild is unavailable + # I don't have this guarantee when someone updates the guild. + member_count = guild.get("member_count", None) + if member_count: + self._member_count = member_count + + self.name = guild.get("name") + self.region = try_enum(VoiceRegion, guild.get("region")) + self.verification_level = try_enum(VerificationLevel, guild.get("verification_level")) + self.default_notifications = try_enum( + NotificationLevel, guild.get("default_message_notifications") + ) + self.explicit_content_filter = try_enum( + ContentFilter, guild.get("explicit_content_filter", 0) + ) + self.afk_timeout = guild.get("afk_timeout") + self.icon = guild.get("icon") + self.unavailable = guild.get("unavailable", False) + self.id = int(guild["id"]) + self._roles = {} + state = self._state # speed up attribute access + for r in guild.get("roles", []): + role = Role(guild=self, data=r, state=state) + self._roles[role.id] = role + + self.mfa_level = guild.get("mfa_level") + self.emojis = tuple(map(lambda d: state.store_emoji(self, d), guild.get("emojis", []))) + self.features = guild.get("features", []) + self.splash = guild.get("splash") + self._system_channel_id = utils._get_as_snowflake(guild, "system_channel_id") + + for mdata in guild.get("members", []): + member = Member(data=mdata, guild=self, state=state) + self._add_member(member) + + self._sync(guild) + self._large = None if member_count is None else self._member_count >= 250 + + self.owner_id = utils._get_as_snowflake(guild, "owner_id") + self.afk_channel = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) + + for obj in guild.get("voice_states", []): + self._update_voice_state(obj, int(obj["channel_id"])) + + def _sync(self, data): + try: + self._large = data["large"] + except KeyError: + pass + + empty_tuple = tuple() + for presence in data.get("presences", []): + user_id = int(presence["user"]["id"]) + member = self.get_member(user_id) + if member is not None: + member._presence_update(presence, empty_tuple) + + if "channels" in data: + channels = data["channels"] + for c in channels: + if c["type"] == ChannelType.text.value: + self._add_channel(TextChannel(guild=self, data=c, state=self._state)) + elif c["type"] == ChannelType.voice.value: + self._add_channel(VoiceChannel(guild=self, data=c, state=self._state)) + elif c["type"] == ChannelType.category.value: + self._add_channel(CategoryChannel(guild=self, data=c, state=self._state)) + + @property + def channels(self): + """List[:class:`abc.GuildChannel`]: A list of channels that belongs to this guild.""" + return list(self._channels.values()) + + @property + def large(self): + """:class:`bool`: Indicates if the guild is a 'large' guild. + + A large guild is defined as having more than ``large_threshold`` count + members, which for this library is set to the maximum of 250. + """ + if self._large is None: + try: + return self._member_count >= 250 + except AttributeError: + return len(self._members) >= 250 + return self._large + + @property + def voice_channels(self): + """List[:class:`VoiceChannel`]: A list of voice channels that belongs to this guild. + + This is sorted by the position and are in UI order from top to bottom. + """ + r = [ch for ch in self._channels.values() if isinstance(ch, VoiceChannel)] + r.sort(key=lambda c: (c.position, c.id)) + return r + + @property + def me(self): + """Similar to :attr:`Client.user` except an instance of :class:`Member`. + This is essentially used to get the member version of yourself. + """ + self_id = self._state.user.id + return self.get_member(self_id) + + @property + def voice_client(self): + """Returns the :class:`VoiceClient` associated with this guild, if any.""" + return self._state._get_voice_client(self.id) + + @property + def text_channels(self): + """List[:class:`TextChannel`]: A list of text channels that belongs to this guild. + + This is sorted by the position and are in UI order from top to bottom. + """ + r = [ch for ch in self._channels.values() if isinstance(ch, TextChannel)] + r.sort(key=lambda c: (c.position, c.id)) + return r + + @property + def categories(self): + """List[:class:`CategoryChannel`]: A list of categories that belongs to this guild. + + This is sorted by the position and are in UI order from top to bottom. + """ + r = [ch for ch in self._channels.values() if isinstance(ch, CategoryChannel)] + r.sort(key=lambda c: (c.position, c.id)) + return r + + def by_category(self): + """Returns every :class:`CategoryChannel` and their associated channels. + + These channels and categories are sorted in the official Discord UI order. + + If the channels do not have a category, then the first element of the tuple is + ``None``. + + Returns + -------- + List[Tuple[Optional[:class:`CategoryChannel`], List[:class:`abc.GuildChannel`]]]: + The categories and their associated channels. + """ + grouped = defaultdict(list) + for channel in self._channels.values(): + if isinstance(channel, CategoryChannel): + continue + + grouped[channel.category_id].append(channel) + + def key(t): + k, v = t + return ((k.position, k.id) if k else (-1, -1), v) + + _get = self._channels.get + as_list = [(_get(k), v) for k, v in grouped.items()] + as_list.sort(key=key) + for _, channels in as_list: + channels.sort(key=lambda c: (not isinstance(c, TextChannel), c.position, c.id)) + return as_list + + def get_channel(self, channel_id): + """Returns a :class:`abc.GuildChannel` with the given ID. If not found, returns None.""" + return self._channels.get(channel_id) + + @property + def system_channel(self): + """Optional[:class:`TextChannel`]: Returns the guild's channel used for system messages. + + Currently this is only for new member joins. If no channel is set, then this returns ``None``. + """ + channel_id = self._system_channel_id + return channel_id and self._channels.get(channel_id) + + @property + def members(self): + """List[:class:`Member`]: A list of members that belong to this guild.""" + return list(self._members.values()) + + def get_member(self, user_id): + """Returns a :class:`Member` with the given ID. If not found, returns None.""" + return self._members.get(user_id) + + @property + def roles(self): + """Returns a :class:`list` of the guild's roles in hierarchy order. + + The first element of this list will be the lowest role in the + hierarchy. + """ + return sorted(self._roles.values()) + + def get_role(self, role_id): + """Returns a :class:`Role` with the given ID. If not found, returns None.""" + return self._roles.get(role_id) + + @utils.cached_slot_property("_default_role") + def default_role(self): + """Gets the @everyone role that all members have by default.""" + return utils.find(lambda r: r.is_default(), self._roles.values()) + + @property + def owner(self): + """:class:`Member`: The member that owns the guild.""" + return self.get_member(self.owner_id) + + @property + def icon_url(self): + """Returns the URL version of the guild's icon. Returns an empty string if it has no icon.""" + return self.icon_url_as() + + def icon_url_as(self, *, format="webp", size=1024): + """Returns a friendly URL version of the guild's icon. Returns an empty string if it has no icon. + + The format must be one of 'webp', 'jpeg', 'jpg', or 'png'. The + size must be a power of 2 between 16 and 2048. + + Parameters + ----------- + format: str + The format to attempt to convert the icon to. + size: int + The size of the image to display. + + Returns + -------- + str + The resulting CDN URL. + + Raises + ------ + InvalidArgument + Bad image format passed to ``format`` or invalid ``size``. + """ + if not valid_icon_size(size): + raise InvalidArgument("size must be a power of 2 between 16 and 2048") + if format not in VALID_ICON_FORMATS: + raise InvalidArgument("format must be one of {}".format(VALID_ICON_FORMATS)) + + if self.icon is None: + return "" + + return "https://cdn.discordapp.com/icons/{0.id}/{0.icon}.{1}?size={2}".format( + self, format, size + ) + + @property + def splash_url(self): + """Returns the URL version of the guild's invite splash. Returns an empty string if it has no splash.""" + return self.splash_url_as() + + def splash_url_as(self, *, format="webp", size=2048): + """Returns a friendly URL version of the guild's invite splash. Returns an empty string if it has no splash. + + The format must be one of 'webp', 'jpeg', 'jpg', or 'png'. The + size must be a power of 2 between 16 and 2048. + + Parameters + ----------- + format: str + The format to attempt to convert the splash to. + size: int + The size of the image to display. + + Returns + -------- + str + The resulting CDN URL. + + Raises + ------ + InvalidArgument + Bad image format passed to ``format`` or invalid ``size``. + """ + if not valid_icon_size(size): + raise InvalidArgument("size must be a power of 2 between 16 and 2048") + if format not in VALID_ICON_FORMATS: + raise InvalidArgument("format must be one of {}".format(VALID_ICON_FORMATS)) + + if self.splash is None: + return "" + + return "https://cdn.discordapp.com/splashes/{0.id}/{0.splash}.{1}?size={2}".format( + self, format, size + ) + + @property + def member_count(self): + """Returns the true member count regardless of it being loaded fully or not.""" + return self._member_count + + @property + def chunked(self): + """Returns a boolean indicating if the guild is "chunked". + + A chunked guild means that :attr:`member_count` is equal to the + number of members stored in the internal :attr:`members` cache. + + If this value returns ``False``, then you should request for + offline members. + """ + count = getattr(self, "_member_count", None) + if count is None: + return False + return count == len(self._members) + + @property + def shard_id(self): + """Returns the shard ID for this guild if applicable.""" + count = self._state.shard_count + if count is None: + return None + return (self.id >> 22) % count + + @property + def created_at(self): + """Returns the guild's creation time in UTC.""" + return utils.snowflake_time(self.id) + + def get_member_named(self, name): + """Returns the first member found that matches the name provided. + + The name can have an optional discriminator argument, e.g. "Jake#0001" + or "Jake" will both do the lookup. However the former will give a more + precise result. Note that the discriminator must have all 4 digits + for this to work. + + If a nickname is passed, then it is looked up via the nickname. Note + however, that a nickname + discriminator combo will not lookup the nickname + but rather the username + discriminator combo due to nickname + discriminator + not being unique. + + If no member is found, ``None`` is returned. + + Parameters + ----------- + name: str + The name of the member to lookup with an optional discriminator. + + Returns + -------- + :class:`Member` + The member in this guild with the associated name. If not found + then ``None`` is returned. + """ + + result = None + members = self.members + if len(name) > 5 and name[-5] == "#": + # The 5 length is checking to see if #0000 is in the string, + # as a#0000 has a length of 6, the minimum for a potential + # discriminator lookup. + potential_discriminator = name[-4:] + + # do the actual lookup and return if found + # if it isn't found then we'll do a full name lookup below. + result = utils.get(members, name=name[:-5], discriminator=potential_discriminator) + if result is not None: + return result + + def pred(m): + return m.nick == name or m.name == name + + return utils.find(pred, members) + + def _create_channel(self, name, overwrites, channel_type, category=None, reason=None): + if overwrites is None: + overwrites = {} + elif not isinstance(overwrites, dict): + raise InvalidArgument("overwrites parameter expects a dict.") + + perms = [] + for target, perm in overwrites.items(): + if not isinstance(perm, PermissionOverwrite): + raise InvalidArgument( + "Expected PermissionOverwrite received {0.__name__}".format(type(perm)) + ) + + allow, deny = perm.pair() + payload = {"allow": allow.value, "deny": deny.value, "id": target.id} + + if isinstance(target, Role): + payload["type"] = "role" + else: + payload["type"] = "member" + + perms.append(payload) + + parent_id = category.id if category else None + return self._state.http.create_channel( + self.id, + name, + channel_type.value, + parent_id=parent_id, + permission_overwrites=perms, + reason=reason, + ) + + async def create_text_channel(self, name, *, overwrites=None, category=None, reason=None): + """|coro| + + Creates a :class:`TextChannel` for the guild. + + Note that you need the :attr:`~Permissions.manage_channels` permission + to create the channel. + + The ``overwrites`` parameter can be used to create a 'secret' + channel upon creation. This parameter expects a :class:`dict` of + overwrites with the target (either a :class:`Member` or a :class:`Role`) + as the key and a :class:`PermissionOverwrite` as the value. + + Examples + ---------- + + Creating a basic channel: + + .. code-block:: python3 + + channel = await guild.create_text_channel('cool-channel') + + Creating a "secret" channel: + + .. code-block:: python3 + + overwrites = { + guild.default_role: discord.PermissionOverwrite(read_messages=False), + guild.me: discord.PermissionOverwrite(read_messages=True) + } + + channel = await guild.create_text_channel('secret', overwrites=overwrites) + + Parameters + ----------- + name: str + The channel's name. + overwrites + A :class:`dict` of target (either a role or a member) to + :class:`PermissionOverwrite` to apply upon creation of a channel. + Useful for creating secret channels. + category: Optional[:class:`CategoryChannel`] + The category to place the newly created channel under. + The permissions will be automatically synced to category if no + overwrites are provided. + reason: Optional[str] + The reason for creating this channel. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have the proper permissions to create this channel. + HTTPException + Creating the channel failed. + InvalidArgument + The permission overwrite information is not in proper form. + + Returns + ------- + :class:`TextChannel` + The channel that was just created. + """ + data = await self._create_channel( + name, overwrites, ChannelType.text, category, reason=reason + ) + channel = TextChannel(state=self._state, guild=self, data=data) + + # temporarily add to the cache + self._channels[channel.id] = channel + return channel + + async def create_voice_channel(self, name, *, overwrites=None, category=None, reason=None): + """|coro| + + Same as :meth:`create_text_channel` except makes a :class:`VoiceChannel` instead. + """ + data = await self._create_channel( + name, overwrites, ChannelType.voice, category, reason=reason + ) + channel = VoiceChannel(state=self._state, guild=self, data=data) + + # temporarily add to the cache + self._channels[channel.id] = channel + return channel + + async def create_category(self, name, *, overwrites=None, reason=None): + """|coro| + + Same as :meth:`create_text_channel` except makes a :class:`CategoryChannel` instead. + + .. note:: + + The ``category`` parameter is not supported in this function since categories + cannot have categories. + """ + data = await self._create_channel(name, overwrites, ChannelType.category, reason=reason) + channel = CategoryChannel(state=self._state, guild=self, data=data) + + # temporarily add to the cache + self._channels[channel.id] = channel + return channel + + create_category_channel = create_category + + async def leave(self): + """|coro| + + Leaves the guild. + + Note + -------- + You cannot leave the guild that you own, you must delete it instead + via :meth:`delete`. + + Raises + -------- + HTTPException + Leaving the guild failed. + """ + await self._state.http.leave_guild(self.id) + + async def delete(self): + """|coro| + + Deletes the guild. You must be the guild owner to delete the + guild. + + Raises + -------- + HTTPException + Deleting the guild failed. + Forbidden + You do not have permissions to delete the guild. + """ + + await self._state.http.delete_guild(self.id) + + async def edit(self, *, reason=None, **fields): + """|coro| + + Edits the guild. + + You must have the :attr:`~Permissions.manage_guild` permission + to edit the guild. + + Parameters + ---------- + name: str + The new name of the guild. + icon: bytes + A :term:`py:bytes-like object` representing the icon. Only PNG/JPEG supported. + Could be ``None`` to denote removal of the icon. + splash: bytes + A :term:`py:bytes-like object` representing the invite splash. + Only PNG/JPEG supported. Could be ``None`` to denote removing the + splash. Only available for partnered guilds with ``INVITE_SPLASH`` + feature. + region: :class:`VoiceRegion` + The new region for the guild's voice communication. + afk_channel: Optional[:class:`VoiceChannel`] + The new channel that is the AFK channel. Could be ``None`` for no AFK channel. + afk_timeout: int + The number of seconds until someone is moved to the AFK channel. + owner: :class:`Member` + The new owner of the guild to transfer ownership to. Note that you must + be owner of the guild to do this. + verification_level: :class:`VerificationLevel` + The new verification level for the guild. + default_notifications: :class:`NotificationLevel` + The new default notification level for the guild. + explicit_content_filter: :class:`ContentFilter` + The new explicit content filter for the guild. + vanity_code: str + The new vanity code for the guild. + system_channel: Optional[:class:`TextChannel`] + The new channel that is used for the system channel. Could be ``None`` for no system channel. + reason: Optional[str] + The reason for editing this guild. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to edit the guild. + HTTPException + Editing the guild failed. + InvalidArgument + The image format passed in to ``icon`` is invalid. It must be + PNG or JPG. This is also raised if you are not the owner of the + guild and request an ownership transfer. + """ + + http = self._state.http + try: + icon_bytes = fields["icon"] + except KeyError: + icon = self.icon + else: + if icon_bytes is not None: + icon = utils._bytes_to_base64_data(icon_bytes) + else: + icon = None + + try: + vanity_code = fields["vanity_code"] + except KeyError: + pass + else: + await http.change_vanity_code(self.id, vanity_code, reason=reason) + + try: + splash_bytes = fields["splash"] + except KeyError: + splash = self.splash + else: + if splash_bytes is not None: + splash = utils._bytes_to_base64_data(splash_bytes) + else: + splash = None + + fields["icon"] = icon + fields["splash"] = splash + + try: + default_message_notifications = int(fields.pop("default_notifications")) + except (TypeError, KeyError): + pass + else: + fields["default_message_notifications"] = default_message_notifications + + try: + afk_channel = fields.pop("afk_channel") + except KeyError: + pass + else: + if afk_channel is None: + fields["afk_channel_id"] = afk_channel + else: + fields["afk_channel_id"] = afk_channel.id + + try: + system_channel = fields.pop("system_channel") + except KeyError: + pass + else: + if system_channel is None: + fields["system_channel_id"] = system_channel + else: + fields["system_channel_id"] = system_channel.id + + if "owner" in fields: + if self.owner != self.me: + raise InvalidArgument("To transfer ownership you must be the owner of the guild.") + + fields["owner_id"] = fields["owner"].id + + if "region" in fields: + fields["region"] = str(fields["region"]) + + level = fields.get("verification_level", self.verification_level) + if not isinstance(level, VerificationLevel): + raise InvalidArgument("verification_level field must be of type VerificationLevel") + + fields["verification_level"] = level.value + + explicit_content_filter = fields.get( + "explicit_content_filter", self.explicit_content_filter + ) + if not isinstance(explicit_content_filter, ContentFilter): + raise InvalidArgument("explicit_content_filter field must be of type ContentFilter") + + fields["explicit_content_filter"] = explicit_content_filter.value + await http.edit_guild(self.id, reason=reason, **fields) + + async def get_ban(self, user): + """|coro| + + Retrieves the :class:`BanEntry` for a user, which is a namedtuple + with a ``user`` and ``reason`` field. See :meth:`bans` for more + information. + + You must have the :attr:`~Permissions.ban_members` permission + to get this information. + + Parameters + ----------- + user: :class:`abc.Snowflake` + The user to get ban information from. + + Raises + ------ + Forbidden + You do not have proper permissions to get the information. + NotFound + This user is not banned. + HTTPException + An error occurred while fetching the information. + + Returns + ------- + BanEntry + The BanEntry object for the specified user. + """ + data = await self._state.http.get_ban(user.id, self.id) + return BanEntry(user=User(state=self._state, data=data["user"]), reason=data["reason"]) + + async def bans(self): + """|coro| + + Retrieves all the users that are banned from the guild. + + This coroutine returns a :class:`list` of BanEntry objects, which is a + namedtuple with a ``user`` field to denote the :class:`User` + that got banned along with a ``reason`` field specifying + why the user was banned that could be set to ``None``. + + You must have the :attr:`~Permissions.ban_members` permission + to get this information. + + Raises + ------- + Forbidden + You do not have proper permissions to get the information. + HTTPException + An error occurred while fetching the information. + + Returns + -------- + List[BanEntry] + A list of BanEntry objects. + """ + + data = await self._state.http.get_bans(self.id) + return [ + BanEntry(user=User(state=self._state, data=e["user"]), reason=e["reason"]) + for e in data + ] + + async def prune_members(self, *, days, reason=None): + """|coro| + + Prunes the guild from its inactive members. + + The inactive members are denoted if they have not logged on in + ``days`` number of days and they have no roles. + + You must have the :attr:`~Permissions.kick_members` permission + to use this. + + To check how many members you would prune without actually pruning, + see the :meth:`estimate_pruned_members` function. + + Parameters + ----------- + days: int + The number of days before counting as inactive. + reason: Optional[str] + The reason for doing this action. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to prune members. + HTTPException + An error occurred while pruning members. + InvalidArgument + An integer was not passed for ``days``. + + Returns + --------- + int + The number of members pruned. + """ + + if not isinstance(days, int): + raise InvalidArgument( + "Expected int for ``days``, received {0.__class__.__name__} instead.".format(days) + ) + + data = await self._state.http.prune_members(self.id, days, reason=reason) + return data["pruned"] + + async def webhooks(self): + """|coro| + + Gets the list of webhooks from this guild. + + Requires :attr:`~.Permissions.manage_webhooks` permissions. + + Raises + ------- + Forbidden + You don't have permissions to get the webhooks. + + Returns + -------- + List[:class:`Webhook`] + The webhooks for this guild. + """ + + data = await self._state.http.guild_webhooks(self.id) + return [Webhook.from_state(d, state=self._state) for d in data] + + async def estimate_pruned_members(self, *, days): + """|coro| + + Similar to :meth:`prune_members` except instead of actually + pruning members, it returns how many members it would prune + from the guild had it been called. + + Parameters + ----------- + days: int + The number of days before counting as inactive. + + Raises + ------- + Forbidden + You do not have permissions to prune members. + HTTPException + An error occurred while fetching the prune members estimate. + InvalidArgument + An integer was not passed for ``days``. + + Returns + --------- + int + The number of members estimated to be pruned. + """ + + if not isinstance(days, int): + raise InvalidArgument( + "Expected int for ``days``, received {0.__class__.__name__} instead.".format(days) + ) + + data = await self._state.http.estimate_pruned_members(self.id, days) + return data["pruned"] + + async def invites(self): + """|coro| + + Returns a list of all active instant invites from the guild. + + You must have the :attr:`~Permissions.manage_guild` permission to get + this information. + + Raises + ------- + Forbidden + You do not have proper permissions to get the information. + HTTPException + An error occurred while fetching the information. + + Returns + ------- + List[:class:`Invite`] + The list of invites that are currently active. + """ + + data = await self._state.http.invites_from(self.id) + result = [] + for invite in data: + channel = self.get_channel(int(invite["channel"]["id"])) + invite["channel"] = channel + invite["guild"] = self + result.append(Invite(state=self._state, data=invite)) + + return result + + async def create_custom_emoji(self, *, name, image, roles=None, reason=None): + r"""|coro| + + Creates a custom :class:`Emoji` for the guild. + + There is currently a limit of 50 static and animated emojis respectively per guild, + unless the guild has the ``MORE_EMOJI`` feature which extends the limit to 200. + + You must have the :attr:`~Permissions.manage_emojis` permission to + do this. + + Parameters + ----------- + name: str + The emoji name. Must be at least 2 characters. + image: bytes + The :term:`py:bytes-like object` representing the image data to use. + Only JPG, PNG and GIF images are supported. + roles: Optional[list[:class:`Role`]] + A :class:`list` of :class:`Role`\s that can use this emoji. Leave empty to make it available to everyone. + reason: Optional[str] + The reason for creating this emoji. Shows up on the audit log. + + Returns + -------- + :class:`Emoji` + The created emoji. + + Raises + ------- + Forbidden + You are not allowed to create emojis. + HTTPException + An error occurred creating an emoji. + """ + + img = utils._bytes_to_base64_data(image) + if roles: + roles = [role.id for role in roles] + data = await self._state.http.create_custom_emoji( + self.id, name, img, roles=roles, reason=reason + ) + return self._state.store_emoji(self, data) + + async def create_role(self, *, reason=None, **fields): + """|coro| + + Creates a :class:`Role` for the guild. + + All fields are optional. + + You must have the :attr:`~Permissions.manage_roles` permission to + do this. + + Parameters + ----------- + name: str + The role name. Defaults to 'new role'. + permissions: :class:`Permissions` + The permissions to have. Defaults to no permissions. + colour: :class:`Colour` + The colour for the role. Defaults to :meth:`Colour.default`. + This is aliased to ``color`` as well. + hoist: bool + Indicates if the role should be shown separately in the member list. + Defaults to False. + mentionable: bool + Indicates if the role should be mentionable by others. + Defaults to False. + reason: Optional[str] + The reason for creating this role. Shows up on the audit log. + + Returns + -------- + :class:`Role` + The newly created role. + + Raises + ------- + Forbidden + You do not have permissions to change the role. + HTTPException + Editing the role failed. + InvalidArgument + An invalid keyword argument was given. + """ + + try: + perms = fields.pop("permissions") + except KeyError: + fields["permissions"] = 0 + else: + fields["permissions"] = perms.value + + try: + colour = fields.pop("colour") + except KeyError: + colour = fields.get("color", Colour.default()) + finally: + fields["color"] = colour.value + + valid_keys = ("name", "permissions", "color", "hoist", "mentionable") + for key in fields: + if key not in valid_keys: + raise InvalidArgument("%r is not a valid field." % key) + + data = await self._state.http.create_role(self.id, reason=reason, **fields) + role = Role(guild=self, data=data, state=self._state) + + # TODO: add to cache + return role + + async def kick(self, user, *, reason=None): + """|coro| + + Kicks a user from the guild. + + The user must meet the :class:`abc.Snowflake` abc. + + You must have the :attr:`~Permissions.kick_members` permission to + do this. + + Parameters + ----------- + user: :class:`abc.Snowflake` + The user to kick from their guild. + reason: Optional[str] + The reason the user got kicked. + + Raises + ------- + Forbidden + You do not have the proper permissions to kick. + HTTPException + Kicking failed. + """ + await self._state.http.kick(user.id, self.id, reason=reason) + + async def ban(self, user, *, reason=None, delete_message_days=1): + """|coro| + + Bans a user from the guild. + + The user must meet the :class:`abc.Snowflake` abc. + + You must have the :attr:`~Permissions.ban_members` permission to + do this. + + Parameters + ----------- + user: :class:`abc.Snowflake` + The user to ban from their guild. + delete_message_days: int + The number of days worth of messages to delete from the user + in the guild. The minimum is 0 and the maximum is 7. + reason: Optional[str] + The reason the user got banned. + + Raises + ------- + Forbidden + You do not have the proper permissions to ban. + HTTPException + Banning failed. + """ + await self._state.http.ban(user.id, self.id, delete_message_days, reason=reason) + + async def unban(self, user, *, reason=None): + """|coro| + + Unbans a user from the guild. + + The user must meet the :class:`abc.Snowflake` abc. + + You must have the :attr:`~Permissions.ban_members` permission to + do this. + + Parameters + ----------- + user: :class:`abc.Snowflake` + The user to unban. + reason: Optional[str] + The reason for doing this action. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have the proper permissions to unban. + HTTPException + Unbanning failed. + """ + await self._state.http.unban(user.id, self.id, reason=reason) + + async def vanity_invite(self): + """|coro| + + Returns the guild's special vanity invite. + + The guild must be partnered, i.e. have 'VANITY_URL' in + :attr:`~Guild.features`. + + You must have the :attr:`~Permissions.manage_guild` permission to use + this as well. + + Returns + -------- + :class:`Invite` + The special vanity invite. + + Raises + ------- + Forbidden + You do not have the proper permissions to get this. + HTTPException + Retrieving the vanity invite failed. + """ + + # we start with { code: abc } + payload = await self._state.http.get_vanity_code(self.id) + + # get the vanity URL channel since default channels aren't + # reliable or a thing anymore + data = await self._state.http.get_invite(payload["code"]) + + payload["guild"] = self + payload["channel"] = self.get_channel(int(data["channel"]["id"])) + payload["revoked"] = False + payload["temporary"] = False + payload["max_uses"] = 0 + payload["max_age"] = 0 + return Invite(state=self._state, data=payload) + + def ack(self): + """|coro| + + Marks every message in this guild as read. + + The user must not be a bot user. + + Raises + ------- + HTTPException + Acking failed. + ClientException + You must not be a bot user. + """ + + state = self._state + if state.is_bot: + raise ClientException("Must not be a bot account to ack messages.") + return state.http.ack_guild(self.id) + + def audit_logs( + self, *, limit=100, before=None, after=None, reverse=None, user=None, action=None + ): + """Return an :class:`AsyncIterator` that enables receiving the guild's audit logs. + + You must have the :attr:`~Permissions.view_audit_log` permission to use this. + + Parameters + ----------- + limit: Optional[int] + The number of entries to retrieve. If ``None`` retrieve all entries. + before: Union[:class:`abc.Snowflake`, datetime] + Retrieve entries before this date or entry. + If a date is provided it must be a timezone-naive datetime representing UTC time. + after: Union[:class:`abc.Snowflake`, datetime] + Retrieve entries after this date or entry. + If a date is provided it must be a timezone-naive datetime representing UTC time. + reverse: bool + If set to true, return entries in oldest->newest order. If unspecified, + this defaults to ``False`` for most cases. However if passing in a + ``after`` parameter then this is set to ``True``. This avoids getting entries + out of order in the ``after`` case. + user: :class:`abc.Snowflake` + The moderator to filter entries from. + action: :class:`AuditLogAction` + The action to filter with. + + Yields + -------- + :class:`AuditLogEntry` + The audit log entry. + + Raises + ------- + Forbidden + You are not allowed to fetch audit logs + HTTPException + An error occurred while fetching the audit logs. + + Examples + ---------- + + Getting the first 100 entries: :: + + async for entry in guild.audit_logs(limit=100): + print('{0.user} did {0.action} to {0.target}'.format(entry)) + + Getting entries for a specific action: :: + + async for entry in guild.audit_logs(action=discord.AuditLogAction.ban): + print('{0.user} banned {0.target}'.format(entry)) + + Getting entries made by a specific user: :: + + entries = await guild.audit_logs(limit=None, user=guild.me).flatten() + await channel.send('I made {} moderation actions.'.format(len(entries))) + """ + if user: + user = user.id + + if action: + action = action.value + + return AuditLogIterator( + self, + before=before, + after=after, + limit=limit, + reverse=reverse, + user_id=user, + action_type=action, + ) diff --git a/discord/http.py b/discord/http.py new file mode 100644 index 000000000..08bb720bf --- /dev/null +++ b/discord/http.py @@ -0,0 +1,911 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +import json +import logging +import sys +from urllib.parse import quote as _uriquote +import weakref + +import aiohttp + +from .errors import HTTPException, Forbidden, NotFound, LoginFailure, GatewayNotFound +from . import __version__, utils + +log = logging.getLogger(__name__) + + +async def json_or_text(response): + text = await response.text(encoding="utf-8") + if response.headers["content-type"] == "application/json": + return json.loads(text) + return text + + +class Route: + BASE = "https://discordapp.com/api/v7" + + def __init__(self, method, path, **parameters): + self.path = path + self.method = method + url = self.BASE + self.path + if parameters: + self.url = url.format( + **{k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()} + ) + else: + self.url = url + + # major parameters: + self.channel_id = parameters.get("channel_id") + self.guild_id = parameters.get("guild_id") + + @property + def bucket(self): + # the bucket is just method + path w/ major parameters + return "{0.method}:{0.channel_id}:{0.guild_id}:{0.path}".format(self) + + +class MaybeUnlock: + def __init__(self, lock): + self.lock = lock + self._unlock = True + + def __enter__(self): + return self + + def defer(self): + self._unlock = False + + def __exit__(self, type, value, traceback): + if self._unlock: + self.lock.release() + + +class HTTPClient: + """Represents an HTTP client sending HTTP requests to the Discord API.""" + + SUCCESS_LOG = "{method} {url} has received {text}" + REQUEST_LOG = "{method} {url} with {json} has returned {status}" + + def __init__(self, connector=None, *, proxy=None, proxy_auth=None, loop=None): + self.loop = asyncio.get_event_loop() if loop is None else loop + self.connector = connector + self._session = aiohttp.ClientSession(connector=connector, loop=self.loop) + self._locks = weakref.WeakValueDictionary() + self._global_over = asyncio.Event(loop=self.loop) + self._global_over.set() + self.token = None + self.bot_token = False + self.proxy = proxy + self.proxy_auth = proxy_auth + + user_agent = "DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" + self.user_agent = user_agent.format(__version__, sys.version_info, aiohttp.__version__) + + def recreate(self): + if self._session.closed: + self._session = aiohttp.ClientSession(connector=self.connector, loop=self.loop) + + async def request(self, route, *, header_bypass_delay=None, **kwargs): + bucket = route.bucket + method = route.method + url = route.url + + lock = self._locks.get(bucket) + if lock is None: + lock = asyncio.Lock(loop=self.loop) + if bucket is not None: + self._locks[bucket] = lock + + # header creation + headers = {"User-Agent": self.user_agent} + + if self.token is not None: + headers["Authorization"] = "Bot " + self.token if self.bot_token else self.token + # some checking if it's a JSON request + if "json" in kwargs: + headers["Content-Type"] = "application/json" + kwargs["data"] = utils.to_json(kwargs.pop("json")) + + try: + reason = kwargs.pop("reason") + except KeyError: + pass + else: + if reason: + headers["X-Audit-Log-Reason"] = _uriquote(reason, safe="/ ") + + kwargs["headers"] = headers + + # Proxy support + if self.proxy is not None: + kwargs["proxy"] = self.proxy + if self.proxy_auth is not None: + kwargs["proxy_auth"] = self.proxy_auth + + if not self._global_over.is_set(): + # wait until the global lock is complete + await self._global_over.wait() + + await lock + with MaybeUnlock(lock) as maybe_lock: + for tries in range(5): + async with self._session.request(method, url, **kwargs) as r: + log.debug( + "%s %s with %s has returned %s", method, url, kwargs.get("data"), r.status + ) + + # even errors have text involved in them so this is safe to call + data = await json_or_text(r) + + # check if we have rate limit header information + remaining = r.headers.get("X-Ratelimit-Remaining") + if remaining == "0" and r.status != 429: + # we've depleted our current bucket + if header_bypass_delay is None: + delta = utils._parse_ratelimit_header(r) + else: + delta = header_bypass_delay + + log.debug( + "A rate limit bucket has been exhausted (bucket: %s, retry: %s).", + bucket, + delta, + ) + maybe_lock.defer() + self.loop.call_later(delta, lock.release) + + # the request was successful so just return the text/json + if 300 > r.status >= 200: + log.debug("%s %s has received %s", method, url, data) + return data + + # we are being rate limited + if r.status == 429: + fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' + + # sleep a bit + retry_after = data["retry_after"] / 1000.0 + log.warning(fmt, retry_after, bucket) + + # check if it's a global rate limit + is_global = data.get("global", False) + if is_global: + log.warning( + "Global rate limit has been hit. Retrying in %.2f seconds.", + retry_after, + ) + self._global_over.clear() + + await asyncio.sleep(retry_after, loop=self.loop) + log.debug("Done sleeping for the rate limit. Retrying...") + + # release the global lock now that the + # global rate limit has passed + if is_global: + self._global_over.set() + log.debug("Global rate limit is now over.") + + continue + + # we've received a 500 or 502, unconditional retry + if r.status in {500, 502}: + await asyncio.sleep(1 + tries * 2, loop=self.loop) + continue + + # the usual error cases + if r.status == 403: + raise Forbidden(r, data) + elif r.status == 404: + raise NotFound(r, data) + else: + raise HTTPException(r, data) + + # We've run out of retries, raise. + raise HTTPException(r, data) + + async def get_attachment(self, url): + async with self._session.get(url) as resp: + if resp.status == 200: + return await resp.read() + elif resp.status == 404: + raise NotFound(resp, "attachment not found") + elif resp.status == 403: + raise Forbidden(resp, "cannot retrieve attachment") + else: + raise HTTPException(resp, "failed to get attachment") + + # state management + + async def close(self): + await self._session.close() + + def _token(self, token, *, bot=True): + self.token = token + self.bot_token = bot + self._ack_token = None + + # login management + + async def static_login(self, token, *, bot): + old_token, old_bot = self.token, self.bot_token + self._token(token, bot=bot) + + try: + data = await self.request(Route("GET", "/users/@me")) + except HTTPException as exc: + self._token(old_token, bot=old_bot) + if exc.response.status == 401: + raise LoginFailure("Improper token has been passed.") from exc + raise + + return data + + def logout(self): + return self.request(Route("POST", "/auth/logout")) + + # Group functionality + + def start_group(self, user_id, recipients): + payload = {"recipients": recipients} + + return self.request( + Route("POST", "/users/{user_id}/channels", user_id=user_id), json=payload + ) + + def leave_group(self, channel_id): + return self.request(Route("DELETE", "/channels/{channel_id}", channel_id=channel_id)) + + def add_group_recipient(self, channel_id, user_id): + r = Route( + "PUT", + "/channels/{channel_id}/recipients/{user_id}", + channel_id=channel_id, + user_id=user_id, + ) + return self.request(r) + + def remove_group_recipient(self, channel_id, user_id): + r = Route( + "DELETE", + "/channels/{channel_id}/recipients/{user_id}", + channel_id=channel_id, + user_id=user_id, + ) + return self.request(r) + + def edit_group(self, channel_id, **options): + valid_keys = ("name", "icon") + payload = {k: v for k, v in options.items() if k in valid_keys} + + return self.request( + Route("PATCH", "/channels/{channel_id}", channel_id=channel_id), json=payload + ) + + def convert_group(self, channel_id): + return self.request(Route("POST", "/channels/{channel_id}/convert", channel_id=channel_id)) + + # Message management + + def start_private_message(self, user_id): + payload = {"recipient_id": user_id} + + return self.request(Route("POST", "/users/@me/channels"), json=payload) + + def send_message(self, channel_id, content, *, tts=False, embed=None, nonce=None): + r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id) + payload = {} + + if content: + payload["content"] = content + + if tts: + payload["tts"] = True + + if embed: + payload["embed"] = embed + + if nonce: + payload["nonce"] = nonce + + return self.request(r, json=payload) + + def send_typing(self, channel_id): + return self.request(Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id)) + + def send_files(self, channel_id, *, files, content=None, tts=False, embed=None, nonce=None): + r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id) + form = aiohttp.FormData() + + payload = {"tts": tts} + if content: + payload["content"] = content + if embed: + payload["embed"] = embed + if nonce: + payload["nonce"] = nonce + + form.add_field("payload_json", utils.to_json(payload)) + if len(files) == 1: + fp = files[0] + form.add_field("file", fp[0], filename=fp[1], content_type="application/octet-stream") + else: + for index, (buffer, filename) in enumerate(files): + form.add_field( + "file%s" % index, + buffer, + filename=filename, + content_type="application/octet-stream", + ) + + return self.request(r, data=form) + + async def ack_message(self, channel_id, message_id): + r = Route( + "POST", + "/channels/{channel_id}/messages/{message_id}/ack", + channel_id=channel_id, + message_id=message_id, + ) + data = await self.request(r, json={"token": self._ack_token}) + self._ack_token = data["token"] + + def ack_guild(self, guild_id): + return self.request(Route("POST", "/guilds/{guild_id}/ack", guild_id=guild_id)) + + def delete_message(self, channel_id, message_id, *, reason=None): + r = Route( + "DELETE", + "/channels/{channel_id}/messages/{message_id}", + channel_id=channel_id, + message_id=message_id, + ) + return self.request(r, reason=reason) + + def delete_messages(self, channel_id, message_ids, *, reason=None): + r = Route("POST", "/channels/{channel_id}/messages/bulk_delete", channel_id=channel_id) + payload = {"messages": message_ids} + + return self.request(r, json=payload, reason=reason) + + def edit_message(self, message_id, channel_id, **fields): + r = Route( + "PATCH", + "/channels/{channel_id}/messages/{message_id}", + channel_id=channel_id, + message_id=message_id, + ) + return self.request(r, json=fields) + + def add_reaction(self, message_id, channel_id, emoji): + r = Route( + "PUT", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me", + channel_id=channel_id, + message_id=message_id, + emoji=emoji, + ) + return self.request(r, header_bypass_delay=0.25) + + def remove_reaction(self, message_id, channel_id, emoji, member_id): + r = Route( + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/{member_id}", + channel_id=channel_id, + message_id=message_id, + member_id=member_id, + emoji=emoji, + ) + return self.request(r, header_bypass_delay=0.25) + + def remove_own_reaction(self, message_id, channel_id, emoji): + r = Route( + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me", + channel_id=channel_id, + message_id=message_id, + emoji=emoji, + ) + return self.request(r, header_bypass_delay=0.25) + + def get_reaction_users(self, message_id, channel_id, emoji, limit, after=None): + r = Route( + "GET", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}", + channel_id=channel_id, + message_id=message_id, + emoji=emoji, + ) + + params = {"limit": limit} + if after: + params["after"] = after + return self.request(r, params=params) + + def clear_reactions(self, message_id, channel_id): + r = Route( + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions", + channel_id=channel_id, + message_id=message_id, + ) + + return self.request(r) + + def get_message(self, channel_id, message_id): + r = Route( + "GET", + "/channels/{channel_id}/messages/{message_id}", + channel_id=channel_id, + message_id=message_id, + ) + return self.request(r) + + def logs_from(self, channel_id, limit, before=None, after=None, around=None): + params = {"limit": limit} + + if before: + params["before"] = before + if after: + params["after"] = after + if around: + params["around"] = around + + return self.request( + Route("GET", "/channels/{channel_id}/messages", channel_id=channel_id), params=params + ) + + def pin_message(self, channel_id, message_id): + return self.request( + Route( + "PUT", + "/channels/{channel_id}/pins/{message_id}", + channel_id=channel_id, + message_id=message_id, + ) + ) + + def unpin_message(self, channel_id, message_id): + return self.request( + Route( + "DELETE", + "/channels/{channel_id}/pins/{message_id}", + channel_id=channel_id, + message_id=message_id, + ) + ) + + def pins_from(self, channel_id): + return self.request(Route("GET", "/channels/{channel_id}/pins", channel_id=channel_id)) + + # Member management + + def kick(self, user_id, guild_id, reason=None): + r = Route( + "DELETE", "/guilds/{guild_id}/members/{user_id}", guild_id=guild_id, user_id=user_id + ) + if reason: + # thanks aiohttp + r.url = "{0.url}?reason={1}".format(r, _uriquote(reason)) + + return self.request(r) + + def ban(self, user_id, guild_id, delete_message_days=1, reason=None): + r = Route("PUT", "/guilds/{guild_id}/bans/{user_id}", guild_id=guild_id, user_id=user_id) + params = {"delete-message-days": delete_message_days} + + if reason: + # thanks aiohttp + r.url = "{0.url}?reason={1}".format(r, _uriquote(reason)) + + return self.request(r, params=params) + + def unban(self, user_id, guild_id, *, reason=None): + r = Route( + "DELETE", "/guilds/{guild_id}/bans/{user_id}", guild_id=guild_id, user_id=user_id + ) + return self.request(r, reason=reason) + + def guild_voice_state(self, user_id, guild_id, *, mute=None, deafen=None, reason=None): + r = Route( + "PATCH", "/guilds/{guild_id}/members/{user_id}", guild_id=guild_id, user_id=user_id + ) + payload = {} + if mute is not None: + payload["mute"] = mute + + if deafen is not None: + payload["deaf"] = deafen + + return self.request(r, json=payload, reason=reason) + + def edit_profile(self, password, username, avatar, **fields): + payload = {"password": password, "username": username, "avatar": avatar} + + if "email" in fields: + payload["email"] = fields["email"] + + if "new_password" in fields: + payload["new_password"] = fields["new_password"] + + return self.request(Route("PATCH", "/users/@me"), json=payload) + + def change_my_nickname(self, guild_id, nickname, *, reason=None): + r = Route("PATCH", "/guilds/{guild_id}/members/@me/nick", guild_id=guild_id) + payload = {"nick": nickname} + return self.request(r, json=payload, reason=reason) + + def change_nickname(self, guild_id, user_id, nickname, *, reason=None): + r = Route( + "PATCH", "/guilds/{guild_id}/members/{user_id}", guild_id=guild_id, user_id=user_id + ) + payload = {"nick": nickname} + return self.request(r, json=payload, reason=reason) + + def edit_member(self, guild_id, user_id, *, reason=None, **fields): + r = Route( + "PATCH", "/guilds/{guild_id}/members/{user_id}", guild_id=guild_id, user_id=user_id + ) + return self.request(r, json=fields, reason=reason) + + # Channel management + + def edit_channel(self, channel_id, *, reason=None, **options): + r = Route("PATCH", "/channels/{channel_id}", channel_id=channel_id) + valid_keys = ( + "name", + "parent_id", + "topic", + "bitrate", + "nsfw", + "user_limit", + "position", + "permission_overwrites", + "rate_limit_per_user", + ) + payload = {k: v for k, v in options.items() if k in valid_keys} + + return self.request(r, reason=reason, json=payload) + + def bulk_channel_update(self, guild_id, data, *, reason=None): + r = Route("PATCH", "/guilds/{guild_id}/channels", guild_id=guild_id) + return self.request(r, json=data, reason=reason) + + def create_channel( + self, + guild_id, + name, + channel_type, + parent_id=None, + permission_overwrites=None, + *, + reason=None + ): + payload = {"name": name, "type": channel_type} + + if permission_overwrites is not None: + payload["permission_overwrites"] = permission_overwrites + + if parent_id is not None: + payload["parent_id"] = parent_id + + return self.request( + Route("POST", "/guilds/{guild_id}/channels", guild_id=guild_id), + json=payload, + reason=reason, + ) + + def delete_channel(self, channel_id, *, reason=None): + return self.request( + Route("DELETE", "/channels/{channel_id}", channel_id=channel_id), reason=reason + ) + + # Webhook management + + def create_webhook(self, channel_id, *, name, avatar=None): + payload = {"name": name} + if avatar is not None: + payload["avatar"] = avatar + + return self.request( + Route("POST", "/channels/{channel_id}/webhooks", channel_id=channel_id), json=payload + ) + + def channel_webhooks(self, channel_id): + return self.request(Route("GET", "/channels/{channel_id}/webhooks", channel_id=channel_id)) + + def guild_webhooks(self, guild_id): + return self.request(Route("GET", "/guilds/{guild_id}/webhooks", guild_id=guild_id)) + + def get_webhook(self, webhook_id): + return self.request(Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)) + + # Guild management + + def leave_guild(self, guild_id): + return self.request(Route("DELETE", "/users/@me/guilds/{guild_id}", guild_id=guild_id)) + + def delete_guild(self, guild_id): + return self.request(Route("DELETE", "/guilds/{guild_id}", guild_id=guild_id)) + + def create_guild(self, name, region, icon): + payload = {"name": name, "icon": icon, "region": region} + + return self.request(Route("POST", "/guilds"), json=payload) + + def edit_guild(self, guild_id, *, reason=None, **fields): + valid_keys = ( + "name", + "region", + "icon", + "afk_timeout", + "owner_id", + "afk_channel_id", + "splash", + "verification_level", + "system_channel_id", + "default_message_notifications", + "explicit_content_filter", + ) + + payload = {k: v for k, v in fields.items() if k in valid_keys} + + return self.request( + Route("PATCH", "/guilds/{guild_id}", guild_id=guild_id), json=payload, reason=reason + ) + + def get_bans(self, guild_id): + return self.request(Route("GET", "/guilds/{guild_id}/bans", guild_id=guild_id)) + + def get_ban(self, user_id, guild_id): + return self.request( + Route("GET", "/guilds/{guild_id}/bans/{user_id}", guild_id=guild_id, user_id=user_id) + ) + + def get_vanity_code(self, guild_id): + return self.request(Route("GET", "/guilds/{guild_id}/vanity-url", guild_id=guild_id)) + + def change_vanity_code(self, guild_id, code, *, reason=None): + payload = {"code": code} + return self.request( + Route("PATCH", "/guilds/{guild_id}/vanity-url", guild_id=guild_id), + json=payload, + reason=reason, + ) + + def prune_members(self, guild_id, days, *, reason=None): + params = {"days": days} + return self.request( + Route("POST", "/guilds/{guild_id}/prune", guild_id=guild_id), + params=params, + reason=reason, + ) + + def estimate_pruned_members(self, guild_id, days): + params = {"days": days} + return self.request( + Route("GET", "/guilds/{guild_id}/prune", guild_id=guild_id), params=params + ) + + def create_custom_emoji(self, guild_id, name, image, *, roles=None, reason=None): + payload = {"name": name, "image": image, "roles": roles or []} + + r = Route("POST", "/guilds/{guild_id}/emojis", guild_id=guild_id) + return self.request(r, json=payload, reason=reason) + + def delete_custom_emoji(self, guild_id, emoji_id, *, reason=None): + r = Route( + "DELETE", "/guilds/{guild_id}/emojis/{emoji_id}", guild_id=guild_id, emoji_id=emoji_id + ) + return self.request(r, reason=reason) + + def edit_custom_emoji(self, guild_id, emoji_id, *, name, roles=None, reason=None): + payload = {"name": name, "roles": roles or []} + r = Route( + "PATCH", "/guilds/{guild_id}/emojis/{emoji_id}", guild_id=guild_id, emoji_id=emoji_id + ) + return self.request(r, json=payload, reason=reason) + + def get_audit_logs( + self, guild_id, limit=100, before=None, after=None, user_id=None, action_type=None + ): + params = {"limit": limit} + if before: + params["before"] = before + if after: + params["after"] = after + if user_id: + params["user_id"] = user_id + if action_type: + params["action_type"] = action_type + + r = Route("GET", "/guilds/{guild_id}/audit-logs", guild_id=guild_id) + return self.request(r, params=params) + + # Invite management + + def create_invite(self, channel_id, *, reason=None, **options): + r = Route("POST", "/channels/{channel_id}/invites", channel_id=channel_id) + payload = { + "max_age": options.get("max_age", 0), + "max_uses": options.get("max_uses", 0), + "temporary": options.get("temporary", False), + "unique": options.get("unique", True), + } + + return self.request(r, reason=reason, json=payload) + + def get_invite(self, invite_id): + return self.request(Route("GET", "/invite/{invite_id}", invite_id=invite_id)) + + def invites_from(self, guild_id): + return self.request(Route("GET", "/guilds/{guild_id}/invites", guild_id=guild_id)) + + def invites_from_channel(self, channel_id): + return self.request(Route("GET", "/channels/{channel_id}/invites", channel_id=channel_id)) + + def delete_invite(self, invite_id, *, reason=None): + return self.request( + Route("DELETE", "/invite/{invite_id}", invite_id=invite_id), reason=reason + ) + + # Role management + + def edit_role(self, guild_id, role_id, *, reason=None, **fields): + r = Route( + "PATCH", "/guilds/{guild_id}/roles/{role_id}", guild_id=guild_id, role_id=role_id + ) + valid_keys = ("name", "permissions", "color", "hoist", "mentionable") + payload = {k: v for k, v in fields.items() if k in valid_keys} + return self.request(r, json=payload, reason=reason) + + def delete_role(self, guild_id, role_id, *, reason=None): + r = Route( + "DELETE", "/guilds/{guild_id}/roles/{role_id}", guild_id=guild_id, role_id=role_id + ) + return self.request(r, reason=reason) + + def replace_roles(self, user_id, guild_id, role_ids, *, reason=None): + return self.edit_member(guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason) + + def create_role(self, guild_id, *, reason=None, **fields): + r = Route("POST", "/guilds/{guild_id}/roles", guild_id=guild_id) + return self.request(r, json=fields, reason=reason) + + def move_role_position(self, guild_id, positions, *, reason=None): + r = Route("PATCH", "/guilds/{guild_id}/roles", guild_id=guild_id) + return self.request(r, json=positions, reason=reason) + + def add_role(self, guild_id, user_id, role_id, *, reason=None): + r = Route( + "PUT", + "/guilds/{guild_id}/members/{user_id}/roles/{role_id}", + guild_id=guild_id, + user_id=user_id, + role_id=role_id, + ) + return self.request(r, reason=reason) + + def remove_role(self, guild_id, user_id, role_id, *, reason=None): + r = Route( + "DELETE", + "/guilds/{guild_id}/members/{user_id}/roles/{role_id}", + guild_id=guild_id, + user_id=user_id, + role_id=role_id, + ) + return self.request(r, reason=reason) + + def edit_channel_permissions(self, channel_id, target, allow, deny, type, *, reason=None): + payload = {"id": target, "allow": allow, "deny": deny, "type": type} + r = Route( + "PUT", + "/channels/{channel_id}/permissions/{target}", + channel_id=channel_id, + target=target, + ) + return self.request(r, json=payload, reason=reason) + + def delete_channel_permissions(self, channel_id, target, *, reason=None): + r = Route( + "DELETE", + "/channels/{channel_id}/permissions/{target}", + channel_id=channel_id, + target=target, + ) + return self.request(r, reason=reason) + + # Voice management + + def move_member(self, user_id, guild_id, channel_id, *, reason=None): + return self.edit_member( + guild_id=guild_id, user_id=user_id, channel_id=channel_id, reason=reason + ) + + # Relationship related + + def remove_relationship(self, user_id): + r = Route("DELETE", "/users/@me/relationships/{user_id}", user_id=user_id) + return self.request(r) + + def add_relationship(self, user_id, type=None): + r = Route("PUT", "/users/@me/relationships/{user_id}", user_id=user_id) + payload = {} + if type is not None: + payload["type"] = type + + return self.request(r, json=payload) + + def send_friend_request(self, username, discriminator): + r = Route("POST", "/users/@me/relationships") + payload = {"username": username, "discriminator": int(discriminator)} + return self.request(r, json=payload) + + # Misc + + def application_info(self): + return self.request(Route("GET", "/oauth2/applications/@me")) + + async def get_gateway(self, *, encoding="json", v=6, zlib=True): + try: + data = await self.request(Route("GET", "/gateway")) + except HTTPException as exc: + raise GatewayNotFound() from exc + if zlib: + value = "{0}?encoding={1}&v={2}&compress=zlib-stream" + else: + value = "{0}?encoding={1}&v={2}" + return value.format(data["url"], encoding, v) + + async def get_bot_gateway(self, *, encoding="json", v=6, zlib=True): + try: + data = await self.request(Route("GET", "/gateway/bot")) + except HTTPException as exc: + raise GatewayNotFound() from exc + + if zlib: + value = "{0}?encoding={1}&v={2}&compress=zlib-stream" + else: + value = "{0}?encoding={1}&v={2}" + return data["shards"], value.format(data["url"], encoding, v) + + def get_user_info(self, user_id): + return self.request(Route("GET", "/users/{user_id}", user_id=user_id)) + + def get_user_profile(self, user_id): + return self.request(Route("GET", "/users/{user_id}/profile", user_id=user_id)) + + def get_mutual_friends(self, user_id): + return self.request(Route("GET", "/users/{user_id}/relationships", user_id=user_id)) + + def change_hypesquad_house(self, house_id): + payload = {"house_id": house_id} + return self.request(Route("POST", "/hypesquad/online"), json=payload) + + def leave_hypesquad_house(self): + return self.request(Route("DELETE", "/hypesquad/online")) diff --git a/discord/invite.py b/discord/invite.py new file mode 100644 index 000000000..7248f7003 --- /dev/null +++ b/discord/invite.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .utils import parse_time +from .mixins import Hashable +from .object import Object + + +class Invite(Hashable): + """Represents a Discord :class:`Guild` or :class:`abc.GuildChannel` invite. + + Depending on the way this object was created, some of the attributes can + have a value of ``None``. + + .. container:: operations + + .. describe:: x == y + + Checks if two invites are equal. + + .. describe:: x != y + + Checks if two invites are not equal. + + .. describe:: hash(x) + + Returns the invite hash. + + .. describe:: str(x) + + Returns the invite URL. + + Attributes + ----------- + max_age: :class:`int` + How long the before the invite expires in seconds. A value of 0 indicates that it doesn't expire. + code: :class:`str` + The URL fragment used for the invite. + guild: :class:`Guild` + The guild the invite is for. + revoked: :class:`bool` + Indicates if the invite has been revoked. + created_at: `datetime.datetime` + A datetime object denoting the time the invite was created. + temporary: :class:`bool` + Indicates that the invite grants temporary membership. + If True, members who joined via this invite will be kicked upon disconnect. + uses: :class:`int` + How many times the invite has been used. + max_uses: :class:`int` + How many times the invite can be used. + inviter: :class:`User` + The user who created the invite. + channel: :class:`abc.GuildChannel` + The channel the invite is for. + """ + + __slots__ = ( + "max_age", + "code", + "guild", + "revoked", + "created_at", + "uses", + "temporary", + "max_uses", + "inviter", + "channel", + "_state", + ) + + def __init__(self, *, state, data): + self._state = state + self.max_age = data.get("max_age") + self.code = data.get("code") + self.guild = data.get("guild") + self.revoked = data.get("revoked") + self.created_at = parse_time(data.get("created_at")) + self.temporary = data.get("temporary") + self.uses = data.get("uses") + self.max_uses = data.get("max_uses") + + inviter_data = data.get("inviter") + self.inviter = None if inviter_data is None else self._state.store_user(inviter_data) + self.channel = data.get("channel") + + @classmethod + def from_incomplete(cls, *, state, data): + guild_id = int(data["guild"]["id"]) + channel_id = int(data["channel"]["id"]) + guild = state._get_guild(guild_id) + if guild is not None: + channel = guild.get_channel(channel_id) + else: + guild = Object(id=guild_id) + channel = Object(id=channel_id) + guild.name = data["guild"]["name"] + + guild.splash = data["guild"]["splash"] + guild.splash_url = "" + if guild.splash: + guild.splash_url = "https://cdn.discordapp.com/splashes/{0.id}/{0.splash}.jpg?size=2048".format( + guild + ) + + channel.name = data["channel"]["name"] + + data["guild"] = guild + data["channel"] = channel + return cls(state=state, data=data) + + def __str__(self): + return self.url + + def __repr__(self): + return "".format(self) + + def __hash__(self): + return hash(self.code) + + @property + def id(self): + """Returns the proper code portion of the invite.""" + return self.code + + @property + def url(self): + """A property that retrieves the invite URL.""" + return "http://discord.gg/" + self.code + + async def delete(self, *, reason=None): + """|coro| + + Revokes the instant invite. + + You must have the :attr:`~Permissions.manage_channels` permission to do this. + + Parameters + ----------- + reason: Optional[str] + The reason for deleting this invite. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to revoke invites. + NotFound + The invite is invalid or expired. + HTTPException + Revoking the invite failed. + """ + + await self._state.http.delete_invite(self.code, reason=reason) diff --git a/discord/iterators.py b/discord/iterators.py new file mode 100644 index 000000000..acb86ab70 --- /dev/null +++ b/discord/iterators.py @@ -0,0 +1,489 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +import datetime + +from .errors import NoMoreItems +from .utils import time_snowflake, maybe_coroutine +from .object import Object +from .audit_logs import AuditLogEntry + + +class _AsyncIterator: + __slots__ = () + + def get(self, **attrs): + def predicate(elem): + for attr, val in attrs.items(): + nested = attr.split("__") + obj = elem + for attribute in nested: + obj = getattr(obj, attribute) + + if obj != val: + return False + return True + + return self.find(predicate) + + async def find(self, predicate): + while True: + try: + elem = await self.next() + except NoMoreItems: + return None + + ret = await maybe_coroutine(predicate, elem) + if ret: + return elem + + def map(self, func): + return _MappedAsyncIterator(self, func) + + def filter(self, predicate): + return _FilteredAsyncIterator(self, predicate) + + async def flatten(self): + ret = [] + while True: + try: + item = await self.next() + except NoMoreItems: + return ret + else: + ret.append(item) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + msg = await self.next() + except NoMoreItems: + raise StopAsyncIteration() + else: + return msg + + +def _identity(x): + return x + + +class _MappedAsyncIterator(_AsyncIterator): + def __init__(self, iterator, func): + self.iterator = iterator + self.func = func + + async def next(self): + # this raises NoMoreItems and will propagate appropriately + item = await self.iterator.next() + return await maybe_coroutine(self.func, item) + + +class _FilteredAsyncIterator(_AsyncIterator): + def __init__(self, iterator, predicate): + self.iterator = iterator + + if predicate is None: + predicate = _identity + + self.predicate = predicate + + async def next(self): + getter = self.iterator.next + pred = self.predicate + while True: + # propagate NoMoreItems similar to _MappedAsyncIterator + item = await getter() + ret = await maybe_coroutine(pred, item) + if ret: + return item + + +class ReactionIterator(_AsyncIterator): + def __init__(self, message, emoji, limit=100, after=None): + self.message = message + self.limit = limit + self.after = after + state = message._state + self.getter = state.http.get_reaction_users + self.state = state + self.emoji = emoji + self.guild = message.guild + self.channel_id = message.channel.id + self.users = asyncio.Queue(loop=state.loop) + + async def next(self): + if self.users.empty(): + await self.fill_users() + + try: + return self.users.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreItems() + + async def fill_users(self): + # this is a hack because >circular imports< + from .user import User + + if self.limit > 0: + retrieve = self.limit if self.limit <= 100 else 100 + + after = self.after.id if self.after else None + data = await self.getter( + self.message.id, self.channel_id, self.emoji, retrieve, after=after + ) + + if data: + self.limit -= retrieve + self.after = Object(id=int(data[0]["id"])) + + if self.guild is None: + for element in reversed(data): + await self.users.put(User(state=self.state, data=element)) + else: + for element in reversed(data): + member_id = int(element["id"]) + member = self.guild.get_member(member_id) + if member is not None: + await self.users.put(member) + else: + await self.users.put(User(state=self.state, data=element)) + + +class HistoryIterator(_AsyncIterator): + """Iterator for receiving a channel's message history. + + The messages endpoint has two behaviours we care about here: + If `before` is specified, the messages endpoint returns the `limit` + newest messages before `before`, sorted with newest first. For filling over + 100 messages, update the `before` parameter to the oldest message received. + Messages will be returned in order by time. + If `after` is specified, it returns the `limit` oldest messages after + `after`, sorted with newest first. For filling over 100 messages, update the + `after` parameter to the newest message received. If messages are not + reversed, they will be out of order (99-0, 199-100, so on) + + A note that if both before and after are specified, before is ignored by the + messages endpoint. + + Parameters + ----------- + messageable: :class:`abc.Messageable` + Messageable class to retrieve message history fro. + limit : int + Maximum number of messages to retrieve + before : :class:`Message` or id-like + Message before which all messages must be. + after : :class:`Message` or id-like + Message after which all messages must be. + around : :class:`Message` or id-like + Message around which all messages must be. Limit max 101. Note that if + limit is an even number, this will return at most limit+1 messages. + reverse: bool + If set to true, return messages in oldest->newest order. Recommended + when using with "after" queries with limit over 100, otherwise messages + will be out of order. + """ + + def __init__(self, messageable, limit, before=None, after=None, around=None, reverse=None): + + if isinstance(before, datetime.datetime): + before = Object(id=time_snowflake(before, high=False)) + if isinstance(after, datetime.datetime): + after = Object(id=time_snowflake(after, high=True)) + if isinstance(around, datetime.datetime): + around = Object(id=time_snowflake(around)) + + self.messageable = messageable + self.limit = limit + self.before = before + self.after = after + self.around = around + + if reverse is None: + self.reverse = after is not None + else: + self.reverse = reverse + + self._filter = None # message dict -> bool + + self.state = self.messageable._state + self.logs_from = self.state.http.logs_from + self.messages = asyncio.Queue(loop=self.state.loop) + + if self.around: + if self.limit is None: + raise ValueError("history does not support around with limit=None") + if self.limit > 101: + raise ValueError("history max limit 101 when specifying around parameter") + elif self.limit == 101: + self.limit = 100 # Thanks discord + elif self.limit == 1: + raise ValueError("Use get_message.") + + self._retrieve_messages = self._retrieve_messages_around_strategy + if self.before and self.after: + self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id + elif self.before: + self._filter = lambda m: int(m["id"]) < self.before.id + elif self.after: + self._filter = lambda m: self.after.id < int(m["id"]) + elif self.before and self.after: + if self.reverse: + self._retrieve_messages = self._retrieve_messages_after_strategy + self._filter = lambda m: int(m["id"]) < self.before.id + else: + self._retrieve_messages = self._retrieve_messages_before_strategy + self._filter = lambda m: int(m["id"]) > self.after.id + elif self.after: + self._retrieve_messages = self._retrieve_messages_after_strategy + else: + self._retrieve_messages = self._retrieve_messages_before_strategy + + async def next(self): + if self.messages.empty(): + await self.fill_messages() + + try: + return self.messages.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreItems() + + def _get_retrieve(self): + l = self.limit + if l is None: + r = 100 + elif l <= 100: + r = l + else: + r = 100 + + self.retrieve = r + return r > 0 + + async def flatten(self): + # this is similar to fill_messages except it uses a list instead + # of a queue to place the messages in. + result = [] + channel = await self.messageable._get_channel() + self.channel = channel + while self._get_retrieve(): + data = await self._retrieve_messages(self.retrieve) + if len(data) < 100: + self.limit = 0 # terminate the infinite loop + + if self.reverse: + data = reversed(data) + if self._filter: + data = filter(self._filter, data) + + for element in data: + result.append(self.state.create_message(channel=channel, data=element)) + return result + + async def fill_messages(self): + if not hasattr(self, "channel"): + # do the required set up + channel = await self.messageable._get_channel() + self.channel = channel + + if self._get_retrieve(): + data = await self._retrieve_messages(self.retrieve) + if self.limit is None and len(data) < 100: + self.limit = 0 # terminate the infinite loop + + if self.reverse: + data = reversed(data) + if self._filter: + data = filter(self._filter, data) + + channel = self.channel + for element in data: + await self.messages.put(self.state.create_message(channel=channel, data=element)) + + async def _retrieve_messages(self, retrieve): + """Retrieve messages and update next parameters.""" + pass + + async def _retrieve_messages_before_strategy(self, retrieve): + """Retrieve messages using before parameter.""" + before = self.before.id if self.before else None + data = await self.logs_from(self.channel.id, retrieve, before=before) + if len(data): + if self.limit is not None: + self.limit -= retrieve + self.before = Object(id=int(data[-1]["id"])) + return data + + async def _retrieve_messages_after_strategy(self, retrieve): + """Retrieve messages using after parameter.""" + after = self.after.id if self.after else None + data = await self.logs_from(self.channel.id, retrieve, after=after) + if len(data): + if self.limit is not None: + self.limit -= retrieve + self.after = Object(id=int(data[0]["id"])) + return data + + async def _retrieve_messages_around_strategy(self, retrieve): + """Retrieve messages using around parameter.""" + if self.around: + around = self.around.id if self.around else None + data = await self.logs_from(self.channel.id, retrieve, around=around) + self.around = None + return data + return [] + + +class AuditLogIterator(_AsyncIterator): + def __init__( + self, + guild, + limit=None, + before=None, + after=None, + reverse=None, + user_id=None, + action_type=None, + ): + if isinstance(before, datetime.datetime): + before = Object(id=time_snowflake(before, high=False)) + if isinstance(after, datetime.datetime): + after = Object(id=time_snowflake(after, high=True)) + + self.guild = guild + self.loop = guild._state.loop + self.request = guild._state.http.get_audit_logs + self.limit = limit + self.before = before + self.user_id = user_id + self.action_type = action_type + self.after = after + self._users = {} + self._state = guild._state + + if reverse is None: + self.reverse = after is not None + else: + self.reverse = reverse + + self._filter = None # entry dict -> bool + + self.entries = asyncio.Queue(loop=self.loop) + + if self.before and self.after: + if self.reverse: + self._strategy = self._after_strategy + self._filter = lambda m: int(m["id"]) < self.before.id + else: + self._strategy = self._before_strategy + self._filter = lambda m: int(m["id"]) > self.after.id + elif self.after: + self._strategy = self._after_strategy + else: + self._strategy = self._before_strategy + + async def _before_strategy(self, retrieve): + before = self.before.id if self.before else None + data = await self.request( + self.guild.id, + limit=retrieve, + user_id=self.user_id, + action_type=self.action_type, + before=before, + ) + + entries = data.get("audit_log_entries", []) + if len(data) and entries: + if self.limit is not None: + self.limit -= retrieve + self.before = Object(id=int(entries[-1]["id"])) + return data.get("users", []), entries + + async def _after_strategy(self, retrieve): + after = self.after.id if self.after else None + data = await self.request( + self.guild.id, + limit=retrieve, + user_id=self.user_id, + action_type=self.action_type, + after=after, + ) + entries = data.get("audit_log_entries", []) + if len(data) and entries: + if self.limit is not None: + self.limit -= retrieve + self.after = Object(id=int(entries[0]["id"])) + return data.get("users", []), entries + + async def next(self): + if self.entries.empty(): + await self._fill() + + try: + return self.entries.get_nowait() + except asyncio.QueueEmpty: + raise NoMoreItems() + + def _get_retrieve(self): + l = self.limit + if l is None: + r = 100 + elif l <= 100: + r = l + else: + r = 100 + + self.retrieve = r + return r > 0 + + async def _fill(self): + from .user import User + + if self._get_retrieve(): + users, data = await self._strategy(self.retrieve) + if self.limit is None and len(data) < 100: + self.limit = 0 # terminate the infinite loop + + if self.reverse: + data = reversed(data) + if self._filter: + data = filter(self._filter, data) + + for user in users: + u = User(data=user, state=self._state) + self._users[u.id] = u + + for element in data: + # TODO: remove this if statement later + if element["action_type"] is None: + continue + + await self.entries.put( + AuditLogEntry(data=element, users=self._users, guild=self.guild) + ) diff --git a/discord/member.py b/discord/member.py new file mode 100644 index 000000000..b1f6cce3c --- /dev/null +++ b/discord/member.py @@ -0,0 +1,621 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import itertools + +import discord.abc + +from . import utils +from .user import BaseUser, User +from .activity import create_activity +from .permissions import Permissions +from .enums import Status, try_enum +from .colour import Colour +from .object import Object + + +class VoiceState: + """Represents a Discord user's voice state. + + Attributes + ------------ + deaf: :class:`bool` + Indicates if the user is currently deafened by the guild. + mute: :class:`bool` + Indicates if the user is currently muted by the guild. + self_mute: :class:`bool` + Indicates if the user is currently muted by their own accord. + self_deaf: :class:`bool` + Indicates if the user is currently deafened by their own accord. + afk: :class:`bool` + Indicates if the user is currently in the AFK channel in the guild. + channel: :class:`VoiceChannel` + The voice channel that the user is currently connected to. None if the user + is not currently in a voice channel. + """ + + __slots__ = ("session_id", "deaf", "mute", "self_mute", "self_deaf", "afk", "channel") + + def __init__(self, *, data, channel=None): + self.session_id = data.get("session_id") + self._update(data, channel) + + def _update(self, data, channel): + self.self_mute = data.get("self_mute", False) + self.self_deaf = data.get("self_deaf", False) + self.afk = data.get("suppress", False) + self.mute = data.get("mute", False) + self.deaf = data.get("deaf", False) + self.channel = channel + + def __repr__(self): + return "".format( + self + ) + + +def flatten_user(cls): + for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): + # ignore private/special methods + if attr.startswith("_"): + continue + + # don't override what we already have + if attr in cls.__dict__: + continue + + # if it's a slotted attribute or a property, redirect it + # slotted members are implemented as member_descriptors in Type.__dict__ + if not hasattr(value, "__annotations__"): + + def getter(self, x=attr): + return getattr(self._user, x) + + setattr(cls, attr, property(getter, doc="Equivalent to :attr:`User.%s`" % attr)) + else: + # probably a member function by now + def generate_function(x): + def general(self, *args, **kwargs): + return getattr(self._user, x)(*args, **kwargs) + + general.__name__ = x + return general + + func = generate_function(attr) + func.__doc__ = value.__doc__ + setattr(cls, attr, func) + + return cls + + +_BaseUser = discord.abc.User + + +@flatten_user +class Member(discord.abc.Messageable, _BaseUser): + """Represents a Discord member to a :class:`Guild`. + + This implements a lot of the functionality of :class:`User`. + + .. container:: operations + + .. describe:: x == y + + Checks if two members are equal. + Note that this works with :class:`User` instances too. + + .. describe:: x != y + + Checks if two members are not equal. + Note that this works with :class:`User` instances too. + + .. describe:: hash(x) + + Returns the member's hash. + + .. describe:: str(x) + + Returns the member's name with the discriminator. + + Attributes + ---------- + joined_at: `datetime.datetime` + A datetime object that specifies the date and time in UTC that the member joined the guild for + the first time. + activities: Tuple[Union[:class:`Game`, :class:`Streaming`, :class:`Spotify`, :class:`Activity`]] + The activities that the user is currently doing. + guild: :class:`Guild` + The guild that the member belongs to. + nick: Optional[:class:`str`] + The guild specific nickname of the user. + """ + + __slots__ = ( + "_roles", + "joined_at", + "_client_status", + "activities", + "guild", + "nick", + "_user", + "_state", + ) + + def __init__(self, *, data, guild, state): + self._state = state + self._user = state.store_user(data["user"]) + self.guild = guild + self.joined_at = utils.parse_time(data.get("joined_at")) + self._update_roles(data) + self._client_status = {None: Status.offline} + self.activities = tuple(map(create_activity, data.get("activities", []))) + self.nick = data.get("nick", None) + + def __str__(self): + return str(self._user) + + def __repr__(self): + return ( + "".format(self, self._user) + ) + + def __eq__(self, other): + return isinstance(other, _BaseUser) and other.id == self.id + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self._user) + + @classmethod + def _copy(cls, member): + self = cls.__new__(cls) # to bypass __init__ + + self._roles = utils.SnowflakeList(member._roles, is_sorted=True) + self.joined_at = member.joined_at + self._client_status = member._client_status.copy() + self.guild = member.guild + self.nick = member.nick + self.activities = member.activities + self._state = member._state + self._user = User._copy(member._user) + return self + + async def _get_channel(self): + ch = await self.create_dm() + return ch + + def _update_roles(self, data): + self._roles = utils.SnowflakeList(map(int, data["roles"])) + + def _update(self, data, user=None): + if user: + self._user.name = user["username"] + self._user.discriminator = user["discriminator"] + self._user.avatar = user["avatar"] + self._user.bot = user.get("bot", False) + + # the nickname change is optional, + # if it isn't in the payload then it didn't change + try: + self.nick = data["nick"] + except KeyError: + pass + + self._update_roles(data) + + def _presence_update(self, data, user): + self.activities = tuple(map(create_activity, data.get("activities", []))) + self._client_status = {key: value for key, value in data.get("client_status", {}).items()} + self._client_status[None] = data["status"] + + if len(user) > 1: + u = self._user + u.name = user.get("username", u.name) + u.avatar = user.get("avatar", u.avatar) + u.discriminator = user.get("discriminator", u.discriminator) + + @property + def status(self): + """:class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead.""" + return try_enum(Status, self._client_status[None]) + + @status.setter + def status(self, value): + # internal use only + self._client_status[None] = str(value) + + @property + def mobile_status(self): + """:class:`Status`: The member's status on a mobile device, if applicable.""" + return try_enum(Status, self._client_status.get("mobile", "offline")) + + @property + def desktop_status(self): + """:class:`Status`: The member's status on the desktop client, if applicable.""" + return try_enum(Status, self._client_status.get("desktop", "offline")) + + @property + def web_status(self): + """:class:`Status`: The member's status on the web client, if applicable.""" + return try_enum(Status, self._client_status.get("web", "offline")) + + def is_on_mobile(self): + """:class:`bool`: A helper function that determines if a member is active on a mobile device.""" + return "mobile" in self._client_status + + @property + def colour(self): + """A property that returns a :class:`Colour` denoting the rendered colour + for the member. If the default colour is the one rendered then an instance + of :meth:`Colour.default` is returned. + + There is an alias for this under ``color``. + """ + + roles = self.roles[1:] # remove @everyone + + # highest order of the colour is the one that gets rendered. + # if the highest is the default colour then the next one with a colour + # is chosen instead + for role in reversed(roles): + if role.colour.value: + return role.colour + return Colour.default() + + color = colour + + @property + def roles(self): + """A :class:`list` of :class:`Role` that the member belongs to. Note + that the first element of this list is always the default '@everyone' + role. + + These roles are sorted by their position in the role hierarchy. + """ + result = [] + g = self.guild + for role_id in self._roles: + role = g.get_role(role_id) + if role: + result.append(role) + result.append(g.default_role) + result.sort() + return result + + @property + def mention(self): + """Returns a string that mentions the member.""" + if self.nick: + return "<@!%s>" % self.id + return "<@%s>" % self.id + + @property + def display_name(self): + """Returns the user's display name. + + For regular users this is just their username, but + if they have a guild specific nickname then that + is returned instead. + """ + return self.nick if self.nick is not None else self.name + + @property + def activity(self): + """Returns a class Union[:class:`Game`, :class:`Streaming`, :class:`Spotify`, :class:`Activity`] for the primary + activity the user is currently doing. Could be None if no activity is being done. + + .. note:: + + A user may have multiple activities, these can be accessed under :attr:`activities`. + """ + if self.activities: + return self.activities[0] + + def mentioned_in(self, message): + """Checks if the member is mentioned in the specified message. + + Parameters + ----------- + message: :class:`Message` + The message to check if you're mentioned in. + """ + if self._user.mentioned_in(message): + return True + + for role in message.role_mentions: + has_role = utils.get(self.roles, id=role.id) is not None + if has_role: + return True + + return False + + def permissions_in(self, channel): + """An alias for :meth:`abc.GuildChannel.permissions_for`. + + Basically equivalent to: + + .. code-block:: python3 + + channel.permissions_for(self) + + Parameters + ----------- + channel + The channel to check your permissions for. + """ + return channel.permissions_for(self) + + @property + def top_role(self): + """Returns the member's highest role. + + This is useful for figuring where a member stands in the role + hierarchy chain. + """ + return self.roles[-1] + + @property + def guild_permissions(self): + """Returns the member's guild permissions. + + This only takes into consideration the guild permissions + and not most of the implied permissions or any of the + channel permission overwrites. For 100% accurate permission + calculation, please use either :meth:`permissions_in` or + :meth:`abc.GuildChannel.permissions_for`. + + This does take into consideration guild ownership and the + administrator implication. + """ + + if self.guild.owner == self: + return Permissions.all() + + base = Permissions.none() + for r in self.roles: + base.value |= r.permissions.value + + if base.administrator: + return Permissions.all() + + return base + + @property + def voice(self): + """Optional[:class:`VoiceState`]: Returns the member's current voice state.""" + return self.guild._voice_state_for(self._user.id) + + async def ban(self, **kwargs): + """|coro| + + Bans this member. Equivalent to :meth:`Guild.ban` + """ + await self.guild.ban(self, **kwargs) + + async def unban(self, *, reason=None): + """|coro| + + Unbans this member. Equivalent to :meth:`Guild.unban` + """ + await self.guild.unban(self, reason=reason) + + async def kick(self, *, reason=None): + """|coro| + + Kicks this member. Equivalent to :meth:`Guild.kick` + """ + await self.guild.kick(self, reason=reason) + + async def edit(self, *, reason=None, **fields): + """|coro| + + Edits the member's data. + + Depending on the parameter passed, this requires different permissions listed below: + + +---------------+--------------------------------------+ + | Parameter | Permission | + +---------------+--------------------------------------+ + | nick | :attr:`Permissions.manage_nicknames` | + +---------------+--------------------------------------+ + | mute | :attr:`Permissions.mute_members` | + +---------------+--------------------------------------+ + | deafen | :attr:`Permissions.deafen_members` | + +---------------+--------------------------------------+ + | roles | :attr:`Permissions.manage_roles` | + +---------------+--------------------------------------+ + | voice_channel | :attr:`Permissions.move_members` | + +---------------+--------------------------------------+ + + All parameters are optional. + + Parameters + ----------- + nick: str + The member's new nickname. Use ``None`` to remove the nickname. + mute: bool + Indicates if the member should be guild muted or un-muted. + deafen: bool + Indicates if the member should be guild deafened or un-deafened. + roles: List[:class:`Roles`] + The member's new list of roles. This *replaces* the roles. + voice_channel: :class:`VoiceChannel` + The voice channel to move the member to. + reason: Optional[str] + The reason for editing this member. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have the proper permissions to the action requested. + HTTPException + The operation failed. + """ + http = self._state.http + guild_id = self.guild.id + payload = {} + + try: + nick = fields["nick"] + except KeyError: + # nick not present so... + pass + else: + nick = nick if nick else "" + if self._state.self_id == self.id: + await http.change_my_nickname(guild_id, nick, reason=reason) + else: + payload["nick"] = nick + + deafen = fields.get("deafen") + if deafen is not None: + payload["deaf"] = deafen + + mute = fields.get("mute") + if mute is not None: + payload["mute"] = mute + + try: + vc = fields["voice_channel"] + except KeyError: + pass + else: + payload["channel_id"] = vc.id + + try: + roles = fields["roles"] + except KeyError: + pass + else: + payload["roles"] = tuple(r.id for r in roles) + + await http.edit_member(guild_id, self.id, reason=reason, **payload) + + # TODO: wait for WS event for modify-in-place behaviour + + async def move_to(self, channel, *, reason=None): + """|coro| + + Moves a member to a new voice channel (they must be connected first). + + You must have the :attr:`~Permissions.move_members` permission to + use this. + + This raises the same exceptions as :meth:`edit`. + + Parameters + ----------- + channel: :class:`VoiceChannel` + The new voice channel to move the member to. + reason: Optional[str] + The reason for doing this action. Shows up on the audit log. + """ + await self.edit(voice_channel=channel, reason=reason) + + async def add_roles(self, *roles, reason=None, atomic=True): + r"""|coro| + + Gives the member a number of :class:`Role`\s. + + You must have the :attr:`~Permissions.manage_roles` permission to + use this. + + Parameters + ----------- + \*roles + An argument list of :class:`abc.Snowflake` representing a :class:`Role` + to give to the member. + reason: Optional[str] + The reason for adding these roles. Shows up on the audit log. + atomic: bool + Whether to atomically add roles. This will ensure that multiple + operations will always be applied regardless of the current + state of the cache. + + Raises + ------- + Forbidden + You do not have permissions to add these roles. + HTTPException + Adding roles failed. + """ + + if not atomic: + new_roles = utils._unique(Object(id=r.id) for s in (self.roles[1:], roles) for r in s) + await self.edit(roles=new_roles, reason=reason) + else: + req = self._state.http.add_role + guild_id = self.guild.id + user_id = self.id + for role in roles: + await req(guild_id, user_id, role.id, reason=reason) + + async def remove_roles(self, *roles, reason=None, atomic=True): + r"""|coro| + + Removes :class:`Role`\s from this member. + + You must have the :attr:`~Permissions.manage_roles` permission to + use this. + + Parameters + ----------- + \*roles + An argument list of :class:`abc.Snowflake` representing a :class:`Role` + to remove from the member. + reason: Optional[str] + The reason for removing these roles. Shows up on the audit log. + atomic: bool + Whether to atomically remove roles. This will ensure that multiple + operations will always be applied regardless of the current + state of the cache. + + Raises + ------- + Forbidden + You do not have permissions to remove these roles. + HTTPException + Removing the roles failed. + """ + + if not atomic: + new_roles = [Object(id=r.id) for r in self.roles[1:]] # remove @everyone + for role in roles: + try: + new_roles.remove(Object(id=role.id)) + except ValueError: + pass + + await self.edit(roles=new_roles, reason=reason) + else: + req = self._state.http.remove_role + guild_id = self.guild.id + user_id = self.id + for role in roles: + await req(guild_id, user_id, role.id, reason=reason) diff --git a/discord/message.py b/discord/message.py new file mode 100644 index 000000000..1c3e085f7 --- /dev/null +++ b/discord/message.py @@ -0,0 +1,799 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +import re + +from . import utils +from .reaction import Reaction +from .emoji import Emoji, PartialEmoji +from .calls import CallMessage +from .enums import MessageType, try_enum +from .errors import InvalidArgument, ClientException, HTTPException +from .embeds import Embed + + +class Attachment: + """Represents an attachment from Discord. + + Attributes + ------------ + id: :class:`int` + The attachment ID. + size: :class:`int` + The attachment size in bytes. + height: Optional[:class:`int`] + The attachment's height, in pixels. Only applicable to images. + width: Optional[:class:`int`] + The attachment's width, in pixels. Only applicable to images. + filename: :class:`str` + The attachment's filename. + url: :class:`str` + The attachment URL. If the message this attachment was attached + to is deleted, then this will 404. + proxy_url: :class:`str` + The proxy URL. This is a cached version of the :attr:`~Attachment.url` in the + case of images. When the message is deleted, this URL might be valid for a few + minutes or not valid at all. + """ + + __slots__ = ("id", "size", "height", "width", "filename", "url", "proxy_url", "_http") + + def __init__(self, *, data, state): + self.id = int(data["id"]) + self.size = data["size"] + self.height = data.get("height") + self.width = data.get("width") + self.filename = data["filename"] + self.url = data.get("url") + self.proxy_url = data.get("proxy_url") + self._http = state.http + + def is_spoiler(self): + """:class:`bool`: Whether this attachment contains a spoiler.""" + return self.filename.startswith("SPOILER_") + + async def save(self, fp, *, seek_begin=True): + """|coro| + + Saves this attachment into a file-like object. + + Parameters + ----------- + fp: Union[BinaryIO, str] + The file-like object to save this attachment to or the filename + to use. If a filename is passed then a file is created with that + filename and used instead. + seek_begin: bool + Whether to seek to the beginning of the file after saving is + successfully done. + + Raises + -------- + HTTPException + Saving the attachment failed. + NotFound + The attachment was deleted. + + Returns + -------- + int + The number of bytes written. + """ + + data = await self._http.get_attachment(self.url) + if isinstance(fp, str): + with open(fp, "wb") as f: + return f.write(data) + else: + written = fp.write(data) + if seek_begin: + fp.seek(0) + return written + + +class Message: + r"""Represents a message from Discord. + + There should be no need to create one of these manually. + + Attributes + ----------- + tts: :class:`bool` + Specifies if the message was done with text-to-speech. + type: :class:`MessageType` + The type of message. In most cases this should not be checked, but it is helpful + in cases where it might be a system message for :attr:`system_content`. + author + A :class:`Member` that sent the message. If :attr:`channel` is a + private channel or the user has the left the guild, then it is a :class:`User` instead. + content: :class:`str` + The actual contents of the message. + nonce + The value used by the discord guild and the client to verify that the message is successfully sent. + This is typically non-important. + embeds: List[:class:`Embed`] + A list of embeds the message has. + channel + The :class:`TextChannel` that the message was sent from. + Could be a :class:`DMChannel` or :class:`GroupChannel` if it's a private message. + call: Optional[:class:`CallMessage`] + The call that the message refers to. This is only applicable to messages of type + :attr:`MessageType.call`. + mention_everyone: :class:`bool` + Specifies if the message mentions everyone. + + .. note:: + + This does not check if the ``@everyone`` or the ``@here`` text is in the message itself. + Rather this boolean indicates if either the ``@everyone`` or the ``@here`` text is in the message + **and** it did end up mentioning. + + mentions: :class:`list` + A list of :class:`Member` that were mentioned. If the message is in a private message + then the list will be of :class:`User` instead. For messages that are not of type + :attr:`MessageType.default`\, this array can be used to aid in system messages. + For more information, see :attr:`system_content`. + + .. warning:: + + The order of the mentions list is not in any particular order so you should + not rely on it. This is a discord limitation, not one with the library. + + channel_mentions: :class:`list` + A list of :class:`abc.GuildChannel` that were mentioned. If the message is in a private message + then the list is always empty. + role_mentions: :class:`list` + A list of :class:`Role` that were mentioned. If the message is in a private message + then the list is always empty. + id: :class:`int` + The message ID. + webhook_id: Optional[:class:`int`] + If this message was sent by a webhook, then this is the webhook ID's that sent this + message. + attachments: List[:class:`Attachment`] + A list of attachments given to a message. + pinned: :class:`bool` + Specifies if the message is currently pinned. + reactions : List[:class:`Reaction`] + Reactions to a message. Reactions can be either custom emoji or standard unicode emoji. + activity: Optional[:class:`dict`] + The activity associated with this message. Sent with Rich-Presence related messages that for + example, request joining, spectating, or listening to or with another member. + + It is a dictionary with the following optional keys: + + - ``type``: An integer denoting the type of message activity being requested. + - ``party_id``: The party ID associated with the party. + application: Optional[:class:`dict`] + The rich presence enabled application associated with this message. + + It is a dictionary with the following keys: + + - ``id``: A string representing the application's ID. + - ``name``: A string representing the application's name. + - ``description``: A string representing the application's description. + - ``icon``: A string representing the icon ID of the application. + - ``cover_image``: A string representing the embed's image asset ID. + """ + + __slots__ = ( + "_edited_timestamp", + "tts", + "content", + "channel", + "webhook_id", + "mention_everyone", + "embeds", + "id", + "mentions", + "author", + "_cs_channel_mentions", + "_cs_raw_mentions", + "attachments", + "_cs_clean_content", + "_cs_raw_channel_mentions", + "nonce", + "pinned", + "role_mentions", + "_cs_raw_role_mentions", + "type", + "call", + "_cs_system_content", + "_cs_guild", + "_state", + "reactions", + "application", + "activity", + ) + + def __init__(self, *, state, channel, data): + self._state = state + self.id = int(data["id"]) + self.webhook_id = utils._get_as_snowflake(data, "webhook_id") + self.reactions = [Reaction(message=self, data=d) for d in data.get("reactions", [])] + self.application = data.get("application") + self.activity = data.get("activity") + self._update(channel, data) + + def __repr__(self): + return "".format(self) + + def _try_patch(self, data, key, transform=None): + try: + value = data[key] + except KeyError: + pass + else: + if transform is None: + setattr(self, key, value) + else: + setattr(self, key, transform(value)) + + def _add_reaction(self, data, emoji, user_id): + reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) + is_me = data["me"] = user_id == self._state.self_id + + if reaction is None: + reaction = Reaction(message=self, data=data, emoji=emoji) + self.reactions.append(reaction) + else: + reaction.count += 1 + if is_me: + reaction.me = is_me + + return reaction + + def _remove_reaction(self, data, emoji, user_id): + reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) + + if reaction is None: + # already removed? + raise ValueError("Emoji already removed?") + + # if reaction isn't in the list, we crash. This means discord + # sent bad data, or we stored improperly + reaction.count -= 1 + + if user_id == self._state.self_id: + reaction.me = False + if reaction.count == 0: + # this raises ValueError if something went wrong as well. + self.reactions.remove(reaction) + + return reaction + + def _update(self, channel, data): + self.channel = channel + self._edited_timestamp = utils.parse_time(data.get("edited_timestamp")) + self._try_patch(data, "pinned") + self._try_patch(data, "application") + self._try_patch(data, "activity") + self._try_patch(data, "mention_everyone") + self._try_patch(data, "tts") + self._try_patch(data, "type", lambda x: try_enum(MessageType, x)) + self._try_patch(data, "content") + self._try_patch( + data, "attachments", lambda x: [Attachment(data=a, state=self._state) for a in x] + ) + self._try_patch(data, "embeds", lambda x: list(map(Embed.from_data, x))) + self._try_patch(data, "nonce") + + for handler in ("author", "mentions", "mention_roles", "call"): + try: + getattr(self, "_handle_%s" % handler)(data[handler]) + except KeyError: + continue + + # clear the cached properties + cached = filter(lambda attr: attr.startswith("_cs_"), self.__slots__) + for attr in cached: + try: + delattr(self, attr) + except AttributeError: + pass + + def _handle_author(self, author): + self.author = self._state.store_user(author) + if self.guild is not None: + found = self.guild.get_member(self.author.id) + if found is not None: + self.author = found + + def _handle_mentions(self, mentions): + self.mentions = [] + if self.guild is None: + self.mentions = [self._state.store_user(m) for m in mentions] + return + + for mention in filter(None, mentions): + id_search = int(mention["id"]) + member = self.guild.get_member(id_search) + if member is not None: + self.mentions.append(member) + + def _handle_mention_roles(self, role_mentions): + self.role_mentions = [] + if self.guild is not None: + for role_id in map(int, role_mentions): + role = self.guild.get_role(role_id) + if role is not None: + self.role_mentions.append(role) + + def _handle_call(self, call): + if call is None or self.type is not MessageType.call: + self.call = None + return + + # we get the participant source from the mentions array or + # the author + + participants = [] + for uid in map(int, call.get("participants", [])): + if uid == self.author.id: + participants.append(self.author) + else: + user = utils.find(lambda u: u.id == uid, self.mentions) + if user is not None: + participants.append(user) + + call["participants"] = participants + self.call = CallMessage(message=self, **call) + + @utils.cached_slot_property("_cs_guild") + def guild(self): + """Optional[:class:`Guild`]: The guild that the message belongs to, if applicable.""" + return getattr(self.channel, "guild", None) + + @utils.cached_slot_property("_cs_raw_mentions") + def raw_mentions(self): + """A property that returns an array of user IDs matched with + the syntax of <@user_id> in the message content. + + This allows you to receive the user IDs of mentioned users + even in a private message context. + """ + return [int(x) for x in re.findall(r"<@!?([0-9]+)>", self.content)] + + @utils.cached_slot_property("_cs_raw_channel_mentions") + def raw_channel_mentions(self): + """A property that returns an array of channel IDs matched with + the syntax of <#channel_id> in the message content. + """ + return [int(x) for x in re.findall(r"<#([0-9]+)>", self.content)] + + @utils.cached_slot_property("_cs_raw_role_mentions") + def raw_role_mentions(self): + """A property that returns an array of role IDs matched with + the syntax of <@&role_id> in the message content. + """ + return [int(x) for x in re.findall(r"<@&([0-9]+)>", self.content)] + + @utils.cached_slot_property("_cs_channel_mentions") + def channel_mentions(self): + if self.guild is None: + return [] + it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions)) + return utils._unique(it) + + @utils.cached_slot_property("_cs_clean_content") + def clean_content(self): + """A property that returns the content in a "cleaned up" + manner. This basically means that mentions are transformed + into the way the client shows it. e.g. ``<#id>`` will transform + into ``#name``. + + This will also transform @everyone and @here mentions into + non-mentions. + """ + + transformations = { + re.escape("<#%s>" % channel.id): "#" + channel.name + for channel in self.channel_mentions + } + + mention_transforms = { + re.escape("<@%s>" % member.id): "@" + member.display_name for member in self.mentions + } + + # add the <@!user_id> cases as well.. + second_mention_transforms = { + re.escape("<@!%s>" % member.id): "@" + member.display_name for member in self.mentions + } + + transformations.update(mention_transforms) + transformations.update(second_mention_transforms) + + if self.guild is not None: + role_transforms = { + re.escape("<@&%s>" % role.id): "@" + role.name for role in self.role_mentions + } + transformations.update(role_transforms) + + def repl(obj): + return transformations.get(re.escape(obj.group(0)), "") + + pattern = re.compile("|".join(transformations.keys())) + result = pattern.sub(repl, self.content) + + transformations = {"@everyone": "@\u200beveryone", "@here": "@\u200bhere"} + + def repl2(obj): + return transformations.get(obj.group(0), "") + + pattern = re.compile("|".join(transformations.keys())) + return pattern.sub(repl2, result) + + @property + def created_at(self): + """datetime.datetime: The message's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @property + def edited_at(self): + """Optional[datetime.datetime]: A naive UTC datetime object containing the edited time of the message.""" + return self._edited_timestamp + + @property + def jump_url(self): + """:class:`str`: Returns a URL that allows the client to jump to this message.""" + guild_id = getattr(self.guild, "id", "@me") + return "https://discordapp.com/channels/{0}/{1.channel.id}/{1.id}".format(guild_id, self) + + @utils.cached_slot_property("_cs_system_content") + def system_content(self): + r"""A property that returns the content that is rendered + regardless of the :attr:`Message.type`. + + In the case of :attr:`MessageType.default`\, this just returns the + regular :attr:`Message.content`. Otherwise this returns an English + message denoting the contents of the system message. + """ + + if self.type is MessageType.default: + return self.content + + if self.type is MessageType.pins_add: + return "{0.name} pinned a message to this channel.".format(self.author) + + if self.type is MessageType.recipient_add: + return "{0.name} added {1.name} to the group.".format(self.author, self.mentions[0]) + + if self.type is MessageType.recipient_remove: + return "{0.name} removed {1.name} from the group.".format( + self.author, self.mentions[0] + ) + + if self.type is MessageType.channel_name_change: + return "{0.author.name} changed the channel name: {0.content}".format(self) + + if self.type is MessageType.channel_icon_change: + return "{0.author.name} changed the channel icon.".format(self) + + if self.type is MessageType.new_member: + formats = [ + "{0} just joined the server - glhf!", + "{0} just joined. Everyone, look busy!", + "{0} just joined. Can I get a heal?", + "{0} joined your party.", + "{0} joined. You must construct additional pylons.", + "Ermagherd. {0} is here.", + "Welcome, {0}. Stay awhile and listen.", + "Welcome, {0}. We were expecting you ( ͡° ͜ʖ ͡°)", + "Welcome, {0}. We hope you brought pizza.", + "Welcome {0}. Leave your weapons by the door.", + "A wild {0} appeared.", + "Swoooosh. {0} just landed.", + "Brace yourselves. {0} just joined the server.", + "{0} just joined. Hide your bananas.", + "{0} just arrived. Seems OP - please nerf.", + "{0} just slid into the server.", + "A {0} has spawned in the server.", + "Big {0} showed up!", + "Where’s {0}? In the server!", + "{0} hopped into the server. Kangaroo!!", + "{0} just showed up. Hold my beer.", + "Challenger approaching - {0} has appeared!", + "It's a bird! It's a plane! Nevermind, it's just {0}.", + "It's {0}! Praise the sun! [T]/", + "Never gonna give {0} up. Never gonna let {0} down.", + "Ha! {0} has joined! You activated my trap card!", + "Cheers, love! {0}'s here!", + "Hey! Listen! {0} has joined!", + "We've been expecting you {0}", + "It's dangerous to go alone, take {0}!", + "{0} has joined the server! It's super effective!", + "Cheers, love! {0} is here!", + "{0} is here, as the prophecy foretold.", + "{0} has arrived. Party's over.", + "Ready player {0}", + "{0} is here to kick butt and chew bubblegum. And {0} is all out of gum.", + "Hello. Is it {0} you're looking for?", + "{0} has joined. Stay a while and listen!", + "Roses are red, violets are blue, {0} joined this server with you", + ] + + index = int(self.created_at.timestamp()) % len(formats) + return formats[index].format(self.author.name) + + if self.type is MessageType.call: + # we're at the call message type now, which is a bit more complicated. + # we can make the assumption that Message.channel is a PrivateChannel + # with the type ChannelType.group or ChannelType.private + call_ended = self.call.ended_timestamp is not None + + if self.channel.me in self.call.participants: + return "{0.author.name} started a call.".format(self) + elif call_ended: + return "You missed a call from {0.author.name}".format(self) + else: + return "{0.author.name} started a call \N{EM DASH} Join the call.".format(self) + + async def delete(self): + """|coro| + + Deletes the message. + + Your own messages could be deleted without any proper permissions. However to + delete other people's messages, you need the :attr:`~Permissions.manage_messages` + permission. + + Raises + ------ + Forbidden + You do not have proper permissions to delete the message. + HTTPException + Deleting the message failed. + """ + await self._state.http.delete_message(self.channel.id, self.id) + + async def edit(self, **fields): + """|coro| + + Edits the message. + + The content must be able to be transformed into a string via ``str(content)``. + + Parameters + ----------- + content: Optional[str] + The new content to replace the message with. + Could be ``None`` to remove the content. + embed: Optional[:class:`Embed`] + The new embed to replace the original with. + Could be ``None`` to remove the embed. + delete_after: Optional[float] + If provided, the number of seconds to wait in the background + before deleting the message we just edited. If the deletion fails, + then it is silently ignored. + + Raises + ------- + HTTPException + Editing the message failed. + """ + + try: + content = fields["content"] + except KeyError: + pass + else: + if content is not None: + fields["content"] = str(content) + + try: + embed = fields["embed"] + except KeyError: + pass + else: + if embed is not None: + fields["embed"] = embed.to_dict() + + data = await self._state.http.edit_message(self.id, self.channel.id, **fields) + self._update(channel=self.channel, data=data) + + try: + delete_after = fields["delete_after"] + except KeyError: + pass + else: + if delete_after is not None: + + async def delete(): + await asyncio.sleep(delete_after, loop=self._state.loop) + try: + await self._state.http.delete_message(self.channel.id, self.id) + except HTTPException: + pass + + asyncio.ensure_future(delete(), loop=self._state.loop) + + async def pin(self): + """|coro| + + Pins the message. + + You must have the :attr:`~Permissions.manage_messages` permission to do + this in a non-private channel context. + + Raises + ------- + Forbidden + You do not have permissions to pin the message. + NotFound + The message or channel was not found or deleted. + HTTPException + Pinning the message failed, probably due to the channel + having more than 50 pinned messages. + """ + + await self._state.http.pin_message(self.channel.id, self.id) + self.pinned = True + + async def unpin(self): + """|coro| + + Unpins the message. + + You must have the :attr:`~Permissions.manage_messages` permission to do + this in a non-private channel context. + + Raises + ------- + Forbidden + You do not have permissions to unpin the message. + NotFound + The message or channel was not found or deleted. + HTTPException + Unpinning the message failed. + """ + + await self._state.http.unpin_message(self.channel.id, self.id) + self.pinned = False + + async def add_reaction(self, emoji): + """|coro| + + Add a reaction to the message. + + The emoji may be a unicode emoji or a custom guild :class:`Emoji`. + + You must have the :attr:`~Permissions.read_message_history` permission + to use this. If nobody else has reacted to the message using this + emoji, the :attr:`~Permissions.add_reactions` permission is required. + + Parameters + ------------ + emoji: Union[:class:`Emoji`, :class:`Reaction`, :class:`PartialEmoji`, str] + The emoji to react with. + + Raises + -------- + HTTPException + Adding the reaction failed. + Forbidden + You do not have the proper permissions to react to the message. + NotFound + The emoji you specified was not found. + InvalidArgument + The emoji parameter is invalid. + """ + + emoji = self._emoji_reaction(emoji) + await self._state.http.add_reaction(self.id, self.channel.id, emoji) + + async def remove_reaction(self, emoji, member): + """|coro| + + Remove a reaction by the member from the message. + + The emoji may be a unicode emoji or a custom guild :class:`Emoji`. + + If the reaction is not your own (i.e. ``member`` parameter is not you) then + the :attr:`~Permissions.manage_messages` permission is needed. + + The ``member`` parameter must represent a member and meet + the :class:`abc.Snowflake` abc. + + Parameters + ------------ + emoji: Union[:class:`Emoji`, :class:`Reaction`, :class:`PartialEmoji`, str] + The emoji to remove. + member: :class:`abc.Snowflake` + The member for which to remove the reaction. + + Raises + -------- + HTTPException + Removing the reaction failed. + Forbidden + You do not have the proper permissions to remove the reaction. + NotFound + The member or emoji you specified was not found. + InvalidArgument + The emoji parameter is invalid. + """ + + emoji = self._emoji_reaction(emoji) + + if member.id == self._state.self_id: + await self._state.http.remove_own_reaction(self.id, self.channel.id, emoji) + else: + await self._state.http.remove_reaction(self.id, self.channel.id, emoji, member.id) + + @staticmethod + def _emoji_reaction(emoji): + if isinstance(emoji, Reaction): + emoji = emoji.emoji + + if isinstance(emoji, Emoji): + return "%s:%s" % (emoji.name, emoji.id) + if isinstance(emoji, PartialEmoji): + return emoji._as_reaction() + if isinstance(emoji, str): + return emoji # this is okay + + raise InvalidArgument( + "emoji argument must be str, Emoji, or Reaction not {.__class__.__name__}.".format( + emoji + ) + ) + + async def clear_reactions(self): + """|coro| + + Removes all the reactions from the message. + + You need the :attr:`~Permissions.manage_messages` permission to use this. + + Raises + -------- + HTTPException + Removing the reactions failed. + Forbidden + You do not have the proper permissions to remove all the reactions. + """ + await self._state.http.clear_reactions(self.id, self.channel.id) + + def ack(self): + """|coro| + + Marks this message as read. + + The user must not be a bot user. + + Raises + ------- + HTTPException + Acking failed. + ClientException + You must not be a bot user. + """ + + state = self._state + if state.is_bot: + raise ClientException("Must not be a bot account to ack messages.") + return state.http.ack_message(self.channel.id, self.id) diff --git a/discord/mixins.py b/discord/mixins.py new file mode 100644 index 000000000..8d3b72c94 --- /dev/null +++ b/discord/mixins.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + + +class EqualityComparable: + __slots__ = () + + def __eq__(self, other): + return isinstance(other, self.__class__) and other.id == self.id + + def __ne__(self, other): + if isinstance(other, self.__class__): + return other.id != self.id + return True + + +class Hashable(EqualityComparable): + __slots__ = () + + def __hash__(self): + return self.id >> 22 diff --git a/discord/object.py b/discord/object.py new file mode 100644 index 000000000..e07b12699 --- /dev/null +++ b/discord/object.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from . import utils +from .mixins import Hashable + + +class Object(Hashable): + """Represents a generic Discord object. + + The purpose of this class is to allow you to create 'miniature' + versions of data classes if you want to pass in just an ID. Most functions + that take in a specific data class with an ID can also take in this class + as a substitute instead. Note that even though this is the case, not all + objects (if any) actually inherit from this class. + + There are also some cases where some websocket events are received + in :issue:`strange order <21>` and when such events happened you would + receive this class rather than the actual data class. These cases are + extremely rare. + + .. container:: operations + + .. describe:: x == y + + Checks if two objects are equal. + + .. describe:: x != y + + Checks if two objects are not equal. + + .. describe:: hash(x) + + Returns the object's hash. + + Attributes + ----------- + id : :class:`str` + The ID of the object. + """ + + def __init__(self, id): + self.id = id + + @property + def created_at(self): + """Returns the snowflake's creation time in UTC.""" + return utils.snowflake_time(self.id) diff --git a/discord/opus.py b/discord/opus.py new file mode 100644 index 000000000..08a86a897 --- /dev/null +++ b/discord/opus.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import array +import ctypes +import ctypes.util +import logging +import os.path +import sys + +from .errors import DiscordException + +log = logging.getLogger(__name__) +c_int_ptr = ctypes.POINTER(ctypes.c_int) +c_int16_ptr = ctypes.POINTER(ctypes.c_int16) +c_float_ptr = ctypes.POINTER(ctypes.c_float) + + +class EncoderStruct(ctypes.Structure): + pass + + +EncoderStructPtr = ctypes.POINTER(EncoderStruct) + + +def _err_lt(result, func, args): + if result < 0: + log.info("error has happened in %s", func.__name__) + raise OpusError(result) + return result + + +def _err_ne(result, func, args): + ret = args[-1]._obj + if ret.value != 0: + log.info("error has happened in %s", func.__name__) + raise OpusError(ret.value) + return result + + +# A list of exported functions. +# The first argument is obviously the name. +# The second one are the types of arguments it takes. +# The third is the result type. +# The fourth is the error handler. +exported_functions = [ + ("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None), + ("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None), + ( + "opus_encoder_create", + [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], + EncoderStructPtr, + _err_ne, + ), + ( + "opus_encode", + [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int32, + _err_lt, + ), + ("opus_encoder_ctl", None, ctypes.c_int32, _err_lt), + ("opus_encoder_destroy", [EncoderStructPtr], None, None), +] + + +def libopus_loader(name): + # create the library... + lib = ctypes.cdll.LoadLibrary(name) + + # register the functions... + for item in exported_functions: + func = getattr(lib, item[0]) + + try: + if item[1]: + func.argtypes = item[1] + + func.restype = item[2] + except KeyError: + pass + + try: + if item[3]: + func.errcheck = item[3] + except KeyError: + log.exception("Error assigning check function to %s", func) + + return lib + + +try: + if sys.platform == "win32": + _basedir = os.path.dirname(os.path.abspath(__file__)) + _bitness = "x64" if sys.maxsize > 2 ** 32 else "x86" + _filename = os.path.join(_basedir, "bin", "libopus-0.{}.dll".format(_bitness)) + _lib = libopus_loader(_filename) + else: + _lib = libopus_loader(ctypes.util.find_library("opus")) +except Exception: + _lib = None + + +def load_opus(name): + """Loads the libopus shared library for use with voice. + + If this function is not called then the library uses the function + `ctypes.util.find_library`__ and then loads that one + if available. + + .. _find library: https://docs.python.org/3.5/library/ctypes.html#finding-shared-libraries + __ `find library`_ + + Not loading a library leads to voice not working. + + This function propagates the exceptions thrown. + + Warning + -------- + The bitness of the library must match the bitness of your python + interpreter. If the library is 64-bit then your python interpreter + must be 64-bit as well. Usually if there's a mismatch in bitness then + the load will throw an exception. + + Note + ---- + On Windows, the .dll extension is not necessary. However, on Linux + the full extension is required to load the library, e.g. ``libopus.so.1``. + On Linux however, `find library`_ will usually find the library automatically + without you having to call this. + + Parameters + ---------- + name: str + The filename of the shared library. + """ + global _lib + _lib = libopus_loader(name) + + +def is_loaded(): + """Function to check if opus lib is successfully loaded either + via the ``ctypes.util.find_library`` call of :func:`load_opus`. + + This must return ``True`` for voice to work. + + Returns + ------- + bool + Indicates if the opus library has been loaded. + """ + global _lib + return _lib is not None + + +class OpusError(DiscordException): + """An exception that is thrown for libopus related errors. + + Attributes + ---------- + code : :class:`int` + The error code returned. + """ + + def __init__(self, code): + self.code = code + msg = _lib.opus_strerror(self.code).decode("utf-8") + log.info('"%s" has happened', msg) + super().__init__(msg) + + +class OpusNotLoaded(DiscordException): + """An exception that is thrown for when libopus is not loaded.""" + + pass + + +# Some constants... +OK = 0 +APPLICATION_AUDIO = 2049 +APPLICATION_VOIP = 2048 +APPLICATION_LOWDELAY = 2051 +CTL_SET_BITRATE = 4002 +CTL_SET_BANDWIDTH = 4008 +CTL_SET_FEC = 4012 +CTL_SET_PLP = 4014 +CTL_SET_SIGNAL = 4024 + +band_ctl = {"narrow": 1101, "medium": 1102, "wide": 1103, "superwide": 1104, "full": 1105} + +signal_ctl = {"auto": -1000, "voice": 3001, "music": 3002} + + +class Encoder: + SAMPLING_RATE = 48000 + CHANNELS = 2 + FRAME_LENGTH = 20 + SAMPLE_SIZE = 4 # (bit_rate / 8) * CHANNELS (bit_rate == 16) + SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH) + + FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE + + def __init__(self, application=APPLICATION_AUDIO): + self.application = application + + if not is_loaded(): + raise OpusNotLoaded() + + self._state = self._create_state() + self.set_bitrate(128) + self.set_fec(True) + self.set_expected_packet_loss_percent(0.15) + self.set_bandwidth("full") + self.set_signal_type("auto") + + def __del__(self): + if hasattr(self, "_state"): + _lib.opus_encoder_destroy(self._state) + self._state = None + + def _create_state(self): + ret = ctypes.c_int() + return _lib.opus_encoder_create( + self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret) + ) + + def set_bitrate(self, kbps): + kbps = min(128, max(16, int(kbps))) + + _lib.opus_encoder_ctl(self._state, CTL_SET_BITRATE, kbps * 1024) + return kbps + + def set_bandwidth(self, req): + if req not in band_ctl: + raise KeyError( + "%r is not a valid bandwidth setting. Try one of: %s" % (req, ",".join(band_ctl)) + ) + + k = band_ctl[req] + _lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k) + + def set_signal_type(self, req): + if req not in signal_ctl: + raise KeyError( + "%r is not a valid signal setting. Try one of: %s" % (req, ",".join(signal_ctl)) + ) + + k = signal_ctl[req] + _lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k) + + def set_fec(self, enabled=True): + _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) + + def set_expected_packet_loss_percent(self, percentage): + _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) + + def encode(self, pcm, frame_size): + max_data_bytes = len(pcm) + pcm = ctypes.cast(pcm, c_int16_ptr) + data = (ctypes.c_char * max_data_bytes)() + + ret = _lib.opus_encode(self._state, pcm, frame_size, data, max_data_bytes) + + return array.array("b", data[:ret]).tobytes() diff --git a/discord/permissions.py b/discord/permissions.py new file mode 100644 index 000000000..4fc17b27b --- /dev/null +++ b/discord/permissions.py @@ -0,0 +1,636 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + + +class Permissions: + """Wraps up the Discord permission value. + + The properties provided are two way. You can set and retrieve individual + bits using the properties as if they were regular bools. This allows + you to edit permissions. + + .. container:: operations + + .. describe:: x == y + + Checks if two permissions are equal. + .. describe:: x != y + + Checks if two permissions are not equal. + .. describe:: x <= y + + Checks if a permission is a subset of another permission. + .. describe:: x >= y + + Checks if a permission is a superset of another permission. + .. describe:: x < y + + Checks if a permission is a strict subset of another permission. + .. describe:: x > y + + Checks if a permission is a strict superset of another permission. + .. describe:: hash(x) + + Return the permission's hash. + .. describe:: iter(x) + + Returns an iterator of ``(perm, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. + + Attributes + ----------- + value + The raw value. This value is a bit array field of a 53-bit integer + representing the currently available permissions. You should query + permissions via the properties rather than using this raw value. + """ + + __slots__ = ("value",) + + def __init__(self, permissions=0): + if not isinstance(permissions, int): + raise TypeError( + "Expected int parameter, received %s instead." % permissions.__class__.__name__ + ) + + self.value = permissions + + def __eq__(self, other): + return isinstance(other, Permissions) and self.value == other.value + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.value) + + def __repr__(self): + return "" % self.value + + def _perm_iterator(self): + for attr in dir(self): + # check if it's a property, because if so it's a permission + is_property = isinstance(getattr(self.__class__, attr), property) + if is_property: + yield (attr, getattr(self, attr)) + + def __iter__(self): + return self._perm_iterator() + + def is_subset(self, other): + """Returns True if self has the same or fewer permissions as other.""" + if isinstance(other, Permissions): + return (self.value & other.value) == self.value + else: + raise TypeError( + "cannot compare {} with {}".format( + self.__class__.__name__, other.__class__.__name__ + ) + ) + + def is_superset(self, other): + """Returns True if self has the same or more permissions as other.""" + if isinstance(other, Permissions): + return (self.value | other.value) == self.value + else: + raise TypeError( + "cannot compare {} with {}".format( + self.__class__.__name__, other.__class__.__name__ + ) + ) + + def is_strict_subset(self, other): + """Returns True if the permissions on other are a strict subset of those on self.""" + return self.is_subset(other) and self != other + + def is_strict_superset(self, other): + """Returns True if the permissions on other are a strict superset of those on self.""" + return self.is_superset(other) and self != other + + __le__ = is_subset + __ge__ = is_superset + __lt__ = is_strict_subset + __gt__ = is_strict_superset + + @classmethod + def none(cls): + """A factory method that creates a :class:`Permissions` with all + permissions set to False.""" + return cls(0) + + @classmethod + def all(cls): + """A factory method that creates a :class:`Permissions` with all + permissions set to True.""" + return cls(0b01111111111101111111110111111111) + + @classmethod + def all_channel(cls): + """A :class:`Permissions` with all channel-specific permissions set to + True and the guild-specific ones set to False. The guild-specific + permissions are currently: + + - manage_guild + - kick_members + - ban_members + - administrator + - change_nicknames + - manage_nicknames + """ + return cls(0b00110011111101111111110001010001) + + @classmethod + def general(cls): + """A factory method that creates a :class:`Permissions` with all + "General" permissions from the official Discord UI set to True.""" + return cls(0b01111100000000000000000010111111) + + @classmethod + def text(cls): + """A factory method that creates a :class:`Permissions` with all + "Text" permissions from the official Discord UI set to True.""" + return cls(0b00000000000001111111110001000000) + + @classmethod + def voice(cls): + """A factory method that creates a :class:`Permissions` with all + "Voice" permissions from the official Discord UI set to True.""" + return cls(0b00000011111100000000000100000000) + + def update(self, **kwargs): + r"""Bulk updates this permission object. + + Allows you to set multiple attributes by using keyword + arguments. The names must be equivalent to the properties + listed. Extraneous key/value pairs will be silently ignored. + + Parameters + ------------ + \*\*kwargs + A list of key/value pairs to bulk update permissions with. + """ + for key, value in kwargs.items(): + try: + is_property = isinstance(getattr(self.__class__, key), property) + except AttributeError: + continue + + if is_property: + setattr(self, key, value) + + def _bit(self, index): + return bool((self.value >> index) & 1) + + def _set(self, index, value): + if value is True: + self.value |= 1 << index + elif value is False: + self.value &= ~(1 << index) + else: + raise TypeError("Value to set for Permissions must be a bool.") + + def handle_overwrite(self, allow, deny): + # Basically this is what's happening here. + # We have an original bit array, e.g. 1010 + # Then we have another bit array that is 'denied', e.g. 1111 + # And then we have the last one which is 'allowed', e.g. 0101 + # We want original OP denied to end up resulting in + # whatever is in denied to be set to 0. + # So 1010 OP 1111 -> 0000 + # Then we take this value and look at the allowed values. + # And whatever is allowed is set to 1. + # So 0000 OP2 0101 -> 0101 + # The OP is base & ~denied. + # The OP2 is base | allowed. + self.value = (self.value & ~deny) | allow + + @property + def create_instant_invite(self): + """Returns True if the user can create instant invites.""" + return self._bit(0) + + @create_instant_invite.setter + def create_instant_invite(self, value): + self._set(0, value) + + @property + def kick_members(self): + """Returns True if the user can kick users from the guild.""" + return self._bit(1) + + @kick_members.setter + def kick_members(self, value): + self._set(1, value) + + @property + def ban_members(self): + """Returns True if a user can ban users from the guild.""" + return self._bit(2) + + @ban_members.setter + def ban_members(self, value): + self._set(2, value) + + @property + def administrator(self): + """Returns True if a user is an administrator. This role overrides all other permissions. + + This also bypasses all channel-specific overrides. + """ + return self._bit(3) + + @administrator.setter + def administrator(self, value): + self._set(3, value) + + @property + def manage_channels(self): + """Returns True if a user can edit, delete, or create channels in the guild. + + This also corresponds to the "Manage Channel" channel-specific override.""" + return self._bit(4) + + @manage_channels.setter + def manage_channels(self, value): + self._set(4, value) + + @property + def manage_guild(self): + """Returns True if a user can edit guild properties.""" + return self._bit(5) + + @manage_guild.setter + def manage_guild(self, value): + self._set(5, value) + + @property + def add_reactions(self): + """Returns True if a user can add reactions to messages.""" + return self._bit(6) + + @add_reactions.setter + def add_reactions(self, value): + self._set(6, value) + + @property + def view_audit_log(self): + """Returns True if a user can view the guild's audit log.""" + return self._bit(7) + + @view_audit_log.setter + def view_audit_log(self, value): + self._set(7, value) + + @property + def priority_speaker(self): + """Returns True if a user can be more easily heard while talking.""" + return self._bit(8) + + @priority_speaker.setter + def priority_speaker(self, value): + self._set(8, value) + + # 1 unused + + @property + def read_messages(self): + """Returns True if a user can read messages from all or specific text channels.""" + return self._bit(10) + + @read_messages.setter + def read_messages(self, value): + self._set(10, value) + + @property + def send_messages(self): + """Returns True if a user can send messages from all or specific text channels.""" + return self._bit(11) + + @send_messages.setter + def send_messages(self, value): + self._set(11, value) + + @property + def send_tts_messages(self): + """Returns True if a user can send TTS messages from all or specific text channels.""" + return self._bit(12) + + @send_tts_messages.setter + def send_tts_messages(self, value): + self._set(12, value) + + @property + def manage_messages(self): + """Returns True if a user can delete or pin messages in a text channel. Note that there are currently no ways to edit other people's messages.""" + return self._bit(13) + + @manage_messages.setter + def manage_messages(self, value): + self._set(13, value) + + @property + def embed_links(self): + """Returns True if a user's messages will automatically be embedded by Discord.""" + return self._bit(14) + + @embed_links.setter + def embed_links(self, value): + self._set(14, value) + + @property + def attach_files(self): + """Returns True if a user can send files in their messages.""" + return self._bit(15) + + @attach_files.setter + def attach_files(self, value): + self._set(15, value) + + @property + def read_message_history(self): + """Returns True if a user can read a text channel's previous messages.""" + return self._bit(16) + + @read_message_history.setter + def read_message_history(self, value): + self._set(16, value) + + @property + def mention_everyone(self): + """Returns True if a user's @everyone or @here will mention everyone in the text channel.""" + return self._bit(17) + + @mention_everyone.setter + def mention_everyone(self, value): + self._set(17, value) + + @property + def external_emojis(self): + """Returns True if a user can use emojis from other guilds.""" + return self._bit(18) + + @external_emojis.setter + def external_emojis(self, value): + self._set(18, value) + + # 1 unused + + @property + def connect(self): + """Returns True if a user can connect to a voice channel.""" + return self._bit(20) + + @connect.setter + def connect(self, value): + self._set(20, value) + + @property + def speak(self): + """Returns True if a user can speak in a voice channel.""" + return self._bit(21) + + @speak.setter + def speak(self, value): + self._set(21, value) + + @property + def mute_members(self): + """Returns True if a user can mute other users.""" + return self._bit(22) + + @mute_members.setter + def mute_members(self, value): + self._set(22, value) + + @property + def deafen_members(self): + """Returns True if a user can deafen other users.""" + return self._bit(23) + + @deafen_members.setter + def deafen_members(self, value): + self._set(23, value) + + @property + def move_members(self): + """Returns True if a user can move users between other voice channels.""" + return self._bit(24) + + @move_members.setter + def move_members(self, value): + self._set(24, value) + + @property + def use_voice_activation(self): + """Returns True if a user can use voice activation in voice channels.""" + return self._bit(25) + + @use_voice_activation.setter + def use_voice_activation(self, value): + self._set(25, value) + + @property + def change_nickname(self): + """Returns True if a user can change their nickname in the guild.""" + return self._bit(26) + + @change_nickname.setter + def change_nickname(self, value): + self._set(26, value) + + @property + def manage_nicknames(self): + """Returns True if a user can change other user's nickname in the guild.""" + return self._bit(27) + + @manage_nicknames.setter + def manage_nicknames(self, value): + self._set(27, value) + + @property + def manage_roles(self): + """Returns True if a user can create or edit roles less than their role's position. + + This also corresponds to the "Manage Permissions" channel-specific override. + """ + return self._bit(28) + + @manage_roles.setter + def manage_roles(self, value): + self._set(28, value) + + @property + def manage_webhooks(self): + """Returns True if a user can create, edit, or delete webhooks.""" + return self._bit(29) + + @manage_webhooks.setter + def manage_webhooks(self, value): + self._set(29, value) + + @property + def manage_emojis(self): + """Returns True if a user can create, edit, or delete emojis.""" + return self._bit(30) + + @manage_emojis.setter + def manage_emojis(self, value): + self._set(30, value) + + # 1 unused + + # after these 32 bits, there's 21 more unused ones technically + + +def augment_from_permissions(cls): + cls.VALID_NAMES = { + name for name in dir(Permissions) if isinstance(getattr(Permissions, name), property) + } + + # make descriptors for all the valid names + for name in cls.VALID_NAMES: + # god bless Python + def getter(self, x=name): + return self._values.get(x) + + def setter(self, value, x=name): + self._set(x, value) + + prop = property(getter, setter) + setattr(cls, name, prop) + + return cls + + +@augment_from_permissions +class PermissionOverwrite: + r"""A type that is used to represent a channel specific permission. + + Unlike a regular :class:`Permissions`\, the default value of a + permission is equivalent to ``None`` and not ``False``. Setting + a value to ``False`` is **explicitly** denying that permission, + while setting a value to ``True`` is **explicitly** allowing + that permission. + + The values supported by this are the same as :class:`Permissions` + with the added possibility of it being set to ``None``. + + Supported operations: + + +-----------+------------------------------------------+ + | Operation | Description | + +===========+==========================================+ + | iter(x) | Returns an iterator of (perm, value) | + | | pairs. This allows this class to be used | + | | as an iterable in e.g. set/list/dict | + | | constructions. | + +-----------+------------------------------------------+ + + Parameters + ----------- + \*\*kwargs + Set the value of permissions by their name. + """ + + __slots__ = ("_values",) + + def __init__(self, **kwargs): + self._values = {} + + for key, value in kwargs.items(): + if key not in self.VALID_NAMES: + raise ValueError("no permission called {0}.".format(key)) + + setattr(self, key, value) + + def _set(self, key, value): + if value not in (True, None, False): + raise TypeError( + "Expected bool or NoneType, received {0.__class__.__name__}".format(value) + ) + + self._values[key] = value + + def pair(self): + """Returns the (allow, deny) pair from this overwrite. + + The value of these pairs is :class:`Permissions`. + """ + + allow = Permissions.none() + deny = Permissions.none() + + for key, value in self._values.items(): + if value is True: + setattr(allow, key, True) + elif value is False: + setattr(deny, key, True) + + return allow, deny + + @classmethod + def from_pair(cls, allow, deny): + """Creates an overwrite from an allow/deny pair of :class:`Permissions`.""" + ret = cls() + for key, value in allow: + if value is True: + setattr(ret, key, True) + + for key, value in deny: + if value is True: + setattr(ret, key, False) + + return ret + + def is_empty(self): + """Checks if the permission overwrite is currently empty. + + An empty permission overwrite is one that has no overwrites set + to True or False. + """ + return all(x is None for x in self._values.values()) + + def update(self, **kwargs): + r"""Bulk updates this permission overwrite object. + + Allows you to set multiple attributes by using keyword + arguments. The names must be equivalent to the properties + listed. Extraneous key/value pairs will be silently ignored. + + Parameters + ------------ + \*\*kwargs + A list of key/value pairs to bulk update with. + """ + for key, value in kwargs.items(): + if key not in self.VALID_NAMES: + continue + + setattr(self, key, value) + + def __iter__(self): + for key in self.VALID_NAMES: + yield key, self._values.get(key) diff --git a/discord/player.py b/discord/player.py new file mode 100644 index 000000000..17040dc2d --- /dev/null +++ b/discord/player.py @@ -0,0 +1,356 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import threading +import subprocess +import audioop +import logging +import shlex +import time + +from .errors import ClientException +from .opus import Encoder as OpusEncoder + +log = logging.getLogger(__name__) + +__all__ = ["AudioSource", "PCMAudio", "FFmpegPCMAudio", "PCMVolumeTransformer"] + + +class AudioSource: + """Represents an audio stream. + + The audio stream can be Opus encoded or not, however if the audio stream + is not Opus encoded then the audio format must be 16-bit 48KHz stereo PCM. + + .. warning:: + + The audio source reads are done in a separate thread. + """ + + def read(self): + """Reads 20ms worth of audio. + + Subclasses must implement this. + + If the audio is complete, then returning an empty + :term:`py:bytes-like object` to signal this is the way to do so. + + If :meth:`is_opus` method returns ``True``, then it must return + 20ms worth of Opus encoded audio. Otherwise, it must be 20ms + worth of 16-bit 48KHz stereo PCM, which is about 3,840 bytes + per frame (20ms worth of audio). + + Returns + -------- + bytes + A bytes like object that represents the PCM or Opus data. + """ + raise NotImplementedError + + def is_opus(self): + """Checks if the audio source is already encoded in Opus. + + Defaults to ``False``. + """ + return False + + def cleanup(self): + """Called when clean-up is needed to be done. + + Useful for clearing buffer data or processes after + it is done playing audio. + """ + pass + + def __del__(self): + self.cleanup() + + +class PCMAudio(AudioSource): + """Represents raw 16-bit 48KHz stereo PCM audio source. + + Attributes + ----------- + stream: file-like object + A file-like object that reads byte data representing raw PCM. + """ + + def __init__(self, stream): + self.stream = stream + + def read(self): + ret = self.stream.read(OpusEncoder.FRAME_SIZE) + if len(ret) != OpusEncoder.FRAME_SIZE: + return b"" + return ret + + +class FFmpegPCMAudio(AudioSource): + """An audio source from FFmpeg (or AVConv). + + This launches a sub-process to a specific input file given. + + .. warning:: + + You must have the ffmpeg or avconv executable in your path environment + variable in order for this to work. + + Parameters + ------------ + source: Union[str, BinaryIO] + The input that ffmpeg will take and convert to PCM bytes. + If ``pipe`` is True then this is a file-like object that is + passed to the stdin of ffmpeg. + executable: str + The executable name (and path) to use. Defaults to ``ffmpeg``. + pipe: bool + If true, denotes that ``source`` parameter will be passed + to the stdin of ffmpeg. Defaults to ``False``. + stderr: Optional[BinaryIO] + A file-like object to pass to the Popen constructor. + Could also be an instance of ``subprocess.PIPE``. + options: Optional[str] + Extra command line arguments to pass to ffmpeg after the ``-i`` flag. + before_options: Optional[str] + Extra command line arguments to pass to ffmpeg before the ``-i`` flag. + + Raises + -------- + ClientException + The subprocess failed to be created. + """ + + def __init__( + self, + source, + *, + executable="ffmpeg", + pipe=False, + stderr=None, + before_options=None, + options=None + ): + stdin = None if not pipe else source + + args = [executable] + + if isinstance(before_options, str): + args.extend(shlex.split(before_options)) + + args.append("-i") + args.append("-" if pipe else source) + args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning")) + + if isinstance(options, str): + args.extend(shlex.split(options)) + + args.append("pipe:1") + + self._process = None + try: + self._process = subprocess.Popen( + args, stdin=stdin, stdout=subprocess.PIPE, stderr=stderr + ) + self._stdout = self._process.stdout + except FileNotFoundError: + raise ClientException(executable + " was not found.") from None + except subprocess.SubprocessError as exc: + raise ClientException("Popen failed: {0.__class__.__name__}: {0}".format(exc)) from exc + + def read(self): + ret = self._stdout.read(OpusEncoder.FRAME_SIZE) + if len(ret) != OpusEncoder.FRAME_SIZE: + return b"" + return ret + + def cleanup(self): + proc = self._process + if proc is None: + return + + log.info("Preparing to terminate ffmpeg process %s.", proc.pid) + proc.kill() + if proc.poll() is None: + log.info("ffmpeg process %s has not terminated. Waiting to terminate...", proc.pid) + proc.communicate() + log.info( + "ffmpeg process %s should have terminated with a return code of %s.", + proc.pid, + proc.returncode, + ) + else: + log.info( + "ffmpeg process %s successfully terminated with return code of %s.", + proc.pid, + proc.returncode, + ) + + self._process = None + + +class PCMVolumeTransformer(AudioSource): + """Transforms a previous :class:`AudioSource` to have volume controls. + + This does not work on audio sources that have :meth:`AudioSource.is_opus` + set to ``True``. + + Parameters + ------------ + original: :class:`AudioSource` + The original AudioSource to transform. + volume: float + The initial volume to set it to. + See :attr:`volume` for more info. + + Raises + ------- + TypeError + Not an audio source. + ClientException + The audio source is opus encoded. + """ + + def __init__(self, original, volume=1.0): + if not isinstance(original, AudioSource): + raise TypeError("expected AudioSource not {0.__class__.__name__}.".format(original)) + + if original.is_opus(): + raise ClientException("AudioSource must not be Opus encoded.") + + self.original = original + self.volume = volume + + @property + def volume(self): + """Retrieves or sets the volume as a floating point percentage (e.g. 1.0 for 100%).""" + return self._volume + + @volume.setter + def volume(self, value): + self._volume = max(value, 0.0) + + def cleanup(self): + self.original.cleanup() + + def read(self): + ret = self.original.read() + return audioop.mul(ret, 2, min(self._volume, 2.0)) + + +class AudioPlayer(threading.Thread): + DELAY = OpusEncoder.FRAME_LENGTH / 1000.0 + + def __init__(self, source, client, *, after=None): + threading.Thread.__init__(self) + self.daemon = True + self.source = source + self.client = client + self.after = after + + self._end = threading.Event() + self._resumed = threading.Event() + self._resumed.set() # we are not paused + self._current_error = None + self._connected = client._connected + self._lock = threading.Lock() + + if after is not None and not callable(after): + raise TypeError('Expected a callable for the "after" parameter.') + + def _do_run(self): + self.loops = 0 + self._start = time.time() + + # getattr lookup speed ups + play_audio = self.client.send_audio_packet + + while not self._end.is_set(): + # are we paused? + if not self._resumed.is_set(): + # wait until we aren't + self._resumed.wait() + continue + + # are we disconnected from voice? + if not self._connected.is_set(): + # wait until we are connected + self._connected.wait() + # reset our internal data + self.loops = 0 + self._start = time.time() + + self.loops += 1 + data = self.source.read() + + if not data: + self.stop() + break + + play_audio(data, encode=not self.source.is_opus()) + next_time = self._start + self.DELAY * self.loops + delay = max(0, self.DELAY + (next_time - time.time())) + time.sleep(delay) + + def run(self): + try: + self._do_run() + except Exception as exc: + self._current_error = exc + self.stop() + finally: + self.source.cleanup() + self._call_after() + + def _call_after(self): + if self.after is not None: + try: + self.after(self._current_error) + except Exception: + log.exception("Calling the after function failed.") + + def stop(self): + self._end.set() + self._resumed.set() + + def pause(self): + self._resumed.clear() + + def resume(self): + self.loops = 0 + self._start = time.time() + self._resumed.set() + + def is_playing(self): + return self._resumed.is_set() and not self._end.is_set() + + def is_paused(self): + return not self._end.is_set() and not self._resumed.is_set() + + def _set_source(self, source): + with self._lock: + self.pause() + self.source = source + self.resume() diff --git a/discord/raw_models.py b/discord/raw_models.py new file mode 100644 index 000000000..e80525c21 --- /dev/null +++ b/discord/raw_models.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2018 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + + +class RawMessageDeleteEvent: + """Represents the event payload for a :func:`on_raw_message_delete` event. + + Attributes + ------------ + channel_id: :class:`int` + The channel ID where the deletion took place. + guild_id: Optional[:class:`int`] + The guild ID where the deletion took place, if applicable. + message_id: :class:`int` + The message ID that got deleted. + """ + + __slots__ = ("message_id", "channel_id", "guild_id") + + def __init__(self, data): + self.message_id = int(data["id"]) + self.channel_id = int(data["channel_id"]) + + try: + self.guild_id = int(data["guild_id"]) + except KeyError: + self.guild_id = None + + +class RawBulkMessageDeleteEvent: + """Represents the event payload for a :func:`on_raw_bulk_message_delete` event. + + Attributes + ----------- + message_ids: Set[:class:`int`] + A :class:`set` of the message IDs that were deleted. + channel_id: :class:`int` + The channel ID where the message got deleted. + guild_id: Optional[:class:`int`] + The guild ID where the message got deleted, if applicable. + """ + + __slots__ = ("message_ids", "channel_id", "guild_id") + + def __init__(self, data): + self.message_ids = {int(x) for x in data.get("ids", [])} + self.channel_id = int(data["channel_id"]) + + try: + self.guild_id = int(data["guild_id"]) + except KeyError: + self.guild_id = None + + +class RawMessageUpdateEvent: + """Represents the payload for a :func:`on_raw_message_edit` event. + + Attributes + ----------- + message_id: :class:`int` + The message ID that got updated. + data: :class:`dict` + The raw data given by the + `gateway `_ + """ + + __slots__ = ("message_id", "data") + + def __init__(self, data): + self.message_id = int(data["id"]) + self.data = data + + +class RawReactionActionEvent: + """Represents the payload for a :func:`on_raw_reaction_add` or + :func:`on_raw_reaction_remove` event. + + Attributes + ----------- + message_id: :class:`int` + The message ID that got or lost a reaction. + user_id: :class:`int` + The user ID who added or removed the reaction. + channel_id: :class:`int` + The channel ID where the reaction got added or removed. + guild_id: Optional[:class:`int`] + The guild ID where the reaction got added or removed, if applicable. + emoji: :class:`PartialEmoji` + The custom or unicode emoji being used. + """ + + __slots__ = ("message_id", "user_id", "channel_id", "guild_id", "emoji") + + def __init__(self, data, emoji): + self.message_id = int(data["message_id"]) + self.channel_id = int(data["channel_id"]) + self.user_id = int(data["user_id"]) + self.emoji = emoji + + try: + self.guild_id = int(data["guild_id"]) + except KeyError: + self.guild_id = None + + +class RawReactionClearEvent: + """Represents the payload for a :func:`on_raw_reaction_clear` event. + + Attributes + ----------- + message_id: :class:`int` + The message ID that got its reactions cleared. + channel_id: :class:`int` + The channel ID where the reactions got cleared. + guild_id: Optional[:class:`int`] + The guild ID where the reactions got cleared. + """ + + __slots__ = ("message_id", "channel_id", "guild_id") + + def __init__(self, data): + self.message_id = int(data["message_id"]) + self.channel_id = int(data["channel_id"]) + + try: + self.guild_id = int(data["guild_id"]) + except KeyError: + self.guild_id = None diff --git a/discord/reaction.py b/discord/reaction.py new file mode 100644 index 000000000..60dc4e047 --- /dev/null +++ b/discord/reaction.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .iterators import ReactionIterator + + +class Reaction: + """Represents a reaction to a message. + + Depending on the way this object was created, some of the attributes can + have a value of ``None``. + + .. container:: operations + + .. describe:: x == y + + Checks if two reactions are equal. This works by checking if the emoji + is the same. So two messages with the same reaction will be considered + "equal". + + .. describe:: x != y + + Checks if two reactions are not equal. + + .. describe:: hash(x) + + Returns the reaction's hash. + + .. describe:: str(x) + + Returns the string form of the reaction's emoji. + + Attributes + ----------- + emoji: :class:`Emoji` or :class:`str` + The reaction emoji. May be a custom emoji, or a unicode emoji. + count: :class:`int` + Number of times this reaction was made + me: :class:`bool` + If the user sent this reaction. + message: :class:`Message` + Message this reaction is for. + """ + + __slots__ = ("message", "count", "emoji", "me") + + def __init__(self, *, message, data, emoji=None): + self.message = message + self.emoji = emoji or message._state.get_reaction_emoji(data["emoji"]) + self.count = data.get("count", 1) + self.me = data.get("me") + + @property + def custom_emoji(self): + """:class:`bool`: If this is a custom emoji.""" + return not isinstance(self.emoji, str) + + def __eq__(self, other): + return isinstance(other, self.__class__) and other.emoji == self.emoji + + def __ne__(self, other): + if isinstance(other, self.__class__): + return other.emoji != self.emoji + return True + + def __hash__(self): + return hash(self.emoji) + + def __str__(self): + return str(self.emoji) + + def __repr__(self): + return "".format(self) + + def users(self, limit=None, after=None): + """Returns an :class:`AsyncIterator` representing the users that have reacted to the message. + + The ``after`` parameter must represent a member + and meet the :class:`abc.Snowflake` abc. + + Parameters + ------------ + limit: int + The maximum number of results to return. + If not provided, returns all the users who + reacted to the message. + after: :class:`abc.Snowflake` + For pagination, reactions are sorted by member. + + Raises + -------- + HTTPException + Getting the users for the reaction failed. + + Examples + --------- + + Usage :: + + # I do not actually recommend doing this. + async for user in reaction.users(): + await channel.send('{0} has reacted with {1.emoji}!'.format(user, reaction)) + + Flattening into a list: :: + + users = await reaction.users().flatten() + # users is now a list... + winner = random.choice(users) + await channel.send('{} has won the raffle.'.format(winner)) + + Yields + -------- + Union[:class:`User`, :class:`Member`] + The member (if retrievable) or the user that has reacted + to this message. The case where it can be a :class:`Member` is + in a guild message context. Sometimes it can be a :class:`User` + if the member has left the guild. + """ + + if self.custom_emoji: + emoji = "{0.name}:{0.id}".format(self.emoji) + else: + emoji = self.emoji + + if limit is None: + limit = self.count + + return ReactionIterator(self.message, emoji, limit, after) diff --git a/discord/relationship.py b/discord/relationship.py new file mode 100644 index 000000000..59ee8cb63 --- /dev/null +++ b/discord/relationship.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .enums import RelationshipType, try_enum + + +class Relationship: + """Represents a relationship in Discord. + + A relationship is like a friendship, a person who is blocked, etc. + Only non-bot accounts can have relationships. + + Attributes + ----------- + user: :class:`User` + The user you have the relationship with. + type: :class:`RelationshipType` + The type of relationship you have. + """ + + __slots__ = ("type", "user", "_state") + + def __init__(self, *, state, data): + self._state = state + self.type = try_enum(RelationshipType, data["type"]) + self.user = state.store_user(data["user"]) + + def __repr__(self): + return "".format(self) + + async def delete(self): + """|coro| + + Deletes the relationship. + + Raises + ------ + HTTPException + Deleting the relationship failed. + """ + + await self._state.http.remove_relationship(self.user.id) + + async def accept(self): + """|coro| + + Accepts the relationship request. e.g. accepting a + friend request. + + Raises + ------- + HTTPException + Accepting the relationship failed. + """ + + await self._state.http.add_relationship(self.user.id) diff --git a/discord/role.py b/discord/role.py new file mode 100644 index 000000000..e0bd6fa96 --- /dev/null +++ b/discord/role.py @@ -0,0 +1,297 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from .permissions import Permissions +from .errors import InvalidArgument +from .colour import Colour +from .mixins import Hashable +from .utils import snowflake_time + + +class Role(Hashable): + """Represents a Discord role in a :class:`Guild`. + + .. container:: operations + + .. describe:: x == y + + Checks if two roles are equal. + + .. describe:: x != y + + Checks if two roles are not equal. + + .. describe:: x > y + + Checks if a role is higher than another in the hierarchy. + + .. describe:: x < y + + Checks if a role is lower than another in the hierarchy. + + .. describe:: x >= y + + Checks if a role is higher or equal to another in the hierarchy. + + .. describe:: x <= y + + Checks if a role is lower or equal to another in the hierarchy. + + .. describe:: hash(x) + + Return the role's hash. + + .. describe:: str(x) + + Returns the role's name. + + Attributes + ---------- + id: :class:`int` + The ID for the role. + name: :class:`str` + The name of the role. + permissions: :class:`Permissions` + Represents the role's permissions. + guild: :class:`Guild` + The guild the role belongs to. + colour: :class:`Colour` + Represents the role colour. An alias exists under ``color``. + hoist: :class:`bool` + Indicates if the role will be displayed separately from other members. + position: :class:`int` + The position of the role. This number is usually positive. The bottom + role has a position of 0. + managed: :class:`bool` + Indicates if the role is managed by the guild through some form of + integrations such as Twitch. + mentionable: :class:`bool` + Indicates if the role can be mentioned by users. + """ + + __slots__ = ( + "id", + "name", + "permissions", + "color", + "colour", + "position", + "managed", + "mentionable", + "hoist", + "guild", + "_state", + ) + + def __init__(self, *, guild, state, data): + self.guild = guild + self._state = state + self.id = int(data["id"]) + self._update(data) + + def __str__(self): + return self.name + + def __repr__(self): + return "".format(self) + + def __lt__(self, other): + if not isinstance(other, Role) or not isinstance(self, Role): + return NotImplemented + + if self.guild != other.guild: + raise RuntimeError("cannot compare roles from two different guilds.") + + # the @everyone role is always the lowest role in hierarchy + guild_id = self.guild.id + if self.id == guild_id: + # everyone_role < everyone_role -> False + return other.id != guild_id + + if self.position < other.position: + return True + + if self.position == other.position: + return int(self.id) > int(other.id) + + return False + + def __le__(self, other): + r = Role.__lt__(other, self) + if r is NotImplemented: + return NotImplemented + return not r + + def __gt__(self, other): + return Role.__lt__(other, self) + + def __ge__(self, other): + r = Role.__lt__(self, other) + if r is NotImplemented: + return NotImplemented + return not r + + def _update(self, data): + self.name = data["name"] + self.permissions = Permissions(data.get("permissions", 0)) + self.position = data.get("position", 0) + self.colour = Colour(data.get("color", 0)) + self.hoist = data.get("hoist", False) + self.managed = data.get("managed", False) + self.mentionable = data.get("mentionable", False) + self.color = self.colour + + def is_default(self): + """Checks if the role is the default role.""" + return self.guild.id == self.id + + @property + def created_at(self): + """Returns the role's creation time in UTC.""" + return snowflake_time(self.id) + + @property + def mention(self): + """Returns a string that allows you to mention a role.""" + return "<@&%s>" % self.id + + @property + def members(self): + """Returns a :class:`list` of :class:`Member` with this role.""" + all_members = self.guild.members + if self.is_default(): + return all_members + + role_id = self.id + return [member for member in all_members if member._roles.has(role_id)] + + async def _move(self, position, reason): + if position <= 0: + raise InvalidArgument("Cannot move role to position 0 or below") + + if self.is_default(): + raise InvalidArgument("Cannot move default role") + + if self.position == position: + return # Save discord the extra request. + + http = self._state.http + + change_range = range(min(self.position, position), max(self.position, position) + 1) + roles = [ + r.id for r in self.guild.roles[1:] if r.position in change_range and r.id != self.id + ] + + if self.position > position: + roles.insert(0, self.id) + else: + roles.append(self.id) + + payload = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)] + await http.move_role_position(self.guild.id, payload, reason=reason) + + async def edit(self, *, reason=None, **fields): + """|coro| + + Edits the role. + + You must have the :attr:`~Permissions.manage_roles` permission to + use this. + + All fields are optional. + + Parameters + ----------- + name: str + The new role name to change to. + permissions: :class:`Permissions` + The new permissions to change to. + colour: :class:`Colour` + The new colour to change to. (aliased to color as well) + hoist: bool + Indicates if the role should be shown separately in the member list. + mentionable: bool + Indicates if the role should be mentionable by others. + position: int + The new role's position. This must be below your top role's + position or it will fail. + reason: Optional[str] + The reason for editing this role. Shows up on the audit log. + + Raises + ------- + Forbidden + You do not have permissions to change the role. + HTTPException + Editing the role failed. + InvalidArgument + An invalid position was given or the default + role was asked to be moved. + """ + + position = fields.get("position") + if position is not None: + await self._move(position, reason=reason) + self.position = position + + try: + colour = fields["colour"] + except KeyError: + colour = fields.get("color", self.colour) + + payload = { + "name": fields.get("name", self.name), + "permissions": fields.get("permissions", self.permissions).value, + "color": colour.value, + "hoist": fields.get("hoist", self.hoist), + "mentionable": fields.get("mentionable", self.mentionable), + } + + data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload) + self._update(data) + + async def delete(self, *, reason=None): + """|coro| + + Deletes the role. + + You must have the :attr:`~Permissions.manage_roles` permission to + use this. + + Parameters + ----------- + reason: Optional[str] + The reason for deleting this role. Shows up on the audit log. + + Raises + -------- + Forbidden + You do not have permissions to delete the role. + HTTPException + Deleting the role failed. + """ + + await self._state.http.delete_role(self.guild.id, self.id, reason=reason) diff --git a/discord/shard.py b/discord/shard.py new file mode 100644 index 000000000..6f028def9 --- /dev/null +++ b/discord/shard.py @@ -0,0 +1,370 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +import itertools +import logging + +import websockets + +from .state import AutoShardedConnectionState +from .client import Client +from .gateway import * +from .errors import ClientException, InvalidArgument +from . import utils +from .enums import Status + +log = logging.getLogger(__name__) + + +class Shard: + def __init__(self, ws, client): + self.ws = ws + self._client = client + self.loop = self._client.loop + self._current = self.loop.create_future() + self._current.set_result(None) # we just need an already done future + self._pending = asyncio.Event(loop=self.loop) + self._pending_task = None + + @property + def id(self): + return self.ws.shard_id + + def is_pending(self): + return not self._pending.is_set() + + def complete_pending_reads(self): + self._pending.set() + + async def _pending_reads(self): + try: + while self.is_pending(): + await self.poll() + except asyncio.CancelledError: + pass + + def launch_pending_reads(self): + self._pending_task = asyncio.ensure_future(self._pending_reads(), loop=self.loop) + + def wait(self): + return self._pending_task + + async def poll(self): + try: + await self.ws.poll_event() + except ResumeWebSocket: + log.info("Got a request to RESUME the websocket at Shard ID %s.", self.id) + coro = DiscordWebSocket.from_client( + self._client, + resume=True, + shard_id=self.id, + session=self.ws.session_id, + sequence=self.ws.sequence, + ) + self.ws = await asyncio.wait_for(coro, timeout=180.0, loop=self.loop) + + def get_future(self): + if self._current.done(): + self._current = asyncio.ensure_future(self.poll(), loop=self.loop) + + return self._current + + +class AutoShardedClient(Client): + """A client similar to :class:`Client` except it handles the complications + of sharding for the user into a more manageable and transparent single + process bot. + + When using this client, you will be able to use it as-if it was a regular + :class:`Client` with a single shard when implementation wise internally it + is split up into multiple shards. This allows you to not have to deal with + IPC or other complicated infrastructure. + + It is recommended to use this client only if you have surpassed at least + 1000 guilds. + + If no :attr:`shard_count` is provided, then the library will use the + Bot Gateway endpoint call to figure out how many shards to use. + + If a ``shard_ids`` parameter is given, then those shard IDs will be used + to launch the internal shards. Note that :attr:`shard_count` must be provided + if this is used. By default, when omitted, the client will launch shards from + 0 to ``shard_count - 1``. + + Attributes + ------------ + shard_ids: Optional[List[:class:`int`]] + An optional list of shard_ids to launch the shards with. + """ + + def __init__(self, *args, loop=None, **kwargs): + kwargs.pop("shard_id", None) + self.shard_ids = kwargs.pop("shard_ids", None) + super().__init__(*args, loop=loop, **kwargs) + + if self.shard_ids is not None: + if self.shard_count is None: + raise ClientException( + "When passing manual shard_ids, you must provide a shard_count." + ) + elif not isinstance(self.shard_ids, (list, tuple)): + raise ClientException("shard_ids parameter must be a list or a tuple.") + + self._connection = AutoShardedConnectionState( + dispatch=self.dispatch, + chunker=self._chunker, + handlers=self._handlers, + syncer=self._syncer, + http=self.http, + loop=self.loop, + **kwargs + ) + + # instead of a single websocket, we have multiple + # the key is the shard_id + self.shards = {} + + def _get_websocket(guild_id): + i = (guild_id >> 22) % self.shard_count + return self.shards[i].ws + + self._connection._get_websocket = _get_websocket + + async def _chunker(self, guild, *, shard_id=None): + try: + guild_id = guild.id + shard_id = shard_id or guild.shard_id + except AttributeError: + guild_id = [s.id for s in guild] + + payload = {"op": 8, "d": {"guild_id": guild_id, "query": "", "limit": 0}} + + ws = self.shards[shard_id].ws + await ws.send_as_json(payload) + + @property + def latency(self): + """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. + + This operates similarly to :meth:`.Client.latency` except it uses the average + latency of every shard's latency. To get a list of shard latency, check the + :attr:`latencies` property. Returns ``nan`` if there are no shards ready. + """ + if not self.shards: + return float("nan") + return sum(latency for _, latency in self.latencies) / len(self.shards) + + @property + def latencies(self): + """List[Tuple[:class:`int`, :class:`float`]]: A list of latencies between a HEARTBEAT and a HEARTBEAT_ACK in seconds. + + This returns a list of tuples with elements ``(shard_id, latency)``. + """ + return [(shard_id, shard.ws.latency) for shard_id, shard in self.shards.items()] + + async def request_offline_members(self, *guilds): + r"""|coro| + + Requests previously offline members from the guild to be filled up + into the :attr:`Guild.members` cache. This function is usually not + called. It should only be used if you have the ``fetch_offline_members`` + parameter set to ``False``. + + When the client logs on and connects to the websocket, Discord does + not provide the library with offline members if the number of members + in the guild is larger than 250. You can check if a guild is large + if :attr:`Guild.large` is ``True``. + + Parameters + ----------- + \*guilds + An argument list of guilds to request offline members for. + + Raises + ------- + InvalidArgument + If any guild is unavailable or not large in the collection. + """ + if any(not g.large or g.unavailable for g in guilds): + raise InvalidArgument("An unavailable or non-large guild was passed.") + + _guilds = sorted(guilds, key=lambda g: g.shard_id) + for shard_id, sub_guilds in itertools.groupby(_guilds, key=lambda g: g.shard_id): + sub_guilds = list(sub_guilds) + await self._connection.request_offline_members(sub_guilds, shard_id=shard_id) + + async def launch_shard(self, gateway, shard_id): + try: + coro = websockets.connect( + gateway, loop=self.loop, klass=DiscordWebSocket, compression=None + ) + ws = await asyncio.wait_for(coro, loop=self.loop, timeout=180.0) + except Exception: + log.info("Failed to connect for shard_id: %s. Retrying...", shard_id) + await asyncio.sleep(5.0, loop=self.loop) + return await self.launch_shard(gateway, shard_id) + + ws.token = self.http.token + ws._connection = self._connection + ws._dispatch = self.dispatch + ws.gateway = gateway + ws.shard_id = shard_id + ws.shard_count = self.shard_count + ws._max_heartbeat_timeout = self._connection.heartbeat_timeout + + try: + # OP HELLO + await asyncio.wait_for(ws.poll_event(), loop=self.loop, timeout=180.0) + await asyncio.wait_for(ws.identify(), loop=self.loop, timeout=180.0) + except asyncio.TimeoutError: + log.info("Timed out when connecting for shard_id: %s. Retrying...", shard_id) + await asyncio.sleep(5.0, loop=self.loop) + return await self.launch_shard(gateway, shard_id) + + # keep reading the shard while others connect + self.shards[shard_id] = ret = Shard(ws, self) + ret.launch_pending_reads() + await asyncio.sleep(5.0, loop=self.loop) + + async def launch_shards(self): + if self.shard_count is None: + self.shard_count, gateway = await self.http.get_bot_gateway() + else: + gateway = await self.http.get_gateway() + + self._connection.shard_count = self.shard_count + + shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count) + + for shard_id in shard_ids: + await self.launch_shard(gateway, shard_id) + + shards_to_wait_for = [] + for shard in self.shards.values(): + shard.complete_pending_reads() + shards_to_wait_for.append(shard.wait()) + + # wait for all pending tasks to finish + await utils.sane_wait_for(shards_to_wait_for, timeout=300.0, loop=self.loop) + + async def _connect(self): + await self.launch_shards() + + while True: + pollers = [shard.get_future() for shard in self.shards.values()] + done, _ = await asyncio.wait( + pollers, loop=self.loop, return_when=asyncio.FIRST_COMPLETED + ) + for f in done: + # we wanna re-raise to the main Client.connect handler if applicable + f.result() + + async def close(self): + """|coro| + + Closes the connection to discord. + """ + if self.is_closed(): + return + + self._closed.set() + + for vc in self.voice_clients: + try: + await vc.disconnect() + except Exception: + pass + + to_close = [shard.ws.close() for shard in self.shards.values()] + if to_close: + await asyncio.wait(to_close, loop=self.loop) + + await self.http.close() + + async def change_presence(self, *, activity=None, status=None, afk=False, shard_id=None): + """|coro| + + Changes the client's presence. + + The activity parameter is a :class:`Activity` object (not a string) that represents + the activity being done currently. This could also be the slimmed down versions, + :class:`Game` and :class:`Streaming`. + + Example: :: + + game = discord.Game("with the API") + await client.change_presence(status=discord.Status.idle, activity=game) + + Parameters + ---------- + activity: Optional[Union[:class:`Game`, :class:`Streaming`, :class:`Activity`]] + The activity being done. ``None`` if no currently active activity is done. + status: Optional[:class:`Status`] + Indicates what status to change to. If None, then + :attr:`Status.online` is used. + afk: bool + Indicates if you are going AFK. This allows the discord + client to know how to handle push notifications better + for you in case you are actually idle and not lying. + shard_id: Optional[int] + The shard_id to change the presence to. If not specified + or ``None``, then it will change the presence of every + shard the bot can see. + + Raises + ------ + InvalidArgument + If the ``activity`` parameter is not of proper type. + """ + + if status is None: + status = "online" + status_enum = Status.online + elif status is Status.offline: + status = "invisible" + status_enum = Status.offline + else: + status_enum = status + status = str(status) + + if shard_id is None: + for shard in self.shards.values(): + await shard.ws.change_presence(activity=activity, status=status, afk=afk) + + guilds = self._connection.guilds + else: + shard = self.shards[shard_id] + await shard.ws.change_presence(activity=activity, status=status, afk=afk) + guilds = [g for g in self._connection.guilds if g.shard_id == shard_id] + + for guild in guilds: + me = guild.me + if me is None: + continue + + me.activities = (activity,) + me.status = status_enum diff --git a/discord/state.py b/discord/state.py new file mode 100644 index 000000000..a6462de65 --- /dev/null +++ b/discord/state.py @@ -0,0 +1,1048 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +from collections import deque, namedtuple, OrderedDict +import copy +import datetime +import enum +import itertools +import logging +import math +import weakref + +from .guild import Guild +from .activity import _ActivityTag +from .user import User, ClientUser +from .emoji import Emoji, PartialEmoji +from .message import Message +from .relationship import Relationship +from .channel import * +from .raw_models import * +from .member import Member +from .role import Role +from .enums import ChannelType, try_enum, Status +from . import utils +from .embeds import Embed + + +class ListenerType(enum.Enum): + chunk = 0 + + +Listener = namedtuple("Listener", ("type", "future", "predicate")) +log = logging.getLogger(__name__) +ReadyState = namedtuple("ReadyState", ("launch", "guilds")) + + +class ConnectionState: + def __init__(self, *, dispatch, chunker, handlers, syncer, http, loop, **options): + self.loop = loop + self.http = http + self.max_messages = max(options.get("max_messages", 5000), 100) + self.dispatch = dispatch + self.chunker = chunker + self.syncer = syncer + self.is_bot = None + self.handlers = handlers + self.shard_count = None + self._ready_task = None + self._fetch_offline = options.get("fetch_offline_members", True) + self.heartbeat_timeout = options.get("heartbeat_timeout", 60.0) + self._listeners = [] + + activity = options.get("activity", None) + if activity: + if not isinstance(activity, _ActivityTag): + raise TypeError("activity parameter must be one of Game, Streaming, or Activity.") + + activity = activity.to_dict() + + status = options.get("status", None) + if status: + if status is Status.offline: + status = "invisible" + else: + status = str(status) + + self._activity = activity + self._status = status + + self.clear() + + def clear(self): + self.user = None + self._users = weakref.WeakValueDictionary() + self._emojis = {} + self._calls = {} + self._guilds = {} + self._voice_clients = {} + + # LRU of max size 128 + self._private_channels = OrderedDict() + # extra dict to look up private channels by user id + self._private_channels_by_user = {} + self._messages = deque(maxlen=self.max_messages) + + def process_listeners(self, listener_type, argument, result): + removed = [] + for i, listener in enumerate(self._listeners): + if listener.type != listener_type: + continue + + future = listener.future + if future.cancelled(): + removed.append(i) + continue + + try: + passed = listener.predicate(argument) + except Exception as exc: + future.set_exception(exc) + removed.append(i) + else: + if passed: + future.set_result(result) + removed.append(i) + if listener.type == ListenerType.chunk: + break + + for index in reversed(removed): + del self._listeners[index] + + def call_handlers(self, key, *args, **kwargs): + try: + func = self.handlers[key] + except KeyError: + pass + else: + func(*args, **kwargs) + + @property + def self_id(self): + u = self.user + return u.id if u else None + + @property + def voice_clients(self): + return list(self._voice_clients.values()) + + def _get_voice_client(self, guild_id): + return self._voice_clients.get(guild_id) + + def _add_voice_client(self, guild_id, voice): + self._voice_clients[guild_id] = voice + + def _remove_voice_client(self, guild_id): + self._voice_clients.pop(guild_id, None) + + def _update_references(self, ws): + for vc in self.voice_clients: + vc.main_ws = ws + + def store_user(self, data): + # this way is 300% faster than `dict.setdefault`. + user_id = int(data["id"]) + try: + return self._users[user_id] + except KeyError: + user = User(state=self, data=data) + if user.discriminator != "0000": + self._users[user_id] = user + return user + + def get_user(self, id): + return self._users.get(id) + + def store_emoji(self, guild, data): + emoji_id = int(data["id"]) + self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) + return emoji + + @property + def guilds(self): + return list(self._guilds.values()) + + def _get_guild(self, guild_id): + return self._guilds.get(guild_id) + + def _add_guild(self, guild): + self._guilds[guild.id] = guild + + def _remove_guild(self, guild): + self._guilds.pop(guild.id, None) + + for emoji in guild.emojis: + self._emojis.pop(emoji.id, None) + + del guild + + @property + def emojis(self): + return list(self._emojis.values()) + + def get_emoji(self, emoji_id): + return self._emojis.get(emoji_id) + + @property + def private_channels(self): + return list(self._private_channels.values()) + + def _get_private_channel(self, channel_id): + try: + value = self._private_channels[channel_id] + except KeyError: + return None + else: + self._private_channels.move_to_end(channel_id) + return value + + def _get_private_channel_by_user(self, user_id): + return self._private_channels_by_user.get(user_id) + + def _add_private_channel(self, channel): + channel_id = channel.id + self._private_channels[channel_id] = channel + + if self.is_bot and len(self._private_channels) > 128: + _, to_remove = self._private_channels.popitem(last=False) + if isinstance(to_remove, DMChannel): + self._private_channels_by_user.pop(to_remove.recipient.id, None) + + if isinstance(channel, DMChannel): + self._private_channels_by_user[channel.recipient.id] = channel + + def add_dm_channel(self, data): + channel = DMChannel(me=self.user, state=self, data=data) + self._add_private_channel(channel) + return channel + + def _remove_private_channel(self, channel): + self._private_channels.pop(channel.id, None) + if isinstance(channel, DMChannel): + self._private_channels_by_user.pop(channel.recipient.id, None) + + def _get_message(self, msg_id): + return utils.find(lambda m: m.id == msg_id, self._messages) + + def _add_guild_from_data(self, guild): + guild = Guild(data=guild, state=self) + self._add_guild(guild) + return guild + + def chunks_needed(self, guild): + for _ in range(math.ceil(guild._member_count / 1000)): + yield self.receive_chunk(guild.id) + + def _get_guild_channel(self, data): + try: + guild = self._get_guild(int(data["guild_id"])) + except KeyError: + channel = self.get_channel(int(data["channel_id"])) + guild = None + else: + channel = guild and guild.get_channel(int(data["channel_id"])) + + return channel, guild + + async def request_offline_members(self, guilds): + # get all the chunks + chunks = [] + for guild in guilds: + chunks.extend(self.chunks_needed(guild)) + + # we only want to request ~75 guilds per chunk request. + splits = [guilds[i : i + 75] for i in range(0, len(guilds), 75)] + for split in splits: + await self.chunker(split) + + # wait for the chunks + if chunks: + try: + await utils.sane_wait_for(chunks, timeout=len(chunks) * 30.0, loop=self.loop) + except asyncio.TimeoutError: + log.info("Somehow timed out waiting for chunks.") + + async def _delay_ready(self): + try: + launch = self._ready_state.launch + + # only real bots wait for GUILD_CREATE streaming + if self.is_bot: + while not launch.is_set(): + # this snippet of code is basically waiting 2 seconds + # until the last GUILD_CREATE was sent + launch.set() + await asyncio.sleep(2, loop=self.loop) + + guilds = next(zip(*self._ready_state.guilds), []) + if self._fetch_offline: + await self.request_offline_members(guilds) + + for guild, unavailable in self._ready_state.guilds: + if unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + + # remove the state + try: + del self._ready_state + except AttributeError: + pass # already been deleted somehow + + # call GUILD_SYNC after we're done chunking + if not self.is_bot: + log.info("Requesting GUILD_SYNC for %s guilds", len(self.guilds)) + await self.syncer([s.id for s in self.guilds]) + except asyncio.CancelledError: + pass + else: + # dispatch the event + self.call_handlers("ready") + self.dispatch("ready") + finally: + self._ready_task = None + + def parse_ready(self, data): + if self._ready_task is not None: + self._ready_task.cancel() + + self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) + self.clear() + self.user = ClientUser(state=self, data=data["user"]) + + guilds = self._ready_state.guilds + for guild_data in data["guilds"]: + guild = self._add_guild_from_data(guild_data) + if (not self.is_bot and not guild.unavailable) or guild.large: + guilds.append((guild, guild.unavailable)) + + for relationship in data.get("relationships", []): + try: + r_id = int(relationship["id"]) + except KeyError: + continue + else: + self.user._relationships[r_id] = Relationship(state=self, data=relationship) + + for pm in data.get("private_channels", []): + factory, _ = _channel_factory(pm["type"]) + self._add_private_channel(factory(me=self.user, data=pm, state=self)) + + self.dispatch("connect") + self._ready_task = asyncio.ensure_future(self._delay_ready(), loop=self.loop) + + def parse_resumed(self, data): + self.dispatch("resumed") + + def parse_message_create(self, data): + channel, _ = self._get_guild_channel(data) + message = Message(channel=channel, data=data, state=self) + self.dispatch("message", message) + self._messages.append(message) + + def parse_message_delete(self, data): + raw = RawMessageDeleteEvent(data) + self.dispatch("raw_message_delete", raw) + + found = self._get_message(raw.message_id) + if found is not None: + self.dispatch("message_delete", found) + self._messages.remove(found) + + def parse_message_delete_bulk(self, data): + raw = RawBulkMessageDeleteEvent(data) + self.dispatch("raw_bulk_message_delete", raw) + + to_be_deleted = [message for message in self._messages if message.id in raw.message_ids] + for msg in to_be_deleted: + self.dispatch("message_delete", msg) + self._messages.remove(msg) + + def parse_message_update(self, data): + raw = RawMessageUpdateEvent(data) + self.dispatch("raw_message_edit", raw) + message = self._get_message(raw.message_id) + if message is not None: + older_message = copy.copy(message) + if "call" in data: + # call state message edit + message._handle_call(data["call"]) + elif "content" not in data: + # embed only edit + message.embeds = [Embed.from_data(d) for d in data["embeds"]] + else: + message._update(channel=message.channel, data=data) + + self.dispatch("message_edit", older_message, message) + + def parse_message_reaction_add(self, data): + emoji_data = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji_data, "id") + emoji = PartialEmoji(animated=emoji_data["animated"], id=emoji_id, name=emoji_data["name"]) + raw = RawReactionActionEvent(data, emoji) + self.dispatch("raw_reaction_add", raw) + + # rich interface here + message = self._get_message(raw.message_id) + if message is not None: + emoji = self._upgrade_partial_emoji(emoji) + reaction = message._add_reaction(data, emoji, raw.user_id) + user = self._get_reaction_user(message.channel, raw.user_id) + if user: + self.dispatch("reaction_add", reaction, user) + + def parse_message_reaction_remove_all(self, data): + raw = RawReactionClearEvent(data) + self.dispatch("raw_reaction_clear", raw) + + message = self._get_message(raw.message_id) + if message is not None: + old_reactions = message.reactions.copy() + message.reactions.clear() + self.dispatch("reaction_clear", message, old_reactions) + + def parse_message_reaction_remove(self, data): + emoji_data = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji_data, "id") + emoji = PartialEmoji(animated=emoji_data["animated"], id=emoji_id, name=emoji_data["name"]) + raw = RawReactionActionEvent(data, emoji) + self.dispatch("raw_reaction_remove", raw) + + message = self._get_message(raw.message_id) + if message is not None: + emoji = self._upgrade_partial_emoji(emoji) + try: + reaction = message._remove_reaction(data, emoji, raw.user_id) + except (AttributeError, ValueError): # eventual consistency lol + pass + else: + user = self._get_reaction_user(message.channel, raw.user_id) + if user: + self.dispatch("reaction_remove", reaction, user) + + def parse_presence_update(self, data): + guild_id = utils._get_as_snowflake(data, "guild_id") + guild = self._get_guild(guild_id) + if guild is None: + log.warning( + "PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.", guild_id + ) + return + + user = data["user"] + member_id = int(user["id"]) + member = guild.get_member(member_id) + if member is None: + if "username" not in user: + # sometimes we receive 'incomplete' member data post-removal. + # skip these useless cases. + return + + member = Member(guild=guild, data=data, state=self) + guild._add_member(member) + + old_member = Member._copy(member) + member._presence_update(data=data, user=user) + self.dispatch("member_update", old_member, member) + + def parse_user_update(self, data): + self.user = ClientUser(state=self, data=data) + + def parse_channel_delete(self, data): + guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + channel_id = int(data["id"]) + if guild is not None: + channel = guild.get_channel(channel_id) + if channel is not None: + guild._remove_channel(channel) + self.dispatch("guild_channel_delete", channel) + else: + # the reason we're doing this is so it's also removed from the + # private channel by user cache as well + channel = self._get_private_channel(channel_id) + if channel is not None: + self._remove_private_channel(channel) + self.dispatch("private_channel_delete", channel) + + def parse_channel_update(self, data): + channel_type = try_enum(ChannelType, data.get("type")) + channel_id = int(data["id"]) + if channel_type is ChannelType.group: + channel = self._get_private_channel(channel_id) + old_channel = copy.copy(channel) + channel._update_group(data) + self.dispatch("private_channel_update", old_channel, channel) + return + + guild_id = utils._get_as_snowflake(data, "guild_id") + guild = self._get_guild(guild_id) + if guild is not None: + channel = guild.get_channel(channel_id) + if channel is not None: + old_channel = copy.copy(channel) + channel._update(guild, data) + self.dispatch("guild_channel_update", old_channel, channel) + else: + log.warning( + "CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.", channel_id + ) + else: + log.warning( + "CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.", guild_id + ) + + def parse_channel_create(self, data): + factory, ch_type = _channel_factory(data["type"]) + if factory is None: + log.warning( + "CHANNEL_CREATE referencing an unknown channel type %s. Discarding.", data["type"] + ) + return + + channel = None + + if ch_type in (ChannelType.group, ChannelType.private): + channel_id = int(data["id"]) + if self._get_private_channel(channel_id) is None: + channel = factory(me=self.user, data=data, state=self) + self._add_private_channel(channel) + self.dispatch("private_channel_create", channel) + else: + guild_id = utils._get_as_snowflake(data, "guild_id") + guild = self._get_guild(guild_id) + if guild is not None: + channel = factory(guild=guild, state=self, data=data) + guild._add_channel(channel) + self.dispatch("guild_channel_create", channel) + else: + log.warning( + "CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.", guild_id + ) + return + + def parse_channel_pins_update(self, data): + channel_id = int(data["channel_id"]) + channel = self.get_channel(channel_id) + if channel is None: + log.warning( + "CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.", + channel_id, + ) + return + + last_pin = ( + utils.parse_time(data["last_pin_timestamp"]) if data["last_pin_timestamp"] else None + ) + + try: + # I have not imported discord.abc in this file + # the isinstance check is also 2x slower than just checking this attribute + # so we're just gonna check it since it's easier and faster and lazier + channel.guild + except AttributeError: + self.dispatch("private_channel_pins_update", channel, last_pin) + else: + self.dispatch("guild_channel_pins_update", channel, last_pin) + + def parse_channel_recipient_add(self, data): + channel = self._get_private_channel(int(data["channel_id"])) + user = self.store_user(data["user"]) + channel.recipients.append(user) + self.dispatch("group_join", channel, user) + + def parse_channel_recipient_remove(self, data): + channel = self._get_private_channel(int(data["channel_id"])) + user = self.store_user(data["user"]) + try: + channel.recipients.remove(user) + except ValueError: + pass + else: + self.dispatch("group_remove", channel, user) + + def parse_guild_member_add(self, data): + guild = self._get_guild(int(data["guild_id"])) + if guild is None: + log.warning( + "GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + return + + member = Member(guild=guild, data=data, state=self) + guild._add_member(member) + guild._member_count += 1 + self.dispatch("member_join", member) + + def parse_guild_member_remove(self, data): + guild = self._get_guild(int(data["guild_id"])) + if guild is not None: + user_id = int(data["user"]["id"]) + member = guild.get_member(user_id) + if member is not None: + guild._remove_member(member) + guild._member_count -= 1 + self.dispatch("member_remove", member) + else: + log.warning( + "GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + + def parse_guild_member_update(self, data): + guild = self._get_guild(int(data["guild_id"])) + user = data["user"] + user_id = int(user["id"]) + if guild is None: + log.warning( + "GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + return + + member = guild.get_member(user_id) + if member is not None: + old_member = copy.copy(member) + member._update(data, user) + self.dispatch("member_update", old_member, member) + else: + log.warning( + "GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.", user_id + ) + + def parse_guild_emojis_update(self, data): + guild = self._get_guild(int(data["guild_id"])) + if guild is None: + log.warning( + "GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + return + + before_emojis = guild.emojis + for emoji in before_emojis: + self._emojis.pop(emoji.id, None) + guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data["emojis"])) + self.dispatch("guild_emojis_update", guild, before_emojis, guild.emojis) + + def _get_create_guild(self, data): + if data.get("unavailable") is False: + # GUILD_CREATE with unavailable in the response + # usually means that the guild has become available + # and is therefore in the cache + guild = self._get_guild(int(data["id"])) + if guild is not None: + guild.unavailable = False + guild._from_data(data) + return guild + + return self._add_guild_from_data(data) + + async def _chunk_and_dispatch(self, guild, unavailable): + chunks = list(self.chunks_needed(guild)) + await self.chunker(guild) + if chunks: + try: + await utils.sane_wait_for(chunks, timeout=len(chunks), loop=self.loop) + except asyncio.TimeoutError: + log.info("Somehow timed out waiting for chunks.") + + if unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + + def parse_guild_create(self, data): + unavailable = data.get("unavailable") + if unavailable is True: + # joined a guild with unavailable == True so.. + return + + guild = self._get_create_guild(data) + + # check if it requires chunking + if guild.large: + if unavailable is False: + # check if we're waiting for 'useful' READY + # and if we are, we don't want to dispatch any + # event such as guild_join or guild_available + # because we're still in the 'READY' phase. Or + # so we say. + try: + state = self._ready_state + state.launch.clear() + state.guilds.append((guild, unavailable)) + except AttributeError: + # the _ready_state attribute is only there during + # processing of useful READY. + pass + else: + return + + # since we're not waiting for 'useful' READY we'll just + # do the chunk request here if wanted + if self._fetch_offline: + asyncio.ensure_future(self._chunk_and_dispatch(guild, unavailable), loop=self.loop) + return + + # Dispatch available if newly available + if unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + + def parse_guild_sync(self, data): + guild = self._get_guild(int(data["id"])) + guild._sync(data) + + def parse_guild_update(self, data): + guild = self._get_guild(int(data["id"])) + if guild is not None: + old_guild = copy.copy(guild) + guild._from_data(data) + self.dispatch("guild_update", old_guild, guild) + else: + log.warning( + "GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.", data["id"] + ) + + def parse_guild_delete(self, data): + guild = self._get_guild(int(data["id"])) + if guild is None: + log.warning( + "GUILD_DELETE referencing an unknown guild ID: %s. Discarding.", data["id"] + ) + return + + if data.get("unavailable", False) and guild is not None: + # GUILD_DELETE with unavailable being True means that the + # guild that was available is now currently unavailable + guild.unavailable = True + self.dispatch("guild_unavailable", guild) + return + + # do a cleanup of the messages cache + self._messages = deque( + (msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages + ) + + self._remove_guild(guild) + self.dispatch("guild_remove", guild) + + def parse_guild_ban_add(self, data): + # we make the assumption that GUILD_BAN_ADD is done + # before GUILD_MEMBER_REMOVE is called + # hence we don't remove it from cache or do anything + # strange with it, the main purpose of this event + # is mainly to dispatch to another event worth listening to for logging + guild = self._get_guild(int(data["guild_id"])) + if guild is not None: + try: + user = User(data=data["user"], state=self) + except KeyError: + pass + else: + member = guild.get_member(user.id) or user + self.dispatch("member_ban", guild, member) + + def parse_guild_ban_remove(self, data): + guild = self._get_guild(int(data["guild_id"])) + if guild is not None: + if "user" in data: + user = self.store_user(data["user"]) + self.dispatch("member_unban", guild, user) + + def parse_guild_role_create(self, data): + guild = self._get_guild(int(data["guild_id"])) + if guild is None: + log.warning( + "GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + return + + role_data = data["role"] + role = Role(guild=guild, data=role_data, state=self) + guild._add_role(role) + self.dispatch("guild_role_create", role) + + def parse_guild_role_delete(self, data): + guild = self._get_guild(int(data["guild_id"])) + if guild is not None: + role_id = int(data["role_id"]) + try: + role = guild._remove_role(role_id) + except KeyError: + return + else: + self.dispatch("guild_role_delete", role) + else: + log.warning( + "GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + + def parse_guild_role_update(self, data): + guild = self._get_guild(int(data["guild_id"])) + if guild is not None: + role_data = data["role"] + role_id = int(role_data["id"]) + role = guild.get_role(role_id) + if role is not None: + old_role = copy.copy(role) + role._update(role_data) + self.dispatch("guild_role_update", old_role, role) + else: + log.warning( + "GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) + + def parse_guild_members_chunk(self, data): + guild_id = int(data["guild_id"]) + guild = self._get_guild(guild_id) + members = data.get("members", []) + for member in members: + m = Member(guild=guild, data=member, state=self) + existing = guild.get_member(m.id) + if existing is None or existing.joined_at is None: + guild._add_member(m) + + log.info("Processed a chunk for %s members in guild ID %s.", len(members), guild_id) + self.process_listeners(ListenerType.chunk, guild, len(members)) + + def parse_webhooks_update(self, data): + channel = self.get_channel(int(data["channel_id"])) + if channel: + self.dispatch("webhooks_update", channel) + + def parse_voice_state_update(self, data): + guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + channel_id = utils._get_as_snowflake(data, "channel_id") + if guild is not None: + if int(data["user_id"]) == self.user.id: + voice = self._get_voice_client(guild.id) + if voice is not None: + ch = guild.get_channel(channel_id) + if ch is not None: + voice.channel = ch + + member, before, after = guild._update_voice_state(data, channel_id) + if member is not None: + self.dispatch("voice_state_update", member, before, after) + else: + log.warning( + "VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.", + data["user_id"], + ) + else: + # in here we're either at private or group calls + call = self._calls.get(channel_id) + if call is not None: + call._update_voice_state(data) + + def parse_voice_server_update(self, data): + try: + key_id = int(data["guild_id"]) + except KeyError: + key_id = int(data["channel_id"]) + + vc = self._get_voice_client(key_id) + if vc is not None: + asyncio.ensure_future(vc._create_socket(key_id, data)) + + def parse_typing_start(self, data): + channel, guild = self._get_guild_channel(data) + if channel is not None: + member = None + user_id = utils._get_as_snowflake(data, "user_id") + if isinstance(channel, DMChannel): + member = channel.recipient + elif isinstance(channel, TextChannel) and guild is not None: + member = guild.get_member(user_id) + elif isinstance(channel, GroupChannel): + member = utils.find(lambda x: x.id == user_id, channel.recipients) + + if member is not None: + timestamp = datetime.datetime.utcfromtimestamp(data.get("timestamp")) + self.dispatch("typing", channel, member, timestamp) + + def parse_relationship_add(self, data): + key = int(data["id"]) + old = self.user.get_relationship(key) + new = Relationship(state=self, data=data) + self.user._relationships[key] = new + if old is not None: + self.dispatch("relationship_update", old, new) + else: + self.dispatch("relationship_add", new) + + def parse_relationship_remove(self, data): + key = int(data["id"]) + try: + old = self.user._relationships.pop(key) + except KeyError: + pass + else: + self.dispatch("relationship_remove", old) + + def _get_reaction_user(self, channel, user_id): + if isinstance(channel, TextChannel): + return channel.guild.get_member(user_id) + return self.get_user(user_id) + + def get_reaction_emoji(self, data): + emoji_id = utils._get_as_snowflake(data, "id") + + if not emoji_id: + return data["name"] + + try: + return self._emojis[emoji_id] + except KeyError: + return PartialEmoji(animated=data["animated"], id=emoji_id, name=data["name"]) + + def _upgrade_partial_emoji(self, emoji): + emoji_id = emoji.id + if not emoji_id: + return emoji.name + try: + return self._emojis[emoji_id] + except KeyError: + return emoji + + def get_channel(self, id): + if id is None: + return None + + pm = self._get_private_channel(id) + if pm is not None: + return pm + + for guild in self.guilds: + channel = guild.get_channel(id) + if channel is not None: + return channel + + def create_message(self, *, channel, data): + return Message(state=self, channel=channel, data=data) + + def receive_chunk(self, guild_id): + future = self.loop.create_future() + listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id) + self._listeners.append(listener) + return future + + +class AutoShardedConnectionState(ConnectionState): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._ready_task = None + + async def request_offline_members(self, guilds, *, shard_id): + # get all the chunks + chunks = [] + for guild in guilds: + chunks.extend(self.chunks_needed(guild)) + + # we only want to request ~75 guilds per chunk request. + splits = [guilds[i : i + 75] for i in range(0, len(guilds), 75)] + for split in splits: + await self.chunker(split, shard_id=shard_id) + + # wait for the chunks + if chunks: + try: + await utils.sane_wait_for(chunks, timeout=len(chunks) * 30.0, loop=self.loop) + except asyncio.TimeoutError: + log.info("Somehow timed out waiting for chunks.") + + async def _delay_ready(self): + launch = self._ready_state.launch + while not launch.is_set(): + # this snippet of code is basically waiting 2 seconds + # until the last GUILD_CREATE was sent + launch.set() + await asyncio.sleep(2.0 * self.shard_count, loop=self.loop) + + if self._fetch_offline: + guilds = sorted(self._ready_state.guilds, key=lambda g: g[0].shard_id) + + for shard_id, sub_guilds_info in itertools.groupby( + guilds, key=lambda g: g[0].shard_id + ): + sub_guilds, sub_available = zip(*sub_guilds_info) + await self.request_offline_members(sub_guilds, shard_id=shard_id) + + for guild, unavailable in zip(sub_guilds, sub_available): + if unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + self.dispatch("shard_ready", shard_id) + else: + for guild, unavailable in self._ready_state.guilds: + if unavailable is False: + self.dispatch("guild_available", guild) + else: + self.dispatch("guild_join", guild) + + # remove the state + try: + del self._ready_state + except AttributeError: + pass # already been deleted somehow + + # regular users cannot shard so we won't worry about it here. + + # clear the current task + self._ready_task = None + + # dispatch the event + self.call_handlers("ready") + self.dispatch("ready") + + def parse_ready(self, data): + if not hasattr(self, "_ready_state"): + self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) + + self.user = ClientUser(state=self, data=data["user"]) + + guilds = self._ready_state.guilds + for guild_data in data["guilds"]: + guild = self._add_guild_from_data(guild_data) + if guild.large: + guilds.append((guild, guild.unavailable)) + + for pm in data.get("private_channels", []): + factory, _ = _channel_factory(pm["type"]) + self._add_private_channel(factory(me=self.user, data=pm, state=self)) + + self.dispatch("connect") + if self._ready_task is None: + self._ready_task = asyncio.ensure_future(self._delay_ready(), loop=self.loop) diff --git a/discord/user.py b/discord/user.py new file mode 100644 index 000000000..c9c440941 --- /dev/null +++ b/discord/user.py @@ -0,0 +1,699 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from collections import namedtuple + +import discord.abc +from .utils import snowflake_time, _bytes_to_base64_data, parse_time, valid_icon_size +from .enums import DefaultAvatar, RelationshipType, UserFlags, HypeSquadHouse +from .errors import ClientException, InvalidArgument +from .colour import Colour + +VALID_STATIC_FORMATS = {"jpeg", "jpg", "webp", "png"} +VALID_AVATAR_FORMATS = VALID_STATIC_FORMATS | {"gif"} + + +class Profile(namedtuple("Profile", "flags user mutual_guilds connected_accounts premium_since")): + __slots__ = () + + @property + def nitro(self): + return self.premium_since is not None + + premium = nitro + + def _has_flag(self, o): + v = o.value + return (self.flags & v) == v + + @property + def staff(self): + return self._has_flag(UserFlags.staff) + + @property + def partner(self): + return self._has_flag(UserFlags.partner) + + @property + def bug_hunter(self): + return self._has_flag(UserFlags.bug_hunter) + + @property + def early_supporter(self): + return self._has_flag(UserFlags.early_supporter) + + @property + def hypesquad(self): + return self._has_flag(UserFlags.hypesquad) + + @property + def hypesquad_houses(self): + flags = ( + UserFlags.hypesquad_bravery, + UserFlags.hypesquad_brilliance, + UserFlags.hypesquad_balance, + ) + return [house for house, flag in zip(HypeSquadHouse, flags) if self._has_flag(flag)] + + +_BaseUser = discord.abc.User + + +class BaseUser(_BaseUser): + __slots__ = ("name", "id", "discriminator", "avatar", "bot", "_state") + + def __init__(self, *, state, data): + self._state = state + self.name = data["username"] + self.id = int(data["id"]) + self.discriminator = data["discriminator"] + self.avatar = data["avatar"] + self.bot = data.get("bot", False) + + def __str__(self): + return "{0.name}#{0.discriminator}".format(self) + + def __eq__(self, other): + return isinstance(other, _BaseUser) and other.id == self.id + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return self.id >> 22 + + @classmethod + def _copy(cls, user): + self = cls.__new__(cls) # bypass __init__ + + self.name = user.name + self.id = user.id + self.discriminator = user.discriminator + self.avatar = user.avatar + self.bot = user.bot + self._state = user._state + + return self + + @property + def avatar_url(self): + """Returns a friendly URL version of the avatar the user has. + + If the user does not have a traditional avatar, their default + avatar URL is returned instead. + + This is equivalent to calling :meth:`avatar_url_as` with + the default parameters (i.e. webp/gif detection and a size of 1024). + """ + return self.avatar_url_as(format=None, size=1024) + + def is_avatar_animated(self): + """:class:`bool`: Returns True if the user has an animated avatar.""" + return bool(self.avatar and self.avatar.startswith("a_")) + + def avatar_url_as(self, *, format=None, static_format="webp", size=1024): + """Returns a friendly URL version of the avatar the user has. + + If the user does not have a traditional avatar, their default + avatar URL is returned instead. + + The format must be one of 'webp', 'jpeg', 'jpg', 'png' or 'gif', and + 'gif' is only valid for animated avatars. The size must be a power of 2 + between 16 and 1024. + + Parameters + ----------- + format: Optional[str] + The format to attempt to convert the avatar to. + If the format is ``None``, then it is automatically + detected into either 'gif' or static_format depending on the + avatar being animated or not. + static_format: 'str' + Format to attempt to convert only non-animated avatars to. + Defaults to 'webp' + size: int + The size of the image to display. + + Returns + -------- + str + The resulting CDN URL. + + Raises + ------ + InvalidArgument + Bad image format passed to ``format`` or ``static_format``, or + invalid ``size``. + """ + if not valid_icon_size(size): + raise InvalidArgument("size must be a power of 2 between 16 and 1024") + if format is not None and format not in VALID_AVATAR_FORMATS: + raise InvalidArgument("format must be None or one of {}".format(VALID_AVATAR_FORMATS)) + if format == "gif" and not self.is_avatar_animated(): + raise InvalidArgument("non animated avatars do not support gif format") + if static_format not in VALID_STATIC_FORMATS: + raise InvalidArgument("static_format must be one of {}".format(VALID_STATIC_FORMATS)) + + if self.avatar is None: + return self.default_avatar_url + + if format is None: + if self.is_avatar_animated(): + format = "gif" + else: + format = static_format + + return "https://cdn.discordapp.com/avatars/{0.id}/{0.avatar}.{1}?size={2}".format( + self, format, size + ) + + @property + def default_avatar(self): + """Returns the default avatar for a given user. This is calculated by the user's discriminator""" + return DefaultAvatar(int(self.discriminator) % len(DefaultAvatar)) + + @property + def default_avatar_url(self): + """Returns a URL for a user's default avatar.""" + return "https://cdn.discordapp.com/embed/avatars/{}.png".format(self.default_avatar.value) + + @property + def colour(self): + """A property that returns a :class:`Colour` denoting the rendered colour + for the user. This always returns :meth:`Colour.default`. + + There is an alias for this under ``color``. + """ + return Colour.default() + + color = colour + + @property + def mention(self): + """Returns a string that allows you to mention the given user.""" + return "<@{0.id}>".format(self) + + def permissions_in(self, channel): + """An alias for :meth:`abc.GuildChannel.permissions_for`. + + Basically equivalent to: + + .. code-block:: python3 + + channel.permissions_for(self) + + Parameters + ----------- + channel + The channel to check your permissions for. + """ + return channel.permissions_for(self) + + @property + def created_at(self): + """Returns the user's creation time in UTC. + + This is when the user's discord account was created.""" + return snowflake_time(self.id) + + @property + def display_name(self): + """Returns the user's display name. + + For regular users this is just their username, but + if they have a guild specific nickname then that + is returned instead. + """ + return self.name + + def mentioned_in(self, message): + """Checks if the user is mentioned in the specified message. + + Parameters + ----------- + message : :class:`Message` + The message to check if you're mentioned in. + """ + + if message.mention_everyone: + return True + + for user in message.mentions: + if user.id == self.id: + return True + + return False + + +class ClientUser(BaseUser): + """Represents your Discord user. + + .. container:: operations + + .. describe:: x == y + + Checks if two users are equal. + + .. describe:: x != y + + Checks if two users are not equal. + + .. describe:: hash(x) + + Return the user's hash. + + .. describe:: str(x) + + Returns the user's name with discriminator. + + Attributes + ----------- + name: :class:`str` + The user's username. + id: :class:`int` + The user's unique ID. + discriminator: :class:`str` + The user's discriminator. This is given when the username has conflicts. + avatar: Optional[:class:`str`] + The avatar hash the user has. Could be None. + bot: :class:`bool` + Specifies if the user is a bot account. + verified: :class:`bool` + Specifies if the user is a verified account. + email: Optional[:class:`str`] + The email the user used when registering. + mfa_enabled: :class:`bool` + Specifies if the user has MFA turned on and working. + premium: :class:`bool` + Specifies if the user is a premium user (e.g. has Discord Nitro). + """ + + __slots__ = ("email", "verified", "mfa_enabled", "premium", "_relationships") + + def __init__(self, *, state, data): + super().__init__(state=state, data=data) + self.verified = data.get("verified", False) + self.email = data.get("email") + self.mfa_enabled = data.get("mfa_enabled", False) + self.premium = data.get("premium", False) + self._relationships = {} + + def __repr__(self): + return ( + "".format(self) + ) + + def get_relationship(self, user_id): + """Retrieves the :class:`Relationship` if applicable. + + Parameters + ----------- + user_id: int + The user ID to check if we have a relationship with them. + + Returns + -------- + Optional[:class:`Relationship`] + The relationship if available or ``None`` + """ + return self._relationships.get(user_id) + + @property + def relationships(self): + """Returns a :class:`list` of :class:`Relationship` that the user has.""" + return list(self._relationships.values()) + + @property + def friends(self): + r"""Returns a :class:`list` of :class:`User`\s that the user is friends with.""" + return [r.user for r in self._relationships.values() if r.type is RelationshipType.friend] + + @property + def blocked(self): + r"""Returns a :class:`list` of :class:`User`\s that the user has blocked.""" + return [r.user for r in self._relationships.values() if r.type is RelationshipType.blocked] + + async def edit(self, **fields): + """|coro| + + Edits the current profile of the client. + + If a bot account is used then a password field is optional, + otherwise it is required. + + Note + ----- + To upload an avatar, a :term:`py:bytes-like object` must be passed in that + represents the image being uploaded. If this is done through a file + then the file must be opened via ``open('some_filename', 'rb')`` and + the :term:`py:bytes-like object` is given through the use of ``fp.read()``. + + The only image formats supported for uploading is JPEG and PNG. + + Parameters + ----------- + password : str + The current password for the client's account. + Only applicable to user accounts. + new_password: str + The new password you wish to change to. + Only applicable to user accounts. + email: str + The new email you wish to change to. + Only applicable to user accounts. + house: Optional[:class:`HypeSquadHouse`] + The hypesquad house you wish to change to. + Could be ``None`` to leave the current house. + Only applicable to user accounts. + username :str + The new username you wish to change to. + avatar: bytes + A :term:`py:bytes-like object` representing the image to upload. + Could be ``None`` to denote no avatar. + + Raises + ------ + HTTPException + Editing your profile failed. + InvalidArgument + Wrong image format passed for ``avatar``. + ClientException + Password is required for non-bot accounts. + House field was not a HypeSquadHouse. + """ + + try: + avatar_bytes = fields["avatar"] + except KeyError: + avatar = self.avatar + else: + if avatar_bytes is not None: + avatar = _bytes_to_base64_data(avatar_bytes) + else: + avatar = None + + not_bot_account = not self.bot + password = fields.get("password") + if not_bot_account and password is None: + raise ClientException("Password is required for non-bot accounts.") + + args = { + "password": password, + "username": fields.get("username", self.name), + "avatar": avatar, + } + + if not_bot_account: + args["email"] = fields.get("email", self.email) + + if "new_password" in fields: + args["new_password"] = fields["new_password"] + + http = self._state.http + + if "house" in fields: + house = fields["house"] + if house is None: + await http.leave_hypesquad_house() + elif not isinstance(house, HypeSquadHouse): + raise ClientException("`house` parameter was not a HypeSquadHouse") + else: + value = house.value + + await http.change_hypesquad_house(value) + + data = await http.edit_profile(**args) + if not_bot_account: + self.email = data["email"] + try: + http._token(data["token"], bot=False) + except KeyError: + pass + + # manually update data by calling __init__ explicitly. + self.__init__(state=self._state, data=data) + + async def create_group(self, *recipients): + r"""|coro| + + Creates a group direct message with the recipients + provided. These recipients must be have a relationship + of type :attr:`RelationshipType.friend`. + + Bot accounts cannot create a group. + + Parameters + ----------- + \*recipients + An argument :class:`list` of :class:`User` to have in + your group. + + Return + ------- + :class:`GroupChannel` + The new group channel. + + Raises + ------- + HTTPException + Failed to create the group direct message. + ClientException + Attempted to create a group with only one recipient. + This does not include yourself. + """ + + from .channel import GroupChannel + + if len(recipients) < 2: + raise ClientException("You must have two or more recipients to create a group.") + + users = [str(u.id) for u in recipients] + data = await self._state.http.start_group(self.id, users) + return GroupChannel(me=self, data=data, state=self._state) + + +class User(BaseUser, discord.abc.Messageable): + """Represents a Discord user. + + .. container:: operations + + .. describe:: x == y + + Checks if two users are equal. + + .. describe:: x != y + + Checks if two users are not equal. + + .. describe:: hash(x) + + Return the user's hash. + + .. describe:: str(x) + + Returns the user's name with discriminator. + + Attributes + ----------- + name: :class:`str` + The user's username. + id: :class:`int` + The user's unique ID. + discriminator: :class:`str` + The user's discriminator. This is given when the username has conflicts. + avatar: Optional[:class:`str`] + The avatar hash the user has. Could be None. + bot: :class:`bool` + Specifies if the user is a bot account. + """ + + __slots__ = ("__weakref__",) + + def __repr__(self): + return "".format( + self + ) + + async def _get_channel(self): + ch = await self.create_dm() + return ch + + @property + def dm_channel(self): + """Returns the :class:`DMChannel` associated with this user if it exists. + + If this returns ``None``, you can create a DM channel by calling the + :meth:`create_dm` coroutine function. + """ + return self._state._get_private_channel_by_user(self.id) + + async def create_dm(self): + """Creates a :class:`DMChannel` with this user. + + This should be rarely called, as this is done transparently for most + people. + """ + found = self.dm_channel + if found is not None: + return found + + state = self._state + data = await state.http.start_private_message(self.id) + return state.add_dm_channel(data) + + @property + def relationship(self): + """Returns the :class:`Relationship` with this user if applicable, ``None`` otherwise.""" + return self._state.user.get_relationship(self.id) + + async def mutual_friends(self): + """|coro| + + Gets all mutual friends of this user. This can only be used by non-bot accounts + + Returns + ------- + List[:class:`User`] + The users that are mutual friends. + + Raises + ------- + Forbidden + Not allowed to get mutual friends of this user. + HTTPException + Getting mutual friends failed. + """ + state = self._state + mutuals = await state.http.get_mutual_friends(self.id) + return [User(state=state, data=friend) for friend in mutuals] + + def is_friend(self): + """:class:`bool`: Checks if the user is your friend.""" + r = self.relationship + if r is None: + return False + return r.type is RelationshipType.friend + + def is_blocked(self): + """:class:`bool`: Checks if the user is blocked.""" + r = self.relationship + if r is None: + return False + return r.type is RelationshipType.blocked + + async def block(self): + """|coro| + + Blocks the user. + + Raises + ------- + Forbidden + Not allowed to block this user. + HTTPException + Blocking the user failed. + """ + + await self._state.http.add_relationship(self.id, type=RelationshipType.blocked.value) + + async def unblock(self): + """|coro| + + Unblocks the user. + + Raises + ------- + Forbidden + Not allowed to unblock this user. + HTTPException + Unblocking the user failed. + """ + await self._state.http.remove_relationship(self.id) + + async def remove_friend(self): + """|coro| + + Removes the user as a friend. + + Raises + ------- + Forbidden + Not allowed to remove this user as a friend. + HTTPException + Removing the user as a friend failed. + """ + await self._state.http.remove_relationship(self.id) + + async def send_friend_request(self): + """|coro| + + Sends the user a friend request. + + Raises + ------- + Forbidden + Not allowed to send a friend request to the user. + HTTPException + Sending the friend request failed. + """ + await self._state.http.send_friend_request( + username=self.name, discriminator=self.discriminator + ) + + async def profile(self): + """|coro| + + Gets the user's profile. This can only be used by non-bot accounts. + + Raises + ------- + Forbidden + Not allowed to fetch profiles. + HTTPException + Fetching the profile failed. + + Returns + -------- + :class:`Profile` + The profile of the user. + """ + + state = self._state + data = await state.http.get_user_profile(self.id) + + def transform(d): + return state._get_guild(int(d["id"])) + + since = data.get("premium_since") + mutual_guilds = list(filter(None, map(transform, data.get("mutual_guilds", [])))) + return Profile( + flags=data["user"].get("flags", 0), + premium_since=parse_time(since), + mutual_guilds=mutual_guilds, + user=self, + connected_accounts=data["connected_accounts"], + ) diff --git a/discord/utils.py b/discord/utils.py new file mode 100644 index 000000000..bf78b8bfa --- /dev/null +++ b/discord/utils.py @@ -0,0 +1,353 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import array +import asyncio +from base64 import b64encode +from bisect import bisect_left +import datetime +from email.utils import parsedate_to_datetime +import functools +from inspect import isawaitable as _isawaitable +import json +from re import split as re_split +import warnings + +from .errors import InvalidArgument + +DISCORD_EPOCH = 1420070400000 + + +class cached_property: + def __init__(self, function): + self.function = function + self.__doc__ = getattr(function, "__doc__") + + def __get__(self, instance, owner): + if instance is None: + return self + + value = self.function(instance) + setattr(instance, self.function.__name__, value) + + return value + + +class CachedSlotProperty: + def __init__(self, name, function): + self.name = name + self.function = function + self.__doc__ = getattr(function, "__doc__") + + def __get__(self, instance, owner): + if instance is None: + return self + + try: + return getattr(instance, self.name) + except AttributeError: + value = self.function(instance) + setattr(instance, self.name, value) + return value + + +def cached_slot_property(name): + def decorator(func): + return CachedSlotProperty(name, func) + + return decorator + + +def parse_time(timestamp): + if timestamp: + return datetime.datetime(*map(int, re_split(r"[^\d]", timestamp.replace("+00:00", "")))) + return None + + +def deprecated(instead=None): + def actual_decorator(func): + @functools.wraps(func) + def decorated(*args, **kwargs): + warnings.simplefilter("always", DeprecationWarning) # turn off filter + if instead: + fmt = "{0.__name__} is deprecated, use {1} instead." + else: + fmt = "{0.__name__} is deprecated." + + warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning) + warnings.simplefilter("default", DeprecationWarning) # reset filter + return func(*args, **kwargs) + + return decorated + + return actual_decorator + + +def oauth_url(client_id, permissions=None, guild=None, redirect_uri=None): + """A helper function that returns the OAuth2 URL for inviting the bot + into guilds. + + Parameters + ----------- + client_id : str + The client ID for your bot. + permissions : :class:`Permissions` + The permissions you're requesting. If not given then you won't be requesting any + permissions. + guild : :class:`Guild` + The guild to pre-select in the authorization screen, if available. + redirect_uri : str + An optional valid redirect URI. + """ + url = "https://discordapp.com/oauth2/authorize?client_id={}&scope=bot".format(client_id) + if permissions is not None: + url = url + "&permissions=" + str(permissions.value) + if guild is not None: + url = url + "&guild_id=" + str(guild.id) + if redirect_uri is not None: + from urllib.parse import urlencode + + url = url + "&response_type=code&" + urlencode({"redirect_uri": redirect_uri}) + return url + + +def snowflake_time(id): + """Returns the creation date in UTC of a discord id.""" + return datetime.datetime.utcfromtimestamp(((id >> 22) + DISCORD_EPOCH) / 1000) + + +def time_snowflake(datetime_obj, high=False): + """Returns a numeric snowflake pretending to be created at the given date. + + When using as the lower end of a range, use time_snowflake(high=False) - 1 to be inclusive, high=True to be exclusive + When using as the higher end of a range, use time_snowflake(high=True) + 1 to be inclusive, high=False to be exclusive + + Parameters + ----------- + datetime_obj + A timezone-naive datetime object representing UTC time. + high + Whether or not to set the lower 22 bit to high or low. + """ + unix_seconds = (datetime_obj - type(datetime_obj)(1970, 1, 1)).total_seconds() + discord_millis = int(unix_seconds * 1000 - DISCORD_EPOCH) + + return (discord_millis << 22) + (2 ** 22 - 1 if high else 0) + + +def find(predicate, seq): + """A helper to return the first element found in the sequence + that meets the predicate. For example: :: + + member = find(lambda m: m.name == 'Mighty', channel.guild.members) + + would find the first :class:`Member` whose name is 'Mighty' and return it. + If an entry is not found, then ``None`` is returned. + + This is different from `filter`_ due to the fact it stops the moment it finds + a valid entry. + + + .. _filter: https://docs.python.org/3.6/library/functions.html#filter + + Parameters + ----------- + predicate + A function that returns a boolean-like result. + seq : iterable + The iterable to search through. + """ + + for element in seq: + if predicate(element): + return element + return None + + +def get(iterable, **attrs): + r"""A helper that returns the first element in the iterable that meets + all the traits passed in ``attrs``. This is an alternative for + :func:`discord.utils.find`. + + When multiple attributes are specified, they are checked using + logical AND, not logical OR. Meaning they have to meet every + attribute passed in and not one of them. + + To have a nested attribute search (i.e. search by ``x.y``) then + pass in ``x__y`` as the keyword argument. + + If nothing is found that matches the attributes passed, then + ``None`` is returned. + + Examples + --------- + + Basic usage: + + .. code-block:: python3 + + member = discord.utils.get(message.guild.members, name='Foo') + + Multiple attribute matching: + + .. code-block:: python3 + + channel = discord.utils.get(guild.voice_channels, name='Foo', bitrate=64000) + + Nested attribute matching: + + .. code-block:: python3 + + channel = discord.utils.get(client.get_all_channels(), guild__name='Cool', name='general') + + Parameters + ----------- + iterable + An iterable to search through. + \*\*attrs + Keyword arguments that denote attributes to search with. + """ + + def predicate(elem): + for attr, val in attrs.items(): + nested = attr.split("__") + obj = elem + for attribute in nested: + obj = getattr(obj, attribute) + + if obj != val: + return False + return True + + return find(predicate, iterable) + + +def _unique(iterable): + seen = set() + adder = seen.add + return [x for x in iterable if not (x in seen or adder(x))] + + +def _get_as_snowflake(data, key): + try: + value = data[key] + except KeyError: + return None + else: + return value and int(value) + + +def _get_mime_type_for_image(data): + if data.startswith(b"\x89\x50\x4E\x47\x0D\x0A\x1A\x0A"): + return "image/png" + elif data.startswith(b"\xFF\xD8") and data.rstrip(b"\0").endswith(b"\xFF\xD9"): + return "image/jpeg" + elif data.startswith(b"\x47\x49\x46\x38\x37\x61") or data.startswith( + b"\x47\x49\x46\x38\x39\x61" + ): + return "image/gif" + elif data.startswith(b"RIFF") and data[8:12] == b"WEBP": + return "image/webp" + else: + raise InvalidArgument("Unsupported image type given") + + +def _bytes_to_base64_data(data): + fmt = "data:{mime};base64,{data}" + mime = _get_mime_type_for_image(data) + b64 = b64encode(data).decode("ascii") + return fmt.format(mime=mime, data=b64) + + +def to_json(obj): + return json.dumps(obj, separators=(",", ":"), ensure_ascii=True) + + +def _parse_ratelimit_header(request): + now = parsedate_to_datetime(request.headers["Date"]) + reset = datetime.datetime.fromtimestamp( + int(request.headers["X-Ratelimit-Reset"]), datetime.timezone.utc + ) + return (reset - now).total_seconds() + + +async def maybe_coroutine(f, *args, **kwargs): + value = f(*args, **kwargs) + if _isawaitable(value): + return await value + else: + return value + + +async def async_all(gen, *, check=_isawaitable): + for elem in gen: + if check(elem): + elem = await elem + if not elem: + return False + return True + + +async def sane_wait_for(futures, *, timeout, loop): + _, pending = await asyncio.wait(futures, timeout=timeout, loop=loop) + + if len(pending) != 0: + raise asyncio.TimeoutError() + + +def valid_icon_size(size): + """Icons must be power of 2 within [16, 2048].""" + return not size & (size - 1) and size in range(16, 2049) + + +class SnowflakeList(array.array): + """Internal data storage class to efficiently store a list of snowflakes. + + This should have the following characteristics: + + - Low memory usage + - O(n) iteration (obviously) + - O(n log n) initial creation if data is unsorted + - O(log n) search and indexing + - O(n) insertion + """ + + __slots__ = () + + def __new__(cls, data, *, is_sorted=False): + return array.array.__new__(cls, "Q", data if is_sorted else sorted(data)) + + def add(self, element): + i = bisect_left(self, element) + self.insert(i, element) + + def get(self, element): + i = bisect_left(self, element) + return self[i] if i != len(self) and self[i] == element else None + + def has(self, element): + i = bisect_left(self, element) + return i != len(self) and self[i] == element diff --git a/discord/voice_client.py b/discord/voice_client.py new file mode 100644 index 000000000..33c8001dd --- /dev/null +++ b/discord/voice_client.py @@ -0,0 +1,438 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +"""Some documentation to refer to: + +- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. +- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. +- We pull the session_id from VOICE_STATE_UPDATE. +- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. +- Then we initiate the voice web socket (vWS) pointing to the endpoint. +- We send opcode 0 with the user_id, server_id, session_id and token using the vWS. +- The vWS sends back opcode 2 with an ssrc, port, modes(array) and hearbeat_interval. +- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. +- Then we send our IP and port via vWS with opcode 1. +- When that's all done, we receive opcode 4 from the vWS. +- Finally we can transmit data to endpoint:port. +""" + +import asyncio +import socket +import logging +import struct +import threading + +from . import opus +from .backoff import ExponentialBackoff +from .gateway import * +from .errors import ClientException, ConnectionClosed +from .player import AudioPlayer, AudioSource + +try: + import nacl.secret + + has_nacl = True +except ImportError: + has_nacl = False + + +log = logging.getLogger(__name__) + + +class VoiceClient: + """Represents a Discord voice connection. + + You do not create these, you typically get them from + e.g. :meth:`VoiceChannel.connect`. + + Warning + -------- + In order to play audio, you must have loaded the opus library + through :func:`opus.load_opus`. + + If you don't do this then the library will not be able to + transmit audio. + + Attributes + ----------- + session_id: :class:`str` + The voice connection session ID. + token: :class:`str` + The voice connection token. + endpoint: :class:`str` + The endpoint we are connecting to. + channel: :class:`abc.Connectable` + The voice channel connected to. + loop + The event loop that the voice client is running on. + """ + + def __init__(self, state, timeout, channel): + if not has_nacl: + raise RuntimeError("PyNaCl library needed in order to use voice") + + self.channel = channel + self.main_ws = None + self.timeout = timeout + self.ws = None + self.socket = None + self.loop = state.loop + self._state = state + # this will be used in the AudioPlayer thread + self._connected = threading.Event() + self._handshake_complete = asyncio.Event(loop=self.loop) + + self._connections = 0 + self.sequence = 0 + self.timestamp = 0 + self._runner = None + self._player = None + self.encoder = opus.Encoder() + + warn_nacl = not has_nacl + + @property + def guild(self): + """Optional[:class:`Guild`]: The guild we're connected to, if applicable.""" + return getattr(self.channel, "guild", None) + + @property + def user(self): + """:class:`ClientUser`: The user connected to voice (i.e. ourselves).""" + return self._state.user + + def checked_add(self, attr, value, limit): + val = getattr(self, attr) + if val + value > limit: + setattr(self, attr, 0) + else: + setattr(self, attr, val + value) + + # connection related + + async def start_handshake(self): + log.info("Starting voice handshake...") + + guild_id, channel_id = self.channel._get_voice_state_pair() + state = self._state + self.main_ws = ws = state._get_websocket(guild_id) + self._connections += 1 + + # request joining + await ws.voice_state(guild_id, channel_id) + + try: + await asyncio.wait_for( + self._handshake_complete.wait(), timeout=self.timeout, loop=self.loop + ) + except asyncio.TimeoutError: + await self.terminate_handshake(remove=True) + raise + + log.info( + "Voice handshake complete. Endpoint found %s (IP: %s)", self.endpoint, self.endpoint_ip + ) + + async def terminate_handshake(self, *, remove=False): + guild_id, channel_id = self.channel._get_voice_state_pair() + self._handshake_complete.clear() + await self.main_ws.voice_state(guild_id, None, self_mute=True) + + log.info( + "The voice handshake is being terminated for Channel ID %s (Guild ID %s)", + channel_id, + guild_id, + ) + if remove: + log.info( + "The voice client has been removed for Channel ID %s (Guild ID %s)", + channel_id, + guild_id, + ) + key_id, _ = self.channel._get_voice_client_key() + self._state._remove_voice_client(key_id) + + async def _create_socket(self, server_id, data): + self._connected.clear() + self.session_id = self.main_ws.session_id + self.server_id = server_id + self.token = data.get("token") + endpoint = data.get("endpoint") + + if endpoint is None or self.token is None: + log.warning( + "Awaiting endpoint... This requires waiting. " + "If timeout occurred considering raising the timeout and reconnecting." + ) + return + + self.endpoint = endpoint.replace(":80", "") + self.endpoint_ip = socket.gethostbyname(self.endpoint) + + if self.socket: + try: + self.socket.close() + except Exception: + pass + + self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + self.socket.setblocking(False) + + if self._handshake_complete.is_set(): + # terminate the websocket and handle the reconnect loop if necessary. + self._handshake_complete.clear() + await self.ws.close(4000) + return + + self._handshake_complete.set() + + async def connect(self, *, reconnect=True, _tries=0, do_handshake=True): + log.info("Connecting to voice...") + try: + del self.secret_key + except AttributeError: + pass + + if do_handshake: + await self.start_handshake() + + try: + self.ws = await DiscordVoiceWebSocket.from_client(self) + self._connected.clear() + while not hasattr(self, "secret_key"): + await self.ws.poll_event() + self._connected.set() + except (ConnectionClosed, asyncio.TimeoutError): + if reconnect and _tries < 5: + log.exception("Failed to connect to voice... Retrying...") + await asyncio.sleep(1 + _tries * 2.0, loop=self.loop) + await self.terminate_handshake() + await self.connect(reconnect=reconnect, _tries=_tries + 1) + else: + raise + + if self._runner is None: + self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) + + async def poll_voice_ws(self, reconnect): + backoff = ExponentialBackoff() + while True: + try: + await self.ws.poll_event() + except (ConnectionClosed, asyncio.TimeoutError) as exc: + if isinstance(exc, ConnectionClosed): + if exc.code == 1000: + await self.disconnect() + break + + if not reconnect: + await self.disconnect() + raise + + retry = backoff.delay() + log.exception("Disconnected from voice... Reconnecting in %.2fs.", retry) + self._connected.clear() + await asyncio.sleep(retry, loop=self.loop) + await self.terminate_handshake() + try: + await self.connect(reconnect=True) + except asyncio.TimeoutError: + # at this point we've retried 5 times... let's continue the loop. + log.warning("Could not connect to voice... Retrying...") + continue + + async def disconnect(self, *, force=False): + """|coro| + + Disconnects this voice client from voice. + """ + if not force and not self._connected.is_set(): + return + + self.stop() + self._connected.clear() + + try: + if self.ws: + await self.ws.close() + + await self.terminate_handshake(remove=True) + finally: + if self.socket: + self.socket.close() + + async def move_to(self, channel): + """|coro| + + Moves you to a different voice channel. + + Parameters + ----------- + channel: :class:`abc.Snowflake` + The channel to move to. Must be a voice channel. + """ + guild_id, _ = self.channel._get_voice_state_pair() + await self.main_ws.voice_state(guild_id, channel.id) + + def is_connected(self): + """:class:`bool`: Indicates if the voice client is connected to voice.""" + return self._connected.is_set() + + # audio related + + def _get_voice_packet(self, data): + header = bytearray(12) + nonce = bytearray(24) + box = nacl.secret.SecretBox(bytes(self.secret_key)) + + # Formulate header + header[0] = 0x80 + header[1] = 0x78 + struct.pack_into(">H", header, 2, self.sequence) + struct.pack_into(">I", header, 4, self.timestamp) + struct.pack_into(">I", header, 8, self.ssrc) + + # Copy header to nonce's first 12 bytes + nonce[:12] = header + + # Encrypt and return the data + return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + + def play(self, source, *, after=None): + """Plays an :class:`AudioSource`. + + The finalizer, ``after`` is called after the source has been exhausted + or an error occurred. + + If an error happens while the audio player is running, the exception is + caught and the audio player is then stopped. + + Parameters + ----------- + source: :class:`AudioSource` + The audio source we're reading from. + after + The finalizer that is called after the stream is exhausted. + All exceptions it throws are silently discarded. This function + must have a single parameter, ``error``, that denotes an + optional exception that was raised during playing. + + Raises + ------- + ClientException + Already playing audio or not connected. + TypeError + source is not a :class:`AudioSource` or after is not a callable. + """ + + if not self._connected: + raise ClientException("Not connected to voice.") + + if self.is_playing(): + raise ClientException("Already playing audio.") + + if not isinstance(source, AudioSource): + raise TypeError("source must an AudioSource not {0.__class__.__name__}".format(source)) + + self._player = AudioPlayer(source, self, after=after) + self._player.start() + + def is_playing(self): + """Indicates if we're currently playing audio.""" + return self._player is not None and self._player.is_playing() + + def is_paused(self): + """Indicates if we're playing audio, but if we're paused.""" + return self._player is not None and self._player.is_paused() + + def stop(self): + """Stops playing audio.""" + if self._player: + self._player.stop() + self._player = None + + def pause(self): + """Pauses the audio playing.""" + if self._player: + self._player.pause() + + def resume(self): + """Resumes the audio playing.""" + if self._player: + self._player.resume() + + @property + def source(self): + """Optional[:class:`AudioSource`]: The audio source being played, if playing. + + This property can also be used to change the audio source currently being played. + """ + return self._player.source if self._player else None + + @source.setter + def source(self, value): + if not isinstance(value, AudioSource): + raise TypeError("expected AudioSource not {0.__class__.__name__}.".format(value)) + + if self._player is None: + raise ValueError("Not playing anything.") + + self._player._set_source(value) + + def send_audio_packet(self, data, *, encode=True): + """Sends an audio packet composed of the data. + + You must be connected to play audio. + + Parameters + ---------- + data: bytes + The :term:`py:bytes-like object` denoting PCM or Opus voice data. + encode: bool + Indicates if ``data`` should be encoded into Opus. + + Raises + ------- + ClientException + You are not connected. + OpusError + Encoding the data failed. + """ + + self.checked_add("sequence", 1, 65535) + if encode: + encoded_data = self.encoder.encode(data, self.encoder.SAMPLES_PER_FRAME) + else: + encoded_data = data + packet = self._get_voice_packet(encoded_data) + try: + self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) + except BlockingIOError: + log.warning( + "A packet has been dropped (seq: %s, timestamp: %s)", self.sequence, self.timestamp + ) + + self.checked_add("timestamp", self.encoder.SAMPLES_PER_FRAME, 4294967295) diff --git a/discord/webhook.py b/discord/webhook.py new file mode 100644 index 000000000..237201a75 --- /dev/null +++ b/discord/webhook.py @@ -0,0 +1,703 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-2017 Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +import asyncio +import json +import time +import re + +import aiohttp + +from . import utils +from .errors import InvalidArgument, HTTPException, Forbidden, NotFound +from .user import BaseUser, User + +__all__ = ["WebhookAdapter", "AsyncWebhookAdapter", "RequestsWebhookAdapter", "Webhook"] + + +class WebhookAdapter: + """Base class for all webhook adapters. + + Attributes + ------------ + webhook: :class:`Webhook` + The webhook that owns this adapter. + """ + + BASE = "https://discordapp.com/api/v7" + + def _prepare(self, webhook): + self._webhook_id = webhook.id + self._webhook_token = webhook.token + self._request_url = "{0.BASE}/webhooks/{1}/{2}".format(self, webhook.id, webhook.token) + self.webhook = webhook + + def request(self, verb, url, payload=None, multipart=None): + """Actually does the request. + + Subclasses must implement this. + + Parameters + ----------- + verb: str + The HTTP verb to use for the request. + url: str + The URL to send the request to. This will have + the query parameters already added to it, if any. + multipart: Optional[dict] + A dict containing multipart form data to send with + the request. If a filename is being uploaded, then it will + be under a ``file`` key which will have a 3-element :class:`tuple` + denoting ``(filename, file, content_type)``. + payload: Optional[dict] + The JSON to send with the request, if any. + """ + raise NotImplementedError() + + def delete_webhook(self): + return self.request("DELETE", self._request_url) + + def edit_webhook(self, **payload): + return self.request("PATCH", self._request_url, payload=payload) + + def handle_execution_response(self, data, *, wait): + """Transforms the webhook execution response into something + more meaningful. + + This is mainly used to convert the data into a :class:`Message` + if necessary. + + Subclasses must implement this. + + Parameters + ------------ + data + The data that was returned from the request. + wait: bool + Whether the webhook execution was asked to wait or not. + """ + raise NotImplementedError() + + def store_user(self, data): + # mocks a ConnectionState for appropriate use for Message + return BaseUser(state=self, data=data) + + def execute_webhook(self, *, payload, wait=False, file=None, files=None): + if file is not None: + multipart = {"file": file, "payload_json": utils.to_json(payload)} + data = None + elif files is not None: + multipart = {"payload_json": utils.to_json(payload)} + for i, file in enumerate(files, start=1): + multipart["file%i" % i] = file + data = None + else: + data = payload + multipart = None + + url = "%s?wait=%d" % (self._request_url, wait) + maybe_coro = self.request("POST", url, multipart=multipart, payload=data) + return self.handle_execution_response(maybe_coro, wait=wait) + + +class AsyncWebhookAdapter(WebhookAdapter): + """A webhook adapter suited for use with aiohttp. + + .. note:: + + You are responsible for cleaning up the client session. + + Parameters + ----------- + session: aiohttp.ClientSession + The session to use to send requests. + """ + + def __init__(self, session): + self.session = session + self.loop = session.loop + + async def request(self, verb, url, payload=None, multipart=None): + headers = {} + data = None + if payload: + headers["Content-Type"] = "application/json" + data = utils.to_json(payload) + + if multipart: + data = aiohttp.FormData() + for key, value in multipart.items(): + if key.startswith("file"): + data.add_field(key, value[1], filename=value[0], content_type=value[2]) + else: + data.add_field(key, value) + + for tries in range(5): + async with self.session.request(verb, url, headers=headers, data=data) as r: + data = await r.text(encoding="utf-8") + if r.headers["Content-Type"] == "application/json": + data = json.loads(data) + + # check if we have rate limit header information + remaining = r.headers.get("X-Ratelimit-Remaining") + if remaining == "0" and r.status != 429: + delta = utils._parse_ratelimit_header(r) + await asyncio.sleep(delta, loop=self.loop) + + if 300 > r.status >= 200: + return data + + # we are being rate limited + if r.status == 429: + retry_after = data["retry_after"] / 1000.0 + await asyncio.sleep(retry_after, loop=self.loop) + continue + + if r.status in (500, 502): + await asyncio.sleep(1 + tries * 2, loop=self.loop) + continue + + if r.status == 403: + raise Forbidden(r, data) + elif r.status == 404: + raise NotFound(r, data) + else: + raise HTTPException(r, data) + + async def handle_execution_response(self, response, *, wait): + data = await response + if not wait: + return data + + # transform into Message object + from .message import Message + + return Message(data=data, state=self, channel=self.webhook.channel) + + +class RequestsWebhookAdapter(WebhookAdapter): + """A webhook adapter suited for use with ``requests``. + + Only versions of requests higher than 2.13.0 are supported. + + Parameters + ----------- + session: Optional[`requests.Session `_] + The requests session to use for sending requests. If not given then + each request will create a new session. Note if a session is given, + the webhook adapter **will not** clean it up for you. You must close + the session yourself. + sleep: bool + Whether to sleep the thread when encountering a 429 or pre-emptive + rate limit or a 5xx status code. Defaults to ``True``. If set to + ``False`` then this will raise an :exc:`HTTPException` instead. + """ + + def __init__(self, session=None, *, sleep=True): + import requests + + self.session = session or requests + self.sleep = sleep + + def request(self, verb, url, payload=None, multipart=None): + headers = {} + data = None + if payload: + headers["Content-Type"] = "application/json" + data = utils.to_json(payload) + + if multipart is not None: + data = {"payload_json": multipart.pop("payload_json")} + + for tries in range(5): + r = self.session.request(verb, url, headers=headers, data=data, files=multipart) + r.encoding = "utf-8" + data = r.text + + # compatibility with aiohttp + r.status = r.status_code + + if r.headers["Content-Type"] == "application/json": + data = json.loads(data) + + # check if we have rate limit header information + remaining = r.headers.get("X-Ratelimit-Remaining") + if remaining == "0" and r.status != 429 and self.sleep: + delta = utils._parse_ratelimit_header(r) + time.sleep(delta) + + if 300 > r.status >= 200: + return data + + # we are being rate limited + if r.status == 429: + if self.sleep: + retry_after = data["retry_after"] / 1000.0 + time.sleep(retry_after) + continue + else: + raise HTTPException(r, data) + + if self.sleep and r.status in (500, 502): + time.sleep(1 + tries * 2) + continue + + if r.status == 403: + raise Forbidden(r, data) + elif r.status == 404: + raise NotFound(r, data) + else: + raise HTTPException(r, data) + + def handle_execution_response(self, response, *, wait): + if not wait: + return response + + # transform into Message object + from .message import Message + + return Message(data=response, state=self, channel=self.webhook.channel) + + +class Webhook: + """Represents a Discord webhook. + + Webhooks are a form to send messages to channels in Discord without a + bot user or authentication. + + There are two main ways to use Webhooks. The first is through the ones + received by the library such as :meth:`.Guild.webhooks` and + :meth:`.TextChannel.webhooks`. The ones received by the library will + automatically have an adapter bound using the library's HTTP session. + Those webhooks will have :meth:`~.Webhook.send`, :meth:`~.Webhook.delete` and + :meth:`~.Webhook.edit` as coroutines. + + The second form involves creating a webhook object manually without having + it bound to a websocket connection using the :meth:`~.Webhook.from_url` or + :meth:`~.Webhook.partial` classmethods. This form allows finer grained control + over how requests are done, allowing you to mix async and sync code using either + ``aiohttp`` or ``requests``. + + For example, creating a webhook from a URL and using ``aiohttp``: + + .. code-block:: python3 + + from discord import Webhook, AsyncWebhookAdapter + import aiohttp + + async def foo(): + async with aiohttp.ClientSession() as session: + webhook = Webhook.from_url('url-here', adapter=AsyncWebhookAdapter(session)) + await webhook.send('Hello World', username='Foo') + + Or creating a webhook from an ID and token and using ``requests``: + + .. code-block:: python3 + + import requests + from discord import Webhook, RequestsWebhookAdapter + + webhook = Webhook.partial(123456, 'abcdefg', adapter=RequestsWebhookAdapter()) + webhook.send('Hello World', username='Foo') + + Attributes + ------------ + id: :class:`int` + The webhook's ID + token: :class:`str` + The authentication token of the webhook. + guild_id: Optional[:class:`int`] + The guild ID this webhook is for. + channel_id: Optional[:class:`int`] + The channel ID this webhook is for. + user: Optional[:class:`abc.User`] + The user this webhook was created by. If the webhook was + received without authentication then this will be ``None``. + name: Optional[:class:`str`] + The default name of the webhook. + avatar: Optional[:class:`str`] + The default avatar of the webhook. + """ + + __slots__ = ( + "id", + "guild_id", + "channel_id", + "user", + "name", + "avatar", + "token", + "_state", + "_adapter", + ) + + def __init__(self, data, *, adapter, state=None): + self.id = int(data["id"]) + self.channel_id = utils._get_as_snowflake(data, "channel_id") + self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.name = data.get("name") + self.avatar = data.get("avatar") + self.token = data["token"] + self._state = state + self._adapter = adapter + self._adapter._prepare(self) + + user = data.get("user") + if user is None: + self.user = None + elif state is None: + self.user = BaseUser(state=None, data=user) + else: + self.user = User(state=state, data=user) + + def __repr__(self): + return "" % self.id + + @property + def url(self): + """Returns the webhook's url.""" + return "https://discordapp.com/api/webhooks/{}/{}".format(self.id, self.token) + + @classmethod + def partial(cls, id, token, *, adapter): + """Creates a partial :class:`Webhook`. + + A partial webhook is just a webhook object with an ID and a token. + + Parameters + ----------- + id: int + The ID of the webhook. + token: str + The authentication token of the webhook. + adapter: :class:`WebhookAdapter` + The webhook adapter to use when sending requests. This is + typically :class:`AsyncWebhookAdapter` for ``aiohttp`` or + :class:`RequestsWebhookAdapter` for ``requests``. + """ + + if not isinstance(adapter, WebhookAdapter): + raise TypeError("adapter must be a subclass of WebhookAdapter") + + data = {"id": id, "token": token} + + return cls(data, adapter=adapter) + + @classmethod + def from_url(cls, url, *, adapter): + """Creates a partial :class:`Webhook` from a webhook URL. + + Parameters + ------------ + url: str + The URL of the webhook. + adapter: :class:`WebhookAdapter` + The webhook adapter to use when sending requests. This is + typically :class:`AsyncWebhookAdapter` for ``aiohttp`` or + :class:`RequestsWebhookAdapter` for ``requests``. + + Raises + ------- + InvalidArgument + The URL is invalid. + """ + + m = re.search( + r"discordapp.com/api/webhooks/(?P[0-9]{17,21})/(?P[A-Za-z0-9\.\-\_]{60,68})", + url, + ) + if m is None: + raise InvalidArgument("Invalid webhook URL given.") + return cls(m.groupdict(), adapter=adapter) + + @classmethod + def from_state(cls, data, state): + return cls(data, adapter=AsyncWebhookAdapter(session=state.http._session), state=state) + + @property + def guild(self): + """Optional[:class:`Guild`]: The guild this webhook belongs to. + + If this is a partial webhook, then this will always return ``None``. + """ + return self._state and self._state._get_guild(self.guild_id) + + @property + def channel(self): + """Optional[:class:`TextChannel`]: The text channel this webhook belongs to. + + If this is a partial webhook, then this will always return ``None``. + """ + guild = self.guild + return guild and guild.get_channel(self.channel_id) + + @property + def created_at(self): + """Returns the webhook's creation time in UTC.""" + return utils.snowflake_time(self.id) + + @property + def avatar_url(self): + """Returns a friendly URL version of the avatar the webhook has. + + If the webhook does not have a traditional avatar, their default + avatar URL is returned instead. + + This is equivalent to calling :meth:`avatar_url_as` with the + default parameters. + """ + return self.avatar_url_as() + + def avatar_url_as(self, *, format=None, size=1024): + """Returns a friendly URL version of the avatar the webhook has. + + If the webhook does not have a traditional avatar, their default + avatar URL is returned instead. + + The format must be one of 'jpeg', 'jpg', or 'png'. + The size must be a power of 2 between 16 and 1024. + + Parameters + ----------- + format: Optional[str] + The format to attempt to convert the avatar to. + If the format is ``None``, then it is equivalent to png. + size: int + The size of the image to display. + + Returns + -------- + str + The resulting CDN URL. + + Raises + ------ + InvalidArgument + Bad image format passed to ``format`` or invalid ``size``. + """ + if self.avatar is None: + # Default is always blurple apparently + return "https://cdn.discordapp.com/embed/avatars/0.png" + + if not utils.valid_icon_size(size): + raise InvalidArgument("size must be a power of 2 between 16 and 1024") + + format = format or "png" + + if format not in ("png", "jpg", "jpeg"): + raise InvalidArgument("format must be one of 'png', 'jpg', or 'jpeg'.") + + return "https://cdn.discordapp.com/avatars/{0.id}/{0.avatar}.{1}?size={2}".format( + self, format, size + ) + + def delete(self): + """|maybecoro| + + Deletes this Webhook. + + If the webhook is constructed with a :class:`RequestsWebhookAdapter` then this is + not a coroutine. + + Raises + ------- + HTTPException + Deleting the webhook failed. + NotFound + This webhook does not exist. + Forbidden + You do not have permissions to delete this webhook. + """ + return self._adapter.delete_webhook() + + def edit(self, **kwargs): + """|maybecoro| + + Edits this Webhook. + + If the webhook is constructed with a :class:`RequestsWebhookAdapter` then this is + not a coroutine. + + Parameters + ------------- + name: Optional[str] + The webhook's new default name. + avatar: Optional[bytes] + A :term:`py:bytes-like object` representing the webhook's new default avatar. + + Raises + ------- + HTTPException + Editing the webhook failed. + NotFound + This webhook does not exist. + Forbidden + You do not have permissions to edit this webhook. + """ + payload = {} + + try: + name = kwargs["name"] + except KeyError: + pass + else: + if name is not None: + payload["name"] = str(name) + else: + payload["name"] = None + + try: + avatar = kwargs["avatar"] + except KeyError: + pass + else: + if avatar is not None: + payload["avatar"] = utils._bytes_to_base64_data(avatar) + else: + payload["avatar"] = None + + return self._adapter.edit_webhook(**payload) + + def send( + self, + content=None, + *, + wait=False, + username=None, + avatar_url=None, + tts=False, + file=None, + files=None, + embed=None, + embeds=None + ): + """|maybecoro| + + Sends a message using the webhook. + + If the webhook is constructed with a :class:`RequestsWebhookAdapter` then this is + not a coroutine. + + The content must be a type that can convert to a string through ``str(content)``. + + To upload a single file, the ``file`` parameter should be used with a + single :class:`File` object. + + If the ``embed`` parameter is provided, it must be of type :class:`Embed` and + it must be a rich embed type. You cannot mix the ``embed`` parameter with the + ``embeds`` parameter, which must be a :class:`list` of :class:`Embed` objects to send. + + Parameters + ------------ + content + The content of the message to send. + wait: bool + Whether the server should wait before sending a response. This essentially + means that the return type of this function changes from ``None`` to + a :class:`Message` if set to ``True``. + username: str + The username to send with this message. If no username is provided + then the default username for the webhook is used. + avatar_url: str + The avatar URL to send with this message. If no avatar URL is provided + then the default avatar for the webhook is used. + tts: bool + Indicates if the message should be sent using text-to-speech. + file: :class:`File` + The file to upload. This cannot be mixed with ``files`` parameter. + files: List[:class:`File`] + A list of files to send with the content. This cannot be mixed with the + ``file`` parameter. + embed: :class:`Embed` + The rich embed for the content to send. This cannot be mixed with + ``embeds`` parameter. + embeds: List[:class:`Embed`] + A list of embeds to send with the content. Maximum of 10. This cannot + be mixed with the ``embed`` parameter. + + Raises + -------- + HTTPException + Sending the message failed. + NotFound + This webhook was not found. + Forbidden + The authorization token for the webhook is incorrect. + InvalidArgument + You specified both ``embed`` and ``embeds`` or the length of + ``embeds`` was invalid. + + Returns + --------- + Optional[:class:`Message`] + The message that was sent. + """ + + payload = {} + + if files is not None and file is not None: + raise InvalidArgument("Cannot mix file and files keyword arguments.") + if embeds is not None and embed is not None: + raise InvalidArgument("Cannot mix embed and embeds keyword arguments.") + + if embeds is not None: + if len(embeds) > 10: + raise InvalidArgument("embeds has a maximum of 10 elements.") + payload["embeds"] = [e.to_dict() for e in embeds] + + if embed is not None: + payload["embeds"] = [embed.to_dict()] + + if content is not None: + payload["content"] = str(content) + + payload["tts"] = tts + if avatar_url: + payload["avatar_url"] = avatar_url + if username: + payload["username"] = username + + if file is not None: + try: + to_pass = (file.filename, file.open_file(), "application/octet-stream") + return self._adapter.execute_webhook(wait=wait, file=to_pass, payload=payload) + finally: + file.close() + elif files is not None: + try: + to_pass = [ + (file.filename, file.open_file(), "application/octet-stream") for file in files + ] + return self._adapter.execute_webhook(wait=wait, files=to_pass, payload=payload) + finally: + for file in files: + file.close() + else: + return self._adapter.execute_webhook(wait=wait, payload=payload) + + def execute(self, *args, **kwargs): + """An alias for :meth:`~.Webhook.send`.""" + return self.send(*args, **kwargs) diff --git a/docs/Makefile b/docs/Makefile index 3798b8efd..49c3b9536 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -14,9 +14,6 @@ help: .PHONY: help Makefile -init: - cd .. && pipenv lock -r --dev > docs/requirements.txt && echo 'git+https://github.com/Rapptz/discord.py@rewrite#egg=discord.py-1.0' >> docs/requirements.txt - # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile diff --git a/docs/guide_cog_creation.rst b/docs/guide_cog_creation.rst index f70e4e414..539859b84 100644 --- a/docs/guide_cog_creation.rst +++ b/docs/guide_cog_creation.rst @@ -18,7 +18,7 @@ Getting started --------------- To start off, be sure that you have installed Python 3.6.2 or higher (3.6.6 or higher on Windows). -Open a terminal or command prompt and type :code:`pip install --process-dependency-links -U git+https://github.com/Cog-Creators/Red-DiscordBot@V3/develop#egg=redbot[test]` +Open a terminal or command prompt and type :code:`pip install -U git+https://github.com/Cog-Creators/Red-DiscordBot@V3/develop#egg=redbot[test]` (note that if you get an error with this, try again but put :code:`python -m` in front of the command This will install the latest version of V3. diff --git a/docs/install_linux_mac.rst b/docs/install_linux_mac.rst index 8d58ed267..8262cc828 100644 --- a/docs/install_linux_mac.rst +++ b/docs/install_linux_mac.rst @@ -148,19 +148,19 @@ To install without audio support: .. code-block:: none - pip3 install -U --process-dependency-links --no-cache-dir Red-DiscordBot + pip3 install -U Red-DiscordBot Or, to install with audio support: .. code-block:: none - pip3 install -U --process-dependency-links --no-cache-dir Red-DiscordBot[voice] + pip3 install -U Red-DiscordBot[voice] Or, install with audio and MongoDB support: .. code-block:: none - pip3 install -U --process-dependency-links --no-cache-dir Red-DiscordBot[voice,mongo] + pip3 install -U Red-DiscordBot[voice,mongo] .. note:: diff --git a/docs/install_windows.rst b/docs/install_windows.rst index db13f7360..82e4ce036 100644 --- a/docs/install_windows.rst +++ b/docs/install_windows.rst @@ -40,19 +40,19 @@ Installing Red .. code-block:: none - python -m pip install -U --process-dependency-links --no-cache-dir Red-DiscordBot + python -m pip install -U Red-DiscordBot * With audio: .. code-block:: none - python -m pip install -U --process-dependency-links --no-cache-dir Red-DiscordBot[voice] + python -m pip install -U Red-DiscordBot[voice] * With audio and MongoDB support: .. code-block:: none - python -m pip install -U --process-dependency-links --no-cache-dir Red-DiscordBot[voice,mongo] + python -m pip install -U Red-DiscordBot[voice,mongo] .. note:: diff --git a/make.bat b/make.bat index 812ad2ec7..492871499 100644 --- a/make.bat +++ b/make.bat @@ -1,6 +1,6 @@ @echo off -if "%1"=="" goto help +if [%1] == [] goto help REM This allows us to expand variables at execution setlocal ENABLEDELAYEDEXPANSION @@ -21,6 +21,17 @@ exit /B %ERRORLEVEL% black -l 99 -N --check !PYFILES! exit /B %ERRORLEVEL% +:update_vendor +if [%REF%] == [] ( + set REF2="rewrite" +) else ( + set REF2=%REF% +) +pip install --upgrade --no-deps -t . https://github.com/Rapptz/discord.py/archive/!REF2!.tar.gz#egg=discord.py +del /S /Q "discord.py*.egg-info" +for /F %%i in ('dir /S /B discord.py*.egg-info') do rmdir /S /Q %%i +goto reformat + :help echo Usage: echo make ^ @@ -28,3 +39,5 @@ echo. echo Commands: echo reformat Reformat all .py files being tracked by git. echo stylecheck Check which tracked .py files need reformatting. +echo update_vendor Update vendored discord.py library to %%REF%%, which defaults to +echo "rewrite" diff --git a/redbot/launcher.py b/redbot/launcher.py index 0904a94cb..b52d4e454 100644 --- a/redbot/launcher.py +++ b/redbot/launcher.py @@ -115,18 +115,7 @@ def update_red(dev=False, voice=False, mongo=False, docs=False, test=False): package = "Red-DiscordBot" if egg_l: package += "[{}]".format(", ".join(egg_l)) - arguments = [ - interpreter, - "-m", - "pip", - "install", - "-U", - "-I", - "--no-cache-dir", - "--force-reinstall", - "--process-dependency-links", - package, - ] + arguments = [interpreter, "-m", "pip", "install", "-U", package] if not is_venv(): arguments.append("--user") code = subprocess.call(arguments) diff --git a/setup.py b/setup.py index bdb05ec3a..17cd7c319 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ install_requires = [ "attrs==18.2.0", "chardet==3.0.4", "colorama==0.4.1", - "discord.py>=1.0.0a0", "distro==1.3.0; sys_platform == 'linux'", "fuzzywuzzy==0.17.0", "idna-ssl==1.1.0", @@ -70,11 +69,6 @@ if os.name == "nt": python_requires = ">=3.6.6,<3.8" -def get_dependency_links(): - with open("dependency_links.txt") as file: - return file.read().splitlines() - - def check_compiler_available(): m = ccompiler.new_compiler() @@ -102,15 +96,12 @@ if __name__ == "__main__": next(r for r in install_requires if r.lower().startswith("python-levenshtein")) ) - if "READTHEDOCS" in os.environ: - install_requires.remove( - next(r for r in install_requires if r.lower().startswith("discord.py")) - ) - setup( name="Red-DiscordBot", version=get_version(), - packages=find_packages(include=["redbot", "redbot.*"]), + packages=( + find_packages(include=("redbot", "redbot.*")) + ["discord", "discord.ext.commands"] + ), package_data={"": ["locales/*.po", "data/*", "data/**/*"]}, url="https://github.com/Cog-Creators/Red-DiscordBot", license="GPLv3", @@ -139,6 +130,5 @@ if __name__ == "__main__": }, python_requires=python_requires, install_requires=install_requires, - dependency_links=get_dependency_links(), extras_require=extras_require, ) diff --git a/tox.ini b/tox.ini index 2f34b0cc3..a453d4a07 100644 --- a/tox.ini +++ b/tox.ini @@ -16,8 +16,6 @@ description = Run unit tests with pytest whitelist_externals = pytest extras = voice, test, mongo -deps = - -r{toxinidir}/dependency_links.txt commands = python -m compileall ./redbot/cogs pytest From 6d5762d71148c93378f22e3f741cf94cec85ecfe Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Mon, 28 Jan 2019 14:38:43 +1100 Subject: [PATCH 7/9] Move Red-Lavalink to main requirements Signed-off-by: Toby Harradine --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 17cd7c319..a79dd6c4d 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ install_requires = [ "pyyaml==3.13", "raven==6.10.0", "raven-aiohttp==0.7.0", + "red-lavalink==0.2.0", "schema==0.6.8", "websockets==6.0", "yarl==1.3.0", @@ -59,7 +60,7 @@ extras_require = { "sphinxcontrib-websupport==1.1.0", "urllib3==1.24.1", ], - "voice": ["red-lavalink==0.2.0"], + "voice": [], "style": ["black==18.9b0", "click==7.0", "toml==0.10.0"], } From 5c1c6e1f0385ce89850dc079adb07e9954b6999e Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Mon, 28 Jan 2019 14:49:55 +1100 Subject: [PATCH 8/9] Remove version from help message Signed-off-by: Toby Harradine --- redbot/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redbot/__main__.py b/redbot/__main__.py index f990c6553..38ef2e5e0 100644 --- a/redbot/__main__.py +++ b/redbot/__main__.py @@ -111,7 +111,7 @@ def list_instances(): def main(): - description = "Red - Version {}".format(__version__) + description = "Red V3" cli_flags = parse_cli_flags(sys.argv[1:]) if cli_flags.list_instances: list_instances() From 91258fea780cc12adb930603866c6a05c8aae594 Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Mon, 28 Jan 2019 14:52:14 +1100 Subject: [PATCH 9/9] Bump version to 3.0.0 Signed-off-by: Toby Harradine --- redbot/core/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redbot/core/__init__.py b/redbot/core/__init__.py index 049f4e4a5..9876bd032 100644 --- a/redbot/core/__init__.py +++ b/redbot/core/__init__.py @@ -148,5 +148,5 @@ class VersionInfo: ) -__version__ = "3.0.0rc3.post1" +__version__ = "3.0.0" version_info = VersionInfo.from_str(__version__)