From b6c8be5f43438edcf5f577bc9f7e2e11762df96d Mon Sep 17 00:00:00 2001 From: Ryan Date: Mon, 1 Oct 2018 01:49:29 -0500 Subject: [PATCH] [MongoDB] Support mongodb+srv protocol (#2159) --- Pipfile.lock | 6 ++++++ redbot/core/drivers/red_mongo.py | 29 +++++++++++++++++++++++++---- setup.py | 2 +- 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/Pipfile.lock b/Pipfile.lock index c27d841ce..b48ed16c6 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -89,6 +89,12 @@ ], "version": "==1.3.0" }, + "dnspython": { + "hashes": [ + "sha256:861e6e58faa730f9845aaaa9c6c832851fbf89382ac52915a51f89c71accdd31" + ], + "version": "==1.15.0" + }, "e1839a8": { "editable": true, "extras": [ diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/red_mongo.py index 18a6382e4..664b691e7 100644 --- a/redbot/core/drivers/red_mongo.py +++ b/redbot/core/drivers/red_mongo.py @@ -9,18 +9,24 @@ _conn = None def _initialize(**kwargs): + kwargs.get("URI", "mongodb") host = kwargs["HOST"] port = kwargs["PORT"] admin_user = kwargs["USERNAME"] admin_pass = kwargs["PASSWORD"] db_name = kwargs.get("DB_NAME", "default_db") + if port is 0: + ports = "" + else: + ports = ":{}".format(port) + if admin_user is not None and admin_pass is not None: - url = "mongodb://{}:{}@{}:{}/{}".format( - quote_plus(admin_user), quote_plus(admin_pass), host, port, db_name + url = "{}://{}:{}@{}{}/{}".format( + uri, quote_plus(admin_user), quote_plus(admin_pass), host, ports, db_name ) else: - url = "mongodb://{}:{}/{}".format(host, port, db_name) + url = "{}://{}{}/{}".format(uri, host, ports, db_name) global _conn _conn = motor.motor_asyncio.AsyncIOMotorClient(url) @@ -111,8 +117,22 @@ class Mongo(BaseDriver): def get_config_details(): + uri = None + while True: + uri = input("Enter URI scheme (mongodb or mongodb+srv): ") + if uri is "": + uri = "mongodb" + + if uri in ["mongodb", "mongodb+srv"]: + break + else: + print("Invalid URI scheme") + host = input("Enter host address: ") - port = int(input("Enter host port: ")) + if uri is "mongodb": + port = int(input("Enter host port: ")) + else: + port = 0 admin_uname = input("Enter login username: ") admin_password = input("Enter login password: ") @@ -128,5 +148,6 @@ def get_config_details(): "USERNAME": admin_uname, "PASSWORD": admin_password, "DB_NAME": db_name, + "URI": uri, } return ret diff --git a/setup.py b/setup.py index 63b45257a..56e97a3e9 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,7 @@ if __name__ == "__main__": "pytest-asyncio==0.9.0", "six==1.11.0", ], - "mongo": ["motor==2.0.0", "pymongo==3.7.1"], + "mongo": ["motor==2.0.0", "pymongo==3.7.1", "dnspython==1.15.0"], "docs": [ "alabaster==0.7.11", "babel==2.6.0",