[V3 Config] Adjust functionality of get_attr (#1342)

* Intermediate commit

* Add defaulting stuff to config

* Remove set_attr in favor of set_raw

* Modify get_attr

* Fix issue with clearing data
This commit is contained in:
Will 2018-02-26 15:13:01 -05:00 committed by palmtree5
parent c428982c00
commit 64af7800dc
5 changed files with 54 additions and 81 deletions

View File

@ -70,7 +70,7 @@ class CommandObj:
async def get(self, async def get(self,
message: discord.Message, message: discord.Message,
command: str) -> str: command: str) -> str:
ccinfo = await self.db(message.guild).commands.get_attr(command) ccinfo = await self.db(message.guild).commands.get_raw(command, default=None)
if not ccinfo: if not ccinfo:
raise NotFound raise NotFound
else: else:
@ -82,7 +82,7 @@ class CommandObj:
response): response):
"""Create a customcommand""" """Create a customcommand"""
# Check if this command is already registered as a customcommand # Check if this command is already registered as a customcommand
if await self.db(ctx.guild).commands.get_attr(command): if await self.db(ctx.guild).commands.get_raw(command, default=None):
raise AlreadyExists() raise AlreadyExists()
author = ctx.message.author author = ctx.message.author
ccinfo = { ccinfo = {
@ -96,8 +96,8 @@ class CommandObj:
'response': response 'response': response
} }
await self.db(ctx.guild).commands.set_attr(command, await self.db(ctx.guild).commands.set_raw(
ccinfo) command, value=ccinfo)
async def edit(self, async def edit(self,
ctx: commands.Context, ctx: commands.Context,
@ -105,11 +105,11 @@ class CommandObj:
response: None): response: None):
"""Edit an already existing custom command""" """Edit an already existing custom command"""
# Check if this command is registered # Check if this command is registered
if not await self.db(ctx.guild).commands.get_attr(command): if not await self.db(ctx.guild).commands.get_raw(command, default=None):
raise NotFound() raise NotFound()
author = ctx.message.author author = ctx.message.author
ccinfo = await self.db(ctx.guild).commands.get_attr(command) ccinfo = await self.db(ctx.guild).commands.get_raw(command, default=None)
def check(m): def check(m):
return m.channel == ctx.channel and m.author == ctx.message.author return m.channel == ctx.channel and m.author == ctx.message.author
@ -138,18 +138,18 @@ class CommandObj:
author.id author.id
) )
await self.db(ctx.guild).commands.set_attr(command, await self.db(ctx.guild).commands.set_raw(
ccinfo) command, value=ccinfo)
async def delete(self, async def delete(self,
ctx: commands.Context, ctx: commands.Context,
command: str): command: str):
"""Delete an already exisiting custom command""" """Delete an already exisiting custom command"""
# Check if this command is registered # Check if this command is registered
if not await self.db(ctx.guild).commands.get_attr(command): if not await self.db(ctx.guild).commands.get_raw(command, default=None):
raise NotFound() raise NotFound()
await self.db(ctx.guild).commands.set_attr(command, await self.db(ctx.guild).commands.set_raw(
None) command, value=None)
class CustomCommands: class CustomCommands:
@ -326,7 +326,7 @@ class CustomCommands:
return return
guild = message.guild guild = message.guild
prefixes = await self.bot.db.guild(message.guild).get_attr('prefix') prefixes = await self.bot.db.guild(guild).get_raw('prefix', default=[])
if len(prefixes) < 1: if len(prefixes) < 1:
def_prefixes = await self.bot.get_prefix(message) def_prefixes = await self.bot.get_prefix(message)

View File

