mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[Config] Rewrite (#869)
This commit is contained in:
parent
5c2be25dfc
commit
99bfb2fc7a
@ -1,3 +1,4 @@
|
|||||||
|
dist: trusty
|
||||||
language: python
|
language: python
|
||||||
python:
|
python:
|
||||||
- "3.5.3"
|
- "3.5.3"
|
||||||
|
|||||||
@ -81,24 +81,24 @@ class Alias:
|
|||||||
if global_:
|
if global_:
|
||||||
curr_aliases = self._aliases.entries()
|
curr_aliases = self._aliases.entries()
|
||||||
curr_aliases.append(alias.to_json())
|
curr_aliases.append(alias.to_json())
|
||||||
await self._aliases.set("entries", curr_aliases)
|
await self._aliases.entries.set(curr_aliases)
|
||||||
else:
|
else:
|
||||||
curr_aliases = self._aliases.guild(ctx.guild).entries()
|
curr_aliases = self._aliases.guild(ctx.guild).entries()
|
||||||
|
|
||||||
curr_aliases.append(alias.to_json())
|
curr_aliases.append(alias.to_json())
|
||||||
await self._aliases.guild(ctx.guild).set("entries", curr_aliases)
|
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
|
||||||
|
|
||||||
await self._aliases.guild(ctx.guild).set("enabled", True)
|
await self._aliases.guild(ctx.guild).enabled.set(True)
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
||||||
global_: bool=False) -> bool:
|
global_: bool=False) -> bool:
|
||||||
if global_:
|
if global_:
|
||||||
aliases = self.unloaded_global_aliases()
|
aliases = self.unloaded_global_aliases()
|
||||||
setter_func = self._aliases.set
|
setter_func = self._aliases.entries.set
|
||||||
else:
|
else:
|
||||||
aliases = self.unloaded_aliases(ctx.guild)
|
aliases = self.unloaded_aliases(ctx.guild)
|
||||||
setter_func = self._aliases.guild(ctx.guild).set
|
setter_func = self._aliases.guild(ctx.guild).entries.set
|
||||||
|
|
||||||
did_delete_alias = False
|
did_delete_alias = False
|
||||||
|
|
||||||
@ -110,7 +110,6 @@ class Alias:
|
|||||||
did_delete_alias = True
|
did_delete_alias = True
|
||||||
|
|
||||||
await setter_func(
|
await setter_func(
|
||||||
"entries",
|
|
||||||
[a.to_json() for a in to_keep]
|
[a.to_json() for a in to_keep]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -355,8 +354,9 @@ class Alias:
|
|||||||
await ctx.send(box("\n".join(names), "diff"))
|
await ctx.send(box("\n".join(names), "diff"))
|
||||||
|
|
||||||
async def on_message(self, message: discord.Message):
|
async def on_message(self, message: discord.Message):
|
||||||
aliases = list(self.unloaded_aliases(message.guild)) + \
|
aliases = list(self.unloaded_global_aliases())
|
||||||
list(self.unloaded_global_aliases())
|
if message.guild is not None:
|
||||||
|
aliases = aliases + list(self.unloaded_aliases(message.guild))
|
||||||
|
|
||||||
if len(aliases) == 0:
|
if len(aliases) == 0:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -30,7 +30,7 @@ class Downloader:
|
|||||||
def __init__(self, bot: Red):
|
def __init__(self, bot: Red):
|
||||||
self.bot = bot
|
self.bot = bot
|
||||||
|
|
||||||
self.conf = Config.get_conf(self, unique_identifier=998240343,
|
self.conf = Config.get_conf(self, identifier=998240343,
|
||||||
force_registration=True)
|
force_registration=True)
|
||||||
|
|
||||||
self.conf.register_global(
|
self.conf.register_global(
|
||||||
@ -73,7 +73,7 @@ class Downloader:
|
|||||||
|
|
||||||
if cog_json not in installed:
|
if cog_json not in installed:
|
||||||
installed.append(cog_json)
|
installed.append(cog_json)
|
||||||
await self.conf.set("installed", installed)
|
await self.conf.installed.set(installed)
|
||||||
|
|
||||||
async def _remove_from_installed(self, cog: Installable):
|
async def _remove_from_installed(self, cog: Installable):
|
||||||
"""
|
"""
|
||||||
@ -86,7 +86,7 @@ class Downloader:
|
|||||||
|
|
||||||
if cog_json in installed:
|
if cog_json in installed:
|
||||||
installed.remove(cog_json)
|
installed.remove(cog_json)
|
||||||
await self.conf.set("installed", installed)
|
await self.conf.installed.set(installed)
|
||||||
|
|
||||||
async def _reinstall_cogs(self, cogs: Tuple[Installable]) -> Tuple[Installable]:
|
async def _reinstall_cogs(self, cogs: Tuple[Installable]) -> Tuple[Installable]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -526,4 +526,4 @@ class RepoManager:
|
|||||||
|
|
||||||
async def _save_repos(self):
|
async def _save_repos(self):
|
||||||
repo_json_info = {name: r.to_json() for name, r in self._repos.items()}
|
repo_json_info = {name: r.to_json() for name, r in self._repos.items()}
|
||||||
await self.downloader_config.set("repos", repo_json_info)
|
await self.downloader_config.repos.set(repo_json_info)
|
||||||
|
|||||||
@ -47,7 +47,7 @@ class Red(commands.Bot):
|
|||||||
kwargs["owner_id"] = cli_flags.owner
|
kwargs["owner_id"] = cli_flags.owner
|
||||||
|
|
||||||
if "owner_id" not in kwargs:
|
if "owner_id" not in kwargs:
|
||||||
kwargs["owner_id"] = self.db.get("owner")
|
kwargs["owner_id"] = self.db.owner()
|
||||||
|
|
||||||
self.counter = Counter()
|
self.counter = Counter()
|
||||||
self.uptime = None
|
self.uptime = None
|
||||||
@ -89,7 +89,7 @@ class Red(commands.Bot):
|
|||||||
for package in self.extensions:
|
for package in self.extensions:
|
||||||
if package.startswith("cogs."):
|
if package.startswith("cogs."):
|
||||||
loaded.append(package)
|
loaded.append(package)
|
||||||
await self.db.set("packages", loaded)
|
await self.db.packages.set(loaded)
|
||||||
|
|
||||||
|
|
||||||
class ExitCodes(Enum):
|
class ExitCodes(Enum):
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def interactive_config(red, token_set, prefix_set):
|
|||||||
print("That doesn't look like a valid token.")
|
print("That doesn't look like a valid token.")
|
||||||
token = ""
|
token = ""
|
||||||
if token:
|
if token:
|
||||||
loop.run_until_complete(red.db.set("token", token))
|
loop.run_until_complete(red.db.token.set(token))
|
||||||
|
|
||||||
if not prefix_set:
|
if not prefix_set:
|
||||||
prefix = ""
|
prefix = ""
|
||||||
@ -39,7 +39,7 @@ def interactive_config(red, token_set, prefix_set):
|
|||||||
if not confirm("> "):
|
if not confirm("> "):
|
||||||
prefix = ""
|
prefix = ""
|
||||||
if prefix:
|
if prefix:
|
||||||
loop.run_until_complete(red.db.set("prefix", [prefix]))
|
loop.run_until_complete(red.db.prefix.set([prefix]))
|
||||||
|
|
||||||
ask_sentry(red)
|
ask_sentry(red)
|
||||||
|
|
||||||
@ -55,9 +55,9 @@ def ask_sentry(red: Red):
|
|||||||
" found issues in a timely manner. If you wish to opt in\n"
|
" found issues in a timely manner. If you wish to opt in\n"
|
||||||
" the process please type \"yes\":\n")
|
" the process please type \"yes\":\n")
|
||||||
if not confirm("> "):
|
if not confirm("> "):
|
||||||
loop.run_until_complete(red.db.set("enable_sentry", False))
|
loop.run_until_complete(red.db.enable_sentry.set(False))
|
||||||
else:
|
else:
|
||||||
loop.run_until_complete(red.db.set("enable_sentry", True))
|
loop.run_until_complete(red.db.enable_sentry.set(True))
|
||||||
print("\nThank you for helping us with the development process!")
|
print("\nThank you for helping us with the development process!")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
791
core/config.py
791
core/config.py
@ -1,521 +1,374 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from core.drivers.red_json import JSON as JSONDriver
|
|
||||||
from core.drivers.red_mongo import Mongo
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Callable
|
from typing import Callable, Union, Tuple
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from .drivers.red_json import JSON as JSONDriver
|
||||||
|
|
||||||
log = logging.getLogger("red.config")
|
log = logging.getLogger("red.config")
|
||||||
|
|
||||||
class BaseConfig:
|
|
||||||
def __init__(self, cog_name, unique_identifier, driver_spawn, force_registration=False,
|
|
||||||
hash_uuid=True, collection="GLOBAL", collection_uuid=None,
|
|
||||||
defaults={}):
|
|
||||||
self.cog_name = cog_name
|
|
||||||
if hash_uuid:
|
|
||||||
self.uuid = str(hash(unique_identifier))
|
|
||||||
else:
|
|
||||||
self.uuid = unique_identifier
|
|
||||||
self.driver_spawn = driver_spawn
|
|
||||||
self._driver = None
|
|
||||||
self.collection = collection
|
|
||||||
self.collection_uuid = collection_uuid
|
|
||||||
|
|
||||||
self.force_registration = force_registration
|
class Value:
|
||||||
|
def __init__(self, identifiers: Tuple[str], default_value, spawner):
|
||||||
|
self._identifiers = identifiers
|
||||||
|
self.default = default_value
|
||||||
|
|
||||||
|
self.spawner = spawner
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identifiers(self):
|
||||||
|
return tuple(str(i) for i in self._identifiers)
|
||||||
|
|
||||||
|
def __call__(self, default=None):
|
||||||
|
driver = self.spawner.get_driver()
|
||||||
try:
|
try:
|
||||||
self.driver.maybe_add_ident(self.uuid)
|
ret = driver.get(self.identifiers)
|
||||||
except AttributeError:
|
except KeyError:
|
||||||
pass
|
return default or self.default
|
||||||
|
return ret
|
||||||
|
|
||||||
self.driver_getmap = {
|
async def set(self, value):
|
||||||
"GLOBAL": self.driver.get_global,
|
driver = self.spawner.get_driver()
|
||||||
"GUILD": self.driver.get_guild,
|
await driver.set(self.identifiers, value)
|
||||||
"CHANNEL": self.driver.get_channel,
|
|
||||||
"ROLE": self.driver.get_role,
|
|
||||||
"USER": self.driver.get_user
|
|
||||||
}
|
|
||||||
|
|
||||||
self.driver_setmap = {
|
|
||||||
"GLOBAL": self.driver.set_global,
|
|
||||||
"GUILD": self.driver.set_guild,
|
|
||||||
"CHANNEL": self.driver.set_channel,
|
|
||||||
"ROLE": self.driver.set_role,
|
|
||||||
"USER": self.driver.set_user
|
|
||||||
}
|
|
||||||
|
|
||||||
self.curr_key = None
|
class Group(Value):
|
||||||
|
def __init__(self, identifiers: Tuple[str],
|
||||||
|
defaults: dict,
|
||||||
|
spawner,
|
||||||
|
force_registration: bool=False):
|
||||||
|
self.defaults = defaults
|
||||||
|
self.force_registration = force_registration
|
||||||
|
self.spawner = spawner
|
||||||
|
|
||||||
self.unsettable_keys = ("cog_name", "cog_identifier", "_id",
|
super().__init__(identifiers, {}, self.spawner)
|
||||||
"guild_id", "channel_id", "role_id",
|
|
||||||
"user_id", "uuid")
|
# noinspection PyTypeChecker
|
||||||
self.invalid_keys = (
|
def __getattr__(self, item: str) -> Union["Group", Value]:
|
||||||
"driver_spawn",
|
"""
|
||||||
"_driver", "collection",
|
Takes in the next accessible item. If it's found to be a Group
|
||||||
"collection_uuid", "force_registration"
|
we return another Group object. If it's found to be a Value
|
||||||
|
we return a Value object. If it is not found and
|
||||||
|
force_registration is True then we raise AttributeException,
|
||||||
|
otherwise return a Value object.
|
||||||
|
:param item:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
is_group = self.is_group(item)
|
||||||
|
is_value = not is_group and self.is_value(item)
|
||||||
|
new_identifiers = self.identifiers + (item, )
|
||||||
|
if is_group:
|
||||||
|
return Group(
|
||||||
|
identifiers=new_identifiers,
|
||||||
|
defaults=self.defaults[item],
|
||||||
|
spawner=self.spawner,
|
||||||
|
force_registration=self.force_registration
|
||||||
|
)
|
||||||
|
elif is_value:
|
||||||
|
return Value(
|
||||||
|
identifiers=new_identifiers,
|
||||||
|
default_value=self.defaults[item],
|
||||||
|
spawner=self.spawner
|
||||||
|
)
|
||||||
|
elif self.force_registration:
|
||||||
|
raise AttributeError(
|
||||||
|
"'{}' is not a valid registered Group"
|
||||||
|
"or value.".format(item)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return Value(
|
||||||
|
identifiers=new_identifiers,
|
||||||
|
default_value=None,
|
||||||
|
spawner=self.spawner
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _super_group(self) -> 'Group':
|
||||||
|
super_group = Group(
|
||||||
|
self.identifiers[:-1],
|
||||||
|
defaults={},
|
||||||
|
spawner=self.spawner,
|
||||||
|
force_registration=self.force_registration
|
||||||
)
|
)
|
||||||
|
return super_group
|
||||||
|
|
||||||
self.defaults = defaults if defaults else {
|
def is_group(self, item: str) -> bool:
|
||||||
"GLOBAL": {}, "GUILD": {}, "CHANNEL": {}, "ROLE": {},
|
"""
|
||||||
"MEMBER": {}, "USER": {}}
|
Determines if an attribute access is pointing at a registered group.
|
||||||
|
:param item:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
default = self.defaults.get(item)
|
||||||
|
return isinstance(default, dict)
|
||||||
|
|
||||||
|
def is_value(self, item: str) -> bool:
|
||||||
|
"""
|
||||||
|
Determines if an attribute access is pointing at a registered value.
|
||||||
|
:param item:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
default = self.defaults[item]
|
||||||
|
except KeyError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return not isinstance(default, dict)
|
||||||
|
|
||||||
|
def get_attr(self, item: str, default=None):
|
||||||
|
"""
|
||||||
|
You should avoid this function whenever possible.
|
||||||
|
:param item:
|
||||||
|
:param default:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
value = getattr(self, item)
|
||||||
|
return value(default=default)
|
||||||
|
|
||||||
|
def all(self) -> dict:
|
||||||
|
"""
|
||||||
|
Gets all entries of the given kind. If this kind is member
|
||||||
|
then this method returns all members from the same
|
||||||
|
server.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
return self._super_group()
|
||||||
|
|
||||||
|
async def set(self, value):
|
||||||
|
if not isinstance(value, dict):
|
||||||
|
raise ValueError(
|
||||||
|
"You may only set the value of a group to be a dict."
|
||||||
|
)
|
||||||
|
await super().set(value)
|
||||||
|
|
||||||
|
async def set_attr(self, item: str, value):
|
||||||
|
"""
|
||||||
|
You should avoid this function whenever possible.
|
||||||
|
:param item:
|
||||||
|
:param value:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
value_obj = getattr(self, item)
|
||||||
|
await value_obj.set(value)
|
||||||
|
|
||||||
|
async def clear(self):
|
||||||
|
"""
|
||||||
|
Wipes out data for the given entry in this category
|
||||||
|
e.g. Guild/Role/User
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self.set({})
|
||||||
|
|
||||||
|
async def clear_all(self):
|
||||||
|
"""
|
||||||
|
Removes all data from all entries.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
await self._super_group.set({})
|
||||||
|
|
||||||
|
|
||||||
|
class MemberGroup(Group):
|
||||||
|
@property
|
||||||
|
def _super_group(self) -> Group:
|
||||||
|
new_identifiers = self.identifiers[:2]
|
||||||
|
group_obj = Group(
|
||||||
|
identifiers=new_identifiers,
|
||||||
|
defaults={},
|
||||||
|
spawner=self.spawner
|
||||||
|
)
|
||||||
|
return group_obj
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _guild_group(self) -> Group:
|
||||||
|
new_identifiers = self.identifiers[:3]
|
||||||
|
group_obj = Group(
|
||||||
|
identifiers=new_identifiers,
|
||||||
|
defaults={},
|
||||||
|
spawner=self.spawner
|
||||||
|
)
|
||||||
|
return group_obj
|
||||||
|
|
||||||
|
def all_guilds(self) -> dict:
|
||||||
|
"""
|
||||||
|
Gets a dict of all guilds and members.
|
||||||
|
|
||||||
|
REMEMBER: ID's are stored in these dicts as STRINGS.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
return self._super_group()
|
||||||
|
|
||||||
|
def all(self) -> dict:
|
||||||
|
"""
|
||||||
|
Returns the dict of all members in the same guild.
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# noinspection PyTypeChecker
|
||||||
|
return self._guild_group()
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
GLOBAL = "GLOBAL"
|
||||||
|
GUILD = "GUILD"
|
||||||
|
CHANNEL = "TEXTCHANNEL"
|
||||||
|
ROLE = "ROLE"
|
||||||
|
USER = "USER"
|
||||||
|
MEMBER = "MEMBER"
|
||||||
|
|
||||||
|
def __init__(self, cog_name: str, unique_identifier: str,
|
||||||
|
driver_spawn: Callable,
|
||||||
|
force_registration: bool=False,
|
||||||
|
defaults: dict=None):
|
||||||
|
self.cog_name = cog_name
|
||||||
|
self.unique_identifier = unique_identifier
|
||||||
|
|
||||||
|
self.spawner = driver_spawn
|
||||||
|
self.force_registration = force_registration
|
||||||
|
self.defaults = defaults or {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_conf(cls, cog_instance: object, unique_identifier: int=0,
|
def get_conf(cls, cog_instance, identifier: int,
|
||||||
force_registration: bool=False):
|
force_registration=False):
|
||||||
"""
|
"""
|
||||||
Gets a config object that cog's can use to safely store data. The
|
Returns a Config instance based on a simplified set of initial
|
||||||
backend to this is totally modular and can easily switch between
|
variables.
|
||||||
JSON and a DB. However, when changed, all data will likely be lost
|
:param cog_instance:
|
||||||
unless cogs write some converters for their data.
|
:param identifier: Any random integer, used to keep your data
|
||||||
|
distinct from any other cog with the same name.
|
||||||
Positional Arguments:
|
:param force_registration: Should config require registration
|
||||||
cog_instance - The cog `self` object, can be passed in from your
|
of data keys before allowing you to get/set values?
|
||||||
cog's __init__ method.
|
:return:
|
||||||
|
|
||||||
Keyword Arguments:
|
|
||||||
unique_identifier - a random integer or string that is used to
|
|
||||||
differentiate your cog from any other named the same. This way we
|
|
||||||
can safely store data for multiple cogs that are named the same.
|
|
||||||
|
|
||||||
YOU SHOULD USE THIS.
|
|
||||||
|
|
||||||
force_registration - A flag which will cause the Config object to
|
|
||||||
throw exceptions if you try to get/set data keys that you have
|
|
||||||
not pre-registered. I highly recommend you ENABLE this as it
|
|
||||||
will help reduce dumb typo errors.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
url = None # TODO: get mongo url
|
|
||||||
port = None # TODO: get mongo port
|
|
||||||
|
|
||||||
def spawn_mongo_driver():
|
|
||||||
return Mongo(url, port)
|
|
||||||
|
|
||||||
# TODO: Determine which backend users want, default to JSON
|
|
||||||
|
|
||||||
cog_name = cog_instance.__class__.__name__
|
cog_name = cog_instance.__class__.__name__
|
||||||
|
uuid = str(hash(identifier))
|
||||||
|
|
||||||
driver_spawn = JSONDriver(cog_name)
|
spawner = JSONDriver(cog_name)
|
||||||
|
return cls(cog_name=cog_name, unique_identifier=uuid,
|
||||||
return cls(cog_name=cog_name, unique_identifier=unique_identifier,
|
force_registration=force_registration,
|
||||||
driver_spawn=driver_spawn, force_registration=force_registration)
|
driver_spawn=spawner)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_core_conf(cls, force_registration: bool=False):
|
def get_core_conf(cls, force_registration: bool=False):
|
||||||
core_data_path = Path.cwd() / 'core' / '.data'
|
core_data_path = Path.cwd() / 'core' / '.data'
|
||||||
driver_spawn = JSONDriver("Core", data_path_override=core_data_path)
|
driver_spawn = JSONDriver("Core", data_path_override=core_data_path)
|
||||||
return cls(cog_name="Core", driver_spawn=driver_spawn,
|
return cls(cog_name="Core", driver_spawn=driver_spawn,
|
||||||
unique_identifier=0,
|
unique_identifier='0',
|
||||||
force_registration=force_registration)
|
force_registration=force_registration)
|
||||||
|
|
||||||
@property
|
def __getattr__(self, item: str) -> Union[Group, Value]:
|
||||||
def driver(self):
|
|
||||||
if self._driver is None:
|
|
||||||
try:
|
|
||||||
self._driver = self.driver_spawn()
|
|
||||||
except TypeError:
|
|
||||||
return self.driver_spawn
|
|
||||||
|
|
||||||
return self._driver
|
|
||||||
|
|
||||||
def __getattr__(self, key):
|
|
||||||
"""This should be used to return config key data as determined by
|
|
||||||
`self.collection` and `self.collection_uuid`."""
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def __setattr__(self, key, value):
|
|
||||||
if 'defaults' in self.__dict__: # Necessary to let the cog load
|
|
||||||
restricted = list(self.defaults[self.collection].keys()) + \
|
|
||||||
list(self.unsettable_keys)
|
|
||||||
if key in restricted:
|
|
||||||
raise ValueError("Not allowed to dynamically set attributes of"
|
|
||||||
" unsettable_keys: {}".format(restricted))
|
|
||||||
else:
|
|
||||||
self.__dict__[key] = value
|
|
||||||
else:
|
|
||||||
self.__dict__[key] = value
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
"""Clears all values in the current context ONLY."""
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def set(self, key, value):
|
|
||||||
"""This should set config key with value `value` in the
|
|
||||||
corresponding collection as defined by `self.collection` and
|
|
||||||
`self.collection_uuid`."""
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def guild(self, guild):
|
|
||||||
"""This should return a `BaseConfig` instance with the corresponding
|
|
||||||
`collection` and `collection_uuid`."""
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def channel(self, channel):
|
|
||||||
"""This should return a `BaseConfig` instance with the corresponding
|
|
||||||
`collection` and `collection_uuid`."""
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def role(self, role):
|
|
||||||
"""This should return a `BaseConfig` instance with the corresponding
|
|
||||||
`collection` and `collection_uuid`."""
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def member(self, member):
|
|
||||||
"""This should return a `BaseConfig` instance with the corresponding
|
|
||||||
`collection` and `collection_uuid`."""
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def user(self, user):
|
|
||||||
"""This should return a `BaseConfig` instance with the corresponding
|
|
||||||
`collection` and `collection_uuid`."""
|
|
||||||
raise NotImplemented
|
|
||||||
|
|
||||||
def register_global(self, **global_defaults):
|
|
||||||
"""
|
"""
|
||||||
Registers a new dict of global defaults. This function should
|
This is used to generate Value or Group objects for global
|
||||||
be called EVERY TIME the cog loads (aka just do it in
|
values.
|
||||||
__init__)!
|
:param item:
|
||||||
|
|
||||||
:param global_defaults: Each key should be the key you want to
|
|
||||||
access data by and the value is the default value of that
|
|
||||||
key.
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
for k, v in global_defaults.items():
|
global_group = self._get_base_group(self.GLOBAL)
|
||||||
try:
|
return getattr(global_group, item)
|
||||||
self._register_global(k, v)
|
|
||||||
except KeyError:
|
|
||||||
log.exception("Bad default global key.")
|
|
||||||
|
|
||||||
def _register_global(self, key, default=None):
|
@staticmethod
|
||||||
"""Registers a global config key `key`"""
|
def _get_defaults_dict(key: str, value) -> dict:
|
||||||
if key in self.unsettable_keys:
|
|
||||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
|
||||||
elif not key.isidentifier():
|
|
||||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
|
||||||
" name.")
|
|
||||||
self.defaults["GLOBAL"][key] = default
|
|
||||||
|
|
||||||
def register_guild(self, **guild_defaults):
|
|
||||||
"""
|
"""
|
||||||
Registers a new dict of guild defaults. This function should
|
Since we're allowing nested config stuff now, not storing the
|
||||||
be called EVERY TIME the cog loads (aka just do it in
|
defaults as a flat dict sounds like a good idea. May turn
|
||||||
__init__)!
|
out to be an awful one but we'll see.
|
||||||
|
|
||||||
:param guild_defaults: Each key should be the key you want to
|
|
||||||
access data by and the value is the default value of that
|
|
||||||
key.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for k, v in guild_defaults.items():
|
|
||||||
try:
|
|
||||||
self._register_guild(k, v)
|
|
||||||
except KeyError:
|
|
||||||
log.exception("Bad default guild key.")
|
|
||||||
|
|
||||||
def _register_guild(self, key, default=None):
|
|
||||||
"""Registers a guild config key `key`"""
|
|
||||||
if key in self.unsettable_keys:
|
|
||||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
|
||||||
elif not key.isidentifier():
|
|
||||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
|
||||||
" name.")
|
|
||||||
self.defaults["GUILD"][key] = default
|
|
||||||
|
|
||||||
def register_channel(self, **channel_defaults):
|
|
||||||
"""
|
|
||||||
Registers a new dict of channel defaults. This function should
|
|
||||||
be called EVERY TIME the cog loads (aka just do it in
|
|
||||||
__init__)!
|
|
||||||
|
|
||||||
:param channel_defaults: Each key should be the key you want to
|
|
||||||
access data by and the value is the default value of that
|
|
||||||
key.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for k, v in channel_defaults.items():
|
|
||||||
try:
|
|
||||||
self._register_channel(k, v)
|
|
||||||
except KeyError:
|
|
||||||
log.exception("Bad default channel key.")
|
|
||||||
|
|
||||||
def _register_channel(self, key, default=None):
|
|
||||||
"""Registers a channel config key `key`"""
|
|
||||||
if key in self.unsettable_keys:
|
|
||||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
|
||||||
elif not key.isidentifier():
|
|
||||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
|
||||||
" name.")
|
|
||||||
self.defaults["CHANNEL"][key] = default
|
|
||||||
|
|
||||||
def register_role(self, **role_defaults):
|
|
||||||
"""
|
|
||||||
Registers a new dict of role defaults. This function should
|
|
||||||
be called EVERY TIME the cog loads (aka just do it in
|
|
||||||
__init__)!
|
|
||||||
|
|
||||||
:param role_defaults: Each key should be the key you want to
|
|
||||||
access data by and the value is the default value of that
|
|
||||||
key.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for k, v in role_defaults.items():
|
|
||||||
try:
|
|
||||||
self._register_role(k, v)
|
|
||||||
except KeyError:
|
|
||||||
log.exception("Bad default role key.")
|
|
||||||
|
|
||||||
def _register_role(self, key, default=None):
|
|
||||||
"""Registers a role config key `key`"""
|
|
||||||
if key in self.unsettable_keys:
|
|
||||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
|
||||||
elif not key.isidentifier():
|
|
||||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
|
||||||
" name.")
|
|
||||||
self.defaults["ROLE"][key] = default
|
|
||||||
|
|
||||||
def register_member(self, **member_defaults):
|
|
||||||
"""
|
|
||||||
Registers a new dict of member defaults. This function should
|
|
||||||
be called EVERY TIME the cog loads (aka just do it in
|
|
||||||
__init__)!
|
|
||||||
|
|
||||||
:param member_defaults: Each key should be the key you want to
|
|
||||||
access data by and the value is the default value of that
|
|
||||||
key.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for k, v in member_defaults.items():
|
|
||||||
try:
|
|
||||||
self._register_member(k, v)
|
|
||||||
except KeyError:
|
|
||||||
log.exception("Bad default member key.")
|
|
||||||
|
|
||||||
def _register_member(self, key, default=None):
|
|
||||||
"""Registers a member config key `key`"""
|
|
||||||
if key in self.unsettable_keys:
|
|
||||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
|
||||||
elif not key.isidentifier():
|
|
||||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
|
||||||
" name.")
|
|
||||||
self.defaults["MEMBER"][key] = default
|
|
||||||
|
|
||||||
def register_user(self, **user_defaults):
|
|
||||||
"""
|
|
||||||
Registers a new dict of user defaults. This function should
|
|
||||||
be called EVERY TIME the cog loads (aka just do it in
|
|
||||||
__init__)!
|
|
||||||
|
|
||||||
:param user_defaults: Each key should be the key you want to
|
|
||||||
access data by and the value is the default value of that
|
|
||||||
key.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for k, v in user_defaults.items():
|
|
||||||
try:
|
|
||||||
self._register_user(k, v)
|
|
||||||
except KeyError:
|
|
||||||
log.exception("Bad default user key.")
|
|
||||||
|
|
||||||
def _register_user(self, key, default=None):
|
|
||||||
"""Registers a user config key `key`"""
|
|
||||||
if key in self.unsettable_keys:
|
|
||||||
raise KeyError("Attempt to use restricted key: '{}'".format(key))
|
|
||||||
elif not key.isidentifier():
|
|
||||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
|
||||||
" name.")
|
|
||||||
self.defaults["USER"][key] = default
|
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseConfig):
|
|
||||||
"""
|
|
||||||
Config object created by `Config.get_conf()`
|
|
||||||
|
|
||||||
This configuration object is designed to make backend data
|
|
||||||
storage mechanisms pluggable. It also is designed to
|
|
||||||
help a cog developer make fewer mistakes (such as
|
|
||||||
typos) when dealing with cog data and to make those mistakes
|
|
||||||
apparent much faster in the design process.
|
|
||||||
|
|
||||||
It also has the capability to safely store data between cogs
|
|
||||||
that share the same name.
|
|
||||||
|
|
||||||
There are two main components to this config object. First,
|
|
||||||
you have the ability to get data on a level specific basis.
|
|
||||||
The seven levels available are: global, guild, channel, role,
|
|
||||||
member, user, and misc.
|
|
||||||
|
|
||||||
The second main component is registering default values for
|
|
||||||
data in each of the levels. This functionality is OPTIONAL
|
|
||||||
and must be explicitly enabled when creating the Config object
|
|
||||||
using the kwarg `force_registration=True`.
|
|
||||||
|
|
||||||
Basic Usage:
|
|
||||||
Creating a Config object:
|
|
||||||
Use the `Config.get_conf()` class method to create new
|
|
||||||
Config objects.
|
|
||||||
|
|
||||||
See the `Config.get_conf()` documentation for more
|
|
||||||
information.
|
|
||||||
|
|
||||||
Registering Default Values (optional):
|
|
||||||
You can register default values for data at all levels
|
|
||||||
EXCEPT misc.
|
|
||||||
|
|
||||||
Simply pass in the key/value pairs as keyword arguments to
|
|
||||||
the respective function.
|
|
||||||
|
|
||||||
e.g.: conf_obj.register_global(enabled=True)
|
|
||||||
conf_obj.register_guild(likes_red=True)
|
|
||||||
|
|
||||||
Retrieving data by attributes:
|
|
||||||
Since I registered the "enabled" key in the previous example
|
|
||||||
at the global level I can now do:
|
|
||||||
|
|
||||||
conf_obj.enabled()
|
|
||||||
|
|
||||||
which will retrieve the current value of the "enabled"
|
|
||||||
key, making use of the default of "True". I can also do
|
|
||||||
the same for the guild key "likes_red":
|
|
||||||
|
|
||||||
conf_obj.guild(guild_obj).likes_red()
|
|
||||||
|
|
||||||
If I elected to not register default values, you can provide them
|
|
||||||
when you try to access the key:
|
|
||||||
|
|
||||||
conf_obj.no_default(default=True)
|
|
||||||
|
|
||||||
However if you do not provide a default and you do not register
|
|
||||||
defaults, accessing the attribute will return "None".
|
|
||||||
|
|
||||||
Saving data:
|
|
||||||
This is accomplished by using the `set` function available at
|
|
||||||
every level.
|
|
||||||
|
|
||||||
e.g.: conf_obj.set("enabled", False)
|
|
||||||
conf_obj.guild(guild_obj).set("likes_red", False)
|
|
||||||
|
|
||||||
If `force_registration` was enabled when the config object
|
|
||||||
was created you will only be allowed to save keys that you
|
|
||||||
have registered.
|
|
||||||
|
|
||||||
Misc data is special, use `conf.misc()` and `conf.set_misc(value)`
|
|
||||||
respectively.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __getattr__(self, key) -> Callable:
|
|
||||||
"""
|
|
||||||
Until I've got a better way to do this I'm just gonna fake __call__
|
|
||||||
|
|
||||||
:param key:
|
:param key:
|
||||||
:return: lambda function with kwarg
|
:param value:
|
||||||
|
:return:
|
||||||
"""
|
"""
|
||||||
return self._get_value_from_key(key)
|
ret = {}
|
||||||
|
partial = ret
|
||||||
def _get_value_from_key(self, key) -> Callable:
|
splitted = key.split('__')
|
||||||
try:
|
for i, k in enumerate(splitted, start=1):
|
||||||
default = self.defaults[self.collection][key]
|
if not k.isidentifier():
|
||||||
except KeyError as e:
|
raise RuntimeError("'{}' is an invalid config key.".format(k))
|
||||||
if self.force_registration:
|
if i == len(splitted):
|
||||||
raise AttributeError("Key '{}' not registered!".format(key)) from e
|
partial[k] = value
|
||||||
default = None
|
else:
|
||||||
|
partial[k] = {}
|
||||||
self.curr_key = key
|
partial = partial[k]
|
||||||
|
|
||||||
if self.collection != "MEMBER":
|
|
||||||
ret = lambda default=default: self.driver_getmap[self.collection](
|
|
||||||
self.cog_name, self.uuid, self.collection_uuid, key,
|
|
||||||
default=default)
|
|
||||||
else:
|
|
||||||
mid, sid = self.collection_uuid
|
|
||||||
ret = lambda default=default: self.driver.get_member(
|
|
||||||
self.cog_name, self.uuid, mid, sid, key,
|
|
||||||
default=default)
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def get(self, key, default=None):
|
@staticmethod
|
||||||
|
def _update_defaults(to_add: dict, _partial: dict):
|
||||||
"""
|
"""
|
||||||
Included as an alternative to registering defaults.
|
This tries to update the defaults dictionary with the nested
|
||||||
|
partial dict generated by _get_defaults_dict. This WILL
|
||||||
:param key:
|
throw an error if you try to have both a value and a group
|
||||||
:param default:
|
registered under the same name.
|
||||||
|
:param to_add:
|
||||||
|
:param _partial:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
for k, v in to_add.items():
|
||||||
|
val_is_dict = isinstance(v, dict)
|
||||||
|
if k in _partial:
|
||||||
|
existing_is_dict = isinstance(_partial[k], dict)
|
||||||
|
if val_is_dict != existing_is_dict:
|
||||||
|
# != is XOR
|
||||||
|
raise KeyError("You cannot register a Group and a Value under"
|
||||||
|
" the same name.")
|
||||||
|
if val_is_dict:
|
||||||
|
Config._update_defaults(v, _partial=_partial[k])
|
||||||
|
else:
|
||||||
|
_partial[k] = v
|
||||||
|
else:
|
||||||
|
_partial[k] = v
|
||||||
|
|
||||||
if default is not None:
|
def _register_default(self, key: str, **kwargs):
|
||||||
return self._get_value_from_key(key)(default)
|
if key not in self.defaults:
|
||||||
else:
|
self.defaults[key] = {}
|
||||||
return self._get_value_from_key(key)()
|
|
||||||
|
|
||||||
async def set(self, key, value):
|
data = deepcopy(kwargs)
|
||||||
# Notice to future developers:
|
|
||||||
# This code was commented to allow users to set keys without having to register them.
|
|
||||||
# That being said, if they try to get keys without registering them
|
|
||||||
# things will blow up. I do highly recommend enforcing the key registration.
|
|
||||||
|
|
||||||
if key in self.unsettable_keys or key in self.invalid_keys:
|
for k, v in data.items():
|
||||||
raise KeyError("Restricted key name, please use another.")
|
to_add = self._get_defaults_dict(k, v)
|
||||||
|
self._update_defaults(to_add, self.defaults[key])
|
||||||
|
|
||||||
if self.force_registration and key not in self.defaults[self.collection]:
|
def register_global(self, **kwargs):
|
||||||
raise AttributeError("Key '{}' not registered!".format(key))
|
self._register_default(self.GLOBAL, **kwargs)
|
||||||
|
|
||||||
if not key.isidentifier():
|
def register_guild(self, **kwargs):
|
||||||
raise RuntimeError("Invalid key name, must be a valid python variable"
|
self._register_default(self.GUILD, **kwargs)
|
||||||
" name.")
|
|
||||||
|
|
||||||
if self.collection == "GLOBAL":
|
def register_channel(self, **kwargs):
|
||||||
await self.driver.set_global(self.cog_name, self.uuid, key, value)
|
# We may need to add a voice channel category later
|
||||||
elif self.collection == "MEMBER":
|
self._register_default(self.CHANNEL, **kwargs)
|
||||||
mid, sid = self.collection_uuid
|
|
||||||
await self.driver.set_member(self.cog_name, self.uuid, mid, sid,
|
|
||||||
key, value)
|
|
||||||
elif self.collection in self.driver_setmap:
|
|
||||||
func = self.driver_setmap[self.collection]
|
|
||||||
await func(self.cog_name, self.uuid, self.collection_uuid, key, value)
|
|
||||||
|
|
||||||
async def clear(self):
|
def register_role(self, **kwargs):
|
||||||
await self.driver_setmap[self.collection](
|
self._register_default(self.ROLE, **kwargs)
|
||||||
self.cog_name, self.uuid, self.collection_uuid, None, None,
|
|
||||||
clear=True)
|
|
||||||
|
|
||||||
def guild(self, guild):
|
def register_user(self, **kwargs):
|
||||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
self._register_default(self.USER, **kwargs)
|
||||||
hash_uuid=False, defaults=self.defaults)
|
|
||||||
new.collection = "GUILD"
|
|
||||||
new.collection_uuid = guild.id
|
|
||||||
new._driver = None
|
|
||||||
return new
|
|
||||||
|
|
||||||
def channel(self, channel):
|
def register_member(self, **kwargs):
|
||||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
self._register_default(self.MEMBER, **kwargs)
|
||||||
hash_uuid=False, defaults=self.defaults)
|
|
||||||
new.collection = "CHANNEL"
|
|
||||||
new.collection_uuid = channel.id
|
|
||||||
new._driver = None
|
|
||||||
return new
|
|
||||||
|
|
||||||
def role(self, role):
|
def _get_base_group(self, key: str, *identifiers: str,
|
||||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
group_class=Group) -> Group:
|
||||||
hash_uuid=False, defaults=self.defaults)
|
# noinspection PyTypeChecker
|
||||||
new.collection = "ROLE"
|
return group_class(
|
||||||
new.collection_uuid = role.id
|
identifiers=(self.unique_identifier, key) + identifiers,
|
||||||
new._driver = None
|
defaults=self.defaults.get(key, {}),
|
||||||
return new
|
spawner=self.spawner,
|
||||||
|
force_registration=self.force_registration
|
||||||
|
)
|
||||||
|
|
||||||
def member(self, member):
|
def guild(self, guild: discord.Guild) -> Group:
|
||||||
guild = member.guild
|
return self._get_base_group(self.GUILD, guild.id)
|
||||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
|
||||||
hash_uuid=False, defaults=self.defaults)
|
def channel(self, channel: discord.TextChannel) -> Group:
|
||||||
new.collection = "MEMBER"
|
return self._get_base_group(self.CHANNEL, channel.id)
|
||||||
new.collection_uuid = (member.id, guild.id)
|
|
||||||
new._driver = None
|
def role(self, role: discord.Role) -> Group:
|
||||||
return new
|
return self._get_base_group(self.ROLE, role.id)
|
||||||
|
|
||||||
|
def user(self, user: discord.User) -> Group:
|
||||||
|
return self._get_base_group(self.USER, user.id)
|
||||||
|
|
||||||
|
def member(self, member: discord.Member) -> MemberGroup:
|
||||||
|
return self._get_base_group(self.MEMBER, member.guild.id, member.id,
|
||||||
|
group_class=MemberGroup)
|
||||||
|
|
||||||
def user(self, user):
|
|
||||||
new = type(self)(self.cog_name, self.uuid, self.driver,
|
|
||||||
hash_uuid=False, defaults=self.defaults)
|
|
||||||
new.collection = "USER"
|
|
||||||
new.collection_uuid = user.id
|
|
||||||
new._driver = None
|
|
||||||
return new
|
|
||||||
|
|||||||
@ -98,7 +98,7 @@ class Core:
|
|||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def adminrole(self, ctx, *, role: discord.Role):
|
async def adminrole(self, ctx, *, role: discord.Role):
|
||||||
"""Sets the admin role for this server"""
|
"""Sets the admin role for this server"""
|
||||||
await ctx.bot.db.guild(ctx.guild).set("admin_role", role.id)
|
await ctx.bot.db.guild(ctx.guild).admin_role.set(role.id)
|
||||||
await ctx.send("The admin role for this server has been set.")
|
await ctx.send("The admin role for this server has been set.")
|
||||||
|
|
||||||
@_set.command()
|
@_set.command()
|
||||||
@ -106,7 +106,7 @@ class Core:
|
|||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def modrole(self, ctx, *, role: discord.Role):
|
async def modrole(self, ctx, *, role: discord.Role):
|
||||||
"""Sets the mod role for this server"""
|
"""Sets the mod role for this server"""
|
||||||
await ctx.bot.db.guild(ctx.guild).set("mod_role", role.id)
|
await ctx.bot.db.guild(ctx.guild).mod_role.set(role.id)
|
||||||
await ctx.send("The mod role for this server has been set.")
|
await ctx.send("The mod role for this server has been set.")
|
||||||
|
|
||||||
@_set.command()
|
@_set.command()
|
||||||
@ -225,7 +225,7 @@ class Core:
|
|||||||
await ctx.bot.send_cmd_help(ctx)
|
await ctx.bot.send_cmd_help(ctx)
|
||||||
return
|
return
|
||||||
prefixes = sorted(prefixes, reverse=True)
|
prefixes = sorted(prefixes, reverse=True)
|
||||||
await ctx.bot.db.set("prefix", prefixes)
|
await ctx.bot.db.prefix.set(prefixes)
|
||||||
await ctx.send("Prefix set.")
|
await ctx.send("Prefix set.")
|
||||||
|
|
||||||
@_set.command(aliases=["serverprefixes"])
|
@_set.command(aliases=["serverprefixes"])
|
||||||
@ -234,11 +234,11 @@ class Core:
|
|||||||
async def serverprefix(self, ctx, *prefixes):
|
async def serverprefix(self, ctx, *prefixes):
|
||||||
"""Sets Red's server prefix(es)"""
|
"""Sets Red's server prefix(es)"""
|
||||||
if not prefixes:
|
if not prefixes:
|
||||||
await ctx.bot.db.guild(ctx.guild).set("prefix", [])
|
await ctx.bot.db.guild(ctx.guild).prefix.set([])
|
||||||
await ctx.send("Server prefixes have been reset.")
|
await ctx.send("Server prefixes have been reset.")
|
||||||
return
|
return
|
||||||
prefixes = sorted(prefixes, reverse=True)
|
prefixes = sorted(prefixes, reverse=True)
|
||||||
await ctx.bot.db.guild(ctx.guild).set("prefix", prefixes)
|
await ctx.bot.db.guild(ctx.guild).prefix.set(prefixes)
|
||||||
await ctx.send("Prefix set.")
|
await ctx.send("Prefix set.")
|
||||||
|
|
||||||
@_set.command()
|
@_set.command()
|
||||||
@ -276,7 +276,7 @@ class Core:
|
|||||||
else:
|
else:
|
||||||
if message.content.strip() == token:
|
if message.content.strip() == token:
|
||||||
self.owner.reset_cooldown(ctx)
|
self.owner.reset_cooldown(ctx)
|
||||||
await ctx.bot.db.set("owner", ctx.author.id)
|
await ctx.bot.db.owner.set(ctx.author.id)
|
||||||
ctx.bot.owner_id = ctx.author.id
|
ctx.bot.owner_id = ctx.author.id
|
||||||
await ctx.send("You have been set as owner.")
|
await ctx.send("You have been set as owner.")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -1,45 +1,12 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
class BaseDriver:
|
class BaseDriver:
|
||||||
def get_global(self, cog_name, ident, collection_id, key, *, default=None):
|
def get_driver(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
|
def get(self, identifiers: Tuple[str]):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
|
async def set(self, identifiers: Tuple[str], value):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_role(self, cog_name, ident, role_id, key, *, default=None):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
|
|
||||||
default=None):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_user(self, cog_name, ident, user_id, key, *, default=None):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def get_misc(self, cog_name, ident, *, default=None):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def set_global(self, cog_name, ident, key, value, clear=False):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def set_channel(self, cog_name, ident, channel_id, key, value,
|
|
||||||
clear=False):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def set_member(self, cog_name, ident, user_id, guild_id, key, value,
|
|
||||||
clear=False):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
async def set_misc(self, cog_name, ident, value, clear=False):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|||||||
@ -1,13 +1,15 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from core.drivers.red_base import BaseDriver
|
||||||
from core.json_io import JsonIO
|
from core.json_io import JsonIO
|
||||||
import os
|
|
||||||
from .red_base import BaseDriver
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
class JSON(BaseDriver):
|
class JSON(BaseDriver):
|
||||||
def __init__(self, cog_name, *args, data_path_override: Path=None,
|
def __init__(self, cog_name, *, data_path_override: Path=None,
|
||||||
file_name_override: str="settings.json", **kwargs):
|
file_name_override: str="settings.json"):
|
||||||
|
super().__init__()
|
||||||
self.cog_name = cog_name
|
self.cog_name = cog_name
|
||||||
self.file_name = file_name_override
|
self.file_name = file_name_override
|
||||||
if data_path_override:
|
if data_path_override:
|
||||||
@ -25,111 +27,23 @@ class JSON(BaseDriver):
|
|||||||
self.data = self.jsonIO._load_json()
|
self.data = self.jsonIO._load_json()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
self.data = {}
|
self.data = {}
|
||||||
|
|
||||||
def maybe_add_ident(self, ident: str):
|
|
||||||
if ident in self.data:
|
|
||||||
return
|
|
||||||
|
|
||||||
self.data[ident] = {}
|
|
||||||
for k in ("GLOBAL", "GUILD", "CHANNEL", "ROLE", "MEMBER", "USER"):
|
|
||||||
if k not in self.data[ident]:
|
|
||||||
self.data[ident][k] = {}
|
|
||||||
|
|
||||||
self.jsonIO._save_json(self.data)
|
self.jsonIO._save_json(self.data)
|
||||||
|
|
||||||
def get_global(self, cog_name, ident, _, key, *, default=None):
|
def get_driver(self):
|
||||||
return self.data[ident]["GLOBAL"].get(key, default)
|
return self
|
||||||
|
|
||||||
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
|
def get(self, identifiers: Tuple[str]):
|
||||||
guilddata = self.data[ident]["GUILD"].get(str(guild_id), {})
|
partial = self.data
|
||||||
return guilddata.get(key, default)
|
for i in identifiers:
|
||||||
|
partial = partial[i]
|
||||||
|
return partial
|
||||||
|
|
||||||
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
|
async def set(self, identifiers, value):
|
||||||
channeldata = self.data[ident]["CHANNEL"].get(str(channel_id), {})
|
partial = self.data
|
||||||
return channeldata.get(key, default)
|
for i in identifiers[:-1]:
|
||||||
|
if i not in partial:
|
||||||
|
partial[i] = {}
|
||||||
|
partial = partial[i]
|
||||||
|
|
||||||
def get_role(self, cog_name, ident, role_id, key, *, default=None):
|
partial[identifiers[-1]] = value
|
||||||
roledata = self.data[ident]["ROLE"].get(str(role_id), {})
|
|
||||||
return roledata.get(key, default)
|
|
||||||
|
|
||||||
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
|
|
||||||
default=None):
|
|
||||||
userdata = self.data[ident]["MEMBER"].get(str(user_id), {})
|
|
||||||
guilddata = userdata.get(str(guild_id), {})
|
|
||||||
return guilddata.get(key, default)
|
|
||||||
|
|
||||||
def get_user(self, cog_name, ident, user_id, key, *, default=None):
|
|
||||||
userdata = self.data[ident]["USER"].get(str(user_id), {})
|
|
||||||
return userdata.get(key, default)
|
|
||||||
|
|
||||||
async def set_global(self, cog_name, ident, key, value, clear=False):
|
|
||||||
if clear:
|
|
||||||
self.data[ident]["GLOBAL"] = {}
|
|
||||||
else:
|
|
||||||
self.data[ident]["GLOBAL"][key] = value
|
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
|
||||||
|
|
||||||
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
|
|
||||||
guild_id = str(guild_id)
|
|
||||||
if clear:
|
|
||||||
self.data[ident]["GUILD"][guild_id] = {}
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.data[ident]["GUILD"][guild_id][key] = value
|
|
||||||
except KeyError:
|
|
||||||
self.data[ident]["GUILD"][guild_id] = {}
|
|
||||||
self.data[ident]["GUILD"][guild_id][key] = value
|
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
|
||||||
|
|
||||||
async def set_channel(self, cog_name, ident, channel_id, key, value, clear=False):
|
|
||||||
channel_id = str(channel_id)
|
|
||||||
if clear:
|
|
||||||
self.data[ident]["CHANNEL"][channel_id] = {}
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.data[ident]["CHANNEL"][channel_id][key] = value
|
|
||||||
except KeyError:
|
|
||||||
self.data[ident]["CHANNEL"][channel_id] = {}
|
|
||||||
self.data[ident]["CHANNEL"][channel_id][key] = value
|
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
|
||||||
|
|
||||||
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
|
|
||||||
role_id = str(role_id)
|
|
||||||
if clear:
|
|
||||||
self.data[ident]["ROLE"][role_id] = {}
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.data[ident]["ROLE"][role_id][key] = value
|
|
||||||
except KeyError:
|
|
||||||
self.data[ident]["ROLE"][role_id] = {}
|
|
||||||
self.data[ident]["ROLE"][role_id][key] = value
|
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
|
||||||
|
|
||||||
async def set_member(self, cog_name, ident, user_id, guild_id, key, value, clear=False):
|
|
||||||
user_id = str(user_id)
|
|
||||||
guild_id = str(guild_id)
|
|
||||||
if clear:
|
|
||||||
self.data[ident]["MEMBER"][user_id] = {}
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
|
|
||||||
except KeyError:
|
|
||||||
if user_id not in self.data[ident]["MEMBER"]:
|
|
||||||
self.data[ident]["MEMBER"][user_id] = {}
|
|
||||||
if guild_id not in self.data[ident]["MEMBER"][user_id]:
|
|
||||||
self.data[ident]["MEMBER"][user_id][guild_id] = {}
|
|
||||||
|
|
||||||
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
|
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
|
||||||
|
|
||||||
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
|
|
||||||
user_id = str(user_id)
|
|
||||||
if clear:
|
|
||||||
self.data[ident]["USER"][user_id] = {}
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
self.data[ident]["USER"][user_id][key] = value
|
|
||||||
except KeyError:
|
|
||||||
self.data[ident]["USER"][user_id] = {}
|
|
||||||
self.data[ident]["USER"][user_id][key] = value
|
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
await self.jsonIO._threadsafe_save_json(self.data)
|
||||||
|
|||||||
2
main.py
2
main.py
@ -113,7 +113,7 @@ if __name__ == '__main__':
|
|||||||
if db_token and not cli_flags.no_prompt:
|
if db_token and not cli_flags.no_prompt:
|
||||||
print("\nDo you want to reset the token? (y/n)")
|
print("\nDo you want to reset the token? (y/n)")
|
||||||
if confirm("> "):
|
if confirm("> "):
|
||||||
loop.run_until_complete(red.db.set("token", ""))
|
loop.run_until_complete(red.db.token.set(""))
|
||||||
print("Token has been reset.")
|
print("Token has been reset.")
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
log.info("Keyboard interrupt detected. Quitting...")
|
log.info("Keyboard interrupt detected. Quitting...")
|
||||||
|
|||||||
@ -2,12 +2,11 @@ from cogs.alias import Alias
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture()
|
||||||
def alias(monkeysession, config):
|
def alias(config):
|
||||||
def get_mock_conf(*args, **kwargs):
|
import cogs.alias.alias
|
||||||
return config
|
|
||||||
|
|
||||||
monkeysession.setattr("core.config.Config.get_conf", get_mock_conf)
|
cogs.alias.alias.Config.get_conf = lambda *args, **kwargs: config
|
||||||
|
|
||||||
return Alias(None)
|
return Alias(None)
|
||||||
|
|
||||||
@ -25,9 +24,17 @@ def test_empty_global_aliases(alias):
|
|||||||
assert list(alias.unloaded_global_aliases()) == []
|
assert list(alias.unloaded_global_aliases()) == []
|
||||||
|
|
||||||
|
|
||||||
|
async def create_test_guild_alias(alias, ctx):
|
||||||
|
await alias.add_alias(ctx, "test", "ping", global_=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_test_global_alias(alias, ctx):
|
||||||
|
await alias.add_alias(ctx, "test", "ping", global_=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_guild_alias(alias, ctx):
|
async def test_add_guild_alias(alias, ctx):
|
||||||
await alias.add_alias(ctx, "test", "ping", global_=False)
|
await create_test_guild_alias(alias, ctx)
|
||||||
|
|
||||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
||||||
assert is_alias is True
|
assert is_alias is True
|
||||||
@ -36,6 +43,7 @@ async def test_add_guild_alias(alias, ctx):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_guild_alias(alias, ctx):
|
async def test_delete_guild_alias(alias, ctx):
|
||||||
|
await create_test_guild_alias(alias, ctx)
|
||||||
is_alias, _ = alias.is_alias(ctx.guild, "test")
|
is_alias, _ = alias.is_alias(ctx.guild, "test")
|
||||||
assert is_alias is True
|
assert is_alias is True
|
||||||
|
|
||||||
@ -47,7 +55,7 @@ async def test_delete_guild_alias(alias, ctx):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_global_alias(alias, ctx):
|
async def test_add_global_alias(alias, ctx):
|
||||||
await alias.add_alias(ctx, "test", "ping", global_=True)
|
await create_test_global_alias(alias, ctx)
|
||||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
||||||
|
|
||||||
assert is_alias is True
|
assert is_alias is True
|
||||||
@ -56,6 +64,7 @@ async def test_add_global_alias(alias, ctx):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_global_alias(alias, ctx):
|
async def test_delete_global_alias(alias, ctx):
|
||||||
|
await create_test_global_alias(alias, ctx)
|
||||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
||||||
assert is_alias is True
|
assert is_alias is True
|
||||||
assert alias_obj.global_ is True
|
assert alias_obj.global_ is True
|
||||||
|
|||||||
@ -17,21 +17,27 @@ def monkeysession(request):
|
|||||||
mpatch.undo()
|
mpatch.undo()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture()
|
||||||
def json_driver(tmpdir_factory):
|
def json_driver(tmpdir_factory):
|
||||||
|
import uuid
|
||||||
|
rand = str(uuid.uuid4())
|
||||||
|
path = Path(str(tmpdir_factory.mktemp(rand)))
|
||||||
driver = red_json.JSON(
|
driver = red_json.JSON(
|
||||||
"PyTest",
|
"PyTest",
|
||||||
data_path_override=Path(str(tmpdir_factory.getbasetemp()))
|
data_path_override=path
|
||||||
)
|
)
|
||||||
return driver
|
return driver
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def config(json_driver):
|
def config(json_driver):
|
||||||
return Config(
|
import uuid
|
||||||
|
conf = Config(
|
||||||
cog_name="PyTest",
|
cog_name="PyTest",
|
||||||
unique_identifier=0,
|
unique_identifier=str(uuid.uuid4()),
|
||||||
driver_spawn=json_driver)
|
driver_spawn=json_driver)
|
||||||
|
yield conf
|
||||||
|
conf.defaults = {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@ -39,19 +45,32 @@ def config_fr(json_driver):
|
|||||||
"""
|
"""
|
||||||
Mocked config object with force_register enabled.
|
Mocked config object with force_register enabled.
|
||||||
"""
|
"""
|
||||||
return Config(
|
import uuid
|
||||||
|
conf = Config(
|
||||||
cog_name="PyTest",
|
cog_name="PyTest",
|
||||||
unique_identifier=0,
|
unique_identifier=str(uuid.uuid4()),
|
||||||
driver_spawn=json_driver,
|
driver_spawn=json_driver,
|
||||||
force_registration=True
|
force_registration=True
|
||||||
)
|
)
|
||||||
|
yield conf
|
||||||
|
conf.defaults = {}
|
||||||
|
|
||||||
|
|
||||||
#region Dpy Mocks
|
#region Dpy Mocks
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture()
|
||||||
def empty_guild():
|
def guild_factory():
|
||||||
mock_guild = namedtuple("Guild", "id members")
|
mock_guild = namedtuple("Guild", "id members")
|
||||||
return mock_guild(random.randint(1, 999999999), [])
|
|
||||||
|
class GuildFactory:
|
||||||
|
def get(self):
|
||||||
|
return mock_guild(random.randint(1, 999999999), [])
|
||||||
|
|
||||||
|
return GuildFactory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def empty_guild(guild_factory):
|
||||||
|
return guild_factory.get()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@ -66,16 +85,39 @@ def empty_role():
|
|||||||
return mock_role(random.randint(1, 999999999))
|
return mock_role(random.randint(1, 999999999))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture()
|
||||||
def empty_member(empty_guild):
|
def member_factory(guild_factory):
|
||||||
mock_member = namedtuple("Member", "id guild")
|
mock_member = namedtuple("Member", "id guild")
|
||||||
return mock_member(random.randint(1, 999999999), empty_guild)
|
|
||||||
|
class MemberFactory:
|
||||||
|
def get(self):
|
||||||
|
return mock_member(
|
||||||
|
random.randint(1, 999999999),
|
||||||
|
guild_factory.get())
|
||||||
|
|
||||||
|
return MemberFactory()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture()
|
||||||
def empty_user():
|
def empty_member(member_factory):
|
||||||
|
return member_factory.get()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def user_factory():
|
||||||
mock_user = namedtuple("User", "id")
|
mock_user = namedtuple("User", "id")
|
||||||
return mock_user(random.randint(1, 999999999))
|
|
||||||
|
class UserFactory:
|
||||||
|
def get(self):
|
||||||
|
return mock_user(
|
||||||
|
random.randint(1, 999999999))
|
||||||
|
|
||||||
|
return UserFactory()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def empty_user(user_factory):
|
||||||
|
return user_factory.get()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
@ -84,7 +126,7 @@ def empty_message():
|
|||||||
return mock_msg("No content.")
|
return mock_msg("No content.")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture()
|
||||||
def ctx(empty_member, empty_channel, red):
|
def ctx(empty_member, empty_channel, red):
|
||||||
mock_ctx = namedtuple("Context", "author guild channel message bot")
|
mock_ctx = namedtuple("Context", "author guild channel message bot")
|
||||||
return mock_ctx(empty_member, empty_member.guild, empty_channel,
|
return mock_ctx(empty_member, empty_member.guild, empty_channel,
|
||||||
@ -93,15 +135,14 @@ def ctx(empty_member, empty_channel, red):
|
|||||||
|
|
||||||
|
|
||||||
#region Red Mock
|
#region Red Mock
|
||||||
@pytest.fixture
|
@pytest.fixture()
|
||||||
def red(monkeysession, config_fr):
|
def red(config_fr):
|
||||||
from core.cli import parse_cli_flags
|
from core.cli import parse_cli_flags
|
||||||
cli_flags = parse_cli_flags()
|
cli_flags = parse_cli_flags()
|
||||||
|
|
||||||
description = "Red v3 - Alpha"
|
description = "Red v3 - Alpha"
|
||||||
|
|
||||||
monkeysession.setattr("core.config.Config.get_core_conf",
|
Config.get_core_conf = (lambda *args, **kwargs: config_fr)
|
||||||
lambda *args, **kwargs: config_fr)
|
|
||||||
|
|
||||||
red = Red(cli_flags, description=description, pm_help=None)
|
red = Red(cli_flags, description=description, pm_help=None)
|
||||||
|
|
||||||
|
|||||||
@ -15,9 +15,9 @@ def test_config_register_global_badvalues(config):
|
|||||||
|
|
||||||
def test_config_register_guild(config, empty_guild):
|
def test_config_register_guild(config, empty_guild):
|
||||||
config.register_guild(enabled=False, some_list=[], some_dict={})
|
config.register_guild(enabled=False, some_list=[], some_dict={})
|
||||||
assert config.defaults["GUILD"]["enabled"] is False
|
assert config.defaults[config.GUILD]["enabled"] is False
|
||||||
assert config.defaults["GUILD"]["some_list"] == []
|
assert config.defaults[config.GUILD]["some_list"] == []
|
||||||
assert config.defaults["GUILD"]["some_dict"] == {}
|
assert config.defaults[config.GUILD]["some_dict"] == {}
|
||||||
|
|
||||||
assert config.guild(empty_guild).enabled() is False
|
assert config.guild(empty_guild).enabled() is False
|
||||||
assert config.guild(empty_guild).some_list() == []
|
assert config.guild(empty_guild).some_list() == []
|
||||||
@ -26,25 +26,25 @@ def test_config_register_guild(config, empty_guild):
|
|||||||
|
|
||||||
def test_config_register_channel(config, empty_channel):
|
def test_config_register_channel(config, empty_channel):
|
||||||
config.register_channel(enabled=False)
|
config.register_channel(enabled=False)
|
||||||
assert config.defaults["CHANNEL"]["enabled"] is False
|
assert config.defaults[config.CHANNEL]["enabled"] is False
|
||||||
assert config.channel(empty_channel).enabled() is False
|
assert config.channel(empty_channel).enabled() is False
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_role(config, empty_role):
|
def test_config_register_role(config, empty_role):
|
||||||
config.register_role(enabled=False)
|
config.register_role(enabled=False)
|
||||||
assert config.defaults["ROLE"]["enabled"] is False
|
assert config.defaults[config.ROLE]["enabled"] is False
|
||||||
assert config.role(empty_role).enabled() is False
|
assert config.role(empty_role).enabled() is False
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_member(config, empty_member):
|
def test_config_register_member(config, empty_member):
|
||||||
config.register_member(some_number=-1)
|
config.register_member(some_number=-1)
|
||||||
assert config.defaults["MEMBER"]["some_number"] == -1
|
assert config.defaults[config.MEMBER]["some_number"] == -1
|
||||||
assert config.member(empty_member).some_number() == -1
|
assert config.member(empty_member).some_number() == -1
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_user(config, empty_user):
|
def test_config_register_user(config, empty_user):
|
||||||
config.register_user(some_value=None)
|
config.register_user(some_value=None)
|
||||||
assert config.defaults["USER"]["some_value"] is None
|
assert config.defaults[config.USER]["some_value"] is None
|
||||||
assert config.user(empty_user).some_value() is None
|
assert config.user(empty_user).some_value() is None
|
||||||
|
|
||||||
|
|
||||||
@ -57,106 +57,233 @@ def test_config_force_register_global(config_fr):
|
|||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
|
||||||
|
# Test nested registration
|
||||||
|
def test_nested_registration(config):
|
||||||
|
config.register_global(foo__bar__baz=False)
|
||||||
|
assert config.foo.bar.baz() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_registration_asdict(config):
|
||||||
|
defaults = {'bar': {'baz': False}}
|
||||||
|
config.register_global(foo=defaults)
|
||||||
|
|
||||||
|
assert config.foo.bar.baz() is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_nested_registration_and_changing(config):
|
||||||
|
defaults = {'bar': {'baz': False}}
|
||||||
|
config.register_global(foo=defaults)
|
||||||
|
|
||||||
|
assert config.foo.bar.baz() is False
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await config.foo.set(True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_doubleset_default(config):
|
||||||
|
config.register_global(foo=True)
|
||||||
|
config.register_global(foo=False)
|
||||||
|
|
||||||
|
assert config.foo() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_registration_multidict(config):
|
||||||
|
defaults = {
|
||||||
|
"foo": {
|
||||||
|
"bar": {
|
||||||
|
"baz": True
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"blah": True
|
||||||
|
}
|
||||||
|
config.register_global(**defaults)
|
||||||
|
|
||||||
|
assert config.foo.bar.baz() is True
|
||||||
|
assert config.blah() is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_group_value_badreg(config):
|
||||||
|
config.register_global(foo=True)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
config.register_global(foo__bar=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_toplevel_reg(config):
|
||||||
|
defaults = {'bar': True, 'baz': False}
|
||||||
|
config.register_global(foo=defaults)
|
||||||
|
|
||||||
|
assert config.foo.bar() is True
|
||||||
|
assert config.foo.baz() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_overlapping(config):
|
||||||
|
config.register_global(foo__bar=True)
|
||||||
|
config.register_global(foo__baz=False)
|
||||||
|
|
||||||
|
assert config.foo.bar() is True
|
||||||
|
assert config.foo.baz() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_nesting_nofr(config):
|
||||||
|
config.register_global(foo={})
|
||||||
|
assert config.foo.bar() is None
|
||||||
|
assert config.foo() == {}
|
||||||
|
|
||||||
|
|
||||||
#region Default Value Overrides
|
#region Default Value Overrides
|
||||||
def test_global_default_override(config):
|
def test_global_default_override(config):
|
||||||
assert config.enabled(True) is True
|
assert config.enabled(True) is True
|
||||||
assert config.get("enabled") is None
|
|
||||||
assert config.get("enabled", default=True) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_global_default_nofr(config):
|
def test_global_default_nofr(config):
|
||||||
assert config.nofr() is None
|
assert config.nofr() is None
|
||||||
assert config.nofr(True) is True
|
assert config.nofr(True) is True
|
||||||
assert config.get("nofr") is None
|
|
||||||
assert config.get("nofr", default=True) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_guild_default_override(config, empty_guild):
|
def test_guild_default_override(config, empty_guild):
|
||||||
assert config.guild(empty_guild).enabled(True) is True
|
assert config.guild(empty_guild).enabled(True) is True
|
||||||
assert config.guild(empty_guild).get("enabled") is None
|
|
||||||
assert config.guild(empty_guild).get("enabled", default=True) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_channel_default_override(config, empty_channel):
|
def test_channel_default_override(config, empty_channel):
|
||||||
assert config.channel(empty_channel).enabled(True) is True
|
assert config.channel(empty_channel).enabled(True) is True
|
||||||
assert config.channel(empty_channel).get("enabled") is None
|
|
||||||
assert config.channel(empty_channel).get("enabled", default=True) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_role_default_override(config, empty_role):
|
def test_role_default_override(config, empty_role):
|
||||||
assert config.role(empty_role).enabled(True) is True
|
assert config.role(empty_role).enabled(True) is True
|
||||||
assert config.role(empty_role).get("enabled") is None
|
|
||||||
assert config.role(empty_role).get("enabled", default=True) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_member_default_override(config, empty_member):
|
def test_member_default_override(config, empty_member):
|
||||||
assert config.member(empty_member).enabled(True) is True
|
assert config.member(empty_member).enabled(True) is True
|
||||||
assert config.member(empty_member).get("enabled") is None
|
|
||||||
assert config.member(empty_member).get("enabled", default=True) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_user_default_override(config, empty_user):
|
def test_user_default_override(config, empty_user):
|
||||||
assert config.user(empty_user).some_value(True) is True
|
assert config.user(empty_user).some_value(True) is True
|
||||||
assert config.user(empty_user).get("some_value") is None
|
|
||||||
assert config.user(empty_user).get("some_value", default=True) is True
|
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
|
||||||
#region Setting Values
|
#region Setting Values
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_global(config):
|
async def test_set_global(config):
|
||||||
await config.set("enabled", True)
|
await config.enabled.set(True)
|
||||||
assert config.enabled() is True
|
assert config.enabled() is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_set_global_badkey(config):
|
|
||||||
with pytest.raises(RuntimeError):
|
|
||||||
await config.set("this is a bad key", True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_set_global_invalidkey(config):
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
await config.set("uuid", True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_guild(config, empty_guild):
|
async def test_set_guild(config, empty_guild):
|
||||||
await config.guild(empty_guild).set("enabled", True)
|
await config.guild(empty_guild).enabled.set(True)
|
||||||
assert config.guild(empty_guild).enabled() is True
|
assert config.guild(empty_guild).enabled() is True
|
||||||
|
|
||||||
curr_list = config.guild(empty_guild).some_list([1, 2, 3])
|
curr_list = config.guild(empty_guild).some_list([1, 2, 3])
|
||||||
assert curr_list == [1, 2, 3]
|
assert curr_list == [1, 2, 3]
|
||||||
curr_list.append(4)
|
curr_list.append(4)
|
||||||
|
|
||||||
await config.guild(empty_guild).set("some_list", curr_list)
|
await config.guild(empty_guild).some_list.set(curr_list)
|
||||||
assert config.guild(empty_guild).some_list() == curr_list
|
assert config.guild(empty_guild).some_list() == curr_list
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_channel(config, empty_channel):
|
async def test_set_channel(config, empty_channel):
|
||||||
await config.channel(empty_channel).set("enabled", True)
|
await config.channel(empty_channel).enabled.set(True)
|
||||||
assert config.channel(empty_channel).enabled() is True
|
assert config.channel(empty_channel).enabled() is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_channel_no_register(config, empty_channel):
|
async def test_set_channel_no_register(config, empty_channel):
|
||||||
await config.channel(empty_channel).set("no_register", True)
|
await config.channel(empty_channel).no_register.set(True)
|
||||||
assert config.channel(empty_channel).no_register() is True
|
assert config.channel(empty_channel).no_register() is True
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
|
||||||
# region Getting Values
|
# Dynamic attribute testing
|
||||||
def test_get_func_w_reg(config):
|
@pytest.mark.asyncio
|
||||||
config.register_global(
|
async def test_set_dynamic_attr(config):
|
||||||
thing=True
|
await config.set_attr("foobar", True)
|
||||||
)
|
|
||||||
assert config.get("thing") is True
|
assert config.foobar() is True
|
||||||
assert config.get("thing", False) is False
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_func_wo_reg(config):
|
def test_get_dynamic_attr(config):
|
||||||
assert config.get("thing") is None
|
assert config.get_attr("foobaz", True) is True
|
||||||
assert config.get("thing", True) is True
|
|
||||||
# endregion
|
|
||||||
|
# Member Group testing
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_membergroup_allguilds(config, empty_member):
|
||||||
|
await config.member(empty_member).foo.set(False)
|
||||||
|
|
||||||
|
all_servers = config.member(empty_member).all_guilds()
|
||||||
|
assert str(empty_member.guild.id) in all_servers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_membergroup_allmembers(config, empty_member):
|
||||||
|
await config.member(empty_member).foo.set(False)
|
||||||
|
|
||||||
|
all_members = config.member(empty_member).all()
|
||||||
|
assert str(empty_member.id) in all_members
|
||||||
|
|
||||||
|
|
||||||
|
# Clearing testing
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_global_clear(config):
|
||||||
|
config.register_global(foo=True, bar=False)
|
||||||
|
|
||||||
|
await config.foo.set(False)
|
||||||
|
await config.bar.set(True)
|
||||||
|
|
||||||
|
assert config.foo() is False
|
||||||
|
assert config.bar() is True
|
||||||
|
|
||||||
|
await config.clear()
|
||||||
|
|
||||||
|
assert config.foo() is True
|
||||||
|
assert config.bar() is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_member_clear(config, member_factory):
|
||||||
|
config.register_member(foo=True)
|
||||||
|
|
||||||
|
m1 = member_factory.get()
|
||||||
|
await config.member(m1).foo.set(False)
|
||||||
|
assert config.member(m1).foo() is False
|
||||||
|
|
||||||
|
m2 = member_factory.get()
|
||||||
|
await config.member(m2).foo.set(False)
|
||||||
|
assert config.member(m2).foo() is False
|
||||||
|
|
||||||
|
assert m1.guild.id != m2.guild.id
|
||||||
|
|
||||||
|
await config.member(m1).clear()
|
||||||
|
assert config.member(m1).foo() is True
|
||||||
|
assert config.member(m2).foo() is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_member_clear_all(config, member_factory):
|
||||||
|
server_ids = []
|
||||||
|
for _ in range(5):
|
||||||
|
member = member_factory.get()
|
||||||
|
await config.member(member).foo.set(True)
|
||||||
|
server_ids.append(member.guild.id)
|
||||||
|
|
||||||
|
member = member_factory.get()
|
||||||
|
assert len(config.member(member).all_guilds()) == len(server_ids)
|
||||||
|
|
||||||
|
await config.member(member).clear_all()
|
||||||
|
|
||||||
|
assert len(config.member(member).all_guilds()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Get All testing
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_get_all(config, user_factory):
|
||||||
|
for _ in range(5):
|
||||||
|
user = user_factory.get()
|
||||||
|
await config.user(user).foo.set(True)
|
||||||
|
|
||||||
|
user = user_factory.get()
|
||||||
|
all_data = config.user(user).all()
|
||||||
|
|
||||||
|
assert len(all_data) == 5
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user