[Trivia] Validate custom trivia file upload using schema (#4659)

* Add custom trivia list schema validation and test

* Address review

* Improve error formatting in trivia list test

Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>
This commit is contained in:
Grant LeBlanc 2021-08-31 14:44:25 -04:00 committed by GitHub
parent 91ecd6560a
commit 173127e015
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 66 additions and 31 deletions

View File

@ -3,7 +3,8 @@ import asyncio
import math
import pathlib
from collections import Counter
from typing import List, Literal
from typing import Any, Dict, List, Literal
from schema import Schema, Optional, Or, SchemaError
import io
import yaml
@ -23,9 +24,23 @@ from .converters import finite_float
from .log import LOG
from .session import TriviaSession
__all__ = ["Trivia", "UNIQUE_ID", "get_core_lists"]
__all__ = ("Trivia", "UNIQUE_ID", "InvalidListError", "get_core_lists", "get_list")
UNIQUE_ID = 0xB3C0E453
TRIVIA_LIST_SCHEMA = Schema(
{
Optional("AUTHOR"): str,
Optional("CONFIG"): {
Optional("max_score"): int,
Optional("timeout"): Or(int, float),
Optional("delay"): Or(int, float),
Optional("bot_plays"): bool,
Optional("reveal_answer"): bool,
Optional("payout_multiplier"): Or(int, float),
},
str: [str, int, bool, float],
}
)
_ = Translator("Trivia", __file__)
@ -282,6 +297,13 @@ class Trivia(commands.Cog):
_("There was an error parsing the trivia list. See logs for more info.")
)
LOG.exception("Custom Trivia file %s failed to upload", parsedfile.filename)
except SchemaError as e:
await ctx.send(
_(
"The custom trivia list was not saved."
" The file does not follow the proper data format.\n{schema_error}"
).format(schema_error=box(e))
)
@commands.is_owner()
@triviaset_custom.command(name="delete", aliases=["remove"])
@ -615,13 +637,7 @@ class Trivia(commands.Cog):
except StopIteration:
raise FileNotFoundError("Could not find the `{}` category.".format(category))
with path.open(encoding="utf-8") as file:
try:
dict_ = yaml.safe_load(file)
except yaml.error.YAMLError as exc:
raise InvalidListError("YAML parsing failed.") from exc
else:
return dict_
return get_list(path)
async def _save_trivia_list(
self, ctx: commands.Context, attachment: discord.Attachment
@ -683,9 +699,10 @@ class Trivia(commands.Cog):
return
buffer = io.BytesIO(await attachment.read())
yaml.safe_load(buffer)
buffer.seek(0)
trivia_dict = yaml.safe_load(buffer)
TRIVIA_LIST_SCHEMA.validate(trivia_dict)
buffer.seek(0)
with file.open("wb") as fp:
fp.write(buffer.read())
await ctx.send(_("Saved Trivia list as {filename}.").format(filename=filename))
@ -709,3 +726,27 @@ def get_core_lists() -> List[pathlib.Path]:
"""Return a list of paths for all trivia lists packaged with the bot."""
core_lists_path = pathlib.Path(__file__).parent.resolve() / "data/lists"
return list(core_lists_path.glob("*.yaml"))
def get_list(path: pathlib.Path) -> Dict[str, Any]:
"""
Returns a trivia list dictionary from the given path.
Raises
------
InvalidListError
Parsing of list's YAML file failed.
SchemaError
The list does not adhere to the schema.
"""
with path.open(encoding="utf-8") as file:
try:
trivia_dict = yaml.safe_load(file)
except yaml.error.YAMLError as exc:
raise InvalidListError("YAML parsing failed.") from exc
try:
TRIVIA_LIST_SCHEMA.validate(trivia_dict)
except SchemaError as exc:
raise InvalidListError("The list does not adhere to the schema.") from exc
return trivia_dict

View File

@ -1,33 +1,27 @@
import textwrap
import yaml
from schema import SchemaError
def test_trivia_lists():
from redbot.cogs.trivia import get_core_lists
from redbot.cogs.trivia import InvalidListError, get_core_lists, get_list
list_names = get_core_lists()
assert list_names
problem_lists = []
for l in list_names:
with l.open(encoding="utf-8") as f:
try:
dict_ = yaml.safe_load(f)
except yaml.error.YAMLError as e:
problem_lists.append((l.stem, "YAML error:\n{!s}".format(e)))
try:
get_list(l)
except InvalidListError as exc:
e = exc.__cause__
if isinstance(e, SchemaError):
problem_lists.append((l.stem, f"SCHEMA error:\n{e!s}"))
else:
for key in list(dict_.keys()):
if key == "CONFIG":
if not isinstance(dict_[key], dict):
problem_lists.append((l.stem, "CONFIG is not a dict"))
elif key == "AUTHOR":
if not isinstance(dict_[key], str):
problem_lists.append((l.stem, "AUTHOR is not a string"))
else:
if not isinstance(dict_[key], list):
problem_lists.append(
(l.stem, "The answers for '{}' are not a list".format(key))
)
problem_lists.append((l.stem, f"YAML error:\n{e!s}"))
if problem_lists:
msg = ""
for l in problem_lists:
msg += "{}: {}\n".format(l[0], l[1])
for name, error in problem_lists:
msg += f"- {name}:\n{textwrap.indent(error, ' ')}"
raise TypeError("The following lists contain errors:\n" + msg)