@ -53,7 +53,7 @@ class Streams:
@commands.command() @commands.command()
async def twitch(self, ctx, channel_name: str): async def twitch(self, ctx, channel_name: str):
"""Checks if a Twitch channel is streaming""" """Checks if a Twitch channel is streaming"""
token = await self.db.tokens.get_attr(TwitchStream.__name__) token = await self.db.tokens.get_raw(TwitchStream.__name__, default=None)
stream = TwitchStream(name=channel_name, stream = TwitchStream(name=channel_name,
token=token) token=token)
await self.check_online(ctx, stream) await self.check_online(ctx, stream)
@ -187,7 +187,7 @@ class Streams:
async def stream_alert(self, ctx, _class, channel_name): async def stream_alert(self, ctx, _class, channel_name):
stream = self.get_stream(_class, channel_name.lower()) stream = self.get_stream(_class, channel_name.lower())
if not stream: if not stream:
token = await self.db.tokens.get_attr(_class.__name__) token = await self.db.tokens.get_raw(_class.__name__, default=None)
stream = _class(name=channel_name, stream = _class(name=channel_name,
token=token) token=token)
try: try:
@ -210,7 +210,7 @@ class Streams:
async def community_alert(self, ctx, _class, community_name): async def community_alert(self, ctx, _class, community_name):
community = self.get_community(_class, community_name) community = self.get_community(_class, community_name)
if not community: if not community:
token = await self.db.tokens.get_attr(_class.__name__) token = await self.db.tokens.get_raw(_class.__name__, default=None)
community = _class(name=community_name, token=token) community = _class(name=community_name, token=token)
try: try:
await community.get_community_streams() await community.get_community_streams()
@ -477,7 +477,7 @@ class Streams:
if not _class: if not _class:
continue continue
token = await self.db.tokens.get_attr(_class.__name__) token = await self.db.tokens.get_raw(_class.__name__)
streams.append(_class(token=token, **raw_stream)) streams.append(_class(token=token, **raw_stream))
# issue 1191 extended resolution: Remove this after suitable period # issue 1191 extended resolution: Remove this after suitable period
@ -497,7 +497,7 @@ class Streams:
if not _class: if not _class:
continue continue
token = await self.db.tokens.get_attr(_class.__name__) token = await self.db.tokens.get_raw(_class.__name__, default=None)
communities.append(_class(token=token, **raw_community)) communities.append(_class(token=token, **raw_community))
# issue 1191 extended resolution: Remove this after suitable period # issue 1191 extended resolution: Remove this after suitable period

View File

@ -267,17 +267,13 @@ class Group(Value):
return not isinstance(default, dict) return not isinstance(default, dict)
def get_attr(self, item: str, default=None, resolve=True): def get_attr(self, item: str):
"""Manually get an attribute of this Group. """Manually get an attribute of this Group.
This is available to use as an alternative to using normal Python This is available to use as an alternative to using normal Python
attribute access. It is required if you find a need for dynamic attribute access. It may be required if you find a need for dynamic
attribute access. attribute access.
Note
----
Use of this method should be avoided wherever possible.
Example Example
------- -------
A possible use case:: A possible use case::
@ -287,32 +283,20 @@ class Group(Value):
user = ctx.author user = ctx.author
# Where the value of item is the name of the data field in Config # Where the value of item is the name of the data field in Config
await ctx.send(await self.conf.user(user).get_attr(item)) await ctx.send(await self.conf.user(user).get_attr(item).foo())
Parameters Parameters
---------- ----------
item : str item : str
The name of the data field in `Config`. The name of the data field in `Config`.
default
This is an optional override to the registered default for this
item.
resolve : bool
If this is :code:`True` this function will return a coroutine that
resolves to a "real" data value when awaited. If :code:`False`,
this method acts the same as `__getattr__`.
Returns Returns
------- -------
`types.coroutine` or `Value` or `Group` `Value` or `Group`
The attribute which was requested, its type depending on the value The attribute which was requested.
of :code:`resolve`.
""" """
value = getattr(self, item) return self.__getattr__(item)
if resolve:
return value(default=default)
else:
return value
async def get_raw(self, *nested_path: str, default=...): async def get_raw(self, *nested_path: str, default=...):
""" """
@ -350,6 +334,16 @@ class Group(Value):
""" """
path = [str(p) for p in nested_path] path = [str(p) for p in nested_path]
if default is ...:
poss_default = self.defaults
for ident in path:
try:
poss_default = poss_default[ident]
except KeyError:
break
else:
default = poss_default
try: try:
return deepcopy(await self.driver.get(*self.identifiers, *path)) return deepcopy(await self.driver.get(*self.identifiers, *path))
except KeyError: except KeyError:
@ -398,27 +392,6 @@ class Group(Value):
) )
await super().set(value) await super().set(value)
async def set_attr(self, item: str, value):
"""Set an attribute by its name.
Similar to `get_attr` in the way it can be used to dynamically set
attributes by name.
Note
----
Use of this method should be avoided wherever possible.
Parameters
----------
item : str
The name of the attribute being set.
value
The raw data value to set the attribute as.
"""
value_obj = getattr(self, item)
await value_obj.set(value)
async def set_raw(self, *nested_path: str, value): async def set_raw(self, *nested_path: str, value):
""" """
Allows a developer to set data as if it was stored in a standard Allows a developer to set data as if it was stored in a standard

View File

