mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
* PostgreSQL driver and general drivers cleanup
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Make tests pass
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Add black --target-version flag in make.bat
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Rewrite postgres driver
  Most of the logic is now in PL/pgSQL. This completely avoids the use of Python f-strings to format identifiers into queries. Although an SQL-injection attack would have been impossible anyway (only the owner would have ever had the ability to do that), using PostgreSQL's format() is more reliable for unusual identifiers. Performance-wise, I'm not sure whether this is an improvement, but I highly doubt that it's worse.
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Reformat
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Fix PostgresDriver.delete_all_data()
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Clean up PL/pgSQL code
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* More PL/pgSQL cleanup
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* PL/pgSQL function optimisations
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Ensure compatibility with PostgreSQL 10 and below
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* More/better docstrings for PG functions
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Fix typo in docstring
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Return correct value on toggle()
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Use composite type for PG function parameters
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Fix JSON driver's Config.clear_all()
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Correct description for Mongo tox recipe
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Fix linting errors
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Update dep specification after merging bumpdeps
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Add towncrier entries
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Update from merge
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Mention [postgres] extra in install docs
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Support more connection options and use better defaults
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Actually pass PG env vars in tox
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
* Replace event trigger with manual DELETE queries
  Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
256 lines
8.7 KiB
Python
256 lines
8.7 KiB
Python
import getpass
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Optional, Any, AsyncIterator, Tuple, Union, Callable, List
|
|
|
|
try:
|
|
# pylint: disable=import-error
|
|
import asyncpg
|
|
except ModuleNotFoundError:
|
|
asyncpg = None
|
|
|
|
from ... import data_manager, errors
|
|
from ..base import BaseDriver, IdentifierData, ConfigCategory
|
|
from ..log import log
|
|
|
|
__all__ = ["PostgresDriver"]
|
|
|
|
# Directory containing this module; the SQL scripts below live next to it.
_PKG_PATH = Path(__file__).parent
# SQL script executed against the connection pool when the driver initializes.
DDL_SCRIPT_PATH = _PKG_PATH.joinpath("ddl.sql")
# SQL script executed when Red's data is deleted without dropping the database.
DROP_DDL_SCRIPT_PATH = _PKG_PATH.joinpath("drop_ddl.sql")
|
|
|
|
|
|
def encode_identifier_data(
    id_data: IdentifierData
) -> Tuple[str, str, str, List[str], List[str], int, bool]:
    """Flatten ``id_data`` into the positional tuple passed to the SQL functions.

    The tuple holds, in order: cog name, cog UUID, category, primary key
    parts, remaining identifiers, primary key length, and the custom-group
    flag.

    For the global category the primary key is always encoded as ``["0"]``
    with a length of ``1``; the values stored on ``id_data`` are not read.
    """
    if id_data.category == ConfigCategory.GLOBAL:
        pkey: List[str] = ["0"]
        pkey_len = 1
    else:
        pkey = list(id_data.primary_key)
        pkey_len = id_data.primary_key_len

    return (
        id_data.cog_name,
        id_data.uuid,
        id_data.category,
        pkey,
        list(id_data.identifiers),
        pkey_len,
        id_data.is_custom,
    )
