diff --git a/redbot/cogs/trivia/trivia.py b/redbot/cogs/trivia/trivia.py index 3f7d3c431..1d6eebe26 100644 --- a/redbot/cogs/trivia/trivia.py +++ b/redbot/cogs/trivia/trivia.py @@ -3,7 +3,8 @@ import asyncio import math import pathlib from collections import Counter -from typing import List, Literal +from typing import Any, Dict, List, Literal +from schema import Schema, Optional, Or, SchemaError import io import yaml @@ -23,9 +24,23 @@ from .converters import finite_float from .log import LOG from .session import TriviaSession -__all__ = ["Trivia", "UNIQUE_ID", "get_core_lists"] +__all__ = ("Trivia", "UNIQUE_ID", "InvalidListError", "get_core_lists", "get_list") UNIQUE_ID = 0xB3C0E453 +TRIVIA_LIST_SCHEMA = Schema( + { + Optional("AUTHOR"): str, + Optional("CONFIG"): { + Optional("max_score"): int, + Optional("timeout"): Or(int, float), + Optional("delay"): Or(int, float), + Optional("bot_plays"): bool, + Optional("reveal_answer"): bool, + Optional("payout_multiplier"): Or(int, float), + }, + str: [str, int, bool, float], + } +) _ = Translator("Trivia", __file__) @@ -282,6 +297,13 @@ class Trivia(commands.Cog): _("There was an error parsing the trivia list. See logs for more info.") ) LOG.exception("Custom Trivia file %s failed to upload", parsedfile.filename) + except SchemaError as e: + await ctx.send( + _( + "The custom trivia list was not saved." + " The file does not follow the proper data format.\n{schema_error}" + ).format(schema_error=box(e)) + ) @commands.is_owner() @triviaset_custom.command(name="delete", aliases=["remove"]) @@ -615,13 +637,7 @@ class Trivia(commands.Cog): except StopIteration: raise FileNotFoundError("Could not find the `{}` category.".format(category)) - with path.open(encoding="utf-8") as file: - try: - dict_ = yaml.safe_load(file) - except yaml.error.YAMLError as exc: - raise InvalidListError("YAML parsing failed.") from exc - else: - return dict_ + return get_list(path) async def _save_trivia_list( self, ctx: commands.Context, attachment: discord.Attachment @@ -683,9 +699,10 @@ class Trivia(commands.Cog): return buffer = io.BytesIO(await attachment.read()) - yaml.safe_load(buffer) - buffer.seek(0) + trivia_dict = yaml.safe_load(buffer) + TRIVIA_LIST_SCHEMA.validate(trivia_dict) + buffer.seek(0) with file.open("wb") as fp: fp.write(buffer.read()) await ctx.send(_("Saved Trivia list as {filename}.").format(filename=filename)) @@ -709,3 +726,27 @@ def get_core_lists() -> List[pathlib.Path]: """Return a list of paths for all trivia lists packaged with the bot.""" core_lists_path = pathlib.Path(__file__).parent.resolve() / "data/lists" return list(core_lists_path.glob("*.yaml")) + + +def get_list(path: pathlib.Path) -> Dict[str, Any]: + """ + Returns a trivia list dictionary from the given path. + + Raises + ------ + InvalidListError + Parsing of list's YAML file failed. + SchemaError + The list does not adhere to the schema. + """ + with path.open(encoding="utf-8") as file: + try: + trivia_dict = yaml.safe_load(file) + except yaml.error.YAMLError as exc: + raise InvalidListError("YAML parsing failed.") from exc + + try: + TRIVIA_LIST_SCHEMA.validate(trivia_dict) + except SchemaError as exc: + raise InvalidListError("The list does not adhere to the schema.") from exc + return trivia_dict diff --git a/tests/cogs/test_trivia.py b/tests/cogs/test_trivia.py index 3c0fef1a0..9aaed0275 100644 --- a/tests/cogs/test_trivia.py +++ b/tests/cogs/test_trivia.py @@ -1,33 +1,27 @@ +import textwrap + import yaml +from schema import SchemaError def test_trivia_lists(): - from redbot.cogs.trivia import get_core_lists + from redbot.cogs.trivia import InvalidListError, get_core_lists, get_list list_names = get_core_lists() assert list_names problem_lists = [] for l in list_names: - with l.open(encoding="utf-8") as f: - try: - dict_ = yaml.safe_load(f) - except yaml.error.YAMLError as e: - problem_lists.append((l.stem, "YAML error:\n{!s}".format(e))) + try: + get_list(l) + except InvalidListError as exc: + e = exc.__cause__ + if isinstance(e, SchemaError): + problem_lists.append((l.stem, f"SCHEMA error:\n{e!s}")) else: - for key in list(dict_.keys()): - if key == "CONFIG": - if not isinstance(dict_[key], dict): - problem_lists.append((l.stem, "CONFIG is not a dict")) - elif key == "AUTHOR": - if not isinstance(dict_[key], str): - problem_lists.append((l.stem, "AUTHOR is not a string")) - else: - if not isinstance(dict_[key], list): - problem_lists.append( - (l.stem, "The answers for '{}' are not a list".format(key)) - ) + problem_lists.append((l.stem, f"YAML error:\n{e!s}")) + if problem_lists: msg = "" - for l in problem_lists: - msg += "{}: {}\n".format(l[0], l[1]) + for name, error in problem_lists: + msg += f"- {name}:\n{textwrap.indent(error, ' ')}" raise TypeError("The following lists contain errors:\n" + msg)