diff --git a/cogs/utils/dataIO.py b/cogs/utils/dataIO.py index e2d1ece05..998a319e4 100644 --- a/cogs/utils/dataIO.py +++ b/cogs/utils/dataIO.py @@ -1,20 +1,83 @@ import json +import os +import logging +from shutil import copy -def fileIO(filename, IO, data=None): - if IO == "save" and data != None: - with open(filename, encoding='utf-8', mode="w") as f: - f.write(json.dumps(data,indent=4,sort_keys=True,separators=(',',' : '))) - elif IO == "load" and data == None: - with open(filename, encoding='utf-8', mode="r") as f: - return json.loads(f.read()) - elif IO == "check" and data == None: +class InvalidFileIO(Exception): + pass + +class CorruptedJSON(Exception): + pass + +class DataIO(): + def __init__(self): + self.logger = logging.getLogger("red") + + def save_json(self, filename, data): + """Saves and backups json file""" + bak_file = os.path.splitext(filename)[0]+'.bak' + self._save_json(filename, data) + copy(filename, bak_file) # Backup copy + + def load_json(self, filename): + """Loads json file and restores backup copy in case of corrupted file""" try: - with open(filename, encoding='utf-8', mode="r") as f: - return True - except: + return self._read_json(filename) + except json.decoder.JSONDecodeError: + result = self._restore_json(filename) + if result: + return self._read_json(filename) # Which hopefully will work + else: + raise CorruptedJSON("{} is corrupted and no backup copy is" + " available.".format(filename)) + + def is_valid_json(self, filename): + """Returns True if readable json file, False if not existing. + Tries to restore backup copy if corrupted""" + try: + data = self._read_json(filename) + except FileNotFoundError: return False - else: - raise("Invalid fileIO call") + except json.decoder.JSONDecodeError: + result = self._restore_json(filename) + return result # If False, no backup copy, might as well + else: # allow the overwrite + return True + + def _read_json(self, filename): + with open(filename, encoding='utf-8', mode="r") as f: + data = json.loads(f.read()) + return data + + def _save_json(self, filename, data): + with open(filename, encoding='utf-8', mode="w") as f: + f.write(json.dumps(data,indent=4,sort_keys=True, + separators=(',',' : '))) + return data + + def _restore_json(self, filename): + bak_file = os.path.splitext(filename)[0]+'.bak' + if os.path.isfile(bak_file): + copy(bak_file, filename) # Restore last working copy + self.logger.warning("{} was corrupted. Restored " + "backup copy.".format(filename)) + return True + else: + self.logger.critical("{} is corrupted and there is no " + "backup copy available.".format(filename)) + return False + + def _legacy_fileio(self, filename, IO, data=None): + """Old fileIO provided for backwards compatibility""" + if IO == "save" and data != None: + return self.save_json(filename, data) + elif IO == "load" and data == None: + return self.load_json(filename) + elif IO == "check" and data == None: + return self.is_valid_json(filename) + else: + raise InvalidFileIO("FileIO was called with invalid" + " parameters") def get_value(filename, key): with open(filename, encoding='utf-8', mode="r") as f: @@ -25,4 +88,7 @@ def set_value(filename, key, value): data = fileIO(filename, "load") data[key] = value fileIO(filename, "save", data) - return True \ No newline at end of file + return True + +dataIO = DataIO() +fileIO = dataIO._legacy_fileio # backwards compatibility \ No newline at end of file diff --git a/red.py b/red.py index 8e63fbaec..d1cc43121 100644 --- a/red.py +++ b/red.py @@ -1,6 +1,7 @@ from discord.ext import commands import discord from cogs.utils.settings import Settings +from cogs.utils.dataIO import dataIO import json import asyncio import os @@ -209,12 +210,9 @@ def check_configs(): if settings.default_mod == "": settings.default_mod = "Process" - cogs_s_path = "data/red/cogs.json" - cogs = {} - if not os.path.isfile(cogs_s_path): + if not os.path.isfile("data/red/cogs.json"): print("Creating new cogs.json...") - with open(cogs_s_path, "w") as f: - f.write(json.dumps(cogs)) + dataIO.save_json("data/red/cogs.json", {}) def set_logger(): @@ -262,12 +260,9 @@ def get_answer(): def set_cog(cog, value): - with open('data/red/cogs.json', "r") as f: - data = json.load(f) + data = dataIO.load_json("data/red/cogs.json") data[cog] = value - with open('data/red/cogs.json', "w") as f: - f.write(json.dumps(data)) - + dataIO.save_json("data/red/cogs.json", data) def load_cogs(): try: @@ -279,8 +274,7 @@ def load_cogs(): no_prompt = False try: - with open('data/red/cogs.json', "r") as f: - registry = json.load(f) + registry = dataIO.load_json("data/red/cogs.json") except: registry = {} @@ -319,8 +313,7 @@ def load_cogs(): registry[extension] = False if extensions: - with open('data/red/cogs.json', "w") as f: - f.write(json.dumps(registry)) + dataIO.save_json("data/red/cogs.json", registry) if failed: print("\nFailed to load: ", end="")