|
|
|
|
|
|
class PostgresDriver(BaseDriver):
    """Config driver that persists data in PostgreSQL through ``asyncpg``.

    Every query goes through one class-level connection pool, which
    `initialize` creates and `teardown` closes.
    """

    # Pool shared by the whole class; stays ``None`` until `initialize` runs.
    _pool: Optional["asyncpg.pool.Pool"] = None

    @classmethod
    async def initialize(cls, **storage_details) -> None:
        """Create the connection pool and run the DDL script.

        Parameters
        ----------
        **storage_details
            Passed verbatim to ``asyncpg.create_pool``.

        Raises
        ------
        errors.MissingExtraRequirements
            If ``asyncpg`` could not be imported.

        """
        if asyncpg is None:
            raise errors.MissingExtraRequirements(
                "Red must be installed with the [postgres] extra to use the PostgreSQL driver"
            )
        cls._pool = await asyncpg.create_pool(**storage_details)
        # Read the script up front so no file handle is held across the await.
        ddl_script = DDL_SCRIPT_PATH.read_text(encoding="utf-8")
        await cls._pool.execute(ddl_script)

    @classmethod
    async def teardown(cls) -> None:
        """Close the connection pool, if one was ever created."""
        if cls._pool is not None:
            await cls._pool.close()

    @staticmethod
    def get_config_details():
        """Interactively ask for the PostgreSQL connection settings.

        Returns
        -------
        dict
            Mapping with the keys ``host``, ``port``, ``user``, ``password``
            and ``database``. Blank answers for host, port, user and
            database become ``None``; the password becomes ``None`` only
            when the literal ``NONE`` is entered. A non-blank port is
            returned as an `int`.

        """
        is_windows = sys.platform == "win32"

        # The Unix-domain socket hint only applies on non-Windows platforms.
        if is_windows:
            unixmsg = ""
        else:
            unixmsg = (
                " - Common directories for PostgreSQL Unix-domain sockets (/run/postgresql, "
                "/var/run/postgresql, /var/pgsql_socket, /private/tmp, and /tmp),\n"
            )
        host = (
            input(
                "Enter the PostgreSQL server's address.\n"
                "If left blank, Red will try the following, in order:\n"
                f" - The PGHOST environment variable,\n{unixmsg}"
                " - localhost.\n"
                "> "
            )
            or None
        )

        print(
            "Enter the PostgreSQL server port.\n"
            "If left blank, this will default to either:\n"
            " - The PGPORT environment variable,\n"
            " - 5432."
        )
        # Keep asking until the answer is blank (use the default) or an integer.
        while True:
            port = input("> ") or None
            if port is None:
                break
            try:
                port = int(port)
            except ValueError:
                print("Port must be a number")
            else:
                break

        user = (
            input(
                "Enter the PostgreSQL server username.\n"
                "If left blank, this will default to either:\n"
                " - The PGUSER environment variable,\n"
                " - The OS name of the user running Red (ident/peer authentication).\n"
                "> "
            )
            or None
        )

        # Mention the passfile location that matches the current platform.
        passfile = r"%APPDATA%\postgresql\pgpass.conf" if is_windows else "~/.pgpass"
        password = getpass.getpass(
            "Enter the PostgreSQL server password. The input will be hidden.\n"
            "  NOTE: If using ident/peer authentication (no password), enter NONE.\n"
            "When NONE is entered, this will default to:\n"
            " - The PGPASSWORD environment variable,\n"
            f" - Looking up the password in the {passfile} passfile,\n"
            " - No password.\n"
            "> "
        )
        if password == "NONE":
            password = None

        database = (
            input(
                "Enter the PostgreSQL database's name.\n"
                "If left blank, this will default to either:\n"
                " - The PGDATABASE environment variable,\n"
                " - The OS name of the user running Red.\n"
                "> "
            )
            or None
        )

        return {
            "host": host,
            "port": port,
            "user": user,
            "password": password,
            "database": database,
        }

    async def get(self, identifier_data: IdentifierData):
        """Fetch and JSON-decode the value stored at ``identifier_data``.

        Raises
        ------
        KeyError
            If the query yields ``None`` or the table does not exist.

        """
        try:
            result = await self._execute(
                "SELECT red_config.get($1)",
                encode_identifier_data(identifier_data),
                method=self._pool.fetchval,
            )
        except asyncpg.UndefinedTableError:
            raise KeyError from None

        if result is None:
            # The result is None both when postgres yields no results, or when it yields a NULL row
            # A 'null' JSON value would be returned as encoded JSON, i.e. the string 'null'
            raise KeyError
        return json.loads(result)

    async def set(self, identifier_data: IdentifierData, value=None):
        """Store ``value``, JSON-encoded, at ``identifier_data``.

        Raises
        ------
        errors.CannotSetSubfield
            If PostgreSQL reports an ``ErrorInAssignmentError``.

        """
        try:
            await self._execute(
                "SELECT red_config.set($1, $2::jsonb)",
                encode_identifier_data(identifier_data),
                json.dumps(value),
            )
        except asyncpg.ErrorInAssignmentError:
            raise errors.CannotSetSubfield

    async def clear(self, identifier_data: IdentifierData):
        """Remove the value at ``identifier_data``; a missing table is ignored."""
        try:
            await self._execute(
                "SELECT red_config.clear($1)", encode_identifier_data(identifier_data)
            )
        except asyncpg.UndefinedTableError:
            # Nothing was ever stored there, so there is nothing to clear.
            pass

    async def inc(
        self, identifier_data: IdentifierData, value: Union[int, float], default: Union[int, float]
    ) -> Union[int, float]:
        """Run ``red_config.inc`` with ``value`` and ``default``; return its result.

        Raises
        ------
        errors.StoredTypeError
            If PostgreSQL reports a ``WrongObjectTypeError``.

        """
        try:
            return await self._execute(
                "SELECT red_config.inc($1, $2, $3)",
                encode_identifier_data(identifier_data),
                value,
                default,
                method=self._pool.fetchval,
            )
        except asyncpg.WrongObjectTypeError as exc:
            raise errors.StoredTypeError(*exc.args)

    async def toggle(self, identifier_data: IdentifierData, default: bool) -> bool:
        """Run ``red_config.toggle`` with ``default``; return its result.

        Raises
        ------
        errors.StoredTypeError
            If PostgreSQL reports a ``WrongObjectTypeError``.

        """
        try:
            return await self._execute(
                # `inc` takes an amount plus a default (three arguments), so it
                # is the wrong routine for a boolean flip.
                # NOTE(review): assumes ddl.sql defines red_config.toggle - confirm.
                "SELECT red_config.toggle($1, $2)",
                encode_identifier_data(identifier_data),
                default,
                method=self._pool.fetchval,
            )
        except asyncpg.WrongObjectTypeError as exc:
            raise errors.StoredTypeError(*exc.args)

    @classmethod
    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
        """Yield a ``(cog_name, cog_id)`` pair for each row of ``red_config.red_cogs``."""
        query = "SELECT cog_name, cog_id FROM red_config.red_cogs"
        log.invisible(query)
        # The cursor is iterated inside a transaction on a single connection.
        async with cls._pool.acquire() as conn, conn.transaction():
            async for row in conn.cursor(query):
                yield row["cog_name"], row["cog_id"]

    @classmethod
    async def delete_all_data(
        cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
    ) -> None:
        """Delete all data being stored by this driver.

        Parameters
        ----------
        interactive : bool
            Set to ``True`` to allow the method to ask the user for
            input from the console, regarding the other unset parameters
            for this method.
        drop_db : Optional[bool]
            Set to ``True`` to drop the entire database for the current
            bot's instance. Otherwise, schemas within the database which
            store bot data will be dropped, as well as functions,
            aggregates, event triggers, and meta-tables. If left as
            ``None`` without ``interactive``, the database is not dropped.
        **kwargs
            Accepted and ignored.

        """
        if interactive is True and drop_db is None:
            print(
                "Please choose from one of the following options:\n"
                " 1. Drop the entire PostgreSQL database for this instance, or\n"
                " 2. Delete all of Red's data within this database, without dropping the database "
                "itself."
            )
            options = ("1", "2")
            while True:
                resp = input("> ")
                if resp in options:
                    # Option 1 drops the whole database; option 2 keeps it.
                    drop_db = resp == "1"
                    break
                print("Please type a number corresponding to one of the options.")

        if drop_db is True:
            storage_details = data_manager.storage_details()
            # DROP DATABASE is a utility statement and cannot take bind
            # parameters, so let the server build it with a safely quoted
            # identifier. If no database name was configured, use the one the
            # pool is connected to.
            drop_statement = await cls._pool.fetchval(
                "SELECT format('DROP DATABASE %I', coalesce($1::text, current_database()::text))",
                storage_details["database"],
            )
            log.invisible("Query: %s", drop_statement)
            # NOTE(review): PostgreSQL refuses to drop a database that still has
            # open connections, which may include this pool's own - confirm how
            # callers are expected to connect when using this option.
            await cls._pool.execute(drop_statement)
        else:
            drop_script = DROP_DDL_SCRIPT_PATH.read_text(encoding="utf-8")
            await cls._pool.execute(drop_script)

    @classmethod
    async def _execute(cls, query: str, *args, method: Optional[Callable] = None) -> Any:
        """Log ``query`` and its arguments, then run it.

        Parameters
        ----------
        query : str
            The SQL text to run.
        *args
            Positional query arguments, forwarded to ``method``.
        method : Optional[Callable]
            Coroutine function called as ``method(query, *args)``. Defaults
            to the pool's ``execute``.

        Returns
        -------
        Any
            Whatever ``method`` returns.

        """
        if method is None:
            method = cls._pool.execute
        log.invisible("Query: %s", query)
        if args:
            log.invisible("Args: %s", args)
        return await method(query, *args)
|