mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
Improve validation in trivia (#5947)
Co-authored-by: Jakub Kuczys <me@jacken.men>
This commit is contained in:
parent
7db635a05b
commit
b493103dcb
121
redbot/cogs/trivia/schema.py
Normal file
121
redbot/cogs/trivia/schema.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import itertools
|
||||||
|
import re
|
||||||
|
from typing import Any, NoReturn
|
||||||
|
|
||||||
|
from schema import And, Const, Optional, Schema, SchemaError, SchemaMissingKeyError, Use
|
||||||
|
|
||||||
|
from redbot.core.i18n import Translator
|
||||||
|
|
||||||
|
__all__ = ("TRIVIA_LIST_SCHEMA", "format_schema_error")
|
||||||
|
|
||||||
|
T_ = Translator("Trivia", __file__)
|
||||||
|
KEY_ERROR_MSG_RE = re.compile(r"Key '(.+)' error:")
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaErrorMessage(str):
|
||||||
|
def format(self, *args: Any, **kwargs: Any) -> str:
|
||||||
|
return T_(str(self))
|
||||||
|
|
||||||
|
|
||||||
|
def int_or_float(value: Any) -> float:
|
||||||
|
if not isinstance(value, (float, int)):
|
||||||
|
raise TypeError("Value needs to be an integer or a float.")
|
||||||
|
return float(value)
|
||||||
|
|
||||||
|
|
||||||
|
def not_str(value: Any) -> float:
|
||||||
|
if isinstance(value, str):
|
||||||
|
raise TypeError("Value needs to not be a string.")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
_ = SchemaErrorMessage
|
||||||
|
NO_QUESTIONS_ERROR_MSG = _("The trivia list does not contain any questions.")
|
||||||
|
ALWAYS_MATCH = Optional(Use(lambda x: x))
|
||||||
|
MATCH_ALL_BUT_STR = Optional(Use(not_str))
|
||||||
|
TRIVIA_LIST_SCHEMA = Schema(
|
||||||
|
{
|
||||||
|
Optional("AUTHOR"): And(str, error=_("{key} key must be a text value.")),
|
||||||
|
Optional("CONFIG"): And(
|
||||||
|
{
|
||||||
|
Optional("max_score"): And(
|
||||||
|
int,
|
||||||
|
lambda n: n >= 1,
|
||||||
|
error=_("{key} key in {parent_key} must be a positive integer."),
|
||||||
|
),
|
||||||
|
Optional("timeout"): And(
|
||||||
|
Use(int_or_float),
|
||||||
|
lambda n: n > 0.0,
|
||||||
|
error=_("{key} key in {parent_key} must be a positive number."),
|
||||||
|
),
|
||||||
|
Optional("delay"): And(
|
||||||
|
Use(int_or_float),
|
||||||
|
lambda n: n >= 4.0,
|
||||||
|
error=_(
|
||||||
|
"{key} key in {parent_key} must be a positive number"
|
||||||
|
" greater than or equal to 4."
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Optional("bot_plays"): Const(
|
||||||
|
bool, error=_("{key} key in {parent_key} must be either true or false.")
|
||||||
|
),
|
||||||
|
Optional("reveal_answer"): Const(
|
||||||
|
bool, error=_("{key} key in {parent_key} must be either true or false.")
|
||||||
|
),
|
||||||
|
Optional("payout_multiplier"): And(
|
||||||
|
Use(int_or_float),
|
||||||
|
lambda n: n >= 0.0,
|
||||||
|
error=_("{key} key in {parent_key} must be a non-negative number."),
|
||||||
|
),
|
||||||
|
Optional("use_spoilers"): Const(
|
||||||
|
bool, error=_("{key} key in {parent_key} must be either true or false.")
|
||||||
|
),
|
||||||
|
# This matches any extra key and always fails validation
|
||||||
|
# for the purpose of better error messages.
|
||||||
|
ALWAYS_MATCH: And(
|
||||||
|
lambda __: False,
|
||||||
|
error=_("{key} is not a key that can be specified in {parent_key}."),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
error=_("{key} should be a 'key: value' mapping."),
|
||||||
|
),
|
||||||
|
str: And(
|
||||||
|
[str, int, bool, float],
|
||||||
|
error=_("Value of question {key} is not a list of text values (answers)."),
|
||||||
|
),
|
||||||
|
# This matches any extra key and always fails validation
|
||||||
|
# for the purpose of better error messages.
|
||||||
|
MATCH_ALL_BUT_STR: And(
|
||||||
|
lambda __: False,
|
||||||
|
error=_("A key of question {key} is not a text value."),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
error=_("A trivia list should be a 'key: value' mapping."),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_schema_error(exc: SchemaError) -> str:
|
||||||
|
if isinstance(exc, SchemaMissingKeyError):
|
||||||
|
return NO_QUESTIONS_ERROR_MSG.format()
|
||||||
|
|
||||||
|
# dict.fromkeys is used for de-duplication with order preservation
|
||||||
|
errors = {idx: msg for idx, msg in enumerate(exc.errors) if msg is not None}
|
||||||
|
if not errors:
|
||||||
|
return str(exc)
|
||||||
|
error_idx, error_msg_fmt = errors.popitem()
|
||||||
|
|
||||||
|
autos = dict.fromkeys(msg for msg in itertools.islice(exc.autos, error_idx) if msg is not None)
|
||||||
|
keys = [match[1] for msg in autos if (match := KEY_ERROR_MSG_RE.fullmatch(msg)) is not None]
|
||||||
|
key_count = len(keys)
|
||||||
|
if key_count == 2:
|
||||||
|
key = keys[-1]
|
||||||
|
parent_key = keys[-2]
|
||||||
|
elif key_count == 1:
|
||||||
|
key = keys[-1]
|
||||||
|
# should only happen for messages where this field isn't used
|
||||||
|
parent_key = "UNKNOWN"
|
||||||
|
else:
|
||||||
|
# should only happen for messages where neither of the fields are used
|
||||||
|
key = parent_key = "UNKNOWN"
|
||||||
|
|
||||||
|
return error_msg_fmt.format(key=repr(key), parent_key=repr(parent_key))
|
||||||
@ -4,7 +4,7 @@ import math
|
|||||||
import pathlib
|
import pathlib
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Any, Dict, List, Literal, Union
|
from typing import Any, Dict, List, Literal, Union
|
||||||
from schema import Schema, Optional, Or, SchemaError
|
import schema
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import yaml
|
import yaml
|
||||||
@ -23,26 +23,11 @@ from .checks import trivia_stop_check
|
|||||||
from .converters import finite_float
|
from .converters import finite_float
|
||||||
from .log import LOG
|
from .log import LOG
|
||||||
from .session import TriviaSession
|
from .session import TriviaSession
|
||||||
|
from .schema import TRIVIA_LIST_SCHEMA, format_schema_error
|
||||||
|
|
||||||
__all__ = ("Trivia", "UNIQUE_ID", "InvalidListError", "get_core_lists", "get_list")
|
__all__ = ("Trivia", "UNIQUE_ID", "InvalidListError", "get_core_lists", "get_list")
|
||||||
|
|
||||||
UNIQUE_ID = 0xB3C0E453
|
UNIQUE_ID = 0xB3C0E453
|
||||||
TRIVIA_LIST_SCHEMA = Schema(
|
|
||||||
{
|
|
||||||
Optional("AUTHOR"): str,
|
|
||||||
Optional("CONFIG"): {
|
|
||||||
Optional("max_score"): int,
|
|
||||||
Optional("timeout"): Or(int, float),
|
|
||||||
Optional("delay"): Or(int, float),
|
|
||||||
Optional("bot_plays"): bool,
|
|
||||||
Optional("reveal_answer"): bool,
|
|
||||||
Optional("payout_multiplier"): Or(int, float),
|
|
||||||
Optional("use_spoilers"): bool,
|
|
||||||
},
|
|
||||||
str: [str, int, bool, float],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
_ = Translator("Trivia", __file__)
|
_ = Translator("Trivia", __file__)
|
||||||
|
|
||||||
|
|
||||||
@ -120,7 +105,7 @@ class Trivia(commands.Cog):
|
|||||||
@triviaset.command(name="maxscore")
|
@triviaset.command(name="maxscore")
|
||||||
async def triviaset_max_score(self, ctx: commands.Context, score: int):
|
async def triviaset_max_score(self, ctx: commands.Context, score: int):
|
||||||
"""Set the total points required to win."""
|
"""Set the total points required to win."""
|
||||||
if score < 0:
|
if score <= 0:
|
||||||
await ctx.send(_("Score must be greater than 0."))
|
await ctx.send(_("Score must be greater than 0."))
|
||||||
return
|
return
|
||||||
settings = self.config.guild(ctx.guild)
|
settings = self.config.guild(ctx.guild)
|
||||||
@ -293,18 +278,18 @@ class Trivia(commands.Cog):
|
|||||||
try:
|
try:
|
||||||
await self._save_trivia_list(ctx=ctx, attachment=parsedfile)
|
await self._save_trivia_list(ctx=ctx, attachment=parsedfile)
|
||||||
except yaml.error.MarkedYAMLError as exc:
|
except yaml.error.MarkedYAMLError as exc:
|
||||||
await ctx.send(_("Invalid syntax: ") + str(exc))
|
await ctx.send(_("Invalid syntax:\n") + box(str(exc)))
|
||||||
except yaml.error.YAMLError:
|
except yaml.error.YAMLError:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
_("There was an error parsing the trivia list. See logs for more info.")
|
_("There was an error parsing the trivia list. See logs for more info.")
|
||||||
)
|
)
|
||||||
LOG.exception("Custom Trivia file %s failed to upload", parsedfile.filename)
|
LOG.exception("Custom Trivia file %s failed to upload", parsedfile.filename)
|
||||||
except SchemaError as e:
|
except schema.SchemaError as exc:
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
_(
|
_(
|
||||||
"The custom trivia list was not saved."
|
"The custom trivia list was not saved."
|
||||||
" The file does not follow the proper data format.\n{schema_error}"
|
" The file does not follow the proper data format.\n{schema_error}"
|
||||||
).format(schema_error=box(e))
|
).format(schema_error=box(format_schema_error(exc)))
|
||||||
)
|
)
|
||||||
|
|
||||||
@commands.is_owner()
|
@commands.is_owner()
|
||||||
@ -740,8 +725,6 @@ def get_list(path: pathlib.Path) -> Dict[str, Any]:
|
|||||||
------
|
------
|
||||||
InvalidListError
|
InvalidListError
|
||||||
Parsing of list's YAML file failed.
|
Parsing of list's YAML file failed.
|
||||||
SchemaError
|
|
||||||
The list does not adhere to the schema.
|
|
||||||
"""
|
"""
|
||||||
with path.open(encoding="utf-8") as file:
|
with path.open(encoding="utf-8") as file:
|
||||||
try:
|
try:
|
||||||
@ -751,6 +734,6 @@ def get_list(path: pathlib.Path) -> Dict[str, Any]:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
TRIVIA_LIST_SCHEMA.validate(trivia_dict)
|
TRIVIA_LIST_SCHEMA.validate(trivia_dict)
|
||||||
except SchemaError as exc:
|
except schema.SchemaError as exc:
|
||||||
raise InvalidListError("The list does not adhere to the schema.") from exc
|
raise InvalidListError("The list does not adhere to the schema.") from exc
|
||||||
return trivia_dict
|
return trivia_dict
|
||||||
|
|||||||
@ -1,7 +1,17 @@
|
|||||||
import textwrap
|
import textwrap
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from schema import SchemaError
|
from schema import And, Optional, SchemaError
|
||||||
|
|
||||||
|
from redbot.cogs.trivia.schema import (
|
||||||
|
ALWAYS_MATCH,
|
||||||
|
MATCH_ALL_BUT_STR,
|
||||||
|
NO_QUESTIONS_ERROR_MSG,
|
||||||
|
TRIVIA_LIST_SCHEMA,
|
||||||
|
format_schema_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_trivia_lists():
|
def test_trivia_lists():
|
||||||
@ -25,3 +35,44 @@ def test_trivia_lists():
|
|||||||
f"- {name}:\n{textwrap.indent(error, ' ')}" for name, error in problem_lists
|
f"- {name}:\n{textwrap.indent(error, ' ')}" for name, error in problem_lists
|
||||||
)
|
)
|
||||||
raise TypeError("The following lists contain errors:\n" + msg)
|
raise TypeError("The following lists contain errors:\n" + msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_error_message(*keys: Any, key: str = "UNKNOWN", parent_key: str = "UNKNOWN") -> str:
|
||||||
|
if not keys:
|
||||||
|
return TRIVIA_LIST_SCHEMA._error
|
||||||
|
|
||||||
|
current = TRIVIA_LIST_SCHEMA.schema
|
||||||
|
for key_name in keys:
|
||||||
|
if isinstance(current, And):
|
||||||
|
current = current.args[0]
|
||||||
|
current = current[key_name]
|
||||||
|
return str(current._error).format(key=repr(key), parent_key=repr(parent_key))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"data,error_msg",
|
||||||
|
(
|
||||||
|
("text", _get_error_message()),
|
||||||
|
({"AUTHOR": 123}, _get_error_message(Optional("AUTHOR"), key="AUTHOR")),
|
||||||
|
({"CONFIG": 123}, _get_error_message(Optional("CONFIG"), key="CONFIG")),
|
||||||
|
(
|
||||||
|
{"CONFIG": {"key": "value"}},
|
||||||
|
_get_error_message(Optional("CONFIG"), ALWAYS_MATCH, key="key", parent_key="CONFIG"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
{"CONFIG": {"bot_plays": "wrong type"}},
|
||||||
|
_get_error_message(
|
||||||
|
Optional("CONFIG"), Optional("bot_plays"), key="bot_plays", parent_key="CONFIG"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
({"AUTHOR": "Correct type but no questions."}, NO_QUESTIONS_ERROR_MSG),
|
||||||
|
({"Question": "wrong type"}, _get_error_message(str, key="Question")),
|
||||||
|
({"Question": [{"wrong": "type"}]}, _get_error_message(str, key="Question")),
|
||||||
|
({123: "wrong key type"}, _get_error_message(MATCH_ALL_BUT_STR, key="123")),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def test_trivia_schema_error_messages(data: Any, error_msg: str):
|
||||||
|
with pytest.raises(SchemaError) as exc:
|
||||||
|
TRIVIA_LIST_SCHEMA.validate(data)
|
||||||
|
|
||||||
|
assert format_schema_error(exc.value) == error_msg
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user