From b493103dcb84633550865a27d867ae70c65f3ccb Mon Sep 17 00:00:00 2001 From: Vexed Date: Mon, 2 Jan 2023 04:24:27 +0000 Subject: [PATCH] Improve validation in trivia (#5947) Co-authored-by: Jakub Kuczys --- redbot/cogs/trivia/schema.py | 121 +++++++++++++++++++++++++++++++++++ redbot/cogs/trivia/trivia.py | 31 ++------- tests/cogs/test_trivia.py | 53 ++++++++++++++- 3 files changed, 180 insertions(+), 25 deletions(-) create mode 100644 redbot/cogs/trivia/schema.py diff --git a/redbot/cogs/trivia/schema.py b/redbot/cogs/trivia/schema.py new file mode 100644 index 000000000..bc1af8877 --- /dev/null +++ b/redbot/cogs/trivia/schema.py @@ -0,0 +1,121 @@ +import itertools +import re +from typing import Any, NoReturn + +from schema import And, Const, Optional, Schema, SchemaError, SchemaMissingKeyError, Use + +from redbot.core.i18n import Translator + +__all__ = ("TRIVIA_LIST_SCHEMA", "format_schema_error") + +T_ = Translator("Trivia", __file__) +KEY_ERROR_MSG_RE = re.compile(r"Key '(.+)' error:") + + +class SchemaErrorMessage(str): + def format(self, *args: Any, **kwargs: Any) -> str: + return T_(str(self)) + + +def int_or_float(value: Any) -> float: + if not isinstance(value, (float, int)): + raise TypeError("Value needs to be an integer or a float.") + return float(value) + + +def not_str(value: Any) -> float: + if isinstance(value, str): + raise TypeError("Value needs to not be a string.") + return value + + +_ = SchemaErrorMessage +NO_QUESTIONS_ERROR_MSG = _("The trivia list does not contain any questions.") +ALWAYS_MATCH = Optional(Use(lambda x: x)) +MATCH_ALL_BUT_STR = Optional(Use(not_str)) +TRIVIA_LIST_SCHEMA = Schema( + { + Optional("AUTHOR"): And(str, error=_("{key} key must be a text value.")), + Optional("CONFIG"): And( + { + Optional("max_score"): And( + int, + lambda n: n >= 1, + error=_("{key} key in {parent_key} must be a positive integer."), + ), + Optional("timeout"): And( + Use(int_or_float), + lambda n: n > 0.0, + error=_("{key} key in {parent_key} must be a positive number."), + ), + Optional("delay"): And( + Use(int_or_float), + lambda n: n >= 4.0, + error=_( + "{key} key in {parent_key} must be a positive number" + " greater than or equal to 4." + ), + ), + Optional("bot_plays"): Const( + bool, error=_("{key} key in {parent_key} must be either true or false.") + ), + Optional("reveal_answer"): Const( + bool, error=_("{key} key in {parent_key} must be either true or false.") + ), + Optional("payout_multiplier"): And( + Use(int_or_float), + lambda n: n >= 0.0, + error=_("{key} key in {parent_key} must be a non-negative number."), + ), + Optional("use_spoilers"): Const( + bool, error=_("{key} key in {parent_key} must be either true or false.") + ), + # This matches any extra key and always fails validation + # for the purpose of better error messages. + ALWAYS_MATCH: And( + lambda __: False, + error=_("{key} is not a key that can be specified in {parent_key}."), + ), + }, + error=_("{key} should be a 'key: value' mapping."), + ), + str: And( + [str, int, bool, float], + error=_("Value of question {key} is not a list of text values (answers)."), + ), + # This matches any extra key and always fails validation + # for the purpose of better error messages. + MATCH_ALL_BUT_STR: And( + lambda __: False, + error=_("A key of question {key} is not a text value."), + ), + }, + error=_("A trivia list should be a 'key: value' mapping."), +) + + +def format_schema_error(exc: SchemaError) -> str: + if isinstance(exc, SchemaMissingKeyError): + return NO_QUESTIONS_ERROR_MSG.format() + + # dict.fromkeys is used for de-duplication with order preservation + errors = {idx: msg for idx, msg in enumerate(exc.errors) if msg is not None} + if not errors: + return str(exc) + error_idx, error_msg_fmt = errors.popitem() + + autos = dict.fromkeys(msg for msg in itertools.islice(exc.autos, error_idx) if msg is not None) + keys = [match[1] for msg in autos if (match := KEY_ERROR_MSG_RE.fullmatch(msg)) is not None] + key_count = len(keys) + if key_count == 2: + key = keys[-1] + parent_key = keys[-2] + elif key_count == 1: + key = keys[-1] + # should only happen for messages where this field isn't used + parent_key = "UNKNOWN" + else: + # should only happen for messages where neither of the fields are used + key = parent_key = "UNKNOWN" + + return error_msg_fmt.format(key=repr(key), parent_key=repr(parent_key)) diff --git a/redbot/cogs/trivia/trivia.py b/redbot/cogs/trivia/trivia.py index 5b6e23e16..ca2c8c47d 100644 --- a/redbot/cogs/trivia/trivia.py +++ b/redbot/cogs/trivia/trivia.py @@ -4,7 +4,7 @@ import math import pathlib from collections import Counter from typing import Any, Dict, List, Literal, Union -from schema import Schema, Optional, Or, SchemaError +import schema import io import yaml @@ -23,26 +23,11 @@ from .checks import trivia_stop_check from .converters import finite_float from .log import LOG from .session import TriviaSession +from .schema import TRIVIA_LIST_SCHEMA, format_schema_error __all__ = ("Trivia", "UNIQUE_ID", "InvalidListError", "get_core_lists", "get_list") UNIQUE_ID = 0xB3C0E453 -TRIVIA_LIST_SCHEMA = Schema( - { - Optional("AUTHOR"): str, - Optional("CONFIG"): { - Optional("max_score"): int, - Optional("timeout"): Or(int, float), - Optional("delay"): Or(int, float), - Optional("bot_plays"): bool, - Optional("reveal_answer"): bool, - Optional("payout_multiplier"): Or(int, float), - Optional("use_spoilers"): bool, - }, - str: [str, int, bool, float], - } -) - _ = Translator("Trivia", __file__) @@ -120,7 +105,7 @@ class Trivia(commands.Cog): @triviaset.command(name="maxscore") async def triviaset_max_score(self, ctx: commands.Context, score: int): """Set the total points required to win.""" - if score < 0: + if score <= 0: await ctx.send(_("Score must be greater than 0.")) return settings = self.config.guild(ctx.guild) @@ -293,18 +278,18 @@ class Trivia(commands.Cog): try: await self._save_trivia_list(ctx=ctx, attachment=parsedfile) except yaml.error.MarkedYAMLError as exc: - await ctx.send(_("Invalid syntax: ") + str(exc)) + await ctx.send(_("Invalid syntax:\n") + box(str(exc))) except yaml.error.YAMLError: await ctx.send( _("There was an error parsing the trivia list. See logs for more info.") ) LOG.exception("Custom Trivia file %s failed to upload", parsedfile.filename) - except SchemaError as e: + except schema.SchemaError as exc: await ctx.send( _( "The custom trivia list was not saved." " The file does not follow the proper data format.\n{schema_error}" - ).format(schema_error=box(e)) + ).format(schema_error=box(format_schema_error(exc))) ) @commands.is_owner() @@ -740,8 +725,6 @@ def get_list(path: pathlib.Path) -> Dict[str, Any]: ------ InvalidListError Parsing of list's YAML file failed. - SchemaError - The list does not adhere to the schema. """ with path.open(encoding="utf-8") as file: try: @@ -751,6 +734,6 @@ def get_list(path: pathlib.Path) -> Dict[str, Any]: try: TRIVIA_LIST_SCHEMA.validate(trivia_dict) - except SchemaError as exc: + except schema.SchemaError as exc: raise InvalidListError("The list does not adhere to the schema.") from exc return trivia_dict diff --git a/tests/cogs/test_trivia.py b/tests/cogs/test_trivia.py index 8d8d9baae..a5795495e 100644 --- a/tests/cogs/test_trivia.py +++ b/tests/cogs/test_trivia.py @@ -1,7 +1,17 @@ import textwrap +from typing import Any +import pytest import yaml -from schema import SchemaError +from schema import And, Optional, SchemaError + +from redbot.cogs.trivia.schema import ( + ALWAYS_MATCH, + MATCH_ALL_BUT_STR, + NO_QUESTIONS_ERROR_MSG, + TRIVIA_LIST_SCHEMA, + format_schema_error, +) def test_trivia_lists(): @@ -25,3 +35,44 @@ def test_trivia_lists(): f"- {name}:\n{textwrap.indent(error, ' ')}" for name, error in problem_lists ) raise TypeError("The following lists contain errors:\n" + msg) + + +def _get_error_message(*keys: Any, key: str = "UNKNOWN", parent_key: str = "UNKNOWN") -> str: + if not keys: + return TRIVIA_LIST_SCHEMA._error + + current = TRIVIA_LIST_SCHEMA.schema + for key_name in keys: + if isinstance(current, And): + current = current.args[0] + current = current[key_name] + return str(current._error).format(key=repr(key), parent_key=repr(parent_key)) + + +@pytest.mark.parametrize( + "data,error_msg", + ( + ("text", _get_error_message()), + ({"AUTHOR": 123}, _get_error_message(Optional("AUTHOR"), key="AUTHOR")), + ({"CONFIG": 123}, _get_error_message(Optional("CONFIG"), key="CONFIG")), + ( + {"CONFIG": {"key": "value"}}, + _get_error_message(Optional("CONFIG"), ALWAYS_MATCH, key="key", parent_key="CONFIG"), + ), + ( + {"CONFIG": {"bot_plays": "wrong type"}}, + _get_error_message( + Optional("CONFIG"), Optional("bot_plays"), key="bot_plays", parent_key="CONFIG" + ), + ), + ({"AUTHOR": "Correct type but no questions."}, NO_QUESTIONS_ERROR_MSG), + ({"Question": "wrong type"}, _get_error_message(str, key="Question")), + ({"Question": [{"wrong": "type"}]}, _get_error_message(str, key="Question")), + ({123: "wrong key type"}, _get_error_message(MATCH_ALL_BUT_STR, key="123")), + ), +) +def test_trivia_schema_error_messages(data: Any, error_msg: str): + with pytest.raises(SchemaError) as exc: + TRIVIA_LIST_SCHEMA.validate(data) + + assert format_schema_error(exc.value) == error_msg