diff --git a/docs/framework_config.rst b/docs/framework_config.rst index d7429bc55..06dd89ef0 100644 --- a/docs/framework_config.rst +++ b/docs/framework_config.rst @@ -191,5 +191,19 @@ Driver Reference **************** .. automodule:: redbot.core.drivers + :members: +Base Driver +^^^^^^^^^^^ .. autoclass:: redbot.core.drivers.red_base.BaseDriver + :members: + +JSON Driver +^^^^^^^^^^^ +.. autoclass:: redbot.core.drivers.red_json.JSON + :members: + +Mongo Driver +^^^^^^^^^^^^ +.. autoclass:: redbot.core.drivers.red_mongo.Mongo + :members: diff --git a/redbot/core/bank.py b/redbot/core/bank.py index 37049c4a2..80a4e920e 100644 --- a/redbot/core/bank.py +++ b/redbot/core/bank.py @@ -435,7 +435,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) - if (await is_global()) is global_: return global_ - if is_global(): + if await is_global(): await _conf.clear_all_users() else: await _conf.clear_all_members() diff --git a/redbot/core/config.py b/redbot/core/config.py index 68700f903..549a89fc1 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -5,7 +5,7 @@ from typing import Callable, Union, Tuple import discord from .data_manager import cog_data_path, core_data_path -from .drivers.red_json import JSON as JSONDriver +from .drivers import get_driver log = logging.getLogger("red.config") @@ -440,7 +440,18 @@ class Config: cog_name = cog_path_override.stem uuid = str(hash(identifier)) - spawner = JSONDriver(cog_name, data_path_override=cog_path_override) + # We have to import this here otherwise we have a circular dependency + from .data_manager import basic_config + + log.debug("Basic config: \n\n{}".format(basic_config)) + + driver_name = basic_config.get('STORAGE_TYPE', 'JSON') + driver_details = basic_config.get('STORAGE_DETAILS', {}) + + log.debug("Using driver: '{}'".format(driver_name)) + + spawner = get_driver(driver_name, cog_name, data_path_override=cog_path_override, + **driver_details) return cls(cog_name=cog_name, unique_identifier=uuid, force_registration=force_registration, driver_spawn=spawner) @@ -458,7 +469,16 @@ class Config: See `force_registration`. """ - driver_spawn = JSONDriver("Core", data_path_override=core_data_path()) + core_path = core_data_path() + + # We have to import this here otherwise we have a circular dependency + from .data_manager import basic_config + + driver_name = basic_config.get('STORAGE_TYPE', 'JSON') + driver_details = basic_config.get('STORAGE_DETAILS', {}) + + driver_spawn = get_driver(driver_name, "Core", data_path_override=core_path, + **driver_details) return cls(cog_name="Core", driver_spawn=driver_spawn, unique_identifier='0', force_registration=force_registration) @@ -848,7 +868,7 @@ class Config: """ if not scopes: - group = Group(identifiers=(self.unique_identifier), + group = Group(identifiers=(self.unique_identifier, ), defaults={}, spawner=self.spawner) else: diff --git a/redbot/core/data_manager.py b/redbot/core/data_manager.py index 7bcb423da..a91b845c5 100644 --- a/redbot/core/data_manager.py +++ b/redbot/core/data_manager.py @@ -11,6 +11,8 @@ __all__ = ['load_basic_configuration', 'cog_data_path', 'core_data_path', jsonio = None basic_config = None +instance_name = None + basic_config_default = { "DATA_PATH": None, "COG_PATH_APPEND": "cogs", @@ -21,12 +23,15 @@ config_dir = Path(appdirs.AppDirs("Red-DiscordBot").user_config_dir) config_file = config_dir / 'config.json' -def load_basic_configuration(instance_name: str): +def load_basic_configuration(instance_name_: str): global jsonio global basic_config + global instance_name jsonio = JsonIO(config_file) + instance_name = instance_name_ + try: config = jsonio._load_json() basic_config = config[instance_name] diff --git a/redbot/core/drivers/__init__.py b/redbot/core/drivers/__init__.py index e69de29bb..b93ea0bd5 100644 --- a/redbot/core/drivers/__init__.py +++ b/redbot/core/drivers/__init__.py @@ -0,0 +1,31 @@ +__all__ = ["get_driver"] + + +def get_driver(type, *args, **kwargs): + """ + Selectively import/load driver classes based on the selected type. This + is required so that dependencies can differ between installs (e.g. so that + you don't need to install a mongo dependency if you will just be running a + json data backend). + + .. note:: + + See the respective classes for information on what ``args`` and ``kwargs`` + should be. + + :param str type: + One of: json, mongo + :param args: + Dependent on driver type. + :param kwargs: + Dependent on driver type. + :return: + Subclass of :py:class:`.red_base.BaseDriver`. + """ + if type == "JSON": + from .red_json import JSON + return JSON(*args, **kwargs) + elif type == "MongoDB": + from .red_mongo import Mongo + return Mongo(*args, **kwargs) + raise RuntimeError("Invalid driver type: '{}'".format(type)) diff --git a/redbot/core/drivers/red_base.py b/redbot/core/drivers/red_base.py index 50ec86132..119270cc2 100644 --- a/redbot/core/drivers/red_base.py +++ b/redbot/core/drivers/red_base.py @@ -3,11 +3,40 @@ from typing import Tuple __all__ = ["BaseDriver"] class BaseDriver: + def __init__(self, cog_name): + self.cog_name = cog_name + def get_driver(self): raise NotImplementedError async def get(self, identifiers: Tuple[str]): + """ + Finds the value indicate by the given identifiers. + + :param identifiers: + A list of identifiers that correspond to nested dict accesses. + :return: + Stored value. + """ + raise NotImplementedError + + def get_config_details(self): + """ + Asks users for additional configuration information necessary + to use this config driver. + + :return: + Dict of configuration details. + """ raise NotImplementedError async def set(self, identifiers: Tuple[str], value): + """ + Sets the value of the key indicated by the given identifiers. + + :param identifiers: + A list of identifiers that correspond to nested dict accesses. + :param value: + Any JSON serializable python object. + """ raise NotImplementedError diff --git a/redbot/core/drivers/red_json.py b/redbot/core/drivers/red_json.py index fbe4186ca..a638e59ec 100644 --- a/redbot/core/drivers/red_json.py +++ b/redbot/core/drivers/red_json.py @@ -9,10 +9,20 @@ __all__ = ["JSON"] class JSON(BaseDriver): + """ + Subclass of :py:class:`.red_base.BaseDriver`. + + .. py:attribute:: file_name + + The name of the file in which to store JSON data. + + .. py:attribute:: data_path + + The path in which to store the file indicated by :py:attr:`file_name`. + """ def __init__(self, cog_name, *, data_path_override: Path=None, file_name_override: str="settings.json"): - super().__init__() - self.cog_name = cog_name + super().__init__(cog_name) self.file_name = file_name_override if data_path_override: self.data_path = data_path_override diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/red_mongo.py index 38f4816cf..33eed5788 100644 --- a/redbot/core/drivers/red_mongo.py +++ b/redbot/core/drivers/red_mongo.py @@ -1,215 +1,129 @@ -import pymongo as m +from typing import Tuple + +import motor.motor_asyncio from .red_base import BaseDriver - -__all__ = ["Mongo", "RedMongoException", "MultipleMatches", - "MissingCollection"] - - -class RedMongoException(Exception): - """Base Red Mongo Exception class""" - pass - - -class MultipleMatches(RedMongoException): - """Raised when multiple documents match a single cog_name and - cog_identifier pair.""" - pass - - -class MissingCollection(RedMongoException): - """Raised when a collection is missing from the mongo db""" - pass +__all__ = ["Mongo"] class Mongo(BaseDriver): - def __init__(self, host, port=27017, admin_user=None, admin_pass=None, - **kwargs): - self.conn = m.MongoClient(host=host, port=port, **kwargs) + """ + Subclass of :py:class:`.red_base.BaseDriver`. + """ + def __init__(self, cog_name, **kwargs): + super().__init__(cog_name) + self.host = kwargs['HOST'] + self.port = kwargs['PORT'] + admin_user = kwargs['USERNAME'] + admin_pass = kwargs['PASSWORD'] + + from ..data_manager import instance_name + + self.instance_name = instance_name + + self.conn = None self.admin_user = admin_user self.admin_pass = admin_pass - self._db = self.conn.red - if self.admin_user is not None and self.admin_pass is not None: - self._db.authenticate(self.admin_user, self.admin_pass) + async def _authenticate(self): + self.conn = motor.motor_asyncio.AsyncIOMotorClient(host=self.host, port=self.port) - self._global = self._db.GLOBAL - self._guild = self._db.GUILD - self._channel = self._db.CHANNEL - self._role = self._db.ROLE - self._member = self._db.MEMBER - self._user = self._db.USER + if None not in (self.admin_pass, self.admin_user): + await self.db.authenticate(self.admin_user, self.admin_pass) - def get_global(self, cog_name, cog_identifier, _, key, *, default=None): - doc = self._global.find( - {"cog_name": cog_name, "cog_identifier": cog_identifier}, - projection=[key, ], batch_size=2) - if doc.count() == 2: - raise MultipleMatches("Too many matching documents at the GLOBAL" - " level: ({}, {})".format(cog_name, - cog_identifier)) - elif doc.count() == 1: - return doc[0].get(key, default) - return default + async def _ensure_connected(self): + if self.conn is None: + await self._authenticate() - def get_guild(self, cog_name, cog_identifier, guild_id, key, *, - default=None): - doc = self._guild.find( - {"cog_name": cog_name, "cog_identifier": cog_identifier, - "guild_id": guild_id}, - projection=[key, ], batch_size=2) - if doc.count() == 2: - raise MultipleMatches("Too many matching documents at the GUILD" - " level: ({}, {}, {})".format( - cog_name, cog_identifier, guild_id)) - elif doc.count() == 1: - return doc[0].get(key, default) - return default + @property + def db(self) -> motor.core.Database: + """ + Gets the mongo database for this cog's name. - def get_channel(self, cog_name, cog_identifier, channel_id, key, *, - default=None): - doc = self._channel.find( - {"cog_name": cog_name, "cog_identifier": cog_identifier, - "channel_id": channel_id}, - projection=[key, ], batch_size=2) - if doc.count() == 2: - raise MultipleMatches("Too many matching documents at the CHANNEL" - " level: ({}, {}, {})".format( - cog_name, cog_identifier, channel_id)) - elif doc.count() == 1: - return doc[0].get(key, default) - return default + .. warning:: - def get_role(self, cog_name, cog_identifier, role_id, key, *, - default=None): - doc = self._role.find( - {"cog_name": cog_name, "cog_identifier": cog_identifier, - "role_id": role_id}, - projection=[key, ], batch_size=2) - if doc.count() == 2: - raise MultipleMatches("Too many matching documents at the ROLE" - " level: ({}, {}, {})".format( - cog_name, cog_identifier, role_id)) - elif doc.count() == 1: - return doc[0].get(key, default) - return default + Right now this will cause a new connection to be made every time the + database is accessed. We will want to create a connection pool down the + line to limit the number of connections. - def get_member(self, cog_name, cog_identifier, user_id, guild_id, key, *, - default=None): - doc = self._member.find( - {"cog_name": cog_name, "cog_identifier": cog_identifier, - "user_id": user_id, "guild_id": guild_id}, - projection=[key, ], batch_size=2) - if doc.count() == 2: - raise MultipleMatches("Too many matching documents at the MEMBER" - " level: ({}, {}, mid {}, sid {})".format( - cog_name, cog_identifier, user_id, - guild_id)) - elif doc.count() == 1: - return doc[0].get(key, default) - return default + :return: + PyMongo Database object. + """ + db_name = "RED_{}".format(self.instance_name) + return self.conn[db_name] - def get_user(self, cog_name, cog_identifier, user_id, key, *, - default=None): - doc = self._user.find( - {"cog_name": cog_name, "cog_identifier": cog_identifier, - "user_id": user_id}, - projection=[key, ], batch_size=2) - if doc.count() == 2: - raise MultipleMatches("Too many matching documents at the USER" - " level: ({}, {}, mid {})".format( - cog_name, cog_identifier, user_id)) - elif doc.count() == 1: - return doc[0].get(key, default) - else: - return default + def get_collection(self) -> motor.core.Collection: + """ + Gets a specified collection within the PyMongo database for this cog. - def set_global(self, cog_name, cog_identifier, key, value, clear=False): - filter = {"cog_name": cog_name, "cog_identifier": cog_identifier} - data = {"$set": {key: value}} - if self._global.count(filter) > 1: - raise MultipleMatches("Too many matching documents at the GLOBAL" - " level: ({}, {})".format(cog_name, - cog_identifier)) - else: - if clear: - self._global.delete_one(filter) - else: - self._global.update_one(filter, data, upsert=True) + Unless you are doing custom stuff ``collection_name`` should be one of the class + attributes of :py:class:`core.config.Config`. - def set_guild(self, cog_name, cog_identifier, guild_id, key, value, - clear=False): - filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, - "guild_id": guild_id} - data = {"$set": {key: value}} - if self._guild.count(filter) > 1: - raise MultipleMatches("Too many matching documents at the GUILD" - " level: ({}, {}, {})".format( - cog_name, cog_identifier, guild_id)) - else: - if clear: - self._guild.delete_one(filter) - else: - self._guild.update_one(filter, data, upsert=True) + :param str collection_name: + :return: + PyMongo collection object. + """ + return self.db[self.cog_name] - def set_channel(self, cog_name, cog_identifier, channel_id, key, value, - clear=False): - filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, - "channel_id": channel_id} - data = {"$set": {key: value}} - if self._channel.count(filter) > 1: - raise MultipleMatches("Too many matching documents at the CHANNEL" - " level: ({}, {}, {})".format( - cog_name, cog_identifier, channel_id)) - else: - if clear: - self._channel.delete_one(filter) - else: - self._channel.update_one(filter, data, upsert=True) + @staticmethod + def _parse_identifiers(identifiers): + uuid, identifiers = identifiers[0], identifiers[1:] + return uuid, identifiers - def set_role(self, cog_name, cog_identifier, role_id, key, value, - clear=False): - filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, - "role_id": role_id} - data = {"$set": {key: value}} - if self._role.count(filter) > 1: - raise MultipleMatches("Too many matching documents at the ROLE" - " level: ({}, {}, {})".format( - cog_name, cog_identifier, role_id)) - else: - if clear: - self._role.delete_one(filter) - else: - self._role.update_one(filter, data, upsert=True) + async def get(self, identifiers: Tuple[str]): + await self._ensure_connected() + uuid, identifiers = self._parse_identifiers(identifiers) - def set_member(self, cog_name, cog_identifier, user_id, guild_id, key, - value, clear=False): - filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, - "guild_id": guild_id, "user_id": user_id} - data = {"$set": {key: value}} - if self._member.count(filter) > 1: - raise MultipleMatches("Too many matching documents at the MEMBER" - " level: ({}, {}, mid {}, sid {})".format( - cog_name, cog_identifier, user_id, - guild_id)) - else: - if clear: - self._member.delete_one(filter) - else: - self._member.update_one(filter, data, upsert=True) + mongo_collection = self.get_collection() - def set_user(self, cog_name, cog_identifier, user_id, key, value, - clear=False): - filter = {"cog_name": cog_name, "cog_identifier": cog_identifier, - "user_id": user_id} - data = {"$set": {key: value}} - if self._user.count(filter) > 1: - raise MultipleMatches("Too many matching documents at the USER" - " level: ({}, {}, mid {})".format( - cog_name, cog_identifier, user_id)) - else: - if clear: - self._user.delete_one(filter) - else: - self._user.update_one(filter, data, upsert=True) + dot_identifiers = '.'.join(identifiers) + + partial = await mongo_collection.find_one( + filter={'_id': uuid}, + projection={dot_identifiers: True} + ) + + if partial is None: + raise KeyError("No matching document was found and Config expects" + " a KeyError.") + + for i in identifiers: + partial = partial[i] + return partial + + async def set(self, identifiers: Tuple[str], value): + await self._ensure_connected() + uuid, identifiers = self._parse_identifiers(identifiers) + + dot_identifiers = '.'.join(identifiers) + + mongo_collection = self.get_collection() + + await mongo_collection.update_one( + {'_id': uuid}, + update={"$set": {dot_identifiers: value}}, + upsert=True + ) + + def get_driver(self): + return self + + +def get_config_details(): + host = input("Enter host address: ") + port = int(input("Enter host port: ")) + + admin_uname = input("Enter login username: ") + admin_password = input("Enter login password: ") + + if admin_uname == "": + admin_uname = admin_password = None + + ret = { + 'HOST': host, + 'PORT': port, + 'USERNAME': admin_uname, + 'PASSWORD': admin_password + } + return ret diff --git a/redbot/setup.py b/redbot/setup.py index 638005fe2..ded53f505 100644 --- a/redbot/setup.py +++ b/redbot/setup.py @@ -89,7 +89,13 @@ def basic_setup(): if storage not in storage_dict: storage = None - default_dirs['STORAGE_TYPE'] = storage_dict[storage] + default_dirs['STORAGE_TYPE'] = storage_dict.get(storage, 1) + + if storage_dict.get(storage, 1) == "MongoDB": + from redbot.core.drivers.red_mongo import get_config_details + default_dirs['STORAGE_DETAILS'] = get_config_details() + else: + default_dirs['STORAGE_DETAILS'] = {} name = "" while len(name) == 0: diff --git a/requirements.txt b/requirements.txt index 1b18025ab..9891ed898 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ appdirs youtube_dl -raven \ No newline at end of file +raven +colorama \ No newline at end of file diff --git a/setup.py b/setup.py index 244d1132c..5cf4c2b42 100644 --- a/setup.py +++ b/setup.py @@ -106,7 +106,7 @@ setup( ], extras_require={ 'test': ['pytest>=3', 'pytest-asyncio'], - 'mongo': ['pymongo', 'motor'], + 'mongo': ['motor'], 'docs': ['sphinx', 'sphinxcontrib-asyncio', 'sphinx_rtd_theme'], 'voice': ['PyNaCl'] }