Improve validation in trivia (#5947)

Co-authored-by: Jakub Kuczys <me@jacken.men>
This commit is contained in:
Vexed 2023-01-02 04:24:27 +00:00 committed by GitHub
parent 7db635a05b
commit b493103dcb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 180 additions and 25 deletions

View File

@ -0,0 +1,121 @@
import itertools
import re
from typing import Any, NoReturn
from schema import And, Const, Optional, Schema, SchemaError, SchemaMissingKeyError, Use
from redbot.core.i18n import Translator
__all__ = ("TRIVIA_LIST_SCHEMA", "format_schema_error")
T_ = Translator("Trivia", __file__)
KEY_ERROR_MSG_RE = re.compile(r"Key '(.+)' error:")
class SchemaErrorMessage(str):
def format(self, *args: Any, **kwargs: Any) -> str:
return T_(str(self))
def int_or_float(value: Any) -> float:
if not isinstance(value, (float, int)):
raise TypeError("Value needs to be an integer or a float.")
return float(value)
def not_str(value: Any) -> float:
if isinstance(value, str):
raise TypeError("Value needs to not be a string.")
return value
_ = SchemaErrorMessage
NO_QUESTIONS_ERROR_MSG = _("The trivia list does not contain any questions.")
ALWAYS_MATCH = Optional(Use(lambda x: x))
MATCH_ALL_BUT_STR = Optional(Use(not_str))
TRIVIA_LIST_SCHEMA = Schema(
{
Optional("AUTHOR"): And(str, error=_("{key} key must be a text value.")),
Optional("CONFIG"): And(
{
Optional("max_score"): And(
int,
lambda n: n >= 1,
error=_("{key} key in {parent_key} must be a positive integer."),
),
Optional("timeout"): And(
Use(int_or_float),
lambda n: n > 0.0,
error=_("{key} key in {parent_key} must be a positive number."),
),
Optional("delay"): And(
Use(int_or_float),
lambda n: n >= 4.0,
error=_(
"{key} key in {parent_key} must be a positive number"
" greater than or equal to 4."
),
),
Optional("bot_plays"): Const(
bool, error=_("{key} key in {parent_key} must be either true or false.")
),
Optional("reveal_answer"): Const(
bool, error=_("{key} key in {parent_key} must be either true or false.")
),
Optional("payout_multiplier"): And(
Use(int_or_float),
lambda n: n >= 0.0,
error=_("{key} key in {parent_key} must be a non-negative number."),
),
Optional("use_spoilers"): Const(
bool, error=_("{key} key in {parent_key} must be either true or false.")
),
# This matches any extra key and always fails validation
# for the purpose of better error messages.
ALWAYS_MATCH: And(
lambda __: False,
error=_("{key} is not a key that can be specified in {parent_key}."),
),
},
error=_("{key} should be a 'key: value' mapping."),
),
str: And(
[str, int, bool, float],
error=_("Value of question {key} is not a list of text values (answers)."),
),
# This matches any extra key and always fails validation
# for the purpose of better error messages.
MATCH_ALL_BUT_STR: And(
lambda __: False,
error=_("A key of question {key} is not a text value."),
),
},
error=_("A trivia list should be a 'key: value' mapping."),
)
def format_schema_error(exc: SchemaError) -> str:
if isinstance(exc, SchemaMissingKeyError):
return NO_QUESTIONS_ERROR_MSG.format()
# dict.fromkeys is used for de-duplication with order preservation
errors = {idx: msg for idx, msg in enumerate(exc.errors) if msg is not None}
if not errors:
return str(exc)
error_idx, error_msg_fmt = errors.popitem()
autos = dict.fromkeys(msg for msg in itertools.islice(exc.autos, error_idx) if msg is not None)
keys = [match[1] for msg in autos if (match := KEY_ERROR_MSG_RE.fullmatch(msg)) is not None]
key_count = len(keys)
if key_count == 2:
key = keys[-1]
parent_key = keys[-2]
elif key_count == 1:
key = keys[-1]
# should only happen for messages where this field isn't used
parent_key = "UNKNOWN"
else:
# should only happen for messages where neither of the fields are used
key = parent_key = "UNKNOWN"
return error_msg_fmt.format(key=repr(key), parent_key=repr(parent_key))

View File

@ -4,7 +4,7 @@ import math
import pathlib import pathlib
from collections import Counter from collections import Counter
from typing import Any, Dict, List, Literal, Union from typing import Any, Dict, List, Literal, Union
from schema import Schema, Optional, Or, SchemaError import schema
import io import io
import yaml import yaml
@ -23,26 +23,11 @@ from .checks import trivia_stop_check
from .converters import finite_float from .converters import finite_float
from .log import LOG from .log import LOG
from .session import TriviaSession from .session import TriviaSession
from .schema import TRIVIA_LIST_SCHEMA, format_schema_error
__all__ = ("Trivia", "UNIQUE_ID", "InvalidListError", "get_core_lists", "get_list") __all__ = ("Trivia", "UNIQUE_ID", "InvalidListError", "get_core_lists", "get_list")
UNIQUE_ID = 0xB3C0E453 UNIQUE_ID = 0xB3C0E453
TRIVIA_LIST_SCHEMA = Schema(
{
Optional("AUTHOR"): str,
Optional("CONFIG"): {
Optional("max_score"): int,
Optional("timeout"): Or(int, float),
Optional("delay"): Or(int, float),
Optional("bot_plays"): bool,
Optional("reveal_answer"): bool,
Optional("payout_multiplier"): Or(int, float),
Optional("use_spoilers"): bool,
},
str: [str, int, bool, float],
}
)
_ = Translator("Trivia", __file__) _ = Translator("Trivia", __file__)
@ -120,7 +105,7 @@ class Trivia(commands.Cog):
@triviaset.command(name="maxscore") @triviaset.command(name="maxscore")
async def triviaset_max_score(self, ctx: commands.Context, score: int): async def triviaset_max_score(self, ctx: commands.Context, score: int):
"""Set the total points required to win.""" """Set the total points required to win."""
if score < 0: if score <= 0:
await ctx.send(_("Score must be greater than 0.")) await ctx.send(_("Score must be greater than 0."))
return return
settings = self.config.guild(ctx.guild) settings = self.config.guild(ctx.guild)
@ -293,18 +278,18 @@ class Trivia(commands.Cog):
try: try:
await self._save_trivia_list(ctx=ctx, attachment=parsedfile) await self._save_trivia_list(ctx=ctx, attachment=parsedfile)
except yaml.error.MarkedYAMLError as exc: except yaml.error.MarkedYAMLError as exc:
await ctx.send(_("Invalid syntax: ") + str(exc)) await ctx.send(_("Invalid syntax:\n") + box(str(exc)))
except yaml.error.YAMLError: except yaml.error.YAMLError:
await ctx.send( await ctx.send(
_("There was an error parsing the trivia list. See logs for more info.") _("There was an error parsing the trivia list. See logs for more info.")
) )
LOG.exception("Custom Trivia file %s failed to upload", parsedfile.filename) LOG.exception("Custom Trivia file %s failed to upload", parsedfile.filename)
except SchemaError as e: except schema.SchemaError as exc:
await ctx.send( await ctx.send(
_( _(
"The custom trivia list was not saved." "The custom trivia list was not saved."
" The file does not follow the proper data format.\n{schema_error}" " The file does not follow the proper data format.\n{schema_error}"
).format(schema_error=box(e)) ).format(schema_error=box(format_schema_error(exc)))
) )
@commands.is_owner() @commands.is_owner()
@ -740,8 +725,6 @@ def get_list(path: pathlib.Path) -> Dict[str, Any]:
------ ------
InvalidListError InvalidListError
Parsing of list's YAML file failed. Parsing of list's YAML file failed.
SchemaError
The list does not adhere to the schema.
""" """
with path.open(encoding="utf-8") as file: with path.open(encoding="utf-8") as file:
try: try:
@ -751,6 +734,6 @@ def get_list(path: pathlib.Path) -> Dict[str, Any]:
try: try:
TRIVIA_LIST_SCHEMA.validate(trivia_dict) TRIVIA_LIST_SCHEMA.validate(trivia_dict)
except SchemaError as exc: except schema.SchemaError as exc:
raise InvalidListError("The list does not adhere to the schema.") from exc raise InvalidListError("The list does not adhere to the schema.") from exc
return trivia_dict return trivia_dict

View File

@ -1,7 +1,17 @@
import textwrap import textwrap
from typing import Any
import pytest
import yaml import yaml
from schema import SchemaError from schema import And, Optional, SchemaError
from redbot.cogs.trivia.schema import (
ALWAYS_MATCH,
MATCH_ALL_BUT_STR,
NO_QUESTIONS_ERROR_MSG,
TRIVIA_LIST_SCHEMA,
format_schema_error,
)
def test_trivia_lists(): def test_trivia_lists():
@ -25,3 +35,44 @@ def test_trivia_lists():
f"- {name}:\n{textwrap.indent(error, ' ')}" for name, error in problem_lists f"- {name}:\n{textwrap.indent(error, ' ')}" for name, error in problem_lists
) )
raise TypeError("The following lists contain errors:\n" + msg) raise TypeError("The following lists contain errors:\n" + msg)
def _get_error_message(*keys: Any, key: str = "UNKNOWN", parent_key: str = "UNKNOWN") -> str:
if not keys:
return TRIVIA_LIST_SCHEMA._error
current = TRIVIA_LIST_SCHEMA.schema
for key_name in keys:
if isinstance(current, And):
current = current.args[0]
current = current[key_name]
return str(current._error).format(key=repr(key), parent_key=repr(parent_key))
@pytest.mark.parametrize(
"data,error_msg",
(
("text", _get_error_message()),
({"AUTHOR": 123}, _get_error_message(Optional("AUTHOR"), key="AUTHOR")),
({"CONFIG": 123}, _get_error_message(Optional("CONFIG"), key="CONFIG")),
(
{"CONFIG": {"key": "value"}},
_get_error_message(Optional("CONFIG"), ALWAYS_MATCH, key="key", parent_key="CONFIG"),
),
(
{"CONFIG": {"bot_plays": "wrong type"}},
_get_error_message(
Optional("CONFIG"), Optional("bot_plays"), key="bot_plays", parent_key="CONFIG"
),
),
({"AUTHOR": "Correct type but no questions."}, NO_QUESTIONS_ERROR_MSG),
({"Question": "wrong type"}, _get_error_message(str, key="Question")),
({"Question": [{"wrong": "type"}]}, _get_error_message(str, key="Question")),
({123: "wrong key type"}, _get_error_message(MATCH_ALL_BUT_STR, key="123")),
),
)
def test_trivia_schema_error_messages(data: Any, error_msg: str):
with pytest.raises(SchemaError) as exc:
TRIVIA_LIST_SCHEMA.validate(data)
assert format_schema_error(exc.value) == error_msg