diff --git a/redbot/core/config.py b/redbot/core/config.py index 79625ca68..012f10474 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -146,6 +146,12 @@ class Value: """ await self.driver.set(*self.identifiers, value=value) + async def clear(self): + """ + Clears the value from record for the data element pointed to by `identifiers`. + """ + await self.driver.clear(*self.identifiers) + class Group(Value): """ @@ -438,14 +444,6 @@ class Group(Value): path = [str(p) for p in nested_path] await self.driver.set(*self.identifiers, *path, value=value) - async def clear(self): - """Wipe all data from this group. - - If used on a global group, it will wipe all global data, but not - local data. - """ - await self.set({}) - _config_cogrefs = {} _config_coreref = None @@ -1037,7 +1035,7 @@ class Config: driver=self.driver) else: group = self._get_base_group(*scopes) - await group.set({}) + await group.clear() async def clear_all(self): """Clear all data from this Config instance. diff --git a/redbot/core/drivers/red_base.py b/redbot/core/drivers/red_base.py index a94f86d0c..c92d7dabb 100644 --- a/redbot/core/drivers/red_base.py +++ b/redbot/core/drivers/red_base.py @@ -1,5 +1,3 @@ -from typing import Tuple - __all__ = ["BaseDriver"] @@ -8,13 +6,18 @@ class BaseDriver: self.cog_name = cog_name self.unique_cog_identifier = None # This is set by Config's init method - async def get(self, *identifiers: Tuple[str]): + async def get(self, *identifiers: str): """ Finds the value indicate by the given identifiers. - :param identifiers: + Parameters + ---------- + identifiers A list of identifiers that correspond to nested dict accesses. - :return: + + Returns + ------- + Any Stored value. """ raise NotImplementedError @@ -24,18 +27,34 @@ class BaseDriver: Asks users for additional configuration information necessary to use this config driver. - :return: + Returns + ------- Dict of configuration details. """ raise NotImplementedError - async def set(self, *identifiers: Tuple[str], value=None): + async def set(self, *identifiers: str, value=None): """ Sets the value of the key indicated by the given identifiers. - :param identifiers: + Parameters + ---------- + identifiers A list of identifiers that correspond to nested dict accesses. - :param value: + value Any JSON serializable python object. """ raise NotImplementedError + + async def clear(self, *identifiers: str): + """ + Clears out the value specified by the given identifiers. + + Equivalent to using ``del`` on a dict. + + Parameters + ---------- + identifiers + A list of identifiers that correspond to nested dict accesses. + """ + raise NotImplementedError diff --git a/redbot/core/drivers/red_json.py b/redbot/core/drivers/red_json.py index fc2a4ec69..dec2a5f58 100644 --- a/redbot/core/drivers/red_json.py +++ b/redbot/core/drivers/red_json.py @@ -58,3 +58,14 @@ class JSON(BaseDriver): partial[full_identifiers[-1]] = value await self.jsonIO._threadsafe_save_json(self.data) + + async def clear(self, *identifiers: str): + partial = self.data + full_identifiers = (self.unique_cog_identifier, *identifiers) + for i in full_identifiers[:-1]: + if i not in partial: + break + partial = partial[i] + else: + del partial[identifiers[-1]] + await self.jsonIO._threadsafe_save_json(self.data) diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/red_mongo.py index 73046680a..de431c3cc 100644 --- a/redbot/core/drivers/red_mongo.py +++ b/redbot/core/drivers/red_mongo.py @@ -71,7 +71,7 @@ class Mongo(BaseDriver): uuid, identifiers = identifiers[0], identifiers[1:] return uuid, identifiers - async def get(self, *identifiers: Tuple[str]): + async def get(self, *identifiers: str): await self._ensure_connected() mongo_collection = self.get_collection() @@ -104,6 +104,17 @@ class Mongo(BaseDriver): upsert=True ) + async def clear(self, *identifiers: str): + await self._ensure_connected() + + dot_identifiers = '.'.join(identifiers) + mongo_collection = self.get_collection() + + await mongo_collection.update_one( + {'_id': self.unique_cog_identifier}, + update={"$unset": {dot_identifiers: 1}} + ) + def get_config_details(): host = input("Enter host address: ") diff --git a/tests/core/test_config.py b/tests/core/test_config.py index d333b7735..92152f07a 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -298,6 +298,16 @@ async def test_member_clear_all(config, member_factory): assert len(await config.all_members()) == 0 +@pytest.mark.asyncio +async def test_clear_value(config_fr): + config_fr.register_global(foo=False) + await config_fr.foo.set(True) + await config_fr.foo.clear() + + with pytest.raises(KeyError): + await config_fr.get_raw('foo') + + # Get All testing @pytest.mark.asyncio async def test_user_get_all_from_kind(config, user_factory):