Add redbot-setup restore cli command (#6709)

This commit is contained in:
Jakub Kuczys
2026-05-13 21:18:17 +02:00
committed by GitHub
parent edce32364f
commit a234fc1e02
4 changed files with 866 additions and 47 deletions
+106 -22
View File
@@ -9,9 +9,12 @@ import os
import re
import shutil
import tarfile
import time
import warnings
from datetime import datetime
from io import BytesIO
from pathlib import Path
from tarfile import TarInfo
from typing import (
AsyncIterable,
AsyncIterator,
@@ -24,6 +27,7 @@ from typing import (
Optional,
Union,
TypeVar,
TypedDict,
TYPE_CHECKING,
Tuple,
cast,
@@ -33,8 +37,9 @@ import aiohttp
import discord
from packaging.requirements import Requirement
import rapidfuzz
from rich.progress import ProgressColumn
from rich.progress_bar import ProgressBar
import rich.progress
from rich.console import Console
from rich.text import Text
from red_commons.logging import VERBOSE, TRACE
from redbot import VersionInfo
@@ -58,6 +63,8 @@ __all__ = (
"fetch_latest_red_version_info",
"deprecated_removed",
"RichIndefiniteBarColumn",
"RichSpeedColumn",
"detailed_progress",
"cli_level_to_log_level",
)
@@ -216,7 +223,27 @@ async def format_fuzzy_results(
return "Perhaps you wanted one of these? " + box("\n".join(lines), lang="vhdl")
def _tar_addfile_from_string(tar: tarfile.TarFile, name: str, string: str) -> None:
encoded = string.encode("utf-8")
fp = BytesIO(encoded)
# TarInfo needs `mtime` and `size`
# https://stackoverflow.com/q/53306000
tar_info = tarfile.TarInfo(name)
tar_info.mtime = time.time()
tar_info.size = len(encoded)
tar.addfile(tar_info, fp)
class BackupDetails(TypedDict):
backup_version: int
async def create_backup(dest: Path = Path.home()) -> Optional[Path]:
# version of backup
BACKUP_VERSION = 2
data_path = Path(data_manager.core_data_path().parent)
if not data_path.exists():
return None
@@ -226,13 +253,20 @@ async def create_backup(dest: Path = Path.home()) -> Optional[Path]:
backup_fpath = dest / f"redv3_{data_manager.instance_name()}_{timestr}.tar.gz"
to_backup = []
# we need trailing separator to not exclude files and folders that only start with these names
exclusions = [
"__pycache__",
# Lavalink will be downloaded on Audio load
"Lavalink.jar",
os.path.join("Downloader", "lib"),
os.path.join("CogManager", "cogs"),
os.path.join("RepoManager", "repos"),
os.path.join("Audio", "logs"),
# cogs and repos installed through Downloader can be reinstalled using restore command
os.path.join("Downloader", "lib", ""),
os.path.join("CogManager", "cogs", ""),
os.path.join("RepoManager", "repos", ""),
os.path.join("Audio", "logs", ""),
# these files are created during backup so we exclude them from data path backup
os.path.join("RepoManager", "repos.json"),
"instance.json",
"backup_details.json",
]
# Avoiding circular imports
@@ -243,19 +277,42 @@ async def create_backup(dest: Path = Path.home()) -> Optional[Path]:
repo_output = []
for repo in repo_mgr.repos:
repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch})
repos_file = data_path / "cogs" / "RepoManager" / "repos.json"
with repos_file.open("w") as fs:
json.dump(repo_output, fs, indent=4)
instance_file = data_path / "instance.json"
with instance_file.open("w") as fs:
json.dump({data_manager.instance_name(): data_manager.basic_config}, fs, indent=4)
for f in data_path.glob("**/*"):
if not any(ex in str(f) for ex in exclusions) and f.is_file():
to_backup.append(f)
with tarfile.open(str(backup_fpath), "w:gz") as tar:
for f in to_backup:
tar.add(str(f), arcname=str(f.relative_to(data_path)), recursive=False)
with rich.progress.Progress(
rich.progress.SpinnerColumn(),
rich.progress.TextColumn("[progress.description]{task.description}"),
RichIndefiniteBarColumn(),
rich.progress.TextColumn("{task.completed} files processed"),
rich.progress.TimeElapsedColumn(),
) as progress:
for f in progress.track(
data_path.glob("**/*"), description="Preparing files for backup..."
):
if not any(ex in str(f) for ex in exclusions) and f.is_file():
to_backup.append(f)
backup_details: BackupDetails = {
"backup_version": BACKUP_VERSION,
}
with tarfile.open(str(backup_fpath), "w:gz", dereference=True) as tar:
with detailed_progress(unit="files") as progress:
progress_tracker = progress.track(to_backup, description="Compressing data")
for f in progress_tracker:
tar.add(str(f), arcname=str(f.relative_to(data_path)), recursive=False)
# add repos backup
repos_data = json.dumps(repo_output, indent=4)
_tar_addfile_from_string(tar, "cogs/RepoManager/repos.json", repos_data)
# add instance's original data
instance_data = json.dumps(
{data_manager.instance_name(): data_manager.basic_config}, indent=4
)
_tar_addfile_from_string(tar, "instance.json", instance_data)
# add info about backup version
_tar_addfile_from_string(tar, "backup_details.json", json.dumps(backup_details))
return backup_fpath
@@ -367,10 +424,10 @@ def deprecated_removed(
)
class RichIndefiniteBarColumn(ProgressColumn):
def render(self, task):
return ProgressBar(
pulse=task.completed < task.total,
class RichIndefiniteBarColumn(rich.progress.ProgressColumn):
def render(self, task: rich.progress.Task) -> rich.progress.ProgressBar:
return rich.progress.ProgressBar(
pulse=task.completed < task.total if task.total is not None else True,
animation_time=task.get_time(),
width=40,
total=task.total,
@@ -378,6 +435,33 @@ class RichIndefiniteBarColumn(ProgressColumn):
)
class RichSpeedColumn(rich.progress.ProgressColumn):
def __init__(self, *, unit: str) -> None:
self.unit = unit
super().__init__()
def render(self, task: rich.progress.Task) -> Text:
speed = task.finished_speed or task.speed
if speed is None:
return Text("?", style="progress.data.speed")
return Text(f"{int(speed)} {self.unit}/s", style="progress.data.speed")
def detailed_progress(*, unit: str, console: Optional[Console] = None) -> rich.progress.Progress:
return rich.progress.Progress(
rich.progress.SpinnerColumn(),
rich.progress.TextColumn("[progress.description]{task.description}"),
rich.progress.BarColumn(bar_width=None),
RichSpeedColumn(unit=unit),
rich.progress.TaskProgressColumn(),
rich.progress.TextColumn("eta"),
rich.progress.TimeRemainingColumn(),
rich.progress.TextColumn("elapsed"),
rich.progress.TimeElapsedColumn(),
console=console,
)
def cli_level_to_log_level(level: int) -> int:
if level == 0:
log_level = logging.INFO