mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-05 18:58:53 -05:00
[Core] Replaced JsonDB with Config (#770)
This commit is contained in:
parent
a8745297dc
commit
3988fbbc09
12
.travis.yml
Normal file
12
.travis.yml
Normal file
@ -0,0 +1,12 @@
|
||||
language: python
|
||||
python:
|
||||
- "3.5.3"
|
||||
- "3.6.1"
|
||||
install:
|
||||
- pip install -r requirements.txt
|
||||
script:
|
||||
- python -m compileall ./cogs
|
||||
- python -m pytest
|
||||
cache: pip
|
||||
notifications:
|
||||
email: false
|
||||
0
cogs/__init__.py
Normal file
0
cogs/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from core.config import Config
|
||||
31
core/bot.py
31
core/bot.py
@ -1,6 +1,7 @@
|
||||
from discord.ext import commands
|
||||
from collections import Counter
|
||||
from core.settings import CoreDB
|
||||
|
||||
from core import Config
|
||||
from enum import Enum
|
||||
import os
|
||||
|
||||
@ -8,17 +9,33 @@ import os
|
||||
class Red(commands.Bot):
|
||||
def __init__(self, cli_flags, **kwargs):
|
||||
self._shutdown_mode = ExitCodes.CRITICAL
|
||||
self.db = CoreDB("core/data/settings.json",
|
||||
relative_path=False)
|
||||
self.db = Config.get_core_conf(force_registration=True)
|
||||
|
||||
self.db.register_global(
|
||||
token=None,
|
||||
prefix=[],
|
||||
packages=[],
|
||||
coowners=[],
|
||||
whitelist=[],
|
||||
blacklist=[]
|
||||
)
|
||||
|
||||
self.db.register_guild(
|
||||
prefix=[],
|
||||
whitelist=[],
|
||||
blacklist=[],
|
||||
admin_role=None,
|
||||
mod_role=None
|
||||
)
|
||||
|
||||
def prefix_manager(bot, message):
|
||||
if not cli_flags.prefix:
|
||||
global_prefix = self.db.get_global("prefix", [])
|
||||
global_prefix = self.db.prefix()
|
||||
else:
|
||||
global_prefix = cli_flags.prefix
|
||||
if message.guild is None:
|
||||
return global_prefix
|
||||
server_prefix = self.db.get(message.guild, "prefix", [])
|
||||
server_prefix = self.db.guild(message.guild).prefix()
|
||||
return server_prefix if server_prefix else global_prefix
|
||||
|
||||
if "command_prefix" not in kwargs:
|
||||
@ -30,7 +47,7 @@ class Red(commands.Bot):
|
||||
|
||||
async def is_owner(self, user, allow_coowners=True):
|
||||
if allow_coowners:
|
||||
if user.id in self.db.get_global("coowners", []):
|
||||
if user.id in self.db.coowners():
|
||||
return True
|
||||
return await super().is_owner(user)
|
||||
|
||||
@ -65,7 +82,7 @@ class Red(commands.Bot):
|
||||
for package in self.extensions:
|
||||
if package.startswith("cogs."):
|
||||
loaded.append(package)
|
||||
await self.db.set_global("packages", loaded)
|
||||
await self.db.set("packages", loaded)
|
||||
|
||||
|
||||
class ExitCodes(Enum):
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
|
||||
@ -23,8 +24,12 @@ def mod_or_permissions(**perms):
|
||||
if ctx.guild is None:
|
||||
return has_perms_or_is_owner
|
||||
author = ctx.author
|
||||
mod_role = ctx.bot.db.get_mod_role(ctx.guild)
|
||||
admin_role = ctx.bot.db.get_admin_role(ctx.guild)
|
||||
mod_role_id = ctx.bot.db.guild(ctx.guild).mod_role()
|
||||
admin_role_id = ctx.bot.db.guild(ctx.guild).admin_role()
|
||||
|
||||
mod_role = discord.utils.get(ctx.guild.roles, id=mod_role_id)
|
||||
admin_role = discord.utils.get(ctx.guild.roles, id=admin_role_id)
|
||||
|
||||
is_staff = mod_role in author.roles or admin_role in author.roles
|
||||
is_guild_owner = author == ctx.guild.owner
|
||||
|
||||
@ -40,7 +45,7 @@ def admin_or_permissions(**perms):
|
||||
return has_perms_or_is_owner
|
||||
author = ctx.author
|
||||
is_guild_owner = author == ctx.guild.owner
|
||||
admin_role = ctx.bot.db.get_admin_role(ctx.guild)
|
||||
admin_role = ctx.bot.db.guild(ctx.guild).admin_role()
|
||||
|
||||
return admin_role in author.roles or has_perms_or_is_owner or is_guild_owner
|
||||
|
||||
|
||||
49
core/cli.py
49
core/cli.py
@ -1,3 +1,4 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
|
||||
@ -19,7 +20,7 @@ def interactive_config(red, token_set, prefix_set):
|
||||
print("That doesn't look like a valid token.")
|
||||
token = ""
|
||||
if token:
|
||||
loop.run_until_complete(red.db.set_global("token", token))
|
||||
loop.run_until_complete(red.db.set("token", token))
|
||||
|
||||
if not prefix_set:
|
||||
prefix = ""
|
||||
@ -36,6 +37,50 @@ def interactive_config(red, token_set, prefix_set):
|
||||
if not confirm("> "):
|
||||
prefix = ""
|
||||
if prefix:
|
||||
loop.run_until_complete(red.db.set_global("prefix", [prefix]))
|
||||
loop.run_until_complete(red.db.set("prefix", [prefix]))
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def parse_cli_flags():
|
||||
parser = argparse.ArgumentParser(description="Red - Discord Bot")
|
||||
parser.add_argument("--owner", help="ID of the owner. Only who hosts "
|
||||
"Red should be owner, this has "
|
||||
"security implications")
|
||||
parser.add_argument("--prefix", "-p", action="append",
|
||||
help="Global prefix. Can be multiple")
|
||||
parser.add_argument("--no-prompt",
|
||||
action="store_true",
|
||||
help="Disables console inputs. Features requiring "
|
||||
"console interaction could be disabled as a "
|
||||
"result")
|
||||
parser.add_argument("--no-cogs",
|
||||
action="store_true",
|
||||
help="Starts Red with no cogs loaded, only core")
|
||||
parser.add_argument("--self-bot",
|
||||
action='store_true',
|
||||
help="Specifies if Red should log in as selfbot")
|
||||
parser.add_argument("--not-bot",
|
||||
action='store_true',
|
||||
help="Specifies if the token used belongs to a bot "
|
||||
"account.")
|
||||
parser.add_argument("--dry-run",
|
||||
action="store_true",
|
||||
help="Makes Red quit with code 0 just before the "
|
||||
"login. This is useful for testing the boot "
|
||||
"process.")
|
||||
parser.add_argument("--debug",
|
||||
action="store_true",
|
||||
help="Sets the loggers level as debug")
|
||||
parser.add_argument("--dev",
|
||||
action="store_true",
|
||||
help="Enables developer mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.prefix:
|
||||
args.prefix = sorted(args.prefix, reverse=True)
|
||||
else:
|
||||
args.prefix = []
|
||||
|
||||
return args
|
||||
521
core/config.py
Normal file
521
core/config.py
Normal file
@ -0,0 +1,521 @@
|
||||
from pathlib import Path
|
||||
|
||||
from core.drivers.red_json import JSON as JSONDriver
|
||||
from core.drivers.red_mongo import Mongo
|
||||
import logging
|
||||
|
||||
from typing import Callable
|
||||
|
||||
log = logging.getLogger("red.config")
|
||||
|
||||
class BaseConfig:
|
||||
def __init__(self, cog_name, unique_identifier, driver_spawn, force_registration=False,
|
||||
hash_uuid=True, collection="GLOBAL", collection_uuid=None,
|
||||
defaults={}):
|
||||
self.cog_name = cog_name
|
||||
if hash_uuid:
|
||||
self.uuid = str(hash(unique_identifier))
|
||||
else:
|
||||
self.uuid = unique_identifier
|
||||
self.driver_spawn = driver_spawn
|
||||
self._driver = None
|
||||
self.collection = collection
|
||||
self.collection_uuid = collection_uuid
|
||||
|
||||
self.force_registration = force_registration
|
||||
|
||||
try:
|
||||
self.driver.maybe_add_ident(self.uuid)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
self.driver_getmap = {
|
||||
"GLOBAL": self.driver.get_global,
|
||||
"GUILD": self.driver.get_guild,
|
||||
"CHANNEL": self.driver.get_channel,
|
||||
"ROLE": self.driver.get_role,
|
||||
"USER": self.driver.get_user
|
||||
}
|
||||
|
||||
self.driver_setmap = {
|
||||
"GLOBAL": self.driver.set_global,
|
||||
"GUILD": self.driver.set_guild,
|
||||
"CHANNEL": self.driver.set_channel,
|
||||
"ROLE": self.driver.set_role,
|
||||
"USER": self.driver.set_user
|
||||
}
|
||||
|
||||
self.curr_key = None
|
||||
|
||||
self.unsettable_keys = ("cog_name", "cog_identifier", "_id",
|
||||
"guild_id", "channel_id", "role_id",
|
||||
"user_id", "uuid")
|
||||
self.invalid_keys = (
|
||||
"driver_spawn",
|
||||
"_driver", "collection",
|
||||
"collection_uuid", "force_registration"
|
||||
)
|
||||
|
||||
self.defaults = defaults if defaults else {
|
||||
"GLOBAL": {}, "GUILD": {}, "CHANNEL": {}, "ROLE": {},
|
||||
"MEMBER": {}, "USER": {}}
|
||||
|
||||
@classmethod
|
||||
def get_conf(cls, cog_instance: object, unique_identifier: int=0,
|
||||
force_registration: bool=False):
|
||||
"""
|
||||
Gets a config object that cog's can use to safely store data. The
|
||||
backend to this is totally modular and can easily switch between
|
||||
JSON and a DB. However, when changed, all data will likely be lost
|
||||
unless cogs write some converters for their data.
|
||||
|
||||
Positional Arguments:
|
||||
cog_instance - The cog `self` object, can be passed in from your
|
||||
cog's __init__ method.
|
||||
|
||||
Keyword Arguments:
|
||||
unique_identifier - a random integer or string that is used to
|
||||
differentiate your cog from any other named the same. This way we
|
||||
can safely store data for multiple cogs that are named the same.
|
||||
|
||||
YOU SHOULD USE THIS.
|
||||
|
||||
force_registration - A flag which will cause the Config object to
|
||||
throw exceptions if you try to get/set data keys that you have
|
||||
not pre-registered. I highly recommend you ENABLE this as it
|
||||
will help reduce dumb typo errors.
|
||||
"""
|
||||
|
||||
url = None # TODO: get mongo url
|
||||
port = None # TODO: get mongo port
|
||||
|
||||
def spawn_mongo_driver():
|
||||
return Mongo(url, port)
|
||||
|
||||
# TODO: Determine which backend users want, default to JSON
|
||||
|
||||
cog_name = cog_instance.__class__.__name__
|
||||
|
||||
driver_spawn = JSONDriver(cog_name)
|
||||
|
||||
return cls(cog_name=cog_name, unique_identifier=unique_identifier,
|
||||
driver_spawn=driver_spawn, force_registration=force_registration)
|
||||
|
||||
@classmethod
|
||||
def get_core_conf(cls, force_registration: bool=False):
|
||||
core_data_path = Path.cwd() / 'core' / '.data'
|
||||
driver_spawn = JSONDriver("Core", data_path_override=core_data_path)
|
||||
return cls(cog_name="Core", driver_spawn=driver_spawn,
|
||||
unique_identifier=0,
|
||||
force_registration=force_registration)
|
||||
|
||||
@property
|
||||
def driver(self):
|
||||
if self._driver is None:
|
||||
try:
|
||||
self._driver = self.driver_spawn()
|
||||
except TypeError:
|
||||
return self.driver_spawn
|
||||
|
||||
return self._driver
|
||||
|
||||
def __getattr__(self, key):
|
||||
"""This should be used to return config key data as determined by
|
||||
`self.collection` and `self.collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if 'defaults' in self.__dict__: # Necessary to let the cog load
|
||||
restricted = list(self.defaults[self.collection].keys()) + \
|
||||
list(self.unsettable_keys)
|
||||
if key in restricted:
|
||||
raise ValueError("Not allowed to dynamically set attributes of"
|
||||
" unsettable_keys: {}".format(restricted))
|
||||
else:
|
||||
self.__dict__[key] = value
|
||||
else:
|
||||
self.__dict__[key] = value
|
||||
|
||||
def clear(self):
|
||||
"""Clears all values in the current context ONLY."""
|
||||
raise NotImplemented
|
||||
|
||||
def set(self, key, value):
|
||||
"""This should set config key with value `value` in the
|
||||
corresponding collection as defined by `self.collection` and
|
||||
`self.collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def guild(self, guild):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def channel(self, channel):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def role(self, role):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def member(self, member):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def user(self, user):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def register_global(self, **global_defaults):
|
||||
"""
|
||||
Registers a new dict of global defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param global_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in global_defaults.items():
|
||||
try:
|
||||
self._register_global(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default global key.")
|
||||
|
||||
def _register_global(self, key, default=None):
|
||||
"""Registers a global config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["GLOBAL"][key] = default
|
||||
|
||||
def register_guild(self, **guild_defaults):
|
||||
"""
|
||||
Registers a new dict of guild defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param guild_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in guild_defaults.items():
|
||||
try:
|
||||
self._register_guild(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default guild key.")
|
||||
|
||||
def _register_guild(self, key, default=None):
|
||||
"""Registers a guild config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["GUILD"][key] = default
|
||||
|
||||
def register_channel(self, **channel_defaults):
|
||||
"""
|
||||
Registers a new dict of channel defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param channel_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in channel_defaults.items():
|
||||
try:
|
||||
self._register_channel(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default channel key.")
|
||||
|
||||
def _register_channel(self, key, default=None):
|
||||
"""Registers a channel config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["CHANNEL"][key] = default
|
||||
|
||||
def register_role(self, **role_defaults):
|
||||
"""
|
||||
Registers a new dict of role defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param role_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in role_defaults.items():
|
||||
try:
|
||||
self._register_role(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default role key.")
|
||||
|
||||
def _register_role(self, key, default=None):
|
||||
"""Registers a role config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["ROLE"][key] = default
|
||||
|
||||
def register_member(self, **member_defaults):
|
||||
"""
|
||||
Registers a new dict of member defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param member_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in member_defaults.items():
|
||||
try:
|
||||
self._register_member(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default member key.")
|
||||
|
||||
def _register_member(self, key, default=None):
|
||||
"""Registers a member config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["MEMBER"][key] = default
|
||||
|
||||
def register_user(self, **user_defaults):
|
||||
"""
|
||||
Registers a new dict of user defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param user_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in user_defaults.items():
|
||||
try:
|
||||
self._register_user(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default user key.")
|
||||
|
||||
def _register_user(self, key, default=None):
|
||||
"""Registers a user config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["USER"][key] = default
|
||||
|
||||
|
||||
class Config(BaseConfig):
|
||||
"""
|
||||
Config object created by `Config.get_conf()`
|
||||
|
||||
This configuration object is designed to make backend data
|
||||
storage mechanisms pluggable. It also is designed to
|
||||
help a cog developer make fewer mistakes (such as
|
||||
typos) when dealing with cog data and to make those mistakes
|
||||
apparent much faster in the design process.
|
||||
|
||||
It also has the capability to safely store data between cogs
|
||||
that share the same name.
|
||||
|
||||
There are two main components to this config object. First,
|
||||
you have the ability to get data on a level specific basis.
|
||||
The seven levels available are: global, guild, channel, role,
|
||||
member, user, and misc.
|
||||
|
||||
The second main component is registering default values for
|
||||
data in each of the levels. This functionality is OPTIONAL
|
||||
and must be explicitly enabled when creating the Config object
|
||||
using the kwarg `force_registration=True`.
|
||||
|
||||
Basic Usage:
|
||||
Creating a Config object:
|
||||
Use the `Config.get_conf()` class method to create new
|
||||
Config objects.
|
||||
|
||||
See the `Config.get_conf()` documentation for more
|
||||
information.
|
||||
|
||||
Registering Default Values (optional):
|
||||
You can register default values for data at all levels
|
||||
EXCEPT misc.
|
||||
|
||||
Simply pass in the key/value pairs as keyword arguments to
|
||||
the respective function.
|
||||
|
||||
e.g.: conf_obj.register_global(enabled=True)
|
||||
conf_obj.register_guild(likes_red=True)
|
||||
|
||||
Retrieving data by attributes:
|
||||
Since I registered the "enabled" key in the previous example
|
||||
at the global level I can now do:
|
||||
|
||||
conf_obj.enabled()
|
||||
|
||||
which will retrieve the current value of the "enabled"
|
||||
key, making use of the default of "True". I can also do
|
||||
the same for the guild key "likes_red":
|
||||
|
||||
conf_obj.guild(guild_obj).likes_red()
|
||||
|
||||
If I elected to not register default values, you can provide them
|
||||
when you try to access the key:
|
||||
|
||||
conf_obj.no_default(default=True)
|
||||
|
||||
However if you do not provide a default and you do not register
|
||||
defaults, accessing the attribute will return "None".
|
||||
|
||||
Saving data:
|
||||
This is accomplished by using the `set` function available at
|
||||
every level.
|
||||
|
||||
e.g.: conf_obj.set("enabled", False)
|
||||
conf_obj.guild(guild_obj).set("likes_red", False)
|
||||
|
||||
If `force_registration` was enabled when the config object
|
||||
was created you will only be allowed to save keys that you
|
||||
have registered.
|
||||
|
||||
Misc data is special, use `conf.misc()` and `conf.set_misc(value)`
|
||||
respectively.
|
||||
"""
|
||||
|
||||
def __getattr__(self, key) -> Callable:
|
||||
"""
|
||||
Until I've got a better way to do this I'm just gonna fake __call__
|
||||
|
||||
:param key:
|
||||
:return: lambda function with kwarg
|
||||
"""
|
||||
return self._get_value_from_key(key)
|
||||
|
||||
def _get_value_from_key(self, key) -> Callable:
|
||||
try:
|
||||
default = self.defaults[self.collection][key]
|
||||
except KeyError as e:
|
||||
if self.force_registration:
|
||||
raise AttributeError("Key '{}' not registered!".format(key)) from e
|
||||
default = None
|
||||
|
||||
self.curr_key = key
|
||||
|
||||
if self.collection != "MEMBER":
|
||||
ret = lambda default=default: self.driver_getmap[self.collection](
|
||||
self.cog_name, self.uuid, self.collection_uuid, key,
|
||||
default=default)
|
||||
else:
|
||||
mid, sid = self.collection_uuid
|
||||
ret = lambda default=default: self.driver.get_member(
|
||||
self.cog_name, self.uuid, mid, sid, key,
|
||||
default=default)
|
||||
return ret
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""
|
||||
Included as an alternative to registering defaults.
|
||||
|
||||
:param key:
|
||||
:param default:
|
||||
:return:
|
||||
"""
|
||||
|
||||
try:
|
||||
return getattr(self, key)(default=default)
|
||||
except AttributeError:
|
||||
return
|
||||
|
||||
async def set(self, key, value):
|
||||
# Notice to future developers:
|
||||
# This code was commented to allow users to set keys without having to register them.
|
||||
# That being said, if they try to get keys without registering them
|
||||
# things will blow up. I do highly recommend enforcing the key registration.
|
||||
|
||||
if key in self.unsettable_keys or key in self.invalid_keys:
|
||||
raise KeyError("Restricted key name, please use another.")
|
||||
|
||||
if self.force_registration and key not in self.defaults[self.collection]:
|
||||
raise AttributeError("Key '{}' not registered!".format(key))
|
||||
|
||||
if not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
|
||||
if self.collection == "GLOBAL":
|
||||
await self.driver.set_global(self.cog_name, self.uuid, key, value)
|
||||
elif self.collection == "MEMBER":
|
||||
mid, sid = self.collection_uuid
|
||||
await self.driver.set_member(self.cog_name, self.uuid, mid, sid,
|
||||
key, value)
|
||||
elif self.collection in self.driver_setmap:
|
||||
func = self.driver_setmap[self.collection]
|
||||
await func(self.cog_name, self.uuid, self.collection_uuid, key, value)
|
||||
|
||||
async def clear(self):
|
||||
await self.driver_setmap[self.collection](
|
||||
self.cog_name, self.uuid, self.collection_uuid, None, None,
|
||||
clear=True)
|
||||
|
||||
def guild(self, guild):
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "GUILD"
|
||||
new.collection_uuid = guild.id
|
||||
new._driver = None
|
||||
return new
|
||||
|
||||
def channel(self, channel):
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "CHANNEL"
|
||||
new.collection_uuid = channel.id
|
||||
new._driver = None
|
||||
return new
|
||||
|
||||
def role(self, role):
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "ROLE"
|
||||
new.collection_uuid = role.id
|
||||
new._driver = None
|
||||
return new
|
||||
|
||||
def member(self, member):
|
||||
guild = member.guild
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "MEMBER"
|
||||
new.collection_uuid = (member.id, guild.id)
|
||||
new._driver = None
|
||||
return new
|
||||
|
||||
def user(self, user):
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "USER"
|
||||
new.collection_uuid = user.id
|
||||
new._driver = None
|
||||
return new
|
||||
0
core/drivers/__init__.py
Normal file
0
core/drivers/__init__.py
Normal file
45
core/drivers/red_base.py
Normal file
45
core/drivers/red_base.py
Normal file
@ -0,0 +1,45 @@
|
||||
class BaseDriver:
|
||||
def get_global(self, cog_name, ident, collection_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_role(self, cog_name, ident, role_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
|
||||
default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_user(self, cog_name, ident, user_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_misc(self, cog_name, ident, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_global(self, cog_name, ident, key, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_channel(self, cog_name, ident, channel_id, key, value,
|
||||
clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_member(self, cog_name, ident, user_id, guild_id, key, value,
|
||||
clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_misc(self, cog_name, ident, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
135
core/drivers/red_json.py
Normal file
135
core/drivers/red_json.py
Normal file
@ -0,0 +1,135 @@
|
||||
from core.json_io import JsonIO
|
||||
import os
|
||||
from .red_base import BaseDriver
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class JSON(BaseDriver):
|
||||
def __init__(self, cog_name, *args, data_path_override: Path=None,
|
||||
file_name_override: str="settings.json", **kwargs):
|
||||
self.cog_name = cog_name
|
||||
self.file_name = file_name_override
|
||||
if data_path_override:
|
||||
self.data_path = data_path_override
|
||||
else:
|
||||
self.data_path = Path.cwd() / 'cogs' / '.data' / self.cog_name
|
||||
|
||||
self.data_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.data_path = self.data_path / self.file_name
|
||||
|
||||
self.jsonIO = JsonIO(self.data_path)
|
||||
|
||||
try:
|
||||
self.data = self.jsonIO._load_json()
|
||||
except FileNotFoundError:
|
||||
self.data = {}
|
||||
|
||||
def maybe_add_ident(self, ident: str):
|
||||
if ident in self.data:
|
||||
return
|
||||
|
||||
self.data[ident] = {}
|
||||
for k in ("GLOBAL", "GUILD", "CHANNEL", "ROLE", "MEMBER", "USER"):
|
||||
if k not in self.data[ident]:
|
||||
self.data[ident][k] = {}
|
||||
|
||||
self.jsonIO._save_json(self.data)
|
||||
|
||||
def get_global(self, cog_name, ident, _, key, *, default=None):
|
||||
return self.data[ident]["GLOBAL"].get(key, default)
|
||||
|
||||
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
|
||||
guilddata = self.data[ident]["GUILD"].get(str(guild_id), {})
|
||||
return guilddata.get(key, default)
|
||||
|
||||
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
|
||||
channeldata = self.data[ident]["CHANNEL"].get(str(channel_id), {})
|
||||
return channeldata.get(key, default)
|
||||
|
||||
def get_role(self, cog_name, ident, role_id, key, *, default=None):
|
||||
roledata = self.data[ident]["ROLE"].get(str(role_id), {})
|
||||
return roledata.get(key, default)
|
||||
|
||||
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
|
||||
default=None):
|
||||
userdata = self.data[ident]["MEMBER"].get(str(user_id), {})
|
||||
guilddata = userdata.get(str(guild_id), {})
|
||||
return guilddata.get(key, default)
|
||||
|
||||
def get_user(self, cog_name, ident, user_id, key, *, default=None):
|
||||
userdata = self.data[ident]["USER"].get(str(user_id), {})
|
||||
return userdata.get(key, default)
|
||||
|
||||
async def set_global(self, cog_name, ident, key, value, clear=False):
|
||||
if clear:
|
||||
self.data[ident]["GLOBAL"] = {}
|
||||
else:
|
||||
self.data[ident]["GLOBAL"][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
|
||||
guild_id = str(guild_id)
|
||||
if clear:
|
||||
self.data[ident]["GUILD"][guild_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["GUILD"][guild_id][key] = value
|
||||
except KeyError:
|
||||
self.data[ident]["GUILD"][guild_id] = {}
|
||||
self.data[ident]["GUILD"][guild_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_channel(self, cog_name, ident, channel_id, key, value, clear=False):
|
||||
channel_id = str(channel_id)
|
||||
if clear:
|
||||
self.data[ident]["CHANNEL"][channel_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["CHANNEL"][channel_id][key] = value
|
||||
except KeyError:
|
||||
self.data[ident]["CHANNEL"][channel_id] = {}
|
||||
self.data[ident]["CHANNEL"][channel_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
|
||||
role_id = str(role_id)
|
||||
if clear:
|
||||
self.data[ident]["ROLE"][role_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["ROLE"][role_id][key] = value
|
||||
except KeyError:
|
||||
self.data[ident]["ROLE"][role_id] = {}
|
||||
self.data[ident]["ROLE"][role_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_member(self, cog_name, ident, user_id, guild_id, key, value, clear=False):
|
||||
user_id = str(user_id)
|
||||
guild_id = str(guild_id)
|
||||
if clear:
|
||||
self.data[ident]["MEMBER"][user_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
|
||||
except KeyError:
|
||||
if user_id not in self.data[ident]["MEMBER"]:
|
||||
self.data[ident]["MEMBER"][user_id] = {}
|
||||
if guild_id not in self.data[ident]["MEMBER"][user_id]:
|
||||
self.data[ident]["MEMBER"][user_id][guild_id] = {}
|
||||
|
||||
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
|
||||
user_id = str(user_id)
|
||||
if clear:
|
||||
self.data[ident]["USER"][user_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["USER"][user_id][key] = value
|
||||
except KeyError:
|
||||
self.data[ident]["USER"][user_id] = {}
|
||||
self.data[ident]["USER"][user_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
211
core/drivers/red_mongo.py
Normal file
211
core/drivers/red_mongo.py
Normal file
@ -0,0 +1,211 @@
|
||||
import pymongo as m
|
||||
from .red_base import BaseDriver
|
||||
|
||||
|
||||
class RedMongoException(Exception):
|
||||
"""Base Red Mongo Exception class"""
|
||||
pass
|
||||
|
||||
|
||||
class MultipleMatches(RedMongoException):
|
||||
"""Raised when multiple documents match a single cog_name and
|
||||
cog_identifier pair."""
|
||||
pass
|
||||
|
||||
|
||||
class MissingCollection(RedMongoException):
|
||||
"""Raised when a collection is missing from the mongo db"""
|
||||
pass
|
||||
|
||||
|
||||
class Mongo(BaseDriver):
|
||||
def __init__(self, host, port=27017, admin_user=None, admin_pass=None,
|
||||
**kwargs):
|
||||
self.conn = m.MongoClient(host=host, port=port, **kwargs)
|
||||
|
||||
self.admin_user = admin_user
|
||||
self.admin_pass = admin_pass
|
||||
|
||||
self._db = self.conn.red
|
||||
if self.admin_user is not None and self.admin_pass is not None:
|
||||
self._db.authenticate(self.admin_user, self.admin_pass)
|
||||
|
||||
self._global = self._db.GLOBAL
|
||||
self._guild = self._db.GUILD
|
||||
self._channel = self._db.CHANNEL
|
||||
self._role = self._db.ROLE
|
||||
self._member = self._db.MEMBER
|
||||
self._user = self._db.USER
|
||||
|
||||
def get_global(self, cog_name, cog_identifier, _, key, *, default=None):
|
||||
doc = self._global.find(
|
||||
{"cog_name": cog_name, "cog_identifier": cog_identifier},
|
||||
projection=[key, ], batch_size=2)
|
||||
if doc.count() == 2:
|
||||
raise MultipleMatches("Too many matching documents at the GLOBAL"
|
||||
" level: ({}, {})".format(cog_name,
|
||||
cog_identifier))
|
||||
elif doc.count() == 1:
|
||||
return doc[0].get(key, default)
|
||||
return default
|
||||
|
||||
def get_guild(self, cog_name, cog_identifier, guild_id, key, *,
|
||||
default=None):
|
||||
doc = self._guild.find(
|
||||
{"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"guild_id": guild_id},
|
||||
projection=[key, ], batch_size=2)
|
||||
if doc.count() == 2:
|
||||
raise MultipleMatches("Too many matching documents at the GUILD"
|
||||
" level: ({}, {}, {})".format(
|
||||
cog_name, cog_identifier, guild_id))
|
||||
elif doc.count() == 1:
|
||||
return doc[0].get(key, default)
|
||||
return default
|
||||
|
||||
def get_channel(self, cog_name, cog_identifier, channel_id, key, *,
|
||||
default=None):
|
||||
doc = self._channel.find(
|
||||
{"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"channel_id": channel_id},
|
||||
projection=[key, ], batch_size=2)
|
||||
if doc.count() == 2:
|
||||
raise MultipleMatches("Too many matching documents at the CHANNEL"
|
||||
" level: ({}, {}, {})".format(
|
||||
cog_name, cog_identifier, channel_id))
|
||||
elif doc.count() == 1:
|
||||
return doc[0].get(key, default)
|
||||
return default
|
||||
|
||||
def get_role(self, cog_name, cog_identifier, role_id, key, *,
|
||||
default=None):
|
||||
doc = self._role.find(
|
||||
{"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"role_id": role_id},
|
||||
projection=[key, ], batch_size=2)
|
||||
if doc.count() == 2:
|
||||
raise MultipleMatches("Too many matching documents at the ROLE"
|
||||
" level: ({}, {}, {})".format(
|
||||
cog_name, cog_identifier, role_id))
|
||||
elif doc.count() == 1:
|
||||
return doc[0].get(key, default)
|
||||
return default
|
||||
|
||||
def get_member(self, cog_name, cog_identifier, user_id, guild_id, key, *,
|
||||
default=None):
|
||||
doc = self._member.find(
|
||||
{"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"user_id": user_id, "guild_id": guild_id},
|
||||
projection=[key, ], batch_size=2)
|
||||
if doc.count() == 2:
|
||||
raise MultipleMatches("Too many matching documents at the MEMBER"
|
||||
" level: ({}, {}, mid {}, sid {})".format(
|
||||
cog_name, cog_identifier, user_id,
|
||||
guild_id))
|
||||
elif doc.count() == 1:
|
||||
return doc[0].get(key, default)
|
||||
return default
|
||||
|
||||
def get_user(self, cog_name, cog_identifier, user_id, key, *,
|
||||
default=None):
|
||||
doc = self._user.find(
|
||||
{"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"user_id": user_id},
|
||||
projection=[key, ], batch_size=2)
|
||||
if doc.count() == 2:
|
||||
raise MultipleMatches("Too many matching documents at the USER"
|
||||
" level: ({}, {}, mid {})".format(
|
||||
cog_name, cog_identifier, user_id))
|
||||
elif doc.count() == 1:
|
||||
return doc[0].get(key, default)
|
||||
else:
|
||||
return default
|
||||
|
||||
def set_global(self, cog_name, cog_identifier, key, value, clear=False):
|
||||
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier}
|
||||
data = {"$set": {key: value}}
|
||||
if self._global.count(filter) > 1:
|
||||
raise MultipleMatches("Too many matching documents at the GLOBAL"
|
||||
" level: ({}, {})".format(cog_name,
|
||||
cog_identifier))
|
||||
else:
|
||||
if clear:
|
||||
self._global.delete_one(filter)
|
||||
else:
|
||||
self._global.update_one(filter, data, upsert=True)
|
||||
|
||||
def set_guild(self, cog_name, cog_identifier, guild_id, key, value,
|
||||
clear=False):
|
||||
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"guild_id": guild_id}
|
||||
data = {"$set": {key: value}}
|
||||
if self._guild.count(filter) > 1:
|
||||
raise MultipleMatches("Too many matching documents at the GUILD"
|
||||
" level: ({}, {}, {})".format(
|
||||
cog_name, cog_identifier, guild_id))
|
||||
else:
|
||||
if clear:
|
||||
self._guild.delete_one(filter)
|
||||
else:
|
||||
self._guild.update_one(filter, data, upsert=True)
|
||||
|
||||
def set_channel(self, cog_name, cog_identifier, channel_id, key, value,
|
||||
clear=False):
|
||||
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"channel_id": channel_id}
|
||||
data = {"$set": {key: value}}
|
||||
if self._channel.count(filter) > 1:
|
||||
raise MultipleMatches("Too many matching documents at the CHANNEL"
|
||||
" level: ({}, {}, {})".format(
|
||||
cog_name, cog_identifier, channel_id))
|
||||
else:
|
||||
if clear:
|
||||
self._channel.delete_one(filter)
|
||||
else:
|
||||
self._channel.update_one(filter, data, upsert=True)
|
||||
|
||||
def set_role(self, cog_name, cog_identifier, role_id, key, value,
|
||||
clear=False):
|
||||
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"role_id": role_id}
|
||||
data = {"$set": {key: value}}
|
||||
if self._role.count(filter) > 1:
|
||||
raise MultipleMatches("Too many matching documents at the ROLE"
|
||||
" level: ({}, {}, {})".format(
|
||||
cog_name, cog_identifier, role_id))
|
||||
else:
|
||||
if clear:
|
||||
self._role.delete_one(filter)
|
||||
else:
|
||||
self._role.update_one(filter, data, upsert=True)
|
||||
|
||||
def set_member(self, cog_name, cog_identifier, user_id, guild_id, key,
|
||||
value, clear=False):
|
||||
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"guild_id": guild_id, "user_id": user_id}
|
||||
data = {"$set": {key: value}}
|
||||
if self._member.count(filter) > 1:
|
||||
raise MultipleMatches("Too many matching documents at the MEMBER"
|
||||
" level: ({}, {}, mid {}, sid {})".format(
|
||||
cog_name, cog_identifier, user_id,
|
||||
guild_id))
|
||||
else:
|
||||
if clear:
|
||||
self._member.delete_one(filter)
|
||||
else:
|
||||
self._member.update_one(filter, data, upsert=True)
|
||||
|
||||
def set_user(self, cog_name, cog_identifier, user_id, key, value,
|
||||
clear=False):
|
||||
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
|
||||
"user_id": user_id}
|
||||
data = {"$set": {key: value}}
|
||||
if self._user.count(filter) > 1:
|
||||
raise MultipleMatches("Too many matching documents at the USER"
|
||||
" level: ({}, {}, mid {})".format(
|
||||
cog_name, cog_identifier, user_id))
|
||||
else:
|
||||
if clear:
|
||||
self._user.delete_one(filter)
|
||||
else:
|
||||
self._user.update_one(filter, data, upsert=True)
|
||||
@ -32,7 +32,7 @@ def init_events(bot, cli_flags):
|
||||
if cli_flags.no_cogs is False:
|
||||
print("Loading packages...")
|
||||
failed = []
|
||||
packages = bot.db.get_global("packages", [])
|
||||
packages = bot.db.packages()
|
||||
|
||||
for package in packages:
|
||||
try:
|
||||
@ -69,7 +69,7 @@ def init_events(bot, cli_flags):
|
||||
print("\nInvite URL: {}\n".format(invite_url))
|
||||
|
||||
@bot.event
|
||||
async def on_command_error(error, ctx):
|
||||
async def on_command_error(ctx, error):
|
||||
if isinstance(error, commands.MissingRequiredArgument):
|
||||
await bot.send_cmd_help(ctx)
|
||||
elif isinstance(error, commands.BadArgument):
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
from discord.ext import commands
|
||||
|
||||
|
||||
def init_global_checks(bot):
|
||||
|
||||
@bot.check
|
||||
@ -5,20 +8,19 @@ def init_global_checks(bot):
|
||||
if await bot.is_owner(ctx.author):
|
||||
return True
|
||||
|
||||
if bot.db.get_global("whitelist", []):
|
||||
return ctx.author.id in bot.db.get_global("whitelist", [])
|
||||
if bot.db.whitelist():
|
||||
return ctx.author.id in bot.db.whitelist()
|
||||
|
||||
return ctx.author.id not in bot.db.get_global("blacklist", [])
|
||||
return ctx.author.id not in bot.db.blacklist()
|
||||
|
||||
@bot.check
|
||||
async def local_perms(ctx):
|
||||
async def local_perms(ctx: commands.Context):
|
||||
if await bot.is_owner(ctx.author):
|
||||
return True
|
||||
elif ctx.message.guild is None:
|
||||
return True
|
||||
guild_perms = bot.db.get_all(ctx.guild, {})
|
||||
local_blacklist = guild_perms.get("blacklist", [])
|
||||
local_whitelist = guild_perms.get("whitelist", [])
|
||||
local_blacklist = bot.db.guild(ctx.guild).blacklist()
|
||||
local_whitelist = bot.db.guild(ctx.guild).whitelist()
|
||||
|
||||
if local_whitelist:
|
||||
return ctx.author.id in local_whitelist
|
||||
|
||||
@ -7,7 +7,7 @@ from uuid import uuid4
|
||||
|
||||
# This is basically our old DataIO, except that it's now threadsafe
|
||||
# and just a base for much more elaborate classes
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
log = logging.getLogger("red")
|
||||
|
||||
@ -17,25 +17,33 @@ MINIFIED = {"sort_keys": True, "separators": (',', ':')}
|
||||
|
||||
class JsonIO:
|
||||
"""Basic functions for atomic saving / loading of json files"""
|
||||
_lock = asyncio.Lock()
|
||||
def __init__(self, path: Path=Path.cwd()):
|
||||
"""
|
||||
:param path: Full path to file.
|
||||
"""
|
||||
self._lock = asyncio.Lock()
|
||||
self.path = path
|
||||
|
||||
def _save_json(self, path, data, settings=PRETTY):
|
||||
log.debug("Saving file {}".format(path))
|
||||
filename, _ = os.path.splitext(path)
|
||||
# noinspection PyUnresolvedReferences
|
||||
def _save_json(self, data, settings=PRETTY):
|
||||
log.debug("Saving file {}".format(self.path))
|
||||
filename = self.path.stem
|
||||
tmp_file = "{}-{}.tmp".format(filename, uuid4().fields[0])
|
||||
with open(tmp_file, encoding="utf-8", mode="w") as f:
|
||||
tmp_path = self.path.parent / tmp_file
|
||||
with tmp_path.open(encoding="utf-8", mode="w") as f:
|
||||
json.dump(data, f, **settings)
|
||||
os.replace(tmp_file, path)
|
||||
tmp_path.replace(self.path)
|
||||
|
||||
async def _threadsafe_save_json(self, path, data, settings=PRETTY):
|
||||
async def _threadsafe_save_json(self, data, settings=PRETTY):
|
||||
loop = asyncio.get_event_loop()
|
||||
func = functools.partial(self._save_json, path, data, settings)
|
||||
func = functools.partial(self._save_json, data, settings)
|
||||
with await self._lock:
|
||||
await loop.run_in_executor(None, func)
|
||||
|
||||
def _load_json(self, path):
|
||||
log.debug("Reading file {}".format(path))
|
||||
with open(path, encoding='utf-8', mode="r") as f:
|
||||
# noinspection PyUnresolvedReferences
|
||||
def _load_json(self):
|
||||
log.debug("Reading file {}".format(self.path))
|
||||
with self.path.open(encoding='utf-8', mode="r") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
124
core/settings.py
124
core/settings.py
@ -1,124 +0,0 @@
|
||||
from core.utils.helpers import JsonGuildDB
|
||||
import discord
|
||||
import argparse
|
||||
|
||||
|
||||
class CoreDB(JsonGuildDB):
|
||||
"""
|
||||
The central DB used by Red to store a variety
|
||||
of settings, both global and guild specific
|
||||
"""
|
||||
|
||||
def get_admin_role(self, guild):
|
||||
"""Returns the guild's admin role
|
||||
|
||||
Returns None if not set or if the role
|
||||
couldn't be retrieved"""
|
||||
_id = self.get_all(guild, {}).get("admin_role", None)
|
||||
return discord.utils.get(guild.roles, id=_id)
|
||||
|
||||
def get_mod_role(self, guild):
|
||||
"""Returns the guild's mod role
|
||||
|
||||
Returns None if not set or if the role
|
||||
couldn't be retrieved"""
|
||||
_id = self.get_all(guild, {}).get("mod_role", None)
|
||||
return discord.utils.get(guild.roles, id=_id)
|
||||
|
||||
async def set_admin_role(self, role):
|
||||
"""Sets the admin role for the guild"""
|
||||
if not isinstance(role, discord.Role):
|
||||
raise TypeError("A valid Discord role must be passed.")
|
||||
await self.set(role.guild, "admin_role", role.id)
|
||||
|
||||
async def set_mod_role(self, role):
|
||||
"""Sets the mod role for the guild"""
|
||||
if not isinstance(role, discord.Role):
|
||||
raise TypeError("A valid Discord role must be passed.")
|
||||
await self.set(role.guild, "mod_role", role.id)
|
||||
|
||||
def get_global_whitelist(self):
|
||||
"""Returns the global whitelist"""
|
||||
return self.get_global("whitelist", [])
|
||||
|
||||
def get_global_blacklist(self):
|
||||
"""Returns the global whitelist"""
|
||||
return self.get_global("blacklist", [])
|
||||
|
||||
async def set_global_whitelist(self, whitelist):
|
||||
"""Sets the global whitelist"""
|
||||
if not isinstance(list, whitelist):
|
||||
raise TypeError("A list of IDs must be passed.")
|
||||
await self.set_global("whitelist", whitelist)
|
||||
|
||||
async def set_global_blacklist(self, blacklist):
|
||||
"""Sets the global blacklist"""
|
||||
if not isinstance(list, blacklist):
|
||||
raise TypeError("A list of IDs must be passed.")
|
||||
await self.set_global("blacklist", blacklist)
|
||||
|
||||
def get_guild_whitelist(self, guild):
|
||||
"""Returns the guild's whitelist"""
|
||||
return self.get(guild, "whitelist", [])
|
||||
|
||||
def get_guild_blacklist(self, guild):
|
||||
"""Returns the guild's blacklist"""
|
||||
return self.get(guild, "blacklist", [])
|
||||
|
||||
async def set_guild_whitelist(self, guild, whitelist):
|
||||
"""Sets the guild's whitelist"""
|
||||
if not isinstance(guild, discord.Guild) or not isinstance(whitelist, list):
|
||||
raise TypeError("A valid Discord guild and a list of IDs "
|
||||
"must be passed.")
|
||||
await self.set(guild, "whitelist", whitelist)
|
||||
|
||||
async def set_guild_blacklist(self, guild, blacklist):
|
||||
"""Sets the guild's blacklist"""
|
||||
if not isinstance(guild, discord.Guild) or not isinstance(blacklist, list):
|
||||
raise TypeError("A valid Discord guild and a list of IDs "
|
||||
"must be passed.")
|
||||
await self.set(guild, "blacklist", blacklist)
|
||||
|
||||
|
||||
def parse_cli_flags():
|
||||
parser = argparse.ArgumentParser(description="Red - Discord Bot")
|
||||
parser.add_argument("--owner", help="ID of the owner. Only who hosts "
|
||||
"Red should be owner, this has "
|
||||
"security implications")
|
||||
parser.add_argument("--prefix", "-p", action="append",
|
||||
help="Global prefix. Can be multiple")
|
||||
parser.add_argument("--no-prompt",
|
||||
action="store_true",
|
||||
help="Disables console inputs. Features requiring "
|
||||
"console interaction could be disabled as a "
|
||||
"result")
|
||||
parser.add_argument("--no-cogs",
|
||||
action="store_true",
|
||||
help="Starts Red with no cogs loaded, only core")
|
||||
parser.add_argument("--self-bot",
|
||||
action='store_true',
|
||||
help="Specifies if Red should log in as selfbot")
|
||||
parser.add_argument("--not-bot",
|
||||
action='store_true',
|
||||
help="Specifies if the token used belongs to a bot "
|
||||
"account.")
|
||||
parser.add_argument("--dry-run",
|
||||
action="store_true",
|
||||
help="Makes Red quit with code 0 just before the "
|
||||
"login. This is useful for testing the boot "
|
||||
"process.")
|
||||
parser.add_argument("--debug",
|
||||
action="store_true",
|
||||
help="Sets the loggers level as debug")
|
||||
parser.add_argument("--dev",
|
||||
action="store_true",
|
||||
help="Enables developer mode")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.prefix:
|
||||
args.prefix = sorted(args.prefix, reverse=True)
|
||||
else:
|
||||
args.prefix = []
|
||||
|
||||
return args
|
||||
@ -1,216 +0,0 @@
|
||||
import os
|
||||
import discord
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from core.json_io import JsonIO
|
||||
|
||||
|
||||
GLOBAL_KEY = '__global__'
|
||||
SENTINEL = object()
|
||||
|
||||
|
||||
class JsonDB(JsonIO):
|
||||
"""
|
||||
A DB-like helper class to streamline the saving of json files
|
||||
|
||||
Parameters:
|
||||
|
||||
file_path: str
|
||||
The path of the json file you want to create / access
|
||||
create_dirs: bool=True
|
||||
If True, it will create any missing directory leading to
|
||||
the file you want to create
|
||||
relative_path: bool=True
|
||||
The file_path you specified is relative to the path from which
|
||||
you're instantiating this object from
|
||||
i.e. If you're in a package's folder and your file_path is
|
||||
'data/settings.json', these files will be created inside
|
||||
the package's folder and not Red's root folder
|
||||
default_value: Optional=None
|
||||
Same behaviour as a defaultdict
|
||||
"""
|
||||
_caller = ""
|
||||
|
||||
def __init__(self, file_path, **kwargs):
|
||||
local = kwargs.pop("relative_path", True)
|
||||
if local and not self._caller:
|
||||
self._caller = self._get_caller_path()
|
||||
|
||||
create_dirs = kwargs.pop("create_dirs", True)
|
||||
default_value = kwargs.pop("default_value", SENTINEL)
|
||||
self.autosave = kwargs.pop("autosave", False)
|
||||
self.path = os.path.join(self._caller, file_path)
|
||||
|
||||
file_exists = os.path.isfile(self.path)
|
||||
|
||||
if create_dirs and not file_exists:
|
||||
path, _ = os.path.split(self.path)
|
||||
if path:
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except FileExistsError:
|
||||
pass
|
||||
|
||||
if file_exists:
|
||||
# Might be worth looking into threadsafe ways for very large files
|
||||
self._data = self._load_json(self.path)
|
||||
else:
|
||||
self._data = {}
|
||||
self._blocking_save()
|
||||
|
||||
if default_value is not SENTINEL:
|
||||
def _get_default():
|
||||
return default_value
|
||||
self._data = defaultdict(_get_default, self._data)
|
||||
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._task = functools.partial(self._threadsafe_save_json, self._data)
|
||||
|
||||
async def set(self, key, value):
|
||||
"""Sets a DB's entry"""
|
||||
self._data[key] = value
|
||||
await self.save()
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Returns a DB's entry"""
|
||||
return self._data.get(key, default)
|
||||
|
||||
async def remove(self, key):
|
||||
"""Removes a DB's entry"""
|
||||
del self._data[key]
|
||||
await self.save()
|
||||
|
||||
async def pop(self, key, default=None):
|
||||
"""Removes and returns a DB's entry"""
|
||||
value = self._data.pop(key, default)
|
||||
await self.save()
|
||||
return value
|
||||
|
||||
async def wipe(self):
|
||||
"""Wipes DB"""
|
||||
self._data = {}
|
||||
await self.save()
|
||||
|
||||
def all(self):
|
||||
"""Returns all DB's data"""
|
||||
return self._data
|
||||
|
||||
def _blocking_save(self):
|
||||
"""Using this should be avoided. Let's stick to threadsafe saves"""
|
||||
self._save_json(self.path, self._data)
|
||||
|
||||
async def save(self):
|
||||
"""Threadsafe save to file"""
|
||||
await self._threadsafe_save_json(self.path, self._data)
|
||||
|
||||
def _get_caller_path(self):
|
||||
frame = inspect.stack()[2]
|
||||
module = inspect.getmodule(frame[0])
|
||||
abspath = os.path.abspath(module.__file__)
|
||||
return os.path.dirname(abspath)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._data
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._data[key]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __repr__(self):
|
||||
return "<{} {}>".format(self.__class__.__name__, self._data)
|
||||
|
||||
|
||||
class JsonGuildDB(JsonDB):
|
||||
"""
|
||||
A DB-like helper class to streamline the saving of json files
|
||||
This is a variant of JsonDB that allows for guild specific data
|
||||
Global data is still allowed with dedicated methods
|
||||
|
||||
Same parameters as JsonDB
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
local = kwargs.get("relative_path", True)
|
||||
if local and not self._caller:
|
||||
self._caller = self._get_caller_path()
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def set(self, guild, key, value):
|
||||
"""Sets a guild's entry"""
|
||||
if not isinstance(guild, discord.Guild):
|
||||
raise TypeError('Can only set guild data')
|
||||
if str(guild.id) not in self._data:
|
||||
self._data[str(guild.id)] = {}
|
||||
self._data[str(guild.id)][key] = value
|
||||
await self.save()
|
||||
|
||||
def get(self, guild, key, default=None):
|
||||
"""Returns a guild's entry"""
|
||||
if not isinstance(guild, discord.Guild):
|
||||
raise TypeError('Can only get guild data')
|
||||
if str(guild.id) not in self._data:
|
||||
return default
|
||||
return self._data[str(guild.id)].get(key, default)
|
||||
|
||||
async def remove(self, guild, key):
|
||||
"""Removes a guild's entry"""
|
||||
if not isinstance(guild, discord.Guild):
|
||||
raise TypeError('Can only remove guild data')
|
||||
if str(guild.id) not in self._data:
|
||||
raise KeyError('Guild data is not present')
|
||||
del self._data[str(guild.id)][key]
|
||||
await self.save()
|
||||
|
||||
async def pop(self, guild, key, default=None):
|
||||
"""Removes and returns a guild's entry"""
|
||||
if not isinstance(guild, discord.Guild):
|
||||
raise TypeError('Can only remove guild data')
|
||||
value = self._data.get(str(guild.id), {}).pop(key, default)
|
||||
await self.save()
|
||||
return value
|
||||
|
||||
def get_all(self, guild, default=None):
|
||||
"""Returns all entries of a guild"""
|
||||
if not isinstance(guild, discord.Guild):
|
||||
raise TypeError('Can only get guild data')
|
||||
return self._data.get(str(guild.id), default)
|
||||
|
||||
async def remove_all(self, guild):
|
||||
"""Removes all entries of a guild"""
|
||||
if not isinstance(guild, discord.Guild):
|
||||
raise TypeError('Can only remove guilds')
|
||||
await super().remove(str(guild.id))
|
||||
|
||||
async def set_global(self, key, value):
|
||||
"""Sets a global value"""
|
||||
if GLOBAL_KEY not in self._data:
|
||||
self._data[GLOBAL_KEY] = {}
|
||||
self._data[GLOBAL_KEY][key] = value
|
||||
await self.save()
|
||||
|
||||
def get_global(self, key, default=None):
|
||||
"""Gets a global value"""
|
||||
if GLOBAL_KEY not in self._data:
|
||||
self._data[GLOBAL_KEY] = {}
|
||||
|
||||
return self._data[GLOBAL_KEY].get(key, default)
|
||||
|
||||
async def remove_global(self, key):
|
||||
"""Removes a global value"""
|
||||
if GLOBAL_KEY not in self._data:
|
||||
self._data[GLOBAL_KEY] = {}
|
||||
del self._data[GLOBAL_KEY][key]
|
||||
await self.save()
|
||||
|
||||
async def pop_global(self, key, default=None):
|
||||
"""Removes and returns a global value"""
|
||||
if GLOBAL_KEY not in self._data:
|
||||
self._data[GLOBAL_KEY] = {}
|
||||
value = self._data[GLOBAL_KEY].pop(key, default)
|
||||
await self.save()
|
||||
return value
|
||||
11
main.py
11
main.py
@ -1,8 +1,7 @@
|
||||
from core.bot import Red, ExitCodes
|
||||
from core.global_checks import init_global_checks
|
||||
from core.events import init_events
|
||||
from core.settings import parse_cli_flags
|
||||
from core.cli import interactive_config, confirm
|
||||
from core.cli import interactive_config, confirm, parse_cli_flags
|
||||
from core.core_commands import Core
|
||||
from core.dev_commands import Dev
|
||||
import asyncio
|
||||
@ -65,8 +64,8 @@ if __name__ == '__main__':
|
||||
if cli_flags.dev:
|
||||
red.add_cog(Dev())
|
||||
|
||||
token = os.environ.get("RED_TOKEN", red.db.get_global("token", None))
|
||||
prefix = cli_flags.prefix or red.db.get_global("prefix", [])
|
||||
token = os.environ.get("RED_TOKEN", red.db.token())
|
||||
prefix = cli_flags.prefix or red.db.prefix()
|
||||
|
||||
if token is None or not prefix:
|
||||
if cli_flags.no_prompt is False:
|
||||
@ -89,11 +88,11 @@ if __name__ == '__main__':
|
||||
"a user account, remember that the --not-bot flag "
|
||||
"must be used. For self-bot functionalities instead, "
|
||||
"--self-bot")
|
||||
db_token = red.db.get_global("token")
|
||||
db_token = red.db.token()
|
||||
if db_token and not cli_flags.no_prompt:
|
||||
print("\nDo you want to reset the token? (y/n)")
|
||||
if confirm("> "):
|
||||
loop.run_until_complete(red.db.remove_global("token"))
|
||||
loop.run_until_complete(red.db.set("token", ""))
|
||||
print("Token has been reset.")
|
||||
except KeyboardInterrupt:
|
||||
log.info("Keyboard interrupt detected. Quitting...")
|
||||
|
||||
@ -1,2 +1,5 @@
|
||||
git+https://github.com/Rapptz/discord.py@rewrite#egg=discord.py[voice]
|
||||
youtube_dl
|
||||
pytest
|
||||
git+https://github.com/pytest-dev/pytest-asyncio
|
||||
pymongo
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
101
tests/conftest.py
Normal file
101
tests/conftest.py
Normal file
@ -0,0 +1,101 @@
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import random
|
||||
|
||||
from core.bot import Red
|
||||
from core.drivers import red_json
|
||||
from core import Config
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def json_driver(tmpdir_factory):
|
||||
driver = red_json.JSON(
|
||||
"PyTest",
|
||||
data_path_override=Path(str(tmpdir_factory.getbasetemp()))
|
||||
)
|
||||
return driver
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config(json_driver):
|
||||
return Config(
|
||||
cog_name="PyTest",
|
||||
unique_identifier=0,
|
||||
driver_spawn=json_driver)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def config_fr(json_driver):
|
||||
"""
|
||||
Mocked config object with force_register enabled.
|
||||
"""
|
||||
return Config(
|
||||
cog_name="PyTest",
|
||||
unique_identifier=0,
|
||||
driver_spawn=json_driver,
|
||||
force_registration=True
|
||||
)
|
||||
|
||||
|
||||
#region Dpy Mocks
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_guild():
|
||||
mock_guild = namedtuple("Guild", "id members")
|
||||
return mock_guild(random.randint(1, 999999999), [])
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_channel():
|
||||
mock_channel = namedtuple("Channel", "id")
|
||||
return mock_channel(random.randint(1, 999999999))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_role():
|
||||
mock_role = namedtuple("Role", "id")
|
||||
return mock_role(random.randint(1, 999999999))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_member(empty_guild):
|
||||
mock_member = namedtuple("Member", "id guild")
|
||||
return mock_member(random.randint(1, 999999999), empty_guild)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_user():
|
||||
mock_user = namedtuple("User", "id")
|
||||
return mock_user(random.randint(1, 999999999))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_message():
|
||||
mock_msg = namedtuple("Message", "content")
|
||||
return mock_msg("No content.")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ctx(empty_member, empty_channel, red):
|
||||
mock_ctx = namedtuple("Context", "author guild channel message bot")
|
||||
return mock_ctx(empty_member, empty_member.guild, empty_channel,
|
||||
empty_message, red)
|
||||
#endregion
|
||||
|
||||
|
||||
#region Red Mock
|
||||
@pytest.fixture
|
||||
def red(monkeypatch, config_fr, event_loop):
|
||||
from core.cli import parse_cli_flags
|
||||
cli_flags = parse_cli_flags()
|
||||
|
||||
description = "Red v3 - Alpha"
|
||||
|
||||
monkeypatch.setattr("core.config.Config.get_core_conf",
|
||||
lambda *args, **kwargs: config_fr)
|
||||
|
||||
red = Red(cli_flags, description=description, pm_help=None,
|
||||
loop=event_loop)
|
||||
return red
|
||||
#endregion
|
||||
0
tests/core/__init__.py
Normal file
0
tests/core/__init__.py
Normal file
147
tests/core/test_config.py
Normal file
147
tests/core/test_config.py
Normal file
@ -0,0 +1,147 @@
|
||||
import pytest
|
||||
|
||||
|
||||
#region Register Tests
|
||||
def test_config_register_global(config):
|
||||
config.register_global(enabled=False)
|
||||
assert config.defaults["GLOBAL"]["enabled"] is False
|
||||
assert config.enabled() is False
|
||||
|
||||
|
||||
def test_config_register_global_badvalues(config):
|
||||
with pytest.raises(RuntimeError):
|
||||
config.register_global(**{"invalid var name": True})
|
||||
|
||||
|
||||
def test_config_register_guild(config, empty_guild):
|
||||
config.register_guild(enabled=False, some_list=[], some_dict={})
|
||||
assert config.defaults["GUILD"]["enabled"] is False
|
||||
assert config.defaults["GUILD"]["some_list"] == []
|
||||
assert config.defaults["GUILD"]["some_dict"] == {}
|
||||
|
||||
assert config.guild(empty_guild).enabled() is False
|
||||
assert config.guild(empty_guild).some_list() == []
|
||||
assert config.guild(empty_guild).some_dict() == {}
|
||||
|
||||
|
||||
def test_config_register_channel(config, empty_channel):
|
||||
config.register_channel(enabled=False)
|
||||
assert config.defaults["CHANNEL"]["enabled"] is False
|
||||
assert config.channel(empty_channel).enabled() is False
|
||||
|
||||
|
||||
def test_config_register_role(config, empty_role):
|
||||
config.register_role(enabled=False)
|
||||
assert config.defaults["ROLE"]["enabled"] is False
|
||||
assert config.role(empty_role).enabled() is False
|
||||
|
||||
|
||||
def test_config_register_member(config, empty_member):
|
||||
config.register_member(some_number=-1)
|
||||
assert config.defaults["MEMBER"]["some_number"] == -1
|
||||
assert config.member(empty_member).some_number() == -1
|
||||
|
||||
|
||||
def test_config_register_user(config, empty_user):
|
||||
config.register_user(some_value=None)
|
||||
assert config.defaults["USER"]["some_value"] is None
|
||||
assert config.user(empty_user).some_value() is None
|
||||
|
||||
|
||||
def test_config_force_register_global(config_fr):
|
||||
with pytest.raises(AttributeError):
|
||||
config_fr.enabled()
|
||||
|
||||
config_fr.register_global(enabled=True)
|
||||
assert config_fr.enabled() is True
|
||||
#endregion
|
||||
|
||||
|
||||
#region Default Value Overrides
|
||||
def test_global_default_override(config):
|
||||
assert config.enabled(True) is True
|
||||
assert config.get("enabled") is None
|
||||
assert config.get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_global_default_nofr(config):
|
||||
assert config.nofr() is None
|
||||
assert config.nofr(True) is True
|
||||
assert config.get("nofr") is None
|
||||
assert config.get("nofr", default=True) is True
|
||||
|
||||
|
||||
def test_guild_default_override(config, empty_guild):
|
||||
assert config.guild(empty_guild).enabled(True) is True
|
||||
assert config.guild(empty_guild).get("enabled") is None
|
||||
assert config.guild(empty_guild).get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_channel_default_override(config, empty_channel):
|
||||
assert config.channel(empty_channel).enabled(True) is True
|
||||
assert config.channel(empty_channel).get("enabled") is None
|
||||
assert config.channel(empty_channel).get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_role_default_override(config, empty_role):
|
||||
assert config.role(empty_role).enabled(True) is True
|
||||
assert config.role(empty_role).get("enabled") is None
|
||||
assert config.role(empty_role).get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_member_default_override(config, empty_member):
|
||||
assert config.member(empty_member).enabled(True) is True
|
||||
assert config.member(empty_member).get("enabled") is None
|
||||
assert config.member(empty_member).get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_user_default_override(config, empty_user):
|
||||
assert config.user(empty_user).some_value(True) is True
|
||||
assert config.user(empty_user).get("some_value") is None
|
||||
assert config.user(empty_user).get("some_value", default=True) is True
|
||||
#endregion
|
||||
|
||||
|
||||
#region Setting Values
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_global(config):
|
||||
await config.set("enabled", True)
|
||||
assert config.enabled() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_global_badkey(config):
|
||||
with pytest.raises(RuntimeError):
|
||||
await config.set("this is a bad key", True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_global_invalidkey(config):
|
||||
with pytest.raises(KeyError):
|
||||
await config.set("uuid", True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_guild(config, empty_guild):
|
||||
await config.guild(empty_guild).set("enabled", True)
|
||||
assert config.guild(empty_guild).enabled() is True
|
||||
|
||||
curr_list = config.guild(empty_guild).some_list([1, 2, 3])
|
||||
assert curr_list == [1, 2, 3]
|
||||
curr_list.append(4)
|
||||
|
||||
await config.guild(empty_guild).set("some_list", curr_list)
|
||||
assert config.guild(empty_guild).some_list() == curr_list
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_channel(config, empty_channel):
|
||||
await config.channel(empty_channel).set("enabled", True)
|
||||
assert config.channel(empty_channel).enabled() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_channel_no_register(config, empty_channel):
|
||||
await config.channel(empty_channel).set("no_register", True)
|
||||
assert config.channel(empty_channel).no_register() is True
|
||||
#endregion
|
||||
6
tests/core/test_installation.py
Normal file
6
tests/core/test_installation.py
Normal file
@ -0,0 +1,6 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_init_bot(red):
|
||||
assert red is not None
|
||||
Loading…
x
Reference in New Issue
Block a user