mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[Config] Rewrite (#869)
This commit is contained in:
parent
5c2be25dfc
commit
99bfb2fc7a
@ -1,3 +1,4 @@
|
||||
dist: trusty
|
||||
language: python
|
||||
python:
|
||||
- "3.5.3"
|
||||
|
||||
@ -81,24 +81,24 @@ class Alias:
|
||||
if global_:
|
||||
curr_aliases = self._aliases.entries()
|
||||
curr_aliases.append(alias.to_json())
|
||||
await self._aliases.set("entries", curr_aliases)
|
||||
await self._aliases.entries.set(curr_aliases)
|
||||
else:
|
||||
curr_aliases = self._aliases.guild(ctx.guild).entries()
|
||||
|
||||
curr_aliases.append(alias.to_json())
|
||||
await self._aliases.guild(ctx.guild).set("entries", curr_aliases)
|
||||
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
|
||||
|
||||
await self._aliases.guild(ctx.guild).set("enabled", True)
|
||||
await self._aliases.guild(ctx.guild).enabled.set(True)
|
||||
return alias
|
||||
|
||||
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
||||
global_: bool=False) -> bool:
|
||||
if global_:
|
||||
aliases = self.unloaded_global_aliases()
|
||||
setter_func = self._aliases.set
|
||||
setter_func = self._aliases.entries.set
|
||||
else:
|
||||
aliases = self.unloaded_aliases(ctx.guild)
|
||||
setter_func = self._aliases.guild(ctx.guild).set
|
||||
setter_func = self._aliases.guild(ctx.guild).entries.set
|
||||
|
||||
did_delete_alias = False
|
||||
|
||||
@ -110,7 +110,6 @@ class Alias:
|
||||
did_delete_alias = True
|
||||
|
||||
await setter_func(
|
||||
"entries",
|
||||
[a.to_json() for a in to_keep]
|
||||
)
|
||||
|
||||
@ -355,8 +354,9 @@ class Alias:
|
||||
await ctx.send(box("\n".join(names), "diff"))
|
||||
|
||||
async def on_message(self, message: discord.Message):
|
||||
aliases = list(self.unloaded_aliases(message.guild)) + \
|
||||
list(self.unloaded_global_aliases())
|
||||
aliases = list(self.unloaded_global_aliases())
|
||||
if message.guild is not None:
|
||||
aliases = aliases + list(self.unloaded_aliases(message.guild))
|
||||
|
||||
if len(aliases) == 0:
|
||||
return
|
||||
|
||||
@ -30,7 +30,7 @@ class Downloader:
|
||||
def __init__(self, bot: Red):
|
||||
self.bot = bot
|
||||
|
||||
self.conf = Config.get_conf(self, unique_identifier=998240343,
|
||||
self.conf = Config.get_conf(self, identifier=998240343,
|
||||
force_registration=True)
|
||||
|
||||
self.conf.register_global(
|
||||
@ -73,7 +73,7 @@ class Downloader:
|
||||
|
||||
if cog_json not in installed:
|
||||
installed.append(cog_json)
|
||||
await self.conf.set("installed", installed)
|
||||
await self.conf.installed.set(installed)
|
||||
|
||||
async def _remove_from_installed(self, cog: Installable):
|
||||
"""
|
||||
@ -86,7 +86,7 @@ class Downloader:
|
||||
|
||||
if cog_json in installed:
|
||||
installed.remove(cog_json)
|
||||
await self.conf.set("installed", installed)
|
||||
await self.conf.installed.set(installed)
|
||||
|
||||
async def _reinstall_cogs(self, cogs: Tuple[Installable]) -> Tuple[Installable]:
|
||||
"""
|
||||
|
||||
@ -526,4 +526,4 @@ class RepoManager:
|
||||
|
||||
async def _save_repos(self):
|
||||
repo_json_info = {name: r.to_json() for name, r in self._repos.items()}
|
||||
await self.downloader_config.set("repos", repo_json_info)
|
||||
await self.downloader_config.repos.set(repo_json_info)
|
||||
|
||||
@ -47,7 +47,7 @@ class Red(commands.Bot):
|
||||
kwargs["owner_id"] = cli_flags.owner
|
||||
|
||||
if "owner_id" not in kwargs:
|
||||
kwargs["owner_id"] = self.db.get("owner")
|
||||
kwargs["owner_id"] = self.db.owner()
|
||||
|
||||
self.counter = Counter()
|
||||
self.uptime = None
|
||||
@ -89,7 +89,7 @@ class Red(commands.Bot):
|
||||
for package in self.extensions:
|
||||
if package.startswith("cogs."):
|
||||
loaded.append(package)
|
||||
await self.db.set("packages", loaded)
|
||||
await self.db.packages.set(loaded)
|
||||
|
||||
|
||||
class ExitCodes(Enum):
|
||||
|
||||
@ -22,7 +22,7 @@ def interactive_config(red, token_set, prefix_set):
|
||||
print("That doesn't look like a valid token.")
|
||||
token = ""
|
||||
if token:
|
||||
loop.run_until_complete(red.db.set("token", token))
|
||||
loop.run_until_complete(red.db.token.set(token))
|
||||
|
||||
if not prefix_set:
|
||||
prefix = ""
|
||||
@ -39,7 +39,7 @@ def interactive_config(red, token_set, prefix_set):
|
||||
if not confirm("> "):
|
||||
prefix = ""
|
||||
if prefix:
|
||||
loop.run_until_complete(red.db.set("prefix", [prefix]))
|
||||
loop.run_until_complete(red.db.prefix.set([prefix]))
|
||||
|
||||
ask_sentry(red)
|
||||
|
||||
@ -55,9 +55,9 @@ def ask_sentry(red: Red):
|
||||
" found issues in a timely manner. If you wish to opt in\n"
|
||||
" the process please type \"yes\":\n")
|
||||
if not confirm("> "):
|
||||
loop.run_until_complete(red.db.set("enable_sentry", False))
|
||||
loop.run_until_complete(red.db.enable_sentry.set(False))
|
||||
else:
|
||||
loop.run_until_complete(red.db.set("enable_sentry", True))
|
||||
loop.run_until_complete(red.db.enable_sentry.set(True))
|
||||
print("\nThank you for helping us with the development process!")
|
||||
|
||||
|
||||
|
||||
803
core/config.py
803
core/config.py
@ -1,521 +1,374 @@
|
||||
from pathlib import Path
|
||||
|
||||
from core.drivers.red_json import JSON as JSONDriver
|
||||
from core.drivers.red_mongo import Mongo
|
||||
import logging
|
||||
|
||||
from typing import Callable
|
||||
from typing import Callable, Union, Tuple
|
||||
|
||||
import discord
|
||||
from copy import deepcopy
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .drivers.red_json import JSON as JSONDriver
|
||||
|
||||
log = logging.getLogger("red.config")
|
||||
|
||||
class BaseConfig:
|
||||
def __init__(self, cog_name, unique_identifier, driver_spawn, force_registration=False,
|
||||
hash_uuid=True, collection="GLOBAL", collection_uuid=None,
|
||||
defaults={}):
|
||||
self.cog_name = cog_name
|
||||
if hash_uuid:
|
||||
self.uuid = str(hash(unique_identifier))
|
||||
else:
|
||||
self.uuid = unique_identifier
|
||||
self.driver_spawn = driver_spawn
|
||||
self._driver = None
|
||||
self.collection = collection
|
||||
self.collection_uuid = collection_uuid
|
||||
|
||||
self.force_registration = force_registration
|
||||
class Value:
|
||||
def __init__(self, identifiers: Tuple[str], default_value, spawner):
|
||||
self._identifiers = identifiers
|
||||
self.default = default_value
|
||||
|
||||
self.spawner = spawner
|
||||
|
||||
@property
|
||||
def identifiers(self):
|
||||
return tuple(str(i) for i in self._identifiers)
|
||||
|
||||
def __call__(self, default=None):
|
||||
driver = self.spawner.get_driver()
|
||||
try:
|
||||
self.driver.maybe_add_ident(self.uuid)
|
||||
except AttributeError:
|
||||
pass
|
||||
ret = driver.get(self.identifiers)
|
||||
except KeyError:
|
||||
return default or self.default
|
||||
return ret
|
||||
|
||||
self.driver_getmap = {
|
||||
"GLOBAL": self.driver.get_global,
|
||||
"GUILD": self.driver.get_guild,
|
||||
"CHANNEL": self.driver.get_channel,
|
||||
"ROLE": self.driver.get_role,
|
||||
"USER": self.driver.get_user
|
||||
}
|
||||
async def set(self, value):
|
||||
driver = self.spawner.get_driver()
|
||||
await driver.set(self.identifiers, value)
|
||||
|
||||
self.driver_setmap = {
|
||||
"GLOBAL": self.driver.set_global,
|
||||
"GUILD": self.driver.set_guild,
|
||||
"CHANNEL": self.driver.set_channel,
|
||||
"ROLE": self.driver.set_role,
|
||||
"USER": self.driver.set_user
|
||||
}
|
||||
|
||||
self.curr_key = None
|
||||
class Group(Value):
|
||||
def __init__(self, identifiers: Tuple[str],
|
||||
defaults: dict,
|
||||
spawner,
|
||||
force_registration: bool=False):
|
||||
self.defaults = defaults
|
||||
self.force_registration = force_registration
|
||||
self.spawner = spawner
|
||||
|
||||
self.unsettable_keys = ("cog_name", "cog_identifier", "_id",
|
||||
"guild_id", "channel_id", "role_id",
|
||||
"user_id", "uuid")
|
||||
self.invalid_keys = (
|
||||
"driver_spawn",
|
||||
"_driver", "collection",
|
||||
"collection_uuid", "force_registration"
|
||||
super().__init__(identifiers, {}, self.spawner)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def __getattr__(self, item: str) -> Union["Group", Value]:
|
||||
"""
|
||||
Takes in the next accessible item. If it's found to be a Group
|
||||
we return another Group object. If it's found to be a Value
|
||||
we return a Value object. If it is not found and
|
||||
force_registration is True then we raise AttributeException,
|
||||
otherwise return a Value object.
|
||||
:param item:
|
||||
:return:
|
||||
"""
|
||||
is_group = self.is_group(item)
|
||||
is_value = not is_group and self.is_value(item)
|
||||
new_identifiers = self.identifiers + (item, )
|
||||
if is_group:
|
||||
return Group(
|
||||
identifiers=new_identifiers,
|
||||
defaults=self.defaults[item],
|
||||
spawner=self.spawner,
|
||||
force_registration=self.force_registration
|
||||
)
|
||||
elif is_value:
|
||||
return Value(
|
||||
identifiers=new_identifiers,
|
||||
default_value=self.defaults[item],
|
||||
spawner=self.spawner
|
||||
)
|
||||
elif self.force_registration:
|
||||
raise AttributeError(
|
||||
"'{}' is not a valid registered Group"
|
||||
"or value.".format(item)
|
||||
)
|
||||
else:
|
||||
return Value(
|
||||
identifiers=new_identifiers,
|
||||
default_value=None,
|
||||
spawner=self.spawner
|
||||
)
|
||||
|
||||
@property
|
||||
def _super_group(self) -> 'Group':
|
||||
super_group = Group(
|
||||
self.identifiers[:-1],
|
||||
defaults={},
|
||||
spawner=self.spawner,
|
||||
force_registration=self.force_registration
|
||||
)
|
||||
return super_group
|
||||
|
||||
self.defaults = defaults if defaults else {
|
||||
"GLOBAL": {}, "GUILD": {}, "CHANNEL": {}, "ROLE": {},
|
||||
"MEMBER": {}, "USER": {}}
|
||||
def is_group(self, item: str) -> bool:
|
||||
"""
|
||||
Determines if an attribute access is pointing at a registered group.
|
||||
:param item:
|
||||
:return:
|
||||
"""
|
||||
default = self.defaults.get(item)
|
||||
return isinstance(default, dict)
|
||||
|
||||
def is_value(self, item: str) -> bool:
|
||||
"""
|
||||
Determines if an attribute access is pointing at a registered value.
|
||||
:param item:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
default = self.defaults[item]
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
return not isinstance(default, dict)
|
||||
|
||||
def get_attr(self, item: str, default=None):
|
||||
"""
|
||||
You should avoid this function whenever possible.
|
||||
:param item:
|
||||
:param default:
|
||||
:return:
|
||||
"""
|
||||
value = getattr(self, item)
|
||||
return value(default=default)
|
||||
|
||||
def all(self) -> dict:
|
||||
"""
|
||||
Gets all entries of the given kind. If this kind is member
|
||||
then this method returns all members from the same
|
||||
server.
|
||||
:return:
|
||||
"""
|
||||
# noinspection PyTypeChecker
|
||||
return self._super_group()
|
||||
|
||||
async def set(self, value):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError(
|
||||
"You may only set the value of a group to be a dict."
|
||||
)
|
||||
await super().set(value)
|
||||
|
||||
async def set_attr(self, item: str, value):
|
||||
"""
|
||||
You should avoid this function whenever possible.
|
||||
:param item:
|
||||
:param value:
|
||||
:return:
|
||||
"""
|
||||
value_obj = getattr(self, item)
|
||||
await value_obj.set(value)
|
||||
|
||||
async def clear(self):
|
||||
"""
|
||||
Wipes out data for the given entry in this category
|
||||
e.g. Guild/Role/User
|
||||
:return:
|
||||
"""
|
||||
await self.set({})
|
||||
|
||||
async def clear_all(self):
|
||||
"""
|
||||
Removes all data from all entries.
|
||||
:return:
|
||||
"""
|
||||
await self._super_group.set({})
|
||||
|
||||
|
||||
class MemberGroup(Group):
|
||||
@property
|
||||
def _super_group(self) -> Group:
|
||||
new_identifiers = self.identifiers[:2]
|
||||
group_obj = Group(
|
||||
identifiers=new_identifiers,
|
||||
defaults={},
|
||||
spawner=self.spawner
|
||||
)
|
||||
return group_obj
|
||||
|
||||
@property
|
||||
def _guild_group(self) -> Group:
|
||||
new_identifiers = self.identifiers[:3]
|
||||
group_obj = Group(
|
||||
identifiers=new_identifiers,
|
||||
defaults={},
|
||||
spawner=self.spawner
|
||||
)
|
||||
return group_obj
|
||||
|
||||
def all_guilds(self) -> dict:
|
||||
"""
|
||||
Gets a dict of all guilds and members.
|
||||
|
||||
REMEMBER: ID's are stored in these dicts as STRINGS.
|
||||
:return:
|
||||
"""
|
||||
# noinspection PyTypeChecker
|
||||
return self._super_group()
|
||||
|
||||
def all(self) -> dict:
|
||||
"""
|
||||
Returns the dict of all members in the same guild.
|
||||
:return:
|
||||
"""
|
||||
# noinspection PyTypeChecker
|
||||
return self._guild_group()
|
||||
|
||||
class Config:
|
||||
GLOBAL = "GLOBAL"
|
||||
GUILD = "GUILD"
|
||||
CHANNEL = "TEXTCHANNEL"
|
||||
ROLE = "ROLE"
|
||||
USER = "USER"
|
||||
MEMBER = "MEMBER"
|
||||
|
||||
def __init__(self, cog_name: str, unique_identifier: str,
|
||||
driver_spawn: Callable,
|
||||
force_registration: bool=False,
|
||||
defaults: dict=None):
|
||||
self.cog_name = cog_name
|
||||
self.unique_identifier = unique_identifier
|
||||
|
||||
self.spawner = driver_spawn
|
||||
self.force_registration = force_registration
|
||||
self.defaults = defaults or {}
|
||||
|
||||
@classmethod
|
||||
def get_conf(cls, cog_instance: object, unique_identifier: int=0,
|
||||
force_registration: bool=False):
|
||||
def get_conf(cls, cog_instance, identifier: int,
|
||||
force_registration=False):
|
||||
"""
|
||||
Gets a config object that cog's can use to safely store data. The
|
||||
backend to this is totally modular and can easily switch between
|
||||
JSON and a DB. However, when changed, all data will likely be lost
|
||||
unless cogs write some converters for their data.
|
||||
|
||||
Positional Arguments:
|
||||
cog_instance - The cog `self` object, can be passed in from your
|
||||
cog's __init__ method.
|
||||
|
||||
Keyword Arguments:
|
||||
unique_identifier - a random integer or string that is used to
|
||||
differentiate your cog from any other named the same. This way we
|
||||
can safely store data for multiple cogs that are named the same.
|
||||
|
||||
YOU SHOULD USE THIS.
|
||||
|
||||
force_registration - A flag which will cause the Config object to
|
||||
throw exceptions if you try to get/set data keys that you have
|
||||
not pre-registered. I highly recommend you ENABLE this as it
|
||||
will help reduce dumb typo errors.
|
||||
Returns a Config instance based on a simplified set of initial
|
||||
variables.
|
||||
:param cog_instance:
|
||||
:param identifier: Any random integer, used to keep your data
|
||||
distinct from any other cog with the same name.
|
||||
:param force_registration: Should config require registration
|
||||
of data keys before allowing you to get/set values?
|
||||
:return:
|
||||
"""
|
||||
|
||||
url = None # TODO: get mongo url
|
||||
port = None # TODO: get mongo port
|
||||
|
||||
def spawn_mongo_driver():
|
||||
return Mongo(url, port)
|
||||
|
||||
# TODO: Determine which backend users want, default to JSON
|
||||
|
||||
cog_name = cog_instance.__class__.__name__
|
||||
uuid = str(hash(identifier))
|
||||
|
||||
driver_spawn = JSONDriver(cog_name)
|
||||
|
||||
return cls(cog_name=cog_name, unique_identifier=unique_identifier,
|
||||
driver_spawn=driver_spawn, force_registration=force_registration)
|
||||
spawner = JSONDriver(cog_name)
|
||||
return cls(cog_name=cog_name, unique_identifier=uuid,
|
||||
force_registration=force_registration,
|
||||
driver_spawn=spawner)
|
||||
|
||||
@classmethod
|
||||
def get_core_conf(cls, force_registration: bool=False):
|
||||
core_data_path = Path.cwd() / 'core' / '.data'
|
||||
driver_spawn = JSONDriver("Core", data_path_override=core_data_path)
|
||||
return cls(cog_name="Core", driver_spawn=driver_spawn,
|
||||
unique_identifier=0,
|
||||
unique_identifier='0',
|
||||
force_registration=force_registration)
|
||||
|
||||
@property
|
||||
def driver(self):
|
||||
if self._driver is None:
|
||||
try:
|
||||
self._driver = self.driver_spawn()
|
||||
except TypeError:
|
||||
return self.driver_spawn
|
||||
def __getattr__(self, item: str) -> Union[Group, Value]:
|
||||
"""
|
||||
This is used to generate Value or Group objects for global
|
||||
values.
|
||||
:param item:
|
||||
:return:
|
||||
"""
|
||||
global_group = self._get_base_group(self.GLOBAL)
|
||||
return getattr(global_group, item)
|
||||
|
||||
return self._driver
|
||||
|
||||
def __getattr__(self, key):
|
||||
"""This should be used to return config key data as determined by
|
||||
`self.collection` and `self.collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if 'defaults' in self.__dict__: # Necessary to let the cog load
|
||||
restricted = list(self.defaults[self.collection].keys()) + \
|
||||
list(self.unsettable_keys)
|
||||
if key in restricted:
|
||||
raise ValueError("Not allowed to dynamically set attributes of"
|
||||
" unsettable_keys: {}".format(restricted))
|
||||
@staticmethod
|
||||
def _get_defaults_dict(key: str, value) -> dict:
|
||||
"""
|
||||
Since we're allowing nested config stuff now, not storing the
|
||||
defaults as a flat dict sounds like a good idea. May turn
|
||||
out to be an awful one but we'll see.
|
||||
:param key:
|
||||
:param value:
|
||||
:return:
|
||||
"""
|
||||
ret = {}
|
||||
partial = ret
|
||||
splitted = key.split('__')
|
||||
for i, k in enumerate(splitted, start=1):
|
||||
if not k.isidentifier():
|
||||
raise RuntimeError("'{}' is an invalid config key.".format(k))
|
||||
if i == len(splitted):
|
||||
partial[k] = value
|
||||
else:
|
||||
self.__dict__[key] = value
|
||||
else:
|
||||
self.__dict__[key] = value
|
||||
|
||||
def clear(self):
|
||||
"""Clears all values in the current context ONLY."""
|
||||
raise NotImplemented
|
||||
|
||||
def set(self, key, value):
|
||||
"""This should set config key with value `value` in the
|
||||
corresponding collection as defined by `self.collection` and
|
||||
`self.collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def guild(self, guild):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def channel(self, channel):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def role(self, role):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def member(self, member):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def user(self, user):
|
||||
"""This should return a `BaseConfig` instance with the corresponding
|
||||
`collection` and `collection_uuid`."""
|
||||
raise NotImplemented
|
||||
|
||||
def register_global(self, **global_defaults):
|
||||
"""
|
||||
Registers a new dict of global defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param global_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in global_defaults.items():
|
||||
try:
|
||||
self._register_global(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default global key.")
|
||||
|
||||
def _register_global(self, key, default=None):
|
||||
"""Registers a global config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["GLOBAL"][key] = default
|
||||
|
||||
def register_guild(self, **guild_defaults):
|
||||
"""
|
||||
Registers a new dict of guild defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param guild_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in guild_defaults.items():
|
||||
try:
|
||||
self._register_guild(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default guild key.")
|
||||
|
||||
def _register_guild(self, key, default=None):
|
||||
"""Registers a guild config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["GUILD"][key] = default
|
||||
|
||||
def register_channel(self, **channel_defaults):
|
||||
"""
|
||||
Registers a new dict of channel defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param channel_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in channel_defaults.items():
|
||||
try:
|
||||
self._register_channel(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default channel key.")
|
||||
|
||||
def _register_channel(self, key, default=None):
|
||||
"""Registers a channel config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["CHANNEL"][key] = default
|
||||
|
||||
def register_role(self, **role_defaults):
|
||||
"""
|
||||
Registers a new dict of role defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param role_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in role_defaults.items():
|
||||
try:
|
||||
self._register_role(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default role key.")
|
||||
|
||||
def _register_role(self, key, default=None):
|
||||
"""Registers a role config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["ROLE"][key] = default
|
||||
|
||||
def register_member(self, **member_defaults):
|
||||
"""
|
||||
Registers a new dict of member defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param member_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in member_defaults.items():
|
||||
try:
|
||||
self._register_member(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default member key.")
|
||||
|
||||
def _register_member(self, key, default=None):
|
||||
"""Registers a member config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["MEMBER"][key] = default
|
||||
|
||||
def register_user(self, **user_defaults):
|
||||
"""
|
||||
Registers a new dict of user defaults. This function should
|
||||
be called EVERY TIME the cog loads (aka just do it in
|
||||
__init__)!
|
||||
|
||||
:param user_defaults: Each key should be the key you want to
|
||||
access data by and the value is the default value of that
|
||||
key.
|
||||
:return:
|
||||
"""
|
||||
for k, v in user_defaults.items():
|
||||
try:
|
||||
self._register_user(k, v)
|
||||
except KeyError:
|
||||
log.exception("Bad default user key.")
|
||||
|
||||
def _register_user(self, key, default=None):
|
||||
"""Registers a user config key `key`"""
|
||||
if key in self.unsettable_keys:
|
||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
||||
elif not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
self.defaults["USER"][key] = default
|
||||
|
||||
|
||||
class Config(BaseConfig):
|
||||
"""
|
||||
Config object created by `Config.get_conf()`
|
||||
|
||||
This configuration object is designed to make backend data
|
||||
storage mechanisms pluggable. It also is designed to
|
||||
help a cog developer make fewer mistakes (such as
|
||||
typos) when dealing with cog data and to make those mistakes
|
||||
apparent much faster in the design process.
|
||||
|
||||
It also has the capability to safely store data between cogs
|
||||
that share the same name.
|
||||
|
||||
There are two main components to this config object. First,
|
||||
you have the ability to get data on a level specific basis.
|
||||
The seven levels available are: global, guild, channel, role,
|
||||
member, user, and misc.
|
||||
|
||||
The second main component is registering default values for
|
||||
data in each of the levels. This functionality is OPTIONAL
|
||||
and must be explicitly enabled when creating the Config object
|
||||
using the kwarg `force_registration=True`.
|
||||
|
||||
Basic Usage:
|
||||
Creating a Config object:
|
||||
Use the `Config.get_conf()` class method to create new
|
||||
Config objects.
|
||||
|
||||
See the `Config.get_conf()` documentation for more
|
||||
information.
|
||||
|
||||
Registering Default Values (optional):
|
||||
You can register default values for data at all levels
|
||||
EXCEPT misc.
|
||||
|
||||
Simply pass in the key/value pairs as keyword arguments to
|
||||
the respective function.
|
||||
|
||||
e.g.: conf_obj.register_global(enabled=True)
|
||||
conf_obj.register_guild(likes_red=True)
|
||||
|
||||
Retrieving data by attributes:
|
||||
Since I registered the "enabled" key in the previous example
|
||||
at the global level I can now do:
|
||||
|
||||
conf_obj.enabled()
|
||||
|
||||
which will retrieve the current value of the "enabled"
|
||||
key, making use of the default of "True". I can also do
|
||||
the same for the guild key "likes_red":
|
||||
|
||||
conf_obj.guild(guild_obj).likes_red()
|
||||
|
||||
If I elected to not register default values, you can provide them
|
||||
when you try to access the key:
|
||||
|
||||
conf_obj.no_default(default=True)
|
||||
|
||||
However if you do not provide a default and you do not register
|
||||
defaults, accessing the attribute will return "None".
|
||||
|
||||
Saving data:
|
||||
This is accomplished by using the `set` function available at
|
||||
every level.
|
||||
|
||||
e.g.: conf_obj.set("enabled", False)
|
||||
conf_obj.guild(guild_obj).set("likes_red", False)
|
||||
|
||||
If `force_registration` was enabled when the config object
|
||||
was created you will only be allowed to save keys that you
|
||||
have registered.
|
||||
|
||||
Misc data is special, use `conf.misc()` and `conf.set_misc(value)`
|
||||
respectively.
|
||||
"""
|
||||
|
||||
def __getattr__(self, key) -> Callable:
|
||||
"""
|
||||
Until I've got a better way to do this I'm just gonna fake __call__
|
||||
|
||||
:param key:
|
||||
:return: lambda function with kwarg
|
||||
"""
|
||||
return self._get_value_from_key(key)
|
||||
|
||||
def _get_value_from_key(self, key) -> Callable:
|
||||
try:
|
||||
default = self.defaults[self.collection][key]
|
||||
except KeyError as e:
|
||||
if self.force_registration:
|
||||
raise AttributeError("Key '{}' not registered!".format(key)) from e
|
||||
default = None
|
||||
|
||||
self.curr_key = key
|
||||
|
||||
if self.collection != "MEMBER":
|
||||
ret = lambda default=default: self.driver_getmap[self.collection](
|
||||
self.cog_name, self.uuid, self.collection_uuid, key,
|
||||
default=default)
|
||||
else:
|
||||
mid, sid = self.collection_uuid
|
||||
ret = lambda default=default: self.driver.get_member(
|
||||
self.cog_name, self.uuid, mid, sid, key,
|
||||
default=default)
|
||||
partial[k] = {}
|
||||
partial = partial[k]
|
||||
return ret
|
||||
|
||||
def get(self, key, default=None):
|
||||
@staticmethod
|
||||
def _update_defaults(to_add: dict, _partial: dict):
|
||||
"""
|
||||
Included as an alternative to registering defaults.
|
||||
|
||||
:param key:
|
||||
:param default:
|
||||
:return:
|
||||
This tries to update the defaults dictionary with the nested
|
||||
partial dict generated by _get_defaults_dict. This WILL
|
||||
throw an error if you try to have both a value and a group
|
||||
registered under the same name.
|
||||
:param to_add:
|
||||
:param _partial:
|
||||
:return:
|
||||
"""
|
||||
for k, v in to_add.items():
|
||||
val_is_dict = isinstance(v, dict)
|
||||
if k in _partial:
|
||||
existing_is_dict = isinstance(_partial[k], dict)
|
||||
if val_is_dict != existing_is_dict:
|
||||
# != is XOR
|
||||
raise KeyError("You cannot register a Group and a Value under"
|
||||
" the same name.")
|
||||
if val_is_dict:
|
||||
Config._update_defaults(v, _partial=_partial[k])
|
||||
else:
|
||||
_partial[k] = v
|
||||
else:
|
||||
_partial[k] = v
|
||||
|
||||
if default is not None:
|
||||
return self._get_value_from_key(key)(default)
|
||||
else:
|
||||
return self._get_value_from_key(key)()
|
||||
def _register_default(self, key: str, **kwargs):
|
||||
if key not in self.defaults:
|
||||
self.defaults[key] = {}
|
||||
|
||||
async def set(self, key, value):
|
||||
# Notice to future developers:
|
||||
# This code was commented to allow users to set keys without having to register them.
|
||||
# That being said, if they try to get keys without registering them
|
||||
# things will blow up. I do highly recommend enforcing the key registration.
|
||||
data = deepcopy(kwargs)
|
||||
|
||||
if key in self.unsettable_keys or key in self.invalid_keys:
|
||||
raise KeyError("Restricted key name, please use another.")
|
||||
for k, v in data.items():
|
||||
to_add = self._get_defaults_dict(k, v)
|
||||
self._update_defaults(to_add, self.defaults[key])
|
||||
|
||||
if self.force_registration and key not in self.defaults[self.collection]:
|
||||
raise AttributeError("Key '{}' not registered!".format(key))
|
||||
def register_global(self, **kwargs):
|
||||
self._register_default(self.GLOBAL, **kwargs)
|
||||
|
||||
if not key.isidentifier():
|
||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
||||
" name.")
|
||||
def register_guild(self, **kwargs):
|
||||
self._register_default(self.GUILD, **kwargs)
|
||||
|
||||
if self.collection == "GLOBAL":
|
||||
await self.driver.set_global(self.cog_name, self.uuid, key, value)
|
||||
elif self.collection == "MEMBER":
|
||||
mid, sid = self.collection_uuid
|
||||
await self.driver.set_member(self.cog_name, self.uuid, mid, sid,
|
||||
key, value)
|
||||
elif self.collection in self.driver_setmap:
|
||||
func = self.driver_setmap[self.collection]
|
||||
await func(self.cog_name, self.uuid, self.collection_uuid, key, value)
|
||||
def register_channel(self, **kwargs):
|
||||
# We may need to add a voice channel category later
|
||||
self._register_default(self.CHANNEL, **kwargs)
|
||||
|
||||
async def clear(self):
|
||||
await self.driver_setmap[self.collection](
|
||||
self.cog_name, self.uuid, self.collection_uuid, None, None,
|
||||
clear=True)
|
||||
def register_role(self, **kwargs):
|
||||
self._register_default(self.ROLE, **kwargs)
|
||||
|
||||
def guild(self, guild):
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "GUILD"
|
||||
new.collection_uuid = guild.id
|
||||
new._driver = None
|
||||
return new
|
||||
def register_user(self, **kwargs):
|
||||
self._register_default(self.USER, **kwargs)
|
||||
|
||||
def channel(self, channel):
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "CHANNEL"
|
||||
new.collection_uuid = channel.id
|
||||
new._driver = None
|
||||
return new
|
||||
def register_member(self, **kwargs):
|
||||
self._register_default(self.MEMBER, **kwargs)
|
||||
|
||||
def role(self, role):
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "ROLE"
|
||||
new.collection_uuid = role.id
|
||||
new._driver = None
|
||||
return new
|
||||
def _get_base_group(self, key: str, *identifiers: str,
|
||||
group_class=Group) -> Group:
|
||||
# noinspection PyTypeChecker
|
||||
return group_class(
|
||||
identifiers=(self.unique_identifier, key) + identifiers,
|
||||
defaults=self.defaults.get(key, {}),
|
||||
spawner=self.spawner,
|
||||
force_registration=self.force_registration
|
||||
)
|
||||
|
||||
def member(self, member):
|
||||
guild = member.guild
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "MEMBER"
|
||||
new.collection_uuid = (member.id, guild.id)
|
||||
new._driver = None
|
||||
return new
|
||||
def guild(self, guild: discord.Guild) -> Group:
|
||||
return self._get_base_group(self.GUILD, guild.id)
|
||||
|
||||
def channel(self, channel: discord.TextChannel) -> Group:
|
||||
return self._get_base_group(self.CHANNEL, channel.id)
|
||||
|
||||
def role(self, role: discord.Role) -> Group:
|
||||
return self._get_base_group(self.ROLE, role.id)
|
||||
|
||||
def user(self, user: discord.User) -> Group:
|
||||
return self._get_base_group(self.USER, user.id)
|
||||
|
||||
def member(self, member: discord.Member) -> MemberGroup:
|
||||
return self._get_base_group(self.MEMBER, member.guild.id, member.id,
|
||||
group_class=MemberGroup)
|
||||
|
||||
def user(self, user):
|
||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
||||
hash_uuid=False, defaults=self.defaults)
|
||||
new.collection = "USER"
|
||||
new.collection_uuid = user.id
|
||||
new._driver = None
|
||||
return new
|
||||
|
||||
@ -98,7 +98,7 @@ class Core:
|
||||
@commands.guild_only()
|
||||
async def adminrole(self, ctx, *, role: discord.Role):
|
||||
"""Sets the admin role for this server"""
|
||||
await ctx.bot.db.guild(ctx.guild).set("admin_role", role.id)
|
||||
await ctx.bot.db.guild(ctx.guild).admin_role.set(role.id)
|
||||
await ctx.send("The admin role for this server has been set.")
|
||||
|
||||
@_set.command()
|
||||
@ -106,7 +106,7 @@ class Core:
|
||||
@commands.guild_only()
|
||||
async def modrole(self, ctx, *, role: discord.Role):
|
||||
"""Sets the mod role for this server"""
|
||||
await ctx.bot.db.guild(ctx.guild).set("mod_role", role.id)
|
||||
await ctx.bot.db.guild(ctx.guild).mod_role.set(role.id)
|
||||
await ctx.send("The mod role for this server has been set.")
|
||||
|
||||
@_set.command()
|
||||
@ -225,7 +225,7 @@ class Core:
|
||||
await ctx.bot.send_cmd_help(ctx)
|
||||
return
|
||||
prefixes = sorted(prefixes, reverse=True)
|
||||
await ctx.bot.db.set("prefix", prefixes)
|
||||
await ctx.bot.db.prefix.set(prefixes)
|
||||
await ctx.send("Prefix set.")
|
||||
|
||||
@_set.command(aliases=["serverprefixes"])
|
||||
@ -234,11 +234,11 @@ class Core:
|
||||
async def serverprefix(self, ctx, *prefixes):
|
||||
"""Sets Red's server prefix(es)"""
|
||||
if not prefixes:
|
||||
await ctx.bot.db.guild(ctx.guild).set("prefix", [])
|
||||
await ctx.bot.db.guild(ctx.guild).prefix.set([])
|
||||
await ctx.send("Server prefixes have been reset.")
|
||||
return
|
||||
prefixes = sorted(prefixes, reverse=True)
|
||||
await ctx.bot.db.guild(ctx.guild).set("prefix", prefixes)
|
||||
await ctx.bot.db.guild(ctx.guild).prefix.set(prefixes)
|
||||
await ctx.send("Prefix set.")
|
||||
|
||||
@_set.command()
|
||||
@ -276,7 +276,7 @@ class Core:
|
||||
else:
|
||||
if message.content.strip() == token:
|
||||
self.owner.reset_cooldown(ctx)
|
||||
await ctx.bot.db.set("owner", ctx.author.id)
|
||||
await ctx.bot.db.owner.set(ctx.author.id)
|
||||
ctx.bot.owner_id = ctx.author.id
|
||||
await ctx.send("You have been set as owner.")
|
||||
else:
|
||||
|
||||
@ -1,45 +1,12 @@
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class BaseDriver:
|
||||
def get_global(self, cog_name, ident, collection_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
def get_driver(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
def get(self, identifiers: Tuple[str]):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_role(self, cog_name, ident, role_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
|
||||
default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_user(self, cog_name, ident, user_id, key, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_misc(self, cog_name, ident, *, default=None):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_global(self, cog_name, ident, key, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_channel(self, cog_name, ident, channel_id, key, value,
|
||||
clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_member(self, cog_name, ident, user_id, guild_id, key, value,
|
||||
clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
|
||||
async def set_misc(self, cog_name, ident, value, clear=False):
|
||||
raise NotImplementedError()
|
||||
async def set(self, identifiers: Tuple[str], value):
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1,13 +1,15 @@
|
||||
from typing import Tuple
|
||||
|
||||
from core.drivers.red_base import BaseDriver
|
||||
from core.json_io import JsonIO
|
||||
import os
|
||||
from .red_base import BaseDriver
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class JSON(BaseDriver):
|
||||
def __init__(self, cog_name, *args, data_path_override: Path=None,
|
||||
file_name_override: str="settings.json", **kwargs):
|
||||
def __init__(self, cog_name, *, data_path_override: Path=None,
|
||||
file_name_override: str="settings.json"):
|
||||
super().__init__()
|
||||
self.cog_name = cog_name
|
||||
self.file_name = file_name_override
|
||||
if data_path_override:
|
||||
@ -25,111 +27,23 @@ class JSON(BaseDriver):
|
||||
self.data = self.jsonIO._load_json()
|
||||
except FileNotFoundError:
|
||||
self.data = {}
|
||||
|
||||
def maybe_add_ident(self, ident: str):
|
||||
if ident in self.data:
|
||||
return
|
||||
|
||||
self.data[ident] = {}
|
||||
for k in ("GLOBAL", "GUILD", "CHANNEL", "ROLE", "MEMBER", "USER"):
|
||||
if k not in self.data[ident]:
|
||||
self.data[ident][k] = {}
|
||||
|
||||
self.jsonIO._save_json(self.data)
|
||||
|
||||
def get_global(self, cog_name, ident, _, key, *, default=None):
|
||||
return self.data[ident]["GLOBAL"].get(key, default)
|
||||
def get_driver(self):
|
||||
return self
|
||||
|
||||
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
|
||||
guilddata = self.data[ident]["GUILD"].get(str(guild_id), {})
|
||||
return guilddata.get(key, default)
|
||||
def get(self, identifiers: Tuple[str]):
|
||||
partial = self.data
|
||||
for i in identifiers:
|
||||
partial = partial[i]
|
||||
return partial
|
||||
|
||||
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
|
||||
channeldata = self.data[ident]["CHANNEL"].get(str(channel_id), {})
|
||||
return channeldata.get(key, default)
|
||||
async def set(self, identifiers, value):
|
||||
partial = self.data
|
||||
for i in identifiers[:-1]:
|
||||
if i not in partial:
|
||||
partial[i] = {}
|
||||
partial = partial[i]
|
||||
|
||||
def get_role(self, cog_name, ident, role_id, key, *, default=None):
|
||||
roledata = self.data[ident]["ROLE"].get(str(role_id), {})
|
||||
return roledata.get(key, default)
|
||||
|
||||
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
|
||||
default=None):
|
||||
userdata = self.data[ident]["MEMBER"].get(str(user_id), {})
|
||||
guilddata = userdata.get(str(guild_id), {})
|
||||
return guilddata.get(key, default)
|
||||
|
||||
def get_user(self, cog_name, ident, user_id, key, *, default=None):
|
||||
userdata = self.data[ident]["USER"].get(str(user_id), {})
|
||||
return userdata.get(key, default)
|
||||
|
||||
async def set_global(self, cog_name, ident, key, value, clear=False):
|
||||
if clear:
|
||||
self.data[ident]["GLOBAL"] = {}
|
||||
else:
|
||||
self.data[ident]["GLOBAL"][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
|
||||
guild_id = str(guild_id)
|
||||
if clear:
|
||||
self.data[ident]["GUILD"][guild_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["GUILD"][guild_id][key] = value
|
||||
except KeyError:
|
||||
self.data[ident]["GUILD"][guild_id] = {}
|
||||
self.data[ident]["GUILD"][guild_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_channel(self, cog_name, ident, channel_id, key, value, clear=False):
|
||||
channel_id = str(channel_id)
|
||||
if clear:
|
||||
self.data[ident]["CHANNEL"][channel_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["CHANNEL"][channel_id][key] = value
|
||||
except KeyError:
|
||||
self.data[ident]["CHANNEL"][channel_id] = {}
|
||||
self.data[ident]["CHANNEL"][channel_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
|
||||
role_id = str(role_id)
|
||||
if clear:
|
||||
self.data[ident]["ROLE"][role_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["ROLE"][role_id][key] = value
|
||||
except KeyError:
|
||||
self.data[ident]["ROLE"][role_id] = {}
|
||||
self.data[ident]["ROLE"][role_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_member(self, cog_name, ident, user_id, guild_id, key, value, clear=False):
|
||||
user_id = str(user_id)
|
||||
guild_id = str(guild_id)
|
||||
if clear:
|
||||
self.data[ident]["MEMBER"][user_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
|
||||
except KeyError:
|
||||
if user_id not in self.data[ident]["MEMBER"]:
|
||||
self.data[ident]["MEMBER"][user_id] = {}
|
||||
if guild_id not in self.data[ident]["MEMBER"][user_id]:
|
||||
self.data[ident]["MEMBER"][user_id][guild_id] = {}
|
||||
|
||||
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
|
||||
user_id = str(user_id)
|
||||
if clear:
|
||||
self.data[ident]["USER"][user_id] = {}
|
||||
else:
|
||||
try:
|
||||
self.data[ident]["USER"][user_id][key] = value
|
||||
except KeyError:
|
||||
self.data[ident]["USER"][user_id] = {}
|
||||
self.data[ident]["USER"][user_id][key] = value
|
||||
partial[identifiers[-1]] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
2
main.py
2
main.py
@ -113,7 +113,7 @@ if __name__ == '__main__':
|
||||
if db_token and not cli_flags.no_prompt:
|
||||
print("\nDo you want to reset the token? (y/n)")
|
||||
if confirm("> "):
|
||||
loop.run_until_complete(red.db.set("token", ""))
|
||||
loop.run_until_complete(red.db.token.set(""))
|
||||
print("Token has been reset.")
|
||||
except KeyboardInterrupt:
|
||||
log.info("Keyboard interrupt detected. Quitting...")
|
||||
|
||||
@ -2,12 +2,11 @@ from cogs.alias import Alias
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def alias(monkeysession, config):
|
||||
def get_mock_conf(*args, **kwargs):
|
||||
return config
|
||||
@pytest.fixture()
|
||||
def alias(config):
|
||||
import cogs.alias.alias
|
||||
|
||||
monkeysession.setattr("core.config.Config.get_conf", get_mock_conf)
|
||||
cogs.alias.alias.Config.get_conf = lambda *args, **kwargs: config
|
||||
|
||||
return Alias(None)
|
||||
|
||||
@ -25,9 +24,17 @@ def test_empty_global_aliases(alias):
|
||||
assert list(alias.unloaded_global_aliases()) == []
|
||||
|
||||
|
||||
async def create_test_guild_alias(alias, ctx):
|
||||
await alias.add_alias(ctx, "test", "ping", global_=False)
|
||||
|
||||
|
||||
async def create_test_global_alias(alias, ctx):
|
||||
await alias.add_alias(ctx, "test", "ping", global_=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_guild_alias(alias, ctx):
|
||||
await alias.add_alias(ctx, "test", "ping", global_=False)
|
||||
await create_test_guild_alias(alias, ctx)
|
||||
|
||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
||||
assert is_alias is True
|
||||
@ -36,6 +43,7 @@ async def test_add_guild_alias(alias, ctx):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_guild_alias(alias, ctx):
|
||||
await create_test_guild_alias(alias, ctx)
|
||||
is_alias, _ = alias.is_alias(ctx.guild, "test")
|
||||
assert is_alias is True
|
||||
|
||||
@ -47,7 +55,7 @@ async def test_delete_guild_alias(alias, ctx):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_global_alias(alias, ctx):
|
||||
await alias.add_alias(ctx, "test", "ping", global_=True)
|
||||
await create_test_global_alias(alias, ctx)
|
||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
||||
|
||||
assert is_alias is True
|
||||
@ -56,6 +64,7 @@ async def test_add_global_alias(alias, ctx):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_global_alias(alias, ctx):
|
||||
await create_test_global_alias(alias, ctx)
|
||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
||||
assert is_alias is True
|
||||
assert alias_obj.global_ is True
|
||||
|
||||
@ -17,21 +17,27 @@ def monkeysession(request):
|
||||
mpatch.undo()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@pytest.fixture()
|
||||
def json_driver(tmpdir_factory):
|
||||
import uuid
|
||||
rand = str(uuid.uuid4())
|
||||
path = Path(str(tmpdir_factory.mktemp(rand)))
|
||||
driver = red_json.JSON(
|
||||
"PyTest",
|
||||
data_path_override=Path(str(tmpdir_factory.getbasetemp()))
|
||||
data_path_override=path
|
||||
)
|
||||
return driver
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config(json_driver):
|
||||
return Config(
|
||||
import uuid
|
||||
conf = Config(
|
||||
cog_name="PyTest",
|
||||
unique_identifier=0,
|
||||
unique_identifier=str(uuid.uuid4()),
|
||||
driver_spawn=json_driver)
|
||||
yield conf
|
||||
conf.defaults = {}
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -39,19 +45,32 @@ def config_fr(json_driver):
|
||||
"""
|
||||
Mocked config object with force_register enabled.
|
||||
"""
|
||||
return Config(
|
||||
import uuid
|
||||
conf = Config(
|
||||
cog_name="PyTest",
|
||||
unique_identifier=0,
|
||||
unique_identifier=str(uuid.uuid4()),
|
||||
driver_spawn=json_driver,
|
||||
force_registration=True
|
||||
)
|
||||
yield conf
|
||||
conf.defaults = {}
|
||||
|
||||
|
||||
#region Dpy Mocks
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_guild():
|
||||
@pytest.fixture()
|
||||
def guild_factory():
|
||||
mock_guild = namedtuple("Guild", "id members")
|
||||
return mock_guild(random.randint(1, 999999999), [])
|
||||
|
||||
class GuildFactory:
|
||||
def get(self):
|
||||
return mock_guild(random.randint(1, 999999999), [])
|
||||
|
||||
return GuildFactory()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def empty_guild(guild_factory):
|
||||
return guild_factory.get()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -66,16 +85,39 @@ def empty_role():
|
||||
return mock_role(random.randint(1, 999999999))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_member(empty_guild):
|
||||
@pytest.fixture()
|
||||
def member_factory(guild_factory):
|
||||
mock_member = namedtuple("Member", "id guild")
|
||||
return mock_member(random.randint(1, 999999999), empty_guild)
|
||||
|
||||
class MemberFactory:
|
||||
def get(self):
|
||||
return mock_member(
|
||||
random.randint(1, 999999999),
|
||||
guild_factory.get())
|
||||
|
||||
return MemberFactory()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def empty_user():
|
||||
@pytest.fixture()
|
||||
def empty_member(member_factory):
|
||||
return member_factory.get()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def user_factory():
|
||||
mock_user = namedtuple("User", "id")
|
||||
return mock_user(random.randint(1, 999999999))
|
||||
|
||||
class UserFactory:
|
||||
def get(self):
|
||||
return mock_user(
|
||||
random.randint(1, 999999999))
|
||||
|
||||
return UserFactory()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def empty_user(user_factory):
|
||||
return user_factory.get()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -84,7 +126,7 @@ def empty_message():
|
||||
return mock_msg("No content.")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture()
|
||||
def ctx(empty_member, empty_channel, red):
|
||||
mock_ctx = namedtuple("Context", "author guild channel message bot")
|
||||
return mock_ctx(empty_member, empty_member.guild, empty_channel,
|
||||
@ -93,15 +135,14 @@ def ctx(empty_member, empty_channel, red):
|
||||
|
||||
|
||||
#region Red Mock
|
||||
@pytest.fixture
|
||||
def red(monkeysession, config_fr):
|
||||
@pytest.fixture()
|
||||
def red(config_fr):
|
||||
from core.cli import parse_cli_flags
|
||||
cli_flags = parse_cli_flags()
|
||||
|
||||
description = "Red v3 - Alpha"
|
||||
|
||||
monkeysession.setattr("core.config.Config.get_core_conf",
|
||||
lambda *args, **kwargs: config_fr)
|
||||
Config.get_core_conf = (lambda *args, **kwargs: config_fr)
|
||||
|
||||
red = Red(cli_flags, description=description, pm_help=None)
|
||||
|
||||
|
||||
@ -15,9 +15,9 @@ def test_config_register_global_badvalues(config):
|
||||
|
||||
def test_config_register_guild(config, empty_guild):
|
||||
config.register_guild(enabled=False, some_list=[], some_dict={})
|
||||
assert config.defaults["GUILD"]["enabled"] is False
|
||||
assert config.defaults["GUILD"]["some_list"] == []
|
||||
assert config.defaults["GUILD"]["some_dict"] == {}
|
||||
assert config.defaults[config.GUILD]["enabled"] is False
|
||||
assert config.defaults[config.GUILD]["some_list"] == []
|
||||
assert config.defaults[config.GUILD]["some_dict"] == {}
|
||||
|
||||
assert config.guild(empty_guild).enabled() is False
|
||||
assert config.guild(empty_guild).some_list() == []
|
||||
@ -26,25 +26,25 @@ def test_config_register_guild(config, empty_guild):
|
||||
|
||||
def test_config_register_channel(config, empty_channel):
|
||||
config.register_channel(enabled=False)
|
||||
assert config.defaults["CHANNEL"]["enabled"] is False
|
||||
assert config.defaults[config.CHANNEL]["enabled"] is False
|
||||
assert config.channel(empty_channel).enabled() is False
|
||||
|
||||
|
||||
def test_config_register_role(config, empty_role):
|
||||
config.register_role(enabled=False)
|
||||
assert config.defaults["ROLE"]["enabled"] is False
|
||||
assert config.defaults[config.ROLE]["enabled"] is False
|
||||
assert config.role(empty_role).enabled() is False
|
||||
|
||||
|
||||
def test_config_register_member(config, empty_member):
|
||||
config.register_member(some_number=-1)
|
||||
assert config.defaults["MEMBER"]["some_number"] == -1
|
||||
assert config.defaults[config.MEMBER]["some_number"] == -1
|
||||
assert config.member(empty_member).some_number() == -1
|
||||
|
||||
|
||||
def test_config_register_user(config, empty_user):
|
||||
config.register_user(some_value=None)
|
||||
assert config.defaults["USER"]["some_value"] is None
|
||||
assert config.defaults[config.USER]["some_value"] is None
|
||||
assert config.user(empty_user).some_value() is None
|
||||
|
||||
|
||||
@ -57,106 +57,233 @@ def test_config_force_register_global(config_fr):
|
||||
#endregion
|
||||
|
||||
|
||||
# Test nested registration
|
||||
def test_nested_registration(config):
|
||||
config.register_global(foo__bar__baz=False)
|
||||
assert config.foo.bar.baz() is False
|
||||
|
||||
|
||||
def test_nested_registration_asdict(config):
|
||||
defaults = {'bar': {'baz': False}}
|
||||
config.register_global(foo=defaults)
|
||||
|
||||
assert config.foo.bar.baz() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_registration_and_changing(config):
|
||||
defaults = {'bar': {'baz': False}}
|
||||
config.register_global(foo=defaults)
|
||||
|
||||
assert config.foo.bar.baz() is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await config.foo.set(True)
|
||||
|
||||
|
||||
def test_doubleset_default(config):
|
||||
config.register_global(foo=True)
|
||||
config.register_global(foo=False)
|
||||
|
||||
assert config.foo() is False
|
||||
|
||||
|
||||
def test_nested_registration_multidict(config):
|
||||
defaults = {
|
||||
"foo": {
|
||||
"bar": {
|
||||
"baz": True
|
||||
}
|
||||
},
|
||||
"blah": True
|
||||
}
|
||||
config.register_global(**defaults)
|
||||
|
||||
assert config.foo.bar.baz() is True
|
||||
assert config.blah() is True
|
||||
|
||||
|
||||
def test_nested_group_value_badreg(config):
|
||||
config.register_global(foo=True)
|
||||
with pytest.raises(KeyError):
|
||||
config.register_global(foo__bar=False)
|
||||
|
||||
|
||||
def test_nested_toplevel_reg(config):
|
||||
defaults = {'bar': True, 'baz': False}
|
||||
config.register_global(foo=defaults)
|
||||
|
||||
assert config.foo.bar() is True
|
||||
assert config.foo.baz() is False
|
||||
|
||||
|
||||
def test_nested_overlapping(config):
|
||||
config.register_global(foo__bar=True)
|
||||
config.register_global(foo__baz=False)
|
||||
|
||||
assert config.foo.bar() is True
|
||||
assert config.foo.baz() is False
|
||||
|
||||
|
||||
def test_nesting_nofr(config):
|
||||
config.register_global(foo={})
|
||||
assert config.foo.bar() is None
|
||||
assert config.foo() == {}
|
||||
|
||||
|
||||
#region Default Value Overrides
|
||||
def test_global_default_override(config):
|
||||
assert config.enabled(True) is True
|
||||
assert config.get("enabled") is None
|
||||
assert config.get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_global_default_nofr(config):
|
||||
assert config.nofr() is None
|
||||
assert config.nofr(True) is True
|
||||
assert config.get("nofr") is None
|
||||
assert config.get("nofr", default=True) is True
|
||||
|
||||
|
||||
def test_guild_default_override(config, empty_guild):
|
||||
assert config.guild(empty_guild).enabled(True) is True
|
||||
assert config.guild(empty_guild).get("enabled") is None
|
||||
assert config.guild(empty_guild).get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_channel_default_override(config, empty_channel):
|
||||
assert config.channel(empty_channel).enabled(True) is True
|
||||
assert config.channel(empty_channel).get("enabled") is None
|
||||
assert config.channel(empty_channel).get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_role_default_override(config, empty_role):
|
||||
assert config.role(empty_role).enabled(True) is True
|
||||
assert config.role(empty_role).get("enabled") is None
|
||||
assert config.role(empty_role).get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_member_default_override(config, empty_member):
|
||||
assert config.member(empty_member).enabled(True) is True
|
||||
assert config.member(empty_member).get("enabled") is None
|
||||
assert config.member(empty_member).get("enabled", default=True) is True
|
||||
|
||||
|
||||
def test_user_default_override(config, empty_user):
|
||||
assert config.user(empty_user).some_value(True) is True
|
||||
assert config.user(empty_user).get("some_value") is None
|
||||
assert config.user(empty_user).get("some_value", default=True) is True
|
||||
#endregion
|
||||
|
||||
|
||||
#region Setting Values
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_global(config):
|
||||
await config.set("enabled", True)
|
||||
await config.enabled.set(True)
|
||||
assert config.enabled() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_global_badkey(config):
|
||||
with pytest.raises(RuntimeError):
|
||||
await config.set("this is a bad key", True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_global_invalidkey(config):
|
||||
with pytest.raises(KeyError):
|
||||
await config.set("uuid", True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_guild(config, empty_guild):
|
||||
await config.guild(empty_guild).set("enabled", True)
|
||||
await config.guild(empty_guild).enabled.set(True)
|
||||
assert config.guild(empty_guild).enabled() is True
|
||||
|
||||
curr_list = config.guild(empty_guild).some_list([1, 2, 3])
|
||||
assert curr_list == [1, 2, 3]
|
||||
curr_list.append(4)
|
||||
|
||||
await config.guild(empty_guild).set("some_list", curr_list)
|
||||
await config.guild(empty_guild).some_list.set(curr_list)
|
||||
assert config.guild(empty_guild).some_list() == curr_list
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_channel(config, empty_channel):
|
||||
await config.channel(empty_channel).set("enabled", True)
|
||||
await config.channel(empty_channel).enabled.set(True)
|
||||
assert config.channel(empty_channel).enabled() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_channel_no_register(config, empty_channel):
|
||||
await config.channel(empty_channel).set("no_register", True)
|
||||
await config.channel(empty_channel).no_register.set(True)
|
||||
assert config.channel(empty_channel).no_register() is True
|
||||
#endregion
|
||||
|
||||
|
||||
# region Getting Values
|
||||
def test_get_func_w_reg(config):
|
||||
config.register_global(
|
||||
thing=True
|
||||
)
|
||||
assert config.get("thing") is True
|
||||
assert config.get("thing", False) is False
|
||||
# Dynamic attribute testing
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_dynamic_attr(config):
|
||||
await config.set_attr("foobar", True)
|
||||
|
||||
assert config.foobar() is True
|
||||
|
||||
|
||||
def test_get_func_wo_reg(config):
|
||||
assert config.get("thing") is None
|
||||
assert config.get("thing", True) is True
|
||||
# endregion
|
||||
def test_get_dynamic_attr(config):
|
||||
assert config.get_attr("foobaz", True) is True
|
||||
|
||||
|
||||
# Member Group testing
|
||||
@pytest.mark.asyncio
|
||||
async def test_membergroup_allguilds(config, empty_member):
|
||||
await config.member(empty_member).foo.set(False)
|
||||
|
||||
all_servers = config.member(empty_member).all_guilds()
|
||||
assert str(empty_member.guild.id) in all_servers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_membergroup_allmembers(config, empty_member):
|
||||
await config.member(empty_member).foo.set(False)
|
||||
|
||||
all_members = config.member(empty_member).all()
|
||||
assert str(empty_member.id) in all_members
|
||||
|
||||
|
||||
# Clearing testing
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_clear(config):
|
||||
config.register_global(foo=True, bar=False)
|
||||
|
||||
await config.foo.set(False)
|
||||
await config.bar.set(True)
|
||||
|
||||
assert config.foo() is False
|
||||
assert config.bar() is True
|
||||
|
||||
await config.clear()
|
||||
|
||||
assert config.foo() is True
|
||||
assert config.bar() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_clear(config, member_factory):
|
||||
config.register_member(foo=True)
|
||||
|
||||
m1 = member_factory.get()
|
||||
await config.member(m1).foo.set(False)
|
||||
assert config.member(m1).foo() is False
|
||||
|
||||
m2 = member_factory.get()
|
||||
await config.member(m2).foo.set(False)
|
||||
assert config.member(m2).foo() is False
|
||||
|
||||
assert m1.guild.id != m2.guild.id
|
||||
|
||||
await config.member(m1).clear()
|
||||
assert config.member(m1).foo() is True
|
||||
assert config.member(m2).foo() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_clear_all(config, member_factory):
|
||||
server_ids = []
|
||||
for _ in range(5):
|
||||
member = member_factory.get()
|
||||
await config.member(member).foo.set(True)
|
||||
server_ids.append(member.guild.id)
|
||||
|
||||
member = member_factory.get()
|
||||
assert len(config.member(member).all_guilds()) == len(server_ids)
|
||||
|
||||
await config.member(member).clear_all()
|
||||
|
||||
assert len(config.member(member).all_guilds()) == 0
|
||||
|
||||
|
||||
# Get All testing
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_get_all(config, user_factory):
|
||||
for _ in range(5):
|
||||
user = user_factory.get()
|
||||
await config.user(user).foo.set(True)
|
||||
|
||||
user = user_factory.get()
|
||||
all_data = config.user(user).all()
|
||||
|
||||
assert len(all_data) == 5
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user