diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..5d94153b1 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,12 @@ +language: python +python: + - "3.5.3" + - "3.6.1" +install: + - pip install -r requirements.txt +script: + - python -m compileall ./cogs + - python -m pytest +cache: pip +notifications: + email: false \ No newline at end of file diff --git a/cogs/__init__.py b/cogs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/__init__.py b/core/__init__.py index e69de29bb..606246c14 100644 --- a/core/__init__.py +++ b/core/__init__.py @@ -0,0 +1 @@ +from core.config import Config \ No newline at end of file diff --git a/core/bot.py b/core/bot.py index 0e9e66cf5..a3363b08c 100644 --- a/core/bot.py +++ b/core/bot.py @@ -1,6 +1,7 @@ from discord.ext import commands from collections import Counter -from core.settings import CoreDB + +from core import Config from enum import Enum import os @@ -8,17 +9,33 @@ import os class Red(commands.Bot): def __init__(self, cli_flags, **kwargs): self._shutdown_mode = ExitCodes.CRITICAL - self.db = CoreDB("core/data/settings.json", - relative_path=False) + self.db = Config.get_core_conf(force_registration=True) + + self.db.register_global( + token=None, + prefix=[], + packages=[], + coowners=[], + whitelist=[], + blacklist=[] + ) + + self.db.register_guild( + prefix=[], + whitelist=[], + blacklist=[], + admin_role=None, + mod_role=None + ) def prefix_manager(bot, message): if not cli_flags.prefix: - global_prefix = self.db.get_global("prefix", []) + global_prefix = self.db.prefix() else: global_prefix = cli_flags.prefix if message.guild is None: return global_prefix - server_prefix = self.db.get(message.guild, "prefix", []) + server_prefix = self.db.guild(message.guild).prefix() return server_prefix if server_prefix else global_prefix if "command_prefix" not in kwargs: @@ -30,7 +47,7 @@ class Red(commands.Bot): async def is_owner(self, user, allow_coowners=True): if allow_coowners: - if user.id in self.db.get_global("coowners", []): + if user.id in self.db.coowners(): return True return await super().is_owner(user) @@ -65,7 +82,7 @@ class Red(commands.Bot): for package in self.extensions: if package.startswith("cogs."): loaded.append(package) - await self.db.set_global("packages", loaded) + await self.db.set("packages", loaded) class ExitCodes(Enum): diff --git a/core/checks.py b/core/checks.py index 1852eaf49..b823b8fa0 100644 --- a/core/checks.py +++ b/core/checks.py @@ -1,3 +1,4 @@ +import discord from discord.ext import commands @@ -23,8 +24,12 @@ def mod_or_permissions(**perms): if ctx.guild is None: return has_perms_or_is_owner author = ctx.author - mod_role = ctx.bot.db.get_mod_role(ctx.guild) - admin_role = ctx.bot.db.get_admin_role(ctx.guild) + mod_role_id = ctx.bot.db.guild(ctx.guild).mod_role() + admin_role_id = ctx.bot.db.guild(ctx.guild).admin_role() + + mod_role = discord.utils.get(ctx.guild.roles, id=mod_role_id) + admin_role = discord.utils.get(ctx.guild.roles, id=admin_role_id) + is_staff = mod_role in author.roles or admin_role in author.roles is_guild_owner = author == ctx.guild.owner @@ -40,7 +45,7 @@ def admin_or_permissions(**perms): return has_perms_or_is_owner author = ctx.author is_guild_owner = author == ctx.guild.owner - admin_role = ctx.bot.db.get_admin_role(ctx.guild) + admin_role = ctx.bot.db.guild(ctx.guild).admin_role() return admin_role in author.roles or has_perms_or_is_owner or is_guild_owner diff --git a/core/cli.py b/core/cli.py index 81193985b..a985ea2b9 100644 --- a/core/cli.py +++ b/core/cli.py @@ -1,3 +1,4 @@ +import argparse import asyncio @@ -19,7 +20,7 @@ def interactive_config(red, token_set, prefix_set): print("That doesn't look like a valid token.") token = "" if token: - loop.run_until_complete(red.db.set_global("token", token)) + loop.run_until_complete(red.db.set("token", token)) if not prefix_set: prefix = "" @@ -36,6 +37,50 @@ def interactive_config(red, token_set, prefix_set): if not confirm("> "): prefix = "" if prefix: - loop.run_until_complete(red.db.set_global("prefix", [prefix])) + loop.run_until_complete(red.db.set("prefix", [prefix])) return token + + +def parse_cli_flags(): + parser = argparse.ArgumentParser(description="Red - Discord Bot") + parser.add_argument("--owner", help="ID of the owner. Only who hosts " + "Red should be owner, this has " + "security implications") + parser.add_argument("--prefix", "-p", action="append", + help="Global prefix. Can be multiple") + parser.add_argument("--no-prompt", + action="store_true", + help="Disables console inputs. Features requiring " + "console interaction could be disabled as a " + "result") + parser.add_argument("--no-cogs", + action="store_true", + help="Starts Red with no cogs loaded, only core") + parser.add_argument("--self-bot", + action='store_true', + help="Specifies if Red should log in as selfbot") + parser.add_argument("--not-bot", + action='store_true', + help="Specifies if the token used belongs to a bot " + "account.") + parser.add_argument("--dry-run", + action="store_true", + help="Makes Red quit with code 0 just before the " + "login. This is useful for testing the boot " + "process.") + parser.add_argument("--debug", + action="store_true", + help="Sets the loggers level as debug") + parser.add_argument("--dev", + action="store_true", + help="Enables developer mode") + + args = parser.parse_args() + + if args.prefix: + args.prefix = sorted(args.prefix, reverse=True) + else: + args.prefix = [] + + return args \ No newline at end of file diff --git a/core/config.py b/core/config.py new file mode 100644 index 000000000..a800e912c --- /dev/null +++ b/core/config.py @@ -0,0 +1,521 @@ +from pathlib import Path + +from core.drivers.red_json import JSON as JSONDriver +from core.drivers.red_mongo import Mongo +import logging + +from typing import Callable + +log = logging.getLogger("red.config") + +class BaseConfig: + def __init__(self, cog_name, unique_identifier, driver_spawn, force_registration=False, + hash_uuid=True, collection="GLOBAL", collection_uuid=None, + defaults={}): + self.cog_name = cog_name + if hash_uuid: + self.uuid = str(hash(unique_identifier)) + else: + self.uuid = unique_identifier + self.driver_spawn = driver_spawn + self._driver = None + self.collection = collection + self.collection_uuid = collection_uuid + + self.force_registration = force_registration + + try: + self.driver.maybe_add_ident(self.uuid) + except AttributeError: + pass + + self.driver_getmap = { + "GLOBAL": self.driver.get_global, + "GUILD": self.driver.get_guild, + "CHANNEL": self.driver.get_channel, + "ROLE": self.driver.get_role, + "USER": self.driver.get_user + } + + self.driver_setmap = { + "GLOBAL": self.driver.set_global, + "GUILD": self.driver.set_guild, + "CHANNEL": self.driver.set_channel, + "ROLE": self.driver.set_role, + "USER": self.driver.set_user + } + + self.curr_key = None + + self.unsettable_keys = ("cog_name", "cog_identifier", "_id", + "guild_id", "channel_id", "role_id", + "user_id", "uuid") + self.invalid_keys = ( + "driver_spawn", + "_driver", "collection", + "collection_uuid", "force_registration" + ) + + self.defaults = defaults if defaults else { + "GLOBAL": {}, "GUILD": {}, "CHANNEL": {}, "ROLE": {}, + "MEMBER": {}, "USER": {}} + + @classmethod + def get_conf(cls, cog_instance: object, unique_identifier: int=0, + force_registration: bool=False): + """ + Gets a config object that cog's can use to safely store data. The + backend to this is totally modular and can easily switch between + JSON and a DB. However, when changed, all data will likely be lost + unless cogs write some converters for their data. + + Positional Arguments: + cog_instance - The cog `self` object, can be passed in from your + cog's __init__ method. + + Keyword Arguments: + unique_identifier - a random integer or string that is used to + differentiate your cog from any other named the same. This way we + can safely store data for multiple cogs that are named the same. + + YOU SHOULD USE THIS. + + force_registration - A flag which will cause the Config object to + throw exceptions if you try to get/set data keys that you have + not pre-registered. I highly recommend you ENABLE this as it + will help reduce dumb typo errors. + """ + + url = None # TODO: get mongo url + port = None # TODO: get mongo port + + def spawn_mongo_driver(): + return Mongo(url, port) + + # TODO: Determine which backend users want, default to JSON + + cog_name = cog_instance.__class__.__name__ + + driver_spawn = JSONDriver(cog_name) + + return cls(cog_name=cog_name, unique_identifier=unique_identifier, + driver_spawn=driver_spawn, force_registration=force_registration) + + @classmethod + def get_core_conf(cls, force_registration: bool=False): + core_data_path = Path.cwd() / 'core' / '.data' + driver_spawn = JSONDriver("Core", data_path_override=core_data_path) + return cls(cog_name="Core", driver_spawn=driver_spawn, + unique_identifier=0, + force_registration=force_registration) + + @property + def driver(self): + if self._driver is None: + try: + self._driver = self.driver_spawn() + except TypeError: + return self.driver_spawn + + return self._driver + + def __getattr__(self, key): + """This should be used to return config key data as determined by + `self.collection` and `self.collection_uuid`.""" + raise NotImplemented + + def __setattr__(self, key, value): + if 'defaults' in self.__dict__: # Necessary to let the cog load + restricted = list(self.defaults[self.collection].keys()) + \ + list(self.unsettable_keys) + if key in restricted: + raise ValueError("Not allowed to dynamically set attributes of" + " unsettable_keys: {}".format(restricted)) + else: + self.__dict__[key] = value + else: + self.__dict__[key] = value + + def clear(self): + """Clears all values in the current context ONLY.""" + raise NotImplemented + + def set(self, key, value): + """This should set config key with value `value` in the + corresponding collection as defined by `self.collection` and + `self.collection_uuid`.""" + raise NotImplemented + + def guild(self, guild): + """This should return a `BaseConfig` instance with the corresponding + `collection` and `collection_uuid`.""" + raise NotImplemented + + def channel(self, channel): + """This should return a `BaseConfig` instance with the corresponding + `collection` and `collection_uuid`.""" + raise NotImplemented + + def role(self, role): + """This should return a `BaseConfig` instance with the corresponding + `collection` and `collection_uuid`.""" + raise NotImplemented + + def member(self, member): + """This should return a `BaseConfig` instance with the corresponding + `collection` and `collection_uuid`.""" + raise NotImplemented + + def user(self, user): + """This should return a `BaseConfig` instance with the corresponding + `collection` and `collection_uuid`.""" + raise NotImplemented + + def register_global(self, **global_defaults): + """ + Registers a new dict of global defaults. This function should + be called EVERY TIME the cog loads (aka just do it in + __init__)! + + :param global_defaults: Each key should be the key you want to + access data by and the value is the default value of that + key. + :return: + """ + for k, v in global_defaults.items(): + try: + self._register_global(k, v) + except KeyError: + log.exception("Bad default global key.") + + def _register_global(self, key, default=None): + """Registers a global config key `key`""" + if key in self.unsettable_keys: + raise KeyError("Attempt to use restricted key: '{}'".format(key)) + elif not key.isidentifier(): + raise RuntimeError("Invalid key name, must be a valid python variable" + " name.") + self.defaults["GLOBAL"][key] = default + + def register_guild(self, **guild_defaults): + """ + Registers a new dict of guild defaults. This function should + be called EVERY TIME the cog loads (aka just do it in + __init__)! + + :param guild_defaults: Each key should be the key you want to + access data by and the value is the default value of that + key. + :return: + """ + for k, v in guild_defaults.items(): + try: + self._register_guild(k, v) + except KeyError: + log.exception("Bad default guild key.") + + def _register_guild(self, key, default=None): + """Registers a guild config key `key`""" + if key in self.unsettable_keys: + raise KeyError("Attempt to use restricted key: '{}'".format(key)) + elif not key.isidentifier(): + raise RuntimeError("Invalid key name, must be a valid python variable" + " name.") + self.defaults["GUILD"][key] = default + + def register_channel(self, **channel_defaults): + """ + Registers a new dict of channel defaults. This function should + be called EVERY TIME the cog loads (aka just do it in + __init__)! + + :param channel_defaults: Each key should be the key you want to + access data by and the value is the default value of that + key. + :return: + """ + for k, v in channel_defaults.items(): + try: + self._register_channel(k, v) + except KeyError: + log.exception("Bad default channel key.") + + def _register_channel(self, key, default=None): + """Registers a channel config key `key`""" + if key in self.unsettable_keys: + raise KeyError("Attempt to use restricted key: '{}'".format(key)) + elif not key.isidentifier(): + raise RuntimeError("Invalid key name, must be a valid python variable" + " name.") + self.defaults["CHANNEL"][key] = default + + def register_role(self, **role_defaults): + """ + Registers a new dict of role defaults. This function should + be called EVERY TIME the cog loads (aka just do it in + __init__)! + + :param role_defaults: Each key should be the key you want to + access data by and the value is the default value of that + key. + :return: + """ + for k, v in role_defaults.items(): + try: + self._register_role(k, v) + except KeyError: + log.exception("Bad default role key.") + + def _register_role(self, key, default=None): + """Registers a role config key `key`""" + if key in self.unsettable_keys: + raise KeyError("Attempt to use restricted key: '{}'".format(key)) + elif not key.isidentifier(): + raise RuntimeError("Invalid key name, must be a valid python variable" + " name.") + self.defaults["ROLE"][key] = default + + def register_member(self, **member_defaults): + """ + Registers a new dict of member defaults. This function should + be called EVERY TIME the cog loads (aka just do it in + __init__)! + + :param member_defaults: Each key should be the key you want to + access data by and the value is the default value of that + key. + :return: + """ + for k, v in member_defaults.items(): + try: + self._register_member(k, v) + except KeyError: + log.exception("Bad default member key.") + + def _register_member(self, key, default=None): + """Registers a member config key `key`""" + if key in self.unsettable_keys: + raise KeyError("Attempt to use restricted key: '{}'".format(key)) + elif not key.isidentifier(): + raise RuntimeError("Invalid key name, must be a valid python variable" + " name.") + self.defaults["MEMBER"][key] = default + + def register_user(self, **user_defaults): + """ + Registers a new dict of user defaults. This function should + be called EVERY TIME the cog loads (aka just do it in + __init__)! + + :param user_defaults: Each key should be the key you want to + access data by and the value is the default value of that + key. + :return: + """ + for k, v in user_defaults.items(): + try: + self._register_user(k, v) + except KeyError: + log.exception("Bad default user key.") + + def _register_user(self, key, default=None): + """Registers a user config key `key`""" + if key in self.unsettable_keys: + raise KeyError("Attempt to use restricted key: '{}'".format(key)) + elif not key.isidentifier(): + raise RuntimeError("Invalid key name, must be a valid python variable" + " name.") + self.defaults["USER"][key] = default + + +class Config(BaseConfig): + """ + Config object created by `Config.get_conf()` + + This configuration object is designed to make backend data + storage mechanisms pluggable. It also is designed to + help a cog developer make fewer mistakes (such as + typos) when dealing with cog data and to make those mistakes + apparent much faster in the design process. + + It also has the capability to safely store data between cogs + that share the same name. + + There are two main components to this config object. First, + you have the ability to get data on a level specific basis. + The seven levels available are: global, guild, channel, role, + member, user, and misc. + + The second main component is registering default values for + data in each of the levels. This functionality is OPTIONAL + and must be explicitly enabled when creating the Config object + using the kwarg `force_registration=True`. + + Basic Usage: + Creating a Config object: + Use the `Config.get_conf()` class method to create new + Config objects. + + See the `Config.get_conf()` documentation for more + information. + + Registering Default Values (optional): + You can register default values for data at all levels + EXCEPT misc. + + Simply pass in the key/value pairs as keyword arguments to + the respective function. + + e.g.: conf_obj.register_global(enabled=True) + conf_obj.register_guild(likes_red=True) + + Retrieving data by attributes: + Since I registered the "enabled" key in the previous example + at the global level I can now do: + + conf_obj.enabled() + + which will retrieve the current value of the "enabled" + key, making use of the default of "True". I can also do + the same for the guild key "likes_red": + + conf_obj.guild(guild_obj).likes_red() + + If I elected to not register default values, you can provide them + when you try to access the key: + + conf_obj.no_default(default=True) + + However if you do not provide a default and you do not register + defaults, accessing the attribute will return "None". + + Saving data: + This is accomplished by using the `set` function available at + every level. + + e.g.: conf_obj.set("enabled", False) + conf_obj.guild(guild_obj).set("likes_red", False) + + If `force_registration` was enabled when the config object + was created you will only be allowed to save keys that you + have registered. + + Misc data is special, use `conf.misc()` and `conf.set_misc(value)` + respectively. + """ + + def __getattr__(self, key) -> Callable: + """ + Until I've got a better way to do this I'm just gonna fake __call__ + + :param key: + :return: lambda function with kwarg + """ + return self._get_value_from_key(key) + + def _get_value_from_key(self, key) -> Callable: + try: + default = self.defaults[self.collection][key] + except KeyError as e: + if self.force_registration: + raise AttributeError("Key '{}' not registered!".format(key)) from e + default = None + + self.curr_key = key + + if self.collection != "MEMBER": + ret = lambda default=default: self.driver_getmap[self.collection]( + self.cog_name, self.uuid, self.collection_uuid, key, + default=default) + else: + mid, sid = self.collection_uuid + ret = lambda default=default: self.driver.get_member( + self.cog_name, self.uuid, mid, sid, key, + default=default) + return ret + + def get(self, key, default=None): + """ + Included as an alternative to registering defaults. + + :param key: + :param default: + :return: + """ + + try: + return getattr(self, key)(default=default) + except AttributeError: + return + + async def set(self, key, value): + # Notice to future developers: + # This code was commented to allow users to set keys without having to register them. + # That being said, if they try to get keys without registering them + # things will blow up. I do highly recommend enforcing the key registration. + + if key in self.unsettable_keys or key in self.invalid_keys: + raise KeyError("Restricted key name, please use another.") + + if self.force_registration and key not in self.defaults[self.collection]: + raise AttributeError("Key '{}' not registered!".format(key)) + + if not key.isidentifier(): + raise RuntimeError("Invalid key name, must be a valid python variable" + " name.") + + if self.collection == "GLOBAL": + await self.driver.set_global(self.cog_name, self.uuid, key, value) + elif self.collection == "MEMBER": + mid, sid = self.collection_uuid + await self.driver.set_member(self.cog_name, self.uuid, mid, sid, + key, value) + elif self.collection in self.driver_setmap: + func = self.driver_setmap[self.collection] + await func(self.cog_name, self.uuid, self.collection_uuid, key, value) + + async def clear(self): + await self.driver_setmap[self.collection]( + self.cog_name, self.uuid, self.collection_uuid, None, None, + clear=True) + + def guild(self, guild): + new = type(self)(self.cog_name, self.uuid, self.driver, + hash_uuid=False, defaults=self.defaults) + new.collection = "GUILD" + new.collection_uuid = guild.id + new._driver = None + return new + + def channel(self, channel): + new = type(self)(self.cog_name, self.uuid, self.driver, + hash_uuid=False, defaults=self.defaults) + new.collection = "CHANNEL" + new.collection_uuid = channel.id + new._driver = None + return new + + def role(self, role): + new = type(self)(self.cog_name, self.uuid, self.driver, + hash_uuid=False, defaults=self.defaults) + new.collection = "ROLE" + new.collection_uuid = role.id + new._driver = None + return new + + def member(self, member): + guild = member.guild + new = type(self)(self.cog_name, self.uuid, self.driver, + hash_uuid=False, defaults=self.defaults) + new.collection = "MEMBER" + new.collection_uuid = (member.id, guild.id) + new._driver = None + return new + + def user(self, user): + new = type(self)(self.cog_name, self.uuid, self.driver, + hash_uuid=False, defaults=self.defaults) + new.collection = "USER" + new.collection_uuid = user.id + new._driver = None + return new diff --git a/core/drivers/__init__.py b/core/drivers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/core/drivers/red_base.py b/core/drivers/red_base.py new file mode 100644 index 000000000..1b58cc1dc --- /dev/null +++ b/core/drivers/red_base.py @@ -0,0 +1,45 @@ +class BaseDriver: + def get_global(self, cog_name, ident, collection_id, key, *, default=None): + raise NotImplementedError() + + def get_guild(self, cog_name, ident, guild_id, key, *, default=None): + raise NotImplementedError() + + def get_channel(self, cog_name, ident, channel_id, key, *, default=None): + raise NotImplementedError() + + def get_role(self, cog_name, ident, role_id, key, *, default=None): + raise NotImplementedError() + + def get_member(self, cog_name, ident, user_id, guild_id, key, *, + default=None): + raise NotImplementedError() + + def get_user(self, cog_name, ident, user_id, key, *, default=None): + raise NotImplementedError() + + def get_misc(self, cog_name, ident, *, default=None): + raise NotImplementedError() + + async def set_global(self, cog_name, ident, key, value, clear=False): + raise NotImplementedError() + + async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False): + raise NotImplementedError() + + async def set_channel(self, cog_name, ident, channel_id, key, value, + clear=False): + raise NotImplementedError() + + async def set_role(self, cog_name, ident, role_id, key, value, clear=False): + raise NotImplementedError() + + async def set_member(self, cog_name, ident, user_id, guild_id, key, value, + clear=False): + raise NotImplementedError() + + async def set_user(self, cog_name, ident, user_id, key, value, clear=False): + raise NotImplementedError() + + async def set_misc(self, cog_name, ident, value, clear=False): + raise NotImplementedError() diff --git a/core/drivers/red_json.py b/core/drivers/red_json.py new file mode 100644 index 000000000..073d37a7c --- /dev/null +++ b/core/drivers/red_json.py @@ -0,0 +1,135 @@ +from core.json_io import JsonIO +import os +from .red_base import BaseDriver + +from pathlib import Path + + +class JSON(BaseDriver): + def __init__(self, cog_name, *args, data_path_override: Path=None, + file_name_override: str="settings.json", **kwargs): + self.cog_name = cog_name + self.file_name = file_name_override + if data_path_override: + self.data_path = data_path_override + else: + self.data_path = Path.cwd() / 'cogs' / '.data' / self.cog_name + + self.data_path.mkdir(parents=True, exist_ok=True) + + self.data_path = self.data_path / self.file_name + + self.jsonIO = JsonIO(self.data_path) + + try: + self.data = self.jsonIO._load_json() + except FileNotFoundError: + self.data = {} + + def maybe_add_ident(self, ident: str): + if ident in self.data: + return + + self.data[ident] = {} + for k in ("GLOBAL", "GUILD", "CHANNEL", "ROLE", "MEMBER", "USER"): + if k not in self.data[ident]: + self.data[ident][k] = {} + + self.jsonIO._save_json(self.data) + + def get_global(self, cog_name, ident, _, key, *, default=None): + return self.data[ident]["GLOBAL"].get(key, default) + + def get_guild(self, cog_name, ident, guild_id, key, *, default=None): + guilddata = self.data[ident]["GUILD"].get(str(guild_id), {}) + return guilddata.get(key, default) + + def get_channel(self, cog_name, ident, channel_id, key, *, default=None): + channeldata = self.data[ident]["CHANNEL"].get(str(channel_id), {}) + return channeldata.get(key, default) + + def get_role(self, cog_name, ident, role_id, key, *, default=None): + roledata = self.data[ident]["ROLE"].get(str(role_id), {}) + return roledata.get(key, default) + + def get_member(self, cog_name, ident, user_id, guild_id, key, *, + default=None): + userdata = self.data[ident]["MEMBER"].get(str(user_id), {}) + guilddata = userdata.get(str(guild_id), {}) + return guilddata.get(key, default) + + def get_user(self, cog_name, ident, user_id, key, *, default=None): + userdata = self.data[ident]["USER"].get(str(user_id), {}) + return userdata.get(key, default) + + async def set_global(self, cog_name, ident, key, value, clear=False): + if clear: + self.data[ident]["GLOBAL"] = {} + else: + self.data[ident]["GLOBAL"][key] = value + await self.jsonIO._threadsafe_save_json(self.data) + + async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False): + guild_id = str(guild_id) + if clear: + self.data[ident]["GUILD"][guild_id] = {} + else: + try: + self.data[ident]["GUILD"][guild_id][key] = value + except KeyError: + self.data[ident]["GUILD"][guild_id] = {} + self.data[ident]["GUILD"][guild_id][key] = value + await self.jsonIO._threadsafe_save_json(self.data) + + async def set_channel(self, cog_name, ident, channel_id, key, value, clear=False): + channel_id = str(channel_id) + if clear: + self.data[ident]["CHANNEL"][channel_id] = {} + else: + try: + self.data[ident]["CHANNEL"][channel_id][key] = value + except KeyError: + self.data[ident]["CHANNEL"][channel_id] = {} + self.data[ident]["CHANNEL"][channel_id][key] = value + await self.jsonIO._threadsafe_save_json(self.data) + + async def set_role(self, cog_name, ident, role_id, key, value, clear=False): + role_id = str(role_id) + if clear: + self.data[ident]["ROLE"][role_id] = {} + else: + try: + self.data[ident]["ROLE"][role_id][key] = value + except KeyError: + self.data[ident]["ROLE"][role_id] = {} + self.data[ident]["ROLE"][role_id][key] = value + await self.jsonIO._threadsafe_save_json(self.data) + + async def set_member(self, cog_name, ident, user_id, guild_id, key, value, clear=False): + user_id = str(user_id) + guild_id = str(guild_id) + if clear: + self.data[ident]["MEMBER"][user_id] = {} + else: + try: + self.data[ident]["MEMBER"][user_id][guild_id][key] = value + except KeyError: + if user_id not in self.data[ident]["MEMBER"]: + self.data[ident]["MEMBER"][user_id] = {} + if guild_id not in self.data[ident]["MEMBER"][user_id]: + self.data[ident]["MEMBER"][user_id][guild_id] = {} + + self.data[ident]["MEMBER"][user_id][guild_id][key] = value + await self.jsonIO._threadsafe_save_json(self.data) + + async def set_user(self, cog_name, ident, user_id, key, value, clear=False): + user_id = str(user_id) + if clear: + self.data[ident]["USER"][user_id] = {} + else: + try: + self.data[ident]["USER"][user_id][key] = value + except KeyError: + self.data[ident]["USER"][user_id] = {} + self.data[ident]["USER"][user_id][key] = value + await self.jsonIO._threadsafe_save_json(self.data) diff --git a/core/drivers/red_mongo.py b/core/drivers/red_mongo.py new file mode 100644 index 000000000..4bf83ca2c --- /dev/null +++ b/core/drivers/red_mongo.py @@ -0,0 +1,211 @@ +import pymongo as m +from .red_base import BaseDriver + + +class RedMongoException(Exception): + """Base Red Mongo Exception class""" + pass + + +class MultipleMatches(RedMongoException): + """Raised when multiple documents match a single cog_name and + cog_identifier pair.""" + pass + + +class MissingCollection(RedMongoException): + """Raised when a collection is missing from the mongo db""" + pass + + +class Mongo(BaseDriver): + def __init__(self, host, port=27017, admin_user=None, admin_pass=None, + **kwargs): + self.conn = m.MongoClient(host=host, port=port, **kwargs) + + self.admin_user = admin_user + self.admin_pass = admin_pass + + self._db = self.conn.red + if self.admin_user is not None and self.admin_pass is not None: + self._db.authenticate(self.admin_user, self.admin_pass) + + self._global = self._db.GLOBAL + self._guild = self._db.GUILD + self._channel = self._db.CHANNEL + self._role = self._db.ROLE + self._member = self._db.MEMBER + self._user = self._db.USER + + def get_global(self, cog_name, cog_identifier, _, key, *, default=None): + doc = self._global.find( + {"cog_name": cog_name, "cog_identifier": cog_identifier}, + projection=[key, ], batch_size=2) + if doc.count() == 2: + raise MultipleMatches("Too many matching documents at the GLOBAL" + " level: ({}, {})".format(cog_name, + cog_identifier)) + elif doc.count() == 1: + return doc[0].get(key, default) + return default + + def get_guild(self, cog_name, cog_identifier, guild_id, key, *, + default=None): + doc = self._guild.find( + {"cog_name": cog_name, "cog_identifier": cog_identifier, + "guild_id": guild_id}, + projection=[key, ], batch_size=2) + if doc.count() == 2: + raise MultipleMatches("Too many matching documents at the GUILD" + " level: ({}, {}, {})".format( + cog_name, cog_identifier, guild_id)) + elif doc.count() == 1: + return doc[0].get(key, default) + return default + + def get_channel(self, cog_name, cog_identifier, channel_id, key, *, + default=None): + doc = self._channel.find( + {"cog_name": cog_name, "cog_identifier": cog_identifier, + "channel_id": channel_id}, + projection=[key, ], batch_size=2) + if doc.count() == 2: + raise MultipleMatches("Too many matching documents at the CHANNEL" + " level: ({}, {}, {})".format( + cog_name, cog_identifier, channel_id)) + elif doc.count() == 1: + return doc[0].get(key, default) + return default + + def get_role(self, cog_name, cog_identifier, role_id, key, *, + default=None): + doc = self._role.find( + {"cog_name": cog_name, "cog_identifier": cog_identifier, + "role_id": role_id}, + projection=[key, ], batch_size=2) + if doc.count() == 2: + raise MultipleMatches("Too many matching documents at the ROLE" + " level: ({}, {}, {})".format( + cog_name, cog_identifier, role_id)) + elif doc.count() == 1: + return doc[0].get(key, default) + return default + + def get_member(self, cog_name, cog_identifier, user_id, guild_id, key, *, + default=None): + doc = self._member.find( + {"cog_name": cog_name, "cog_identifier": cog_identifier, + "user_id": user_id, "guild_id": guild_id}, + projection=[key, ], batch_size=2) + if doc.count() == 2: + raise MultipleMatches("Too many matching documents at the MEMBER" + " level: ({}, {}, mid {}, sid {})".format( + cog_name, cog_identifier, user_id, + guild_id)) + elif doc.count() == 1: + return doc[0].get(key, default) + return default + + def get_user(self, cog_name, cog_identifier, user_id, key, *, + default=None): + doc = self._user.find( + {"cog_name": cog_name, "cog_identifier": cog_identifier, + "user_id": user_id}, + projection=[key, ], batch_size=2) + if doc.count() == 2: + raise MultipleMatches("Too many matching documents at the USER" + " level: ({}, {}, mid {})".format( + cog_name, cog_identifier, user_id)) + elif doc.count() == 1: + return doc[0].get(key, default) + else: + return default + + def set_global(self, cog_name, cog_identifier, key, value, clear=False): + filter = {"cog_name": cog_name, "cog_identifier": cog_identifier} + data = {"$set": {key: value}} + if self._global.count(filter) > 1: + raise MultipleMatches("Too many matching documents at the GLOBAL" + " level: ({}, {})".format(cog_name, + cog_identifier)) + else: + if clear: + self._global.delete_one(filter) + else: + self._global.update_one(filter, data, upsert=True) + + def set_guild(self, cog_name, cog_identifier, guild_id, key, value, + clear=False): + filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, + "guild_id": guild_id} + data = {"$set": {key: value}} + if self._guild.count(filter) > 1: + raise MultipleMatches("Too many matching documents at the GUILD" + " level: ({}, {}, {})".format( + cog_name, cog_identifier, guild_id)) + else: + if clear: + self._guild.delete_one(filter) + else: + self._guild.update_one(filter, data, upsert=True) + + def set_channel(self, cog_name, cog_identifier, channel_id, key, value, + clear=False): + filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, + "channel_id": channel_id} + data = {"$set": {key: value}} + if self._channel.count(filter) > 1: + raise MultipleMatches("Too many matching documents at the CHANNEL" + " level: ({}, {}, {})".format( + cog_name, cog_identifier, channel_id)) + else: + if clear: + self._channel.delete_one(filter) + else: + self._channel.update_one(filter, data, upsert=True) + + def set_role(self, cog_name, cog_identifier, role_id, key, value, + clear=False): + filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, + "role_id": role_id} + data = {"$set": {key: value}} + if self._role.count(filter) > 1: + raise MultipleMatches("Too many matching documents at the ROLE" + " level: ({}, {}, {})".format( + cog_name, cog_identifier, role_id)) + else: + if clear: + self._role.delete_one(filter) + else: + self._role.update_one(filter, data, upsert=True) + + def set_member(self, cog_name, cog_identifier, user_id, guild_id, key, + value, clear=False): + filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, + "guild_id": guild_id, "user_id": user_id} + data = {"$set": {key: value}} + if self._member.count(filter) > 1: + raise MultipleMatches("Too many matching documents at the MEMBER" + " level: ({}, {}, mid {}, sid {})".format( + cog_name, cog_identifier, user_id, + guild_id)) + else: + if clear: + self._member.delete_one(filter) + else: + self._member.update_one(filter, data, upsert=True) + + def set_user(self, cog_name, cog_identifier, user_id, key, value, + clear=False): + filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, + "user_id": user_id} + data = {"$set": {key: value}} + if self._user.count(filter) > 1: + raise MultipleMatches("Too many matching documents at the USER" + " level: ({}, {}, mid {})".format( + cog_name, cog_identifier, user_id)) + else: + if clear: + self._user.delete_one(filter) + else: + self._user.update_one(filter, data, upsert=True) diff --git a/core/events.py b/core/events.py index 4dbad77ed..309f99d02 100644 --- a/core/events.py +++ b/core/events.py @@ -32,7 +32,7 @@ def init_events(bot, cli_flags): if cli_flags.no_cogs is False: print("Loading packages...") failed = [] - packages = bot.db.get_global("packages", []) + packages = bot.db.packages() for package in packages: try: @@ -69,7 +69,7 @@ def init_events(bot, cli_flags): print("\nInvite URL: {}\n".format(invite_url)) @bot.event - async def on_command_error(error, ctx): + async def on_command_error(ctx, error): if isinstance(error, commands.MissingRequiredArgument): await bot.send_cmd_help(ctx) elif isinstance(error, commands.BadArgument): diff --git a/core/global_checks.py b/core/global_checks.py index 7df92f7e3..d2f405da6 100644 --- a/core/global_checks.py +++ b/core/global_checks.py @@ -1,3 +1,6 @@ +from discord.ext import commands + + def init_global_checks(bot): @bot.check @@ -5,20 +8,19 @@ def init_global_checks(bot): if await bot.is_owner(ctx.author): return True - if bot.db.get_global("whitelist", []): - return ctx.author.id in bot.db.get_global("whitelist", []) + if bot.db.whitelist(): + return ctx.author.id in bot.db.whitelist() - return ctx.author.id not in bot.db.get_global("blacklist", []) + return ctx.author.id not in bot.db.blacklist() @bot.check - async def local_perms(ctx): + async def local_perms(ctx: commands.Context): if await bot.is_owner(ctx.author): return True elif ctx.message.guild is None: return True - guild_perms = bot.db.get_all(ctx.guild, {}) - local_blacklist = guild_perms.get("blacklist", []) - local_whitelist = guild_perms.get("whitelist", []) + local_blacklist = bot.db.guild(ctx.guild).blacklist() + local_whitelist = bot.db.guild(ctx.guild).whitelist() if local_whitelist: return ctx.author.id in local_whitelist diff --git a/core/json_io.py b/core/json_io.py index 823c2acb6..060cf09c8 100644 --- a/core/json_io.py +++ b/core/json_io.py @@ -7,7 +7,7 @@ from uuid import uuid4 # This is basically our old DataIO, except that it's now threadsafe # and just a base for much more elaborate classes - +from pathlib import Path log = logging.getLogger("red") @@ -17,25 +17,33 @@ MINIFIED = {"sort_keys": True, "separators": (',', ':')} class JsonIO: """Basic functions for atomic saving / loading of json files""" - _lock = asyncio.Lock() + def __init__(self, path: Path=Path.cwd()): + """ + :param path: Full path to file. + """ + self._lock = asyncio.Lock() + self.path = path - def _save_json(self, path, data, settings=PRETTY): - log.debug("Saving file {}".format(path)) - filename, _ = os.path.splitext(path) + # noinspection PyUnresolvedReferences + def _save_json(self, data, settings=PRETTY): + log.debug("Saving file {}".format(self.path)) + filename = self.path.stem tmp_file = "{}-{}.tmp".format(filename, uuid4().fields[0]) - with open(tmp_file, encoding="utf-8", mode="w") as f: + tmp_path = self.path.parent / tmp_file + with tmp_path.open(encoding="utf-8", mode="w") as f: json.dump(data, f, **settings) - os.replace(tmp_file, path) + tmp_path.replace(self.path) - async def _threadsafe_save_json(self, path, data, settings=PRETTY): + async def _threadsafe_save_json(self, data, settings=PRETTY): loop = asyncio.get_event_loop() - func = functools.partial(self._save_json, path, data, settings) + func = functools.partial(self._save_json, data, settings) with await self._lock: await loop.run_in_executor(None, func) - def _load_json(self, path): - log.debug("Reading file {}".format(path)) - with open(path, encoding='utf-8', mode="r") as f: + # noinspection PyUnresolvedReferences + def _load_json(self): + log.debug("Reading file {}".format(self.path)) + with self.path.open(encoding='utf-8', mode="r") as f: data = json.load(f) return data diff --git a/core/settings.py b/core/settings.py deleted file mode 100644 index c3e95fc41..000000000 --- a/core/settings.py +++ /dev/null @@ -1,124 +0,0 @@ -from core.utils.helpers import JsonGuildDB -import discord -import argparse - - -class CoreDB(JsonGuildDB): - """ - The central DB used by Red to store a variety - of settings, both global and guild specific - """ - - def get_admin_role(self, guild): - """Returns the guild's admin role - - Returns None if not set or if the role - couldn't be retrieved""" - _id = self.get_all(guild, {}).get("admin_role", None) - return discord.utils.get(guild.roles, id=_id) - - def get_mod_role(self, guild): - """Returns the guild's mod role - - Returns None if not set or if the role - couldn't be retrieved""" - _id = self.get_all(guild, {}).get("mod_role", None) - return discord.utils.get(guild.roles, id=_id) - - async def set_admin_role(self, role): - """Sets the admin role for the guild""" - if not isinstance(role, discord.Role): - raise TypeError("A valid Discord role must be passed.") - await self.set(role.guild, "admin_role", role.id) - - async def set_mod_role(self, role): - """Sets the mod role for the guild""" - if not isinstance(role, discord.Role): - raise TypeError("A valid Discord role must be passed.") - await self.set(role.guild, "mod_role", role.id) - - def get_global_whitelist(self): - """Returns the global whitelist""" - return self.get_global("whitelist", []) - - def get_global_blacklist(self): - """Returns the global whitelist""" - return self.get_global("blacklist", []) - - async def set_global_whitelist(self, whitelist): - """Sets the global whitelist""" - if not isinstance(list, whitelist): - raise TypeError("A list of IDs must be passed.") - await self.set_global("whitelist", whitelist) - - async def set_global_blacklist(self, blacklist): - """Sets the global blacklist""" - if not isinstance(list, blacklist): - raise TypeError("A list of IDs must be passed.") - await self.set_global("blacklist", blacklist) - - def get_guild_whitelist(self, guild): - """Returns the guild's whitelist""" - return self.get(guild, "whitelist", []) - - def get_guild_blacklist(self, guild): - """Returns the guild's blacklist""" - return self.get(guild, "blacklist", []) - - async def set_guild_whitelist(self, guild, whitelist): - """Sets the guild's whitelist""" - if not isinstance(guild, discord.Guild) or not isinstance(whitelist, list): - raise TypeError("A valid Discord guild and a list of IDs " - "must be passed.") - await self.set(guild, "whitelist", whitelist) - - async def set_guild_blacklist(self, guild, blacklist): - """Sets the guild's blacklist""" - if not isinstance(guild, discord.Guild) or not isinstance(blacklist, list): - raise TypeError("A valid Discord guild and a list of IDs " - "must be passed.") - await self.set(guild, "blacklist", blacklist) - - -def parse_cli_flags(): - parser = argparse.ArgumentParser(description="Red - Discord Bot") - parser.add_argument("--owner", help="ID of the owner. Only who hosts " - "Red should be owner, this has " - "security implications") - parser.add_argument("--prefix", "-p", action="append", - help="Global prefix. Can be multiple") - parser.add_argument("--no-prompt", - action="store_true", - help="Disables console inputs. Features requiring " - "console interaction could be disabled as a " - "result") - parser.add_argument("--no-cogs", - action="store_true", - help="Starts Red with no cogs loaded, only core") - parser.add_argument("--self-bot", - action='store_true', - help="Specifies if Red should log in as selfbot") - parser.add_argument("--not-bot", - action='store_true', - help="Specifies if the token used belongs to a bot " - "account.") - parser.add_argument("--dry-run", - action="store_true", - help="Makes Red quit with code 0 just before the " - "login. This is useful for testing the boot " - "process.") - parser.add_argument("--debug", - action="store_true", - help="Sets the loggers level as debug") - parser.add_argument("--dev", - action="store_true", - help="Enables developer mode") - - args = parser.parse_args() - - if args.prefix: - args.prefix = sorted(args.prefix, reverse=True) - else: - args.prefix = [] - - return args diff --git a/core/utils/helpers.py b/core/utils/helpers.py deleted file mode 100644 index 21f71cb79..000000000 --- a/core/utils/helpers.py +++ /dev/null @@ -1,216 +0,0 @@ -import os -import discord -import asyncio -import functools -import inspect -from collections import defaultdict -from core.json_io import JsonIO - - -GLOBAL_KEY = '__global__' -SENTINEL = object() - - -class JsonDB(JsonIO): - """ - A DB-like helper class to streamline the saving of json files - - Parameters: - - file_path: str - The path of the json file you want to create / access - create_dirs: bool=True - If True, it will create any missing directory leading to - the file you want to create - relative_path: bool=True - The file_path you specified is relative to the path from which - you're instantiating this object from - i.e. If you're in a package's folder and your file_path is - 'data/settings.json', these files will be created inside - the package's folder and not Red's root folder - default_value: Optional=None - Same behaviour as a defaultdict - """ - _caller = "" - - def __init__(self, file_path, **kwargs): - local = kwargs.pop("relative_path", True) - if local and not self._caller: - self._caller = self._get_caller_path() - - create_dirs = kwargs.pop("create_dirs", True) - default_value = kwargs.pop("default_value", SENTINEL) - self.autosave = kwargs.pop("autosave", False) - self.path = os.path.join(self._caller, file_path) - - file_exists = os.path.isfile(self.path) - - if create_dirs and not file_exists: - path, _ = os.path.split(self.path) - if path: - try: - os.makedirs(path) - except FileExistsError: - pass - - if file_exists: - # Might be worth looking into threadsafe ways for very large files - self._data = self._load_json(self.path) - else: - self._data = {} - self._blocking_save() - - if default_value is not SENTINEL: - def _get_default(): - return default_value - self._data = defaultdict(_get_default, self._data) - - self._loop = asyncio.get_event_loop() - self._task = functools.partial(self._threadsafe_save_json, self._data) - - async def set(self, key, value): - """Sets a DB's entry""" - self._data[key] = value - await self.save() - - def get(self, key, default=None): - """Returns a DB's entry""" - return self._data.get(key, default) - - async def remove(self, key): - """Removes a DB's entry""" - del self._data[key] - await self.save() - - async def pop(self, key, default=None): - """Removes and returns a DB's entry""" - value = self._data.pop(key, default) - await self.save() - return value - - async def wipe(self): - """Wipes DB""" - self._data = {} - await self.save() - - def all(self): - """Returns all DB's data""" - return self._data - - def _blocking_save(self): - """Using this should be avoided. Let's stick to threadsafe saves""" - self._save_json(self.path, self._data) - - async def save(self): - """Threadsafe save to file""" - await self._threadsafe_save_json(self.path, self._data) - - def _get_caller_path(self): - frame = inspect.stack()[2] - module = inspect.getmodule(frame[0]) - abspath = os.path.abspath(module.__file__) - return os.path.dirname(abspath) - - def __contains__(self, key): - return key in self._data - - def __getitem__(self, key): - return self._data[key] - - def __len__(self): - return len(self._data) - - def __repr__(self): - return "<{} {}>".format(self.__class__.__name__, self._data) - - -class JsonGuildDB(JsonDB): - """ - A DB-like helper class to streamline the saving of json files - This is a variant of JsonDB that allows for guild specific data - Global data is still allowed with dedicated methods - - Same parameters as JsonDB - """ - - def __init__(self, *args, **kwargs): - local = kwargs.get("relative_path", True) - if local and not self._caller: - self._caller = self._get_caller_path() - - super().__init__(*args, **kwargs) - - async def set(self, guild, key, value): - """Sets a guild's entry""" - if not isinstance(guild, discord.Guild): - raise TypeError('Can only set guild data') - if str(guild.id) not in self._data: - self._data[str(guild.id)] = {} - self._data[str(guild.id)][key] = value - await self.save() - - def get(self, guild, key, default=None): - """Returns a guild's entry""" - if not isinstance(guild, discord.Guild): - raise TypeError('Can only get guild data') - if str(guild.id) not in self._data: - return default - return self._data[str(guild.id)].get(key, default) - - async def remove(self, guild, key): - """Removes a guild's entry""" - if not isinstance(guild, discord.Guild): - raise TypeError('Can only remove guild data') - if str(guild.id) not in self._data: - raise KeyError('Guild data is not present') - del self._data[str(guild.id)][key] - await self.save() - - async def pop(self, guild, key, default=None): - """Removes and returns a guild's entry""" - if not isinstance(guild, discord.Guild): - raise TypeError('Can only remove guild data') - value = self._data.get(str(guild.id), {}).pop(key, default) - await self.save() - return value - - def get_all(self, guild, default=None): - """Returns all entries of a guild""" - if not isinstance(guild, discord.Guild): - raise TypeError('Can only get guild data') - return self._data.get(str(guild.id), default) - - async def remove_all(self, guild): - """Removes all entries of a guild""" - if not isinstance(guild, discord.Guild): - raise TypeError('Can only remove guilds') - await super().remove(str(guild.id)) - - async def set_global(self, key, value): - """Sets a global value""" - if GLOBAL_KEY not in self._data: - self._data[GLOBAL_KEY] = {} - self._data[GLOBAL_KEY][key] = value - await self.save() - - def get_global(self, key, default=None): - """Gets a global value""" - if GLOBAL_KEY not in self._data: - self._data[GLOBAL_KEY] = {} - - return self._data[GLOBAL_KEY].get(key, default) - - async def remove_global(self, key): - """Removes a global value""" - if GLOBAL_KEY not in self._data: - self._data[GLOBAL_KEY] = {} - del self._data[GLOBAL_KEY][key] - await self.save() - - async def pop_global(self, key, default=None): - """Removes and returns a global value""" - if GLOBAL_KEY not in self._data: - self._data[GLOBAL_KEY] = {} - value = self._data[GLOBAL_KEY].pop(key, default) - await self.save() - return value diff --git a/main.py b/main.py index 71cf08122..2613d0328 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,7 @@ from core.bot import Red, ExitCodes from core.global_checks import init_global_checks from core.events import init_events -from core.settings import parse_cli_flags -from core.cli import interactive_config, confirm +from core.cli import interactive_config, confirm, parse_cli_flags from core.core_commands import Core from core.dev_commands import Dev import asyncio @@ -65,8 +64,8 @@ if __name__ == '__main__': if cli_flags.dev: red.add_cog(Dev()) - token = os.environ.get("RED_TOKEN", red.db.get_global("token", None)) - prefix = cli_flags.prefix or red.db.get_global("prefix", []) + token = os.environ.get("RED_TOKEN", red.db.token()) + prefix = cli_flags.prefix or red.db.prefix() if token is None or not prefix: if cli_flags.no_prompt is False: @@ -89,11 +88,11 @@ if __name__ == '__main__': "a user account, remember that the --not-bot flag " "must be used. For self-bot functionalities instead, " "--self-bot") - db_token = red.db.get_global("token") + db_token = red.db.token() if db_token and not cli_flags.no_prompt: print("\nDo you want to reset the token? (y/n)") if confirm("> "): - loop.run_until_complete(red.db.remove_global("token")) + loop.run_until_complete(red.db.set("token", "")) print("Token has been reset.") except KeyboardInterrupt: log.info("Keyboard interrupt detected. Quitting...") diff --git a/requirements.txt b/requirements.txt index 0b21c2f7e..08a4d6f99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,5 @@ git+https://github.com/Rapptz/discord.py@rewrite#egg=discord.py[voice] -youtube_dl \ No newline at end of file +youtube_dl +pytest +git+https://github.com/pytest-dev/pytest-asyncio +pymongo \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..20efc7a69 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,101 @@ +from collections import namedtuple +from pathlib import Path + +import pytest +import random + +from core.bot import Red +from core.drivers import red_json +from core import Config + + +@pytest.fixture(scope="module") +def json_driver(tmpdir_factory): + driver = red_json.JSON( + "PyTest", + data_path_override=Path(str(tmpdir_factory.getbasetemp())) + ) + return driver + + +@pytest.fixture(scope="module") +def config(json_driver): + return Config( + cog_name="PyTest", + unique_identifier=0, + driver_spawn=json_driver) + + +@pytest.fixture(scope="module") +def config_fr(json_driver): + """ + Mocked config object with force_register enabled. + """ + return Config( + cog_name="PyTest", + unique_identifier=0, + driver_spawn=json_driver, + force_registration=True + ) + + +#region Dpy Mocks +@pytest.fixture(scope="module") +def empty_guild(): + mock_guild = namedtuple("Guild", "id members") + return mock_guild(random.randint(1, 999999999), []) + + +@pytest.fixture(scope="module") +def empty_channel(): + mock_channel = namedtuple("Channel", "id") + return mock_channel(random.randint(1, 999999999)) + + +@pytest.fixture(scope="module") +def empty_role(): + mock_role = namedtuple("Role", "id") + return mock_role(random.randint(1, 999999999)) + + +@pytest.fixture(scope="module") +def empty_member(empty_guild): + mock_member = namedtuple("Member", "id guild") + return mock_member(random.randint(1, 999999999), empty_guild) + + +@pytest.fixture(scope="module") +def empty_user(): + mock_user = namedtuple("User", "id") + return mock_user(random.randint(1, 999999999)) + + +@pytest.fixture(scope="module") +def empty_message(): + mock_msg = namedtuple("Message", "content") + return mock_msg("No content.") + + +@pytest.fixture(scope="module") +def ctx(empty_member, empty_channel, red): + mock_ctx = namedtuple("Context", "author guild channel message bot") + return mock_ctx(empty_member, empty_member.guild, empty_channel, + empty_message, red) +#endregion + + +#region Red Mock +@pytest.fixture +def red(monkeypatch, config_fr, event_loop): + from core.cli import parse_cli_flags + cli_flags = parse_cli_flags() + + description = "Red v3 - Alpha" + + monkeypatch.setattr("core.config.Config.get_core_conf", + lambda *args, **kwargs: config_fr) + + red = Red(cli_flags, description=description, pm_help=None, + loop=event_loop) + return red +#endregion \ No newline at end of file diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/core/test_config.py b/tests/core/test_config.py new file mode 100644 index 000000000..ff1c24256 --- /dev/null +++ b/tests/core/test_config.py @@ -0,0 +1,147 @@ +import pytest + + +#region Register Tests +def test_config_register_global(config): + config.register_global(enabled=False) + assert config.defaults["GLOBAL"]["enabled"] is False + assert config.enabled() is False + + +def test_config_register_global_badvalues(config): + with pytest.raises(RuntimeError): + config.register_global(**{"invalid var name": True}) + + +def test_config_register_guild(config, empty_guild): + config.register_guild(enabled=False, some_list=[], some_dict={}) + assert config.defaults["GUILD"]["enabled"] is False + assert config.defaults["GUILD"]["some_list"] == [] + assert config.defaults["GUILD"]["some_dict"] == {} + + assert config.guild(empty_guild).enabled() is False + assert config.guild(empty_guild).some_list() == [] + assert config.guild(empty_guild).some_dict() == {} + + +def test_config_register_channel(config, empty_channel): + config.register_channel(enabled=False) + assert config.defaults["CHANNEL"]["enabled"] is False + assert config.channel(empty_channel).enabled() is False + + +def test_config_register_role(config, empty_role): + config.register_role(enabled=False) + assert config.defaults["ROLE"]["enabled"] is False + assert config.role(empty_role).enabled() is False + + +def test_config_register_member(config, empty_member): + config.register_member(some_number=-1) + assert config.defaults["MEMBER"]["some_number"] == -1 + assert config.member(empty_member).some_number() == -1 + + +def test_config_register_user(config, empty_user): + config.register_user(some_value=None) + assert config.defaults["USER"]["some_value"] is None + assert config.user(empty_user).some_value() is None + + +def test_config_force_register_global(config_fr): + with pytest.raises(AttributeError): + config_fr.enabled() + + config_fr.register_global(enabled=True) + assert config_fr.enabled() is True +#endregion + + +#region Default Value Overrides +def test_global_default_override(config): + assert config.enabled(True) is True + assert config.get("enabled") is None + assert config.get("enabled", default=True) is True + + +def test_global_default_nofr(config): + assert config.nofr() is None + assert config.nofr(True) is True + assert config.get("nofr") is None + assert config.get("nofr", default=True) is True + + +def test_guild_default_override(config, empty_guild): + assert config.guild(empty_guild).enabled(True) is True + assert config.guild(empty_guild).get("enabled") is None + assert config.guild(empty_guild).get("enabled", default=True) is True + + +def test_channel_default_override(config, empty_channel): + assert config.channel(empty_channel).enabled(True) is True + assert config.channel(empty_channel).get("enabled") is None + assert config.channel(empty_channel).get("enabled", default=True) is True + + +def test_role_default_override(config, empty_role): + assert config.role(empty_role).enabled(True) is True + assert config.role(empty_role).get("enabled") is None + assert config.role(empty_role).get("enabled", default=True) is True + + +def test_member_default_override(config, empty_member): + assert config.member(empty_member).enabled(True) is True + assert config.member(empty_member).get("enabled") is None + assert config.member(empty_member).get("enabled", default=True) is True + + +def test_user_default_override(config, empty_user): + assert config.user(empty_user).some_value(True) is True + assert config.user(empty_user).get("some_value") is None + assert config.user(empty_user).get("some_value", default=True) is True +#endregion + + +#region Setting Values +@pytest.mark.asyncio +async def test_set_global(config): + await config.set("enabled", True) + assert config.enabled() is True + + +@pytest.mark.asyncio +async def test_set_global_badkey(config): + with pytest.raises(RuntimeError): + await config.set("this is a bad key", True) + + +@pytest.mark.asyncio +async def test_set_global_invalidkey(config): + with pytest.raises(KeyError): + await config.set("uuid", True) + + +@pytest.mark.asyncio +async def test_set_guild(config, empty_guild): + await config.guild(empty_guild).set("enabled", True) + assert config.guild(empty_guild).enabled() is True + + curr_list = config.guild(empty_guild).some_list([1, 2, 3]) + assert curr_list == [1, 2, 3] + curr_list.append(4) + + await config.guild(empty_guild).set("some_list", curr_list) + assert config.guild(empty_guild).some_list() == curr_list + + +@pytest.mark.asyncio +async def test_set_channel(config, empty_channel): + await config.channel(empty_channel).set("enabled", True) + assert config.channel(empty_channel).enabled() is True + + +@pytest.mark.asyncio +async def test_set_channel_no_register(config, empty_channel): + await config.channel(empty_channel).set("no_register", True) + assert config.channel(empty_channel).no_register() is True +#endregion diff --git a/tests/core/test_installation.py b/tests/core/test_installation.py new file mode 100644 index 000000000..e2be71cf6 --- /dev/null +++ b/tests/core/test_installation.py @@ -0,0 +1,6 @@ +import pytest + + +@pytest.mark.asyncio +async def test_can_init_bot(red): + assert red is not None \ No newline at end of file