mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[V3 Config] Adjust functionality of get_attr (#1342)
* Intermediate commit * Add defaulting stuff to config * Remove set_attr in favor of set_raw * Modify get_attr * Fix issue with clearing data
This commit is contained in:
parent
c428982c00
commit
64af7800dc
@ -70,7 +70,7 @@ class CommandObj:
|
|||||||
async def get(self,
|
async def get(self,
|
||||||
message: discord.Message,
|
message: discord.Message,
|
||||||
command: str) -> str:
|
command: str) -> str:
|
||||||
ccinfo = await self.db(message.guild).commands.get_attr(command)
|
ccinfo = await self.db(message.guild).commands.get_raw(command, default=None)
|
||||||
if not ccinfo:
|
if not ccinfo:
|
||||||
raise NotFound
|
raise NotFound
|
||||||
else:
|
else:
|
||||||
@ -82,7 +82,7 @@ class CommandObj:
|
|||||||
response):
|
response):
|
||||||
"""Create a customcommand"""
|
"""Create a customcommand"""
|
||||||
# Check if this command is already registered as a customcommand
|
# Check if this command is already registered as a customcommand
|
||||||
if await self.db(ctx.guild).commands.get_attr(command):
|
if await self.db(ctx.guild).commands.get_raw(command, default=None):
|
||||||
raise AlreadyExists()
|
raise AlreadyExists()
|
||||||
author = ctx.message.author
|
author = ctx.message.author
|
||||||
ccinfo = {
|
ccinfo = {
|
||||||
@ -96,8 +96,8 @@ class CommandObj:
|
|||||||
'response': response
|
'response': response
|
||||||
|
|
||||||
}
|
}
|
||||||
await self.db(ctx.guild).commands.set_attr(command,
|
await self.db(ctx.guild).commands.set_raw(
|
||||||
ccinfo)
|
command, value=ccinfo)
|
||||||
|
|
||||||
async def edit(self,
|
async def edit(self,
|
||||||
ctx: commands.Context,
|
ctx: commands.Context,
|
||||||
@ -105,11 +105,11 @@ class CommandObj:
|
|||||||
response: None):
|
response: None):
|
||||||
"""Edit an already existing custom command"""
|
"""Edit an already existing custom command"""
|
||||||
# Check if this command is registered
|
# Check if this command is registered
|
||||||
if not await self.db(ctx.guild).commands.get_attr(command):
|
if not await self.db(ctx.guild).commands.get_raw(command, default=None):
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
|
|
||||||
author = ctx.message.author
|
author = ctx.message.author
|
||||||
ccinfo = await self.db(ctx.guild).commands.get_attr(command)
|
ccinfo = await self.db(ctx.guild).commands.get_raw(command, default=None)
|
||||||
|
|
||||||
def check(m):
|
def check(m):
|
||||||
return m.channel == ctx.channel and m.author == ctx.message.author
|
return m.channel == ctx.channel and m.author == ctx.message.author
|
||||||
@ -138,18 +138,18 @@ class CommandObj:
|
|||||||
author.id
|
author.id
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.db(ctx.guild).commands.set_attr(command,
|
await self.db(ctx.guild).commands.set_raw(
|
||||||
ccinfo)
|
command, value=ccinfo)
|
||||||
|
|
||||||
async def delete(self,
|
async def delete(self,
|
||||||
ctx: commands.Context,
|
ctx: commands.Context,
|
||||||
command: str):
|
command: str):
|
||||||
"""Delete an already exisiting custom command"""
|
"""Delete an already exisiting custom command"""
|
||||||
# Check if this command is registered
|
# Check if this command is registered
|
||||||
if not await self.db(ctx.guild).commands.get_attr(command):
|
if not await self.db(ctx.guild).commands.get_raw(command, default=None):
|
||||||
raise NotFound()
|
raise NotFound()
|
||||||
await self.db(ctx.guild).commands.set_attr(command,
|
await self.db(ctx.guild).commands.set_raw(
|
||||||
None)
|
command, value=None)
|
||||||
|
|
||||||
|
|
||||||
class CustomCommands:
|
class CustomCommands:
|
||||||
@ -326,7 +326,7 @@ class CustomCommands:
|
|||||||
return
|
return
|
||||||
|
|
||||||
guild = message.guild
|
guild = message.guild
|
||||||
prefixes = await self.bot.db.guild(message.guild).get_attr('prefix')
|
prefixes = await self.bot.db.guild(guild).get_raw('prefix', default=[])
|
||||||
|
|
||||||
if len(prefixes) < 1:
|
if len(prefixes) < 1:
|
||||||
def_prefixes = await self.bot.get_prefix(message)
|
def_prefixes = await self.bot.get_prefix(message)
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class Streams:
|
|||||||
@commands.command()
|
@commands.command()
|
||||||
async def twitch(self, ctx, channel_name: str):
|
async def twitch(self, ctx, channel_name: str):
|
||||||
"""Checks if a Twitch channel is streaming"""
|
"""Checks if a Twitch channel is streaming"""
|
||||||
token = await self.db.tokens.get_attr(TwitchStream.__name__)
|
token = await self.db.tokens.get_raw(TwitchStream.__name__, default=None)
|
||||||
stream = TwitchStream(name=channel_name,
|
stream = TwitchStream(name=channel_name,
|
||||||
token=token)
|
token=token)
|
||||||
await self.check_online(ctx, stream)
|
await self.check_online(ctx, stream)
|
||||||
@ -187,7 +187,7 @@ class Streams:
|
|||||||
async def stream_alert(self, ctx, _class, channel_name):
|
async def stream_alert(self, ctx, _class, channel_name):
|
||||||
stream = self.get_stream(_class, channel_name.lower())
|
stream = self.get_stream(_class, channel_name.lower())
|
||||||
if not stream:
|
if not stream:
|
||||||
token = await self.db.tokens.get_attr(_class.__name__)
|
token = await self.db.tokens.get_raw(_class.__name__, default=None)
|
||||||
stream = _class(name=channel_name,
|
stream = _class(name=channel_name,
|
||||||
token=token)
|
token=token)
|
||||||
try:
|
try:
|
||||||
@ -210,7 +210,7 @@ class Streams:
|
|||||||
async def community_alert(self, ctx, _class, community_name):
|
async def community_alert(self, ctx, _class, community_name):
|
||||||
community = self.get_community(_class, community_name)
|
community = self.get_community(_class, community_name)
|
||||||
if not community:
|
if not community:
|
||||||
token = await self.db.tokens.get_attr(_class.__name__)
|
token = await self.db.tokens.get_raw(_class.__name__, default=None)
|
||||||
community = _class(name=community_name, token=token)
|
community = _class(name=community_name, token=token)
|
||||||
try:
|
try:
|
||||||
await community.get_community_streams()
|
await community.get_community_streams()
|
||||||
@ -477,7 +477,7 @@ class Streams:
|
|||||||
if not _class:
|
if not _class:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
token = await self.db.tokens.get_attr(_class.__name__)
|
token = await self.db.tokens.get_raw(_class.__name__)
|
||||||
streams.append(_class(token=token, **raw_stream))
|
streams.append(_class(token=token, **raw_stream))
|
||||||
|
|
||||||
# issue 1191 extended resolution: Remove this after suitable period
|
# issue 1191 extended resolution: Remove this after suitable period
|
||||||
@ -497,7 +497,7 @@ class Streams:
|
|||||||
if not _class:
|
if not _class:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
token = await self.db.tokens.get_attr(_class.__name__)
|
token = await self.db.tokens.get_raw(_class.__name__, default=None)
|
||||||
communities.append(_class(token=token, **raw_community))
|
communities.append(_class(token=token, **raw_community))
|
||||||
|
|
||||||
# issue 1191 extended resolution: Remove this after suitable period
|
# issue 1191 extended resolution: Remove this after suitable period
|
||||||
|
|||||||
@ -267,17 +267,13 @@ class Group(Value):
|
|||||||
|
|
||||||
return not isinstance(default, dict)
|
return not isinstance(default, dict)
|
||||||
|
|
||||||
def get_attr(self, item: str, default=None, resolve=True):
|
def get_attr(self, item: str):
|
||||||
"""Manually get an attribute of this Group.
|
"""Manually get an attribute of this Group.
|
||||||
|
|
||||||
This is available to use as an alternative to using normal Python
|
This is available to use as an alternative to using normal Python
|
||||||
attribute access. It is required if you find a need for dynamic
|
attribute access. It may be required if you find a need for dynamic
|
||||||
attribute access.
|
attribute access.
|
||||||
|
|
||||||
Note
|
|
||||||
----
|
|
||||||
Use of this method should be avoided wherever possible.
|
|
||||||
|
|
||||||
Example
|
Example
|
||||||
-------
|
-------
|
||||||
A possible use case::
|
A possible use case::
|
||||||
@ -287,32 +283,20 @@ class Group(Value):
|
|||||||
user = ctx.author
|
user = ctx.author
|
||||||
|
|
||||||
# Where the value of item is the name of the data field in Config
|
# Where the value of item is the name of the data field in Config
|
||||||
await ctx.send(await self.conf.user(user).get_attr(item))
|
await ctx.send(await self.conf.user(user).get_attr(item).foo())
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
item : str
|
item : str
|
||||||
The name of the data field in `Config`.
|
The name of the data field in `Config`.
|
||||||
default
|
|
||||||
This is an optional override to the registered default for this
|
|
||||||
item.
|
|
||||||
resolve : bool
|
|
||||||
If this is :code:`True` this function will return a coroutine that
|
|
||||||
resolves to a "real" data value when awaited. If :code:`False`,
|
|
||||||
this method acts the same as `__getattr__`.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
`types.coroutine` or `Value` or `Group`
|
`Value` or `Group`
|
||||||
The attribute which was requested, its type depending on the value
|
The attribute which was requested.
|
||||||
of :code:`resolve`.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
value = getattr(self, item)
|
return self.__getattr__(item)
|
||||||
if resolve:
|
|
||||||
return value(default=default)
|
|
||||||
else:
|
|
||||||
return value
|
|
||||||
|
|
||||||
async def get_raw(self, *nested_path: str, default=...):
|
async def get_raw(self, *nested_path: str, default=...):
|
||||||
"""
|
"""
|
||||||
@ -350,6 +334,16 @@ class Group(Value):
|
|||||||
"""
|
"""
|
||||||
path = [str(p) for p in nested_path]
|
path = [str(p) for p in nested_path]
|
||||||
|
|
||||||
|
if default is ...:
|
||||||
|
poss_default = self.defaults
|
||||||
|
for ident in path:
|
||||||
|
try:
|
||||||
|
poss_default = poss_default[ident]
|
||||||
|
except KeyError:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
default = poss_default
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return deepcopy(await self.driver.get(*self.identifiers, *path))
|
return deepcopy(await self.driver.get(*self.identifiers, *path))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -398,27 +392,6 @@ class Group(Value):
|
|||||||
)
|
)
|
||||||
await super().set(value)
|
await super().set(value)
|
||||||
|
|
||||||
async def set_attr(self, item: str, value):
|
|
||||||
"""Set an attribute by its name.
|
|
||||||
|
|
||||||
Similar to `get_attr` in the way it can be used to dynamically set
|
|
||||||
attributes by name.
|
|
||||||
|
|
||||||
Note
|
|
||||||
----
|
|
||||||
Use of this method should be avoided wherever possible.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
item : str
|
|
||||||
The name of the attribute being set.
|
|
||||||
value
|
|
||||||
The raw data value to set the attribute as.
|
|
||||||
|
|
||||||
"""
|
|
||||||
value_obj = getattr(self, item)
|
|
||||||
await value_obj.set(value)
|
|
||||||
|
|
||||||
async def set_raw(self, *nested_path: str, value):
|
async def set_raw(self, *nested_path: str, value):
|
||||||
"""
|
"""
|
||||||
Allows a developer to set data as if it was stored in a standard
|
Allows a developer to set data as if it was stored in a standard
|
||||||
|
|||||||
@ -77,8 +77,8 @@ class Case:
|
|||||||
case_emb = await self.message_content()
|
case_emb = await self.message_content()
|
||||||
await self.message.edit(embed=case_emb)
|
await self.message.edit(embed=case_emb)
|
||||||
|
|
||||||
await _conf.guild(self.guild).cases.set_attr(
|
await _conf.guild(self.guild).cases.set_raw(
|
||||||
str(self.case_number), self.to_json()
|
str(self.case_number), value=self.to_json()
|
||||||
)
|
)
|
||||||
|
|
||||||
async def message_content(self):
|
async def message_content(self):
|
||||||
@ -245,7 +245,7 @@ class CaseType:
|
|||||||
"case_str": self.case_str,
|
"case_str": self.case_str,
|
||||||
"audit_type": self.audit_type
|
"audit_type": self.audit_type
|
||||||
}
|
}
|
||||||
await _conf.casetypes.set_attr(self.name, data)
|
await _conf.casetypes.set_raw(self.name, value=data)
|
||||||
|
|
||||||
async def is_enabled(self) -> bool:
|
async def is_enabled(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -262,8 +262,8 @@ class CaseType:
|
|||||||
"""
|
"""
|
||||||
if not self.guild:
|
if not self.guild:
|
||||||
return False
|
return False
|
||||||
return await _conf.guild(self.guild).casetypes.get_attr(self.name,
|
return await _conf.guild(self.guild).casetypes.get_raw(
|
||||||
self.default_setting)
|
self.name, default=self.default_setting)
|
||||||
|
|
||||||
async def set_enabled(self, enabled: bool):
|
async def set_enabled(self, enabled: bool):
|
||||||
"""
|
"""
|
||||||
@ -275,7 +275,7 @@ class CaseType:
|
|||||||
True if the case should be enabled, otherwise False"""
|
True if the case should be enabled, otherwise False"""
|
||||||
if not self.guild:
|
if not self.guild:
|
||||||
return
|
return
|
||||||
await _conf.guild(self.guild).casetypes.set_attr(self.name, enabled)
|
await _conf.guild(self.guild).casetypes.set_raw(self.name, value=enabled)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_json(cls, data: dict):
|
def from_json(cls, data: dict):
|
||||||
@ -310,7 +310,7 @@ async def get_next_case_number(guild: discord.Guild) -> str:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
cases = sorted(
|
cases = sorted(
|
||||||
(await _conf.guild(guild).get_attr("cases")),
|
(await _conf.guild(guild).get_raw("cases")),
|
||||||
key=lambda x: int(x),
|
key=lambda x: int(x),
|
||||||
reverse=True
|
reverse=True
|
||||||
)
|
)
|
||||||
@ -342,11 +342,12 @@ async def get_case(case_number: int, guild: discord.Guild,
|
|||||||
If there is no case for the specified number
|
If there is no case for the specified number
|
||||||
|
|
||||||
"""
|
"""
|
||||||
case = await _conf.guild(guild).cases.get_attr(str(case_number))
|
try:
|
||||||
if case is None:
|
case = await _conf.guild(guild).cases.get_raw(str(case_number))
|
||||||
|
except KeyError as e:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"That case does not exist for guild {}".format(guild.name)
|
"That case does not exist for guild {}".format(guild.name)
|
||||||
)
|
) from e
|
||||||
mod_channel = await get_modlog_channel(guild)
|
mod_channel = await get_modlog_channel(guild)
|
||||||
return await Case.from_json(mod_channel, bot, case)
|
return await Case.from_json(mod_channel, bot, case)
|
||||||
|
|
||||||
@ -368,7 +369,7 @@ async def get_all_cases(guild: discord.Guild, bot: Red) -> List[Case]:
|
|||||||
A list of all cases for the guild
|
A list of all cases for the guild
|
||||||
|
|
||||||
"""
|
"""
|
||||||
cases = await _conf.guild(guild).get_attr("cases")
|
cases = await _conf.guild(guild).get_raw("cases")
|
||||||
case_numbers = list(cases.keys())
|
case_numbers = list(cases.keys())
|
||||||
case_list = []
|
case_list = []
|
||||||
for case in case_numbers:
|
for case in case_numbers:
|
||||||
@ -440,7 +441,7 @@ async def create_case(guild: discord.Guild, created_at: datetime, action_type: s
|
|||||||
case_emb = await case.message_content()
|
case_emb = await case.message_content()
|
||||||
msg = await mod_channel.send(embed=case_emb)
|
msg = await mod_channel.send(embed=case_emb)
|
||||||
case.message = msg
|
case.message = msg
|
||||||
await _conf.guild(guild).cases.set_attr(str(next_case_number), case.to_json())
|
await _conf.guild(guild).cases.set_raw(str(next_case_number), value=case.to_json())
|
||||||
return case
|
return case
|
||||||
|
|
||||||
|
|
||||||
@ -459,7 +460,7 @@ async def get_casetype(name: str, guild: discord.Guild=None) -> Union[CaseType,
|
|||||||
-------
|
-------
|
||||||
CaseType or None
|
CaseType or None
|
||||||
"""
|
"""
|
||||||
casetypes = await _conf.get_attr("casetypes")
|
casetypes = await _conf.get_raw("casetypes")
|
||||||
if name in casetypes:
|
if name in casetypes:
|
||||||
data = casetypes[name]
|
data = casetypes[name]
|
||||||
data["name"] = name
|
data["name"] = name
|
||||||
@ -480,7 +481,7 @@ async def get_all_casetypes(guild: discord.Guild=None) -> List[CaseType]:
|
|||||||
A list of case types
|
A list of case types
|
||||||
|
|
||||||
"""
|
"""
|
||||||
casetypes = await _conf.get_attr("casetypes")
|
casetypes = await _conf.get_raw("casetypes", default={})
|
||||||
typelist = []
|
typelist = []
|
||||||
for ct in casetypes.keys():
|
for ct in casetypes.keys():
|
||||||
data = casetypes[ct]
|
data = casetypes[ct]
|
||||||
|
|||||||
@ -219,14 +219,14 @@ async def test_set_channel_no_register(config, empty_channel):
|
|||||||
# Dynamic attribute testing
|
# Dynamic attribute testing
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_dynamic_attr(config):
|
async def test_set_dynamic_attr(config):
|
||||||
await config.set_attr("foobar", True)
|
await config.set_raw("foobar", value=True)
|
||||||
|
|
||||||
assert await config.foobar() is True
|
assert await config.foobar() is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_dynamic_attr(config):
|
async def test_get_dynamic_attr(config):
|
||||||
assert await config.get_attr("foobaz", True) is True
|
assert await config.get_raw("foobaz", default=True) is True
|
||||||
|
|
||||||
|
|
||||||
# Member Group testing
|
# Member Group testing
|
||||||
@ -299,13 +299,12 @@ async def test_member_clear_all(config, member_factory):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_clear_value(config_fr):
|
async def test_clear_value(config):
|
||||||
config_fr.register_global(foo=False)
|
await config.foo.set(True)
|
||||||
await config_fr.foo.set(True)
|
await config.foo.clear()
|
||||||
await config_fr.foo.clear()
|
|
||||||
|
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
await config_fr.get_raw('foo')
|
await config.get_raw('foo')
|
||||||
|
|
||||||
|
|
||||||
# Get All testing
|
# Get All testing
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user