@ -77,8 +77,8 @@ class Case:
case_emb = await self.message_content() case_emb = await self.message_content()
await self.message.edit(embed=case_emb) await self.message.edit(embed=case_emb)
await _conf.guild(self.guild).cases.set_attr( await _conf.guild(self.guild).cases.set_raw(
str(self.case_number), self.to_json() str(self.case_number), value=self.to_json()
) )
async def message_content(self): async def message_content(self):
@ -245,7 +245,7 @@ class CaseType:
"case_str": self.case_str, "case_str": self.case_str,
"audit_type": self.audit_type "audit_type": self.audit_type
} }
await _conf.casetypes.set_attr(self.name, data) await _conf.casetypes.set_raw(self.name, value=data)
async def is_enabled(self) -> bool: async def is_enabled(self) -> bool:
""" """
@ -262,8 +262,8 @@ class CaseType:
""" """
if not self.guild: if not self.guild:
return False return False
return await _conf.guild(self.guild).casetypes.get_attr(self.name, return await _conf.guild(self.guild).casetypes.get_raw(
self.default_setting) self.name, default=self.default_setting)
async def set_enabled(self, enabled: bool): async def set_enabled(self, enabled: bool):
""" """
@ -275,7 +275,7 @@ class CaseType:
True if the case should be enabled, otherwise False""" True if the case should be enabled, otherwise False"""
if not self.guild: if not self.guild:
return return
await _conf.guild(self.guild).casetypes.set_attr(self.name, enabled) await _conf.guild(self.guild).casetypes.set_raw(self.name, value=enabled)
@classmethod @classmethod
def from_json(cls, data: dict): def from_json(cls, data: dict):
@ -310,7 +310,7 @@ async def get_next_case_number(guild: discord.Guild) -> str:
""" """
cases = sorted( cases = sorted(
(await _conf.guild(guild).get_attr("cases")), (await _conf.guild(guild).get_raw("cases")),
key=lambda x: int(x), key=lambda x: int(x),
reverse=True reverse=True
) )
@ -342,11 +342,12 @@ async def get_case(case_number: int, guild: discord.Guild,
If there is no case for the specified number If there is no case for the specified number
""" """
case = await _conf.guild(guild).cases.get_attr(str(case_number)) try:
if case is None: case = await _conf.guild(guild).cases.get_raw(str(case_number))
except KeyError as e:
raise RuntimeError( raise RuntimeError(
"That case does not exist for guild {}".format(guild.name) "That case does not exist for guild {}".format(guild.name)
) ) from e
mod_channel = await get_modlog_channel(guild) mod_channel = await get_modlog_channel(guild)
return await Case.from_json(mod_channel, bot, case) return await Case.from_json(mod_channel, bot, case)
@ -368,7 +369,7 @@ async def get_all_cases(guild: discord.Guild, bot: Red) -> List[Case]:
A list of all cases for the guild A list of all cases for the guild
""" """
cases = await _conf.guild(guild).get_attr("cases") cases = await _conf.guild(guild).get_raw("cases")
case_numbers = list(cases.keys()) case_numbers = list(cases.keys())
case_list = [] case_list = []
for case in case_numbers: for case in case_numbers:
@ -440,7 +441,7 @@ async def create_case(guild: discord.Guild, created_at: datetime, action_type: s
case_emb = await case.message_content() case_emb = await case.message_content()
msg = await mod_channel.send(embed=case_emb) msg = await mod_channel.send(embed=case_emb)
case.message = msg case.message = msg
await _conf.guild(guild).cases.set_attr(str(next_case_number), case.to_json()) await _conf.guild(guild).cases.set_raw(str(next_case_number), value=case.to_json())
return case return case
@ -459,7 +460,7 @@ async def get_casetype(name: str, guild: discord.Guild=None) -> Union[CaseType,
------- -------
CaseType or None CaseType or None
""" """
casetypes = await _conf.get_attr("casetypes") casetypes = await _conf.get_raw("casetypes")
if name in casetypes: if name in casetypes:
data = casetypes[name] data = casetypes[name]
data["name"] = name data["name"] = name
@ -480,7 +481,7 @@ async def get_all_casetypes(guild: discord.Guild=None) -> List[CaseType]:
A list of case types A list of case types
""" """
casetypes = await _conf.get_attr("casetypes") casetypes = await _conf.get_raw("casetypes", default={})
typelist = [] typelist = []
for ct in casetypes.keys(): for ct in casetypes.keys():
data = casetypes[ct] data = casetypes[ct]

View File

@ -219,14 +219,14 @@ async def test_set_channel_no_register(config, empty_channel):
# Dynamic attribute testing # Dynamic attribute testing
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_dynamic_attr(config): async def test_set_dynamic_attr(config):
await config.set_attr("foobar", True) await config.set_raw("foobar", value=True)
assert await config.foobar() is True assert await config.foobar() is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_dynamic_attr(config): async def test_get_dynamic_attr(config):
assert await config.get_attr("foobaz", True) is True assert await config.get_raw("foobaz", default=True) is True
# Member Group testing # Member Group testing
@ -299,13 +299,12 @@ async def test_member_clear_all(config, member_factory):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_clear_value(config_fr): async def test_clear_value(config):
config_fr.register_global(foo=False) await config.foo.set(True)
await config_fr.foo.set(True) await config.foo.clear()
await config_fr.foo.clear()
with pytest.raises(KeyError): with pytest.raises(KeyError):
await config_fr.get_raw('foo') await config.get_raw('foo')
# Get All testing # Get All testing