- Added a new owner-only `/data [download/upload]` command for data-file backup and restoration. *This is required because v0.4.2 requires a rebuild of the stack, so the existing data should be backed up in case of data loss.*
238 lines
8.4 KiB
Python
238 lines
8.4 KiB
Python
from __future__ import annotations
|
|
|
|
import io
|
|
import os
|
|
import json
|
|
import time
|
|
import shutil
|
|
import asyncio
|
|
from typing import Optional, Literal
|
|
|
|
import aiohttp
|
|
import discord
|
|
from discord import app_commands
|
|
from discord.ext import commands
|
|
|
|
|
|
# Hard cap on any payload accepted for import: 8 MiB, for safety.
MAX_UPLOAD_BYTES = 8 << 20

# Reply sent (ephemerally) to users who fail the owner check.
OWNER_HINT = "This command is restricted to the **server owner** (or bot owner)."
|
|
|
|
def _now_stamp() -> str:
|
|
return time.strftime("%Y%m%d-%H%M%S", time.gmtime())
|
|
|
|
class DataAdmin(commands.Cog):
    """
    [ADMIN] Backup/restore the bot data file.

    Owner-only: guild owner or application (bot) owner.

    Works against ``bot.data_manager``, which is expected to expose
    ``json_path``, ``lock`` (used here as a synchronous context manager),
    ``_data`` and ``_save(payload)``, and optionally ``_default_payload()``.
    """

    def __init__(self, bot: commands.Bot):
        self.bot = bot

    # --- permission helper ---

    async def _is_owner(self, interaction: discord.Interaction) -> bool:
        """
        Return True when the invoking user owns the current guild or the bot.

        The guild-owner check needs no API call, so it runs first. Bot
        ownership is delegated to ``commands.Bot.is_owner``, which honours an
        explicit ``owner_id``/``owner_ids`` on the bot as well as team-owned
        applications (and caches its own lookup). Any error from that lookup
        is treated as "not an owner".
        """
        user = interaction.user
        guild = interaction.guild
        if guild is not None and guild.owner_id == user.id:
            return True
        try:
            return await self.bot.is_owner(user)
        except Exception:
            return False

    # --- helpers ---

    def _require_dm(self):
        """Return ``bot.data_manager``; raise RuntimeError if it is missing."""
        dm = getattr(self.bot, "data_manager", None)
        if not dm:
            raise RuntimeError("DataManager unavailable")
        return dm

    def _dm_path(self) -> str:
        """Return the data manager's ``json_path``; raise RuntimeError if unset."""
        dm = getattr(self.bot, "data_manager", None)
        if not dm or not getattr(dm, "json_path", None):
            raise RuntimeError("DataManager/json_path unavailable")
        return dm.json_path

    def _merge_with_defaults(self, incoming: dict) -> dict:
        """
        Ensure required keys exist; preserve unknown keys.

        Defaults come from ``data_manager._default_payload()`` when that call
        succeeds, otherwise from a minimal fallback schema. Keys already in
        *incoming* always win. Missing list/dict defaults are shallow-copied
        so the merged payload never aliases the defaults object.
        """
        dm = self._require_dm()
        try:
            defaults = dm._default_payload()  # type: ignore[attr-defined]
        except Exception:
            defaults = {
                "_counters": {},
                "_events_seen": {},
                "_counter_last_ts": {},
            }

        merged = dict(incoming)
        for key, value in defaults.items():
            if key not in merged:
                merged[key] = value.copy() if isinstance(value, (list, dict)) else value
        return merged

    async def _download_attachment_bytes(self, att: discord.Attachment) -> bytes:
        """Return the attachment's bytes; raise ValueError if it exceeds MAX_UPLOAD_BYTES."""
        if att.size > MAX_UPLOAD_BYTES:
            raise ValueError(f"Attachment too large ({att.size} bytes)")
        return await att.read()

    async def _download_url_bytes(self, url: str) -> bytes:
        """
        Fetch *url* over HTTP(S) and return the body, capped at MAX_UPLOAD_BYTES.

        The body is streamed and the transfer aborted as soon as it exceeds
        the cap, so an oversized response is never buffered in full.

        Raises:
            ValueError: the URL is not http(s), or the body is too large.
            RuntimeError: the server answered with an HTTP error status.
        """
        if not url.lower().startswith(("http://", "https://")):
            raise ValueError("Only http:// and https:// URLs are supported")

        timeout = aiohttp.ClientTimeout(total=25, sock_connect=10, sock_read=15)
        headers = {
            "User-Agent": "ShaiWatcher/backup-restore (+https://example.invalid)"
        }
        async with aiohttp.ClientSession(timeout=timeout) as sess:
            async with sess.get(url, headers=headers, allow_redirects=True) as resp:
                if resp.status >= 400:
                    raise RuntimeError(f"HTTP {resp.status}")
                buf = bytearray()
                async for chunk in resp.content.iter_chunked(64 * 1024):
                    buf.extend(chunk)
                    if len(buf) > MAX_UPLOAD_BYTES:
                        raise ValueError(
                            f"Downloaded file too large (more than {MAX_UPLOAD_BYTES} bytes)"
                        )
                return bytes(buf)

    @staticmethod
    def _parse_payload(raw: bytes) -> dict:
        """
        Decode *raw* as UTF-8 JSON and return the top-level object.

        Raises:
            ValueError: with a user-facing message when the bytes are not
                UTF-8, not valid JSON, not a JSON object, or re-encode to more
                than MAX_UPLOAD_BYTES.
        """
        try:
            text = raw.decode("utf-8")
        except UnicodeDecodeError:
            raise ValueError("The file/URL is not valid UTF-8 text.") from None

        try:
            payload = json.loads(text)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON: {e}") from None

        if not isinstance(payload, dict):
            raise ValueError("Top-level JSON must be an **object** (not an array/string).")

        # Final size sanity check on the parsed form.
        encoded_size = len(json.dumps(payload, ensure_ascii=False).encode("utf-8"))
        if encoded_size > MAX_UPLOAD_BYTES:
            raise ValueError(f"Refusing to import unusually large JSON ({encoded_size} bytes).")
        return payload

    def _atomic_replace(self, new_payload: dict) -> Optional[str]:
        """
        Replace the DataManager payload, taking a timestamped backup first.

        Under ``data_manager.lock``, the current file (if any) is copied to
        ``<json_path>.manual.<UTC stamp>.bak``. The merged payload is then
        installed as ``_data`` and persisted via ``_save``.

        If the backup cannot be written, the import is aborted with
        RuntimeError instead of overwriting the only copy of the data. If
        ``_save`` raises, the previous in-memory payload is restored before the
        error propagates.

        Returns:
            The backup file path, or None when there was no existing file to
            back up.
        """
        dm = self._require_dm()
        src_path = self._dm_path()
        bak_path = f"{src_path}.manual.{_now_stamp()}.bak"

        # Build the merged payload before touching anything, so a failure
        # here leaves both the file and the in-memory state untouched.
        merged = self._merge_with_defaults(new_payload)

        with dm.lock:
            backup: Optional[str] = None
            if os.path.exists(src_path):
                try:
                    shutil.copy2(src_path, bak_path)
                except OSError as exc:
                    raise RuntimeError(
                        f"backup to {bak_path} failed, import aborted: {exc}"
                    ) from exc
                backup = bak_path

            previous = getattr(dm, "_data", None)
            dm._data = merged  # type: ignore[attr-defined]
            try:
                dm._save(merged)  # type: ignore[attr-defined]
            except Exception:
                # Don't leave the in-memory payload ahead of a failed save.
                dm._data = previous  # type: ignore[attr-defined]
                raise
        return backup

    # --- slash command ---

    @app_commands.command(
        name="data",
        description="[ADMIN] Download or upload the bot data file (owner-only)",
    )
    @app_commands.describe(
        action="Choose 'download' to get the current file, or 'upload' to restore from JSON",
        attachment="Optional JSON attachment (used for 'upload')",
        url="Optional direct URL to a JSON file (used for 'upload')",
    )
    async def data_cmd(
        self,
        interaction: discord.Interaction,
        action: Literal["download", "upload"],
        attachment: Optional[discord.Attachment] = None,
        url: Optional[str] = None,
    ):
        # Permission gate: guild owner or bot owner only.
        if not await self._is_owner(interaction):
            return await interaction.response.send_message(OWNER_HINT, ephemeral=True)

        # Resolve the data file path up front, so both actions fail fast and
        # identically when the DataManager is not wired up.
        try:
            dm_path = self._dm_path()
        except Exception as e:
            return await interaction.response.send_message(
                f"DataManager unavailable: {e}", ephemeral=True
            )

        if action == "download":
            await self._send_download(interaction, dm_path)
        else:
            await self._run_upload(interaction, dm_path, attachment, url)

    async def _send_download(self, interaction: discord.Interaction, dm_path: str) -> None:
        """Reply ephemerally with an exact byte copy of the data file as ``data.json``."""
        await interaction.response.defer(ephemeral=True, thinking=False)
        try:
            # Hold the same lock _atomic_replace writes under, so the read is
            # not interleaved with a save made under that lock. Raw bytes
            # guarantee an exact copy.
            with self._require_dm().lock:
                with open(dm_path, "rb") as fh:
                    raw = fh.read()
        except Exception as e:
            await interaction.followup.send(f"Failed to read data file: {e}", ephemeral=True)
            return

        try:
            await interaction.followup.send(
                content="Here is the current data file.",
                file=discord.File(io.BytesIO(raw), filename="data.json"),
                ephemeral=True,
            )
        except Exception as e:
            # Reading succeeded; report a send failure as such, not as a read error.
            await interaction.followup.send(f"Failed to send data file: {e}", ephemeral=True)

    async def _run_upload(
        self,
        interaction: discord.Interaction,
        dm_path: str,
        attachment: Optional[discord.Attachment],
        url: Optional[str],
    ) -> None:
        """Fetch JSON from exactly one source, validate it, and install it with a backup."""
        # Exactly one source is required; an empty-string URL counts as absent.
        if bool(attachment) == bool(url):
            await interaction.response.send_message(
                "For `upload`, provide **exactly one** of: `attachment` **or** `url`.",
                ephemeral=True,
            )
            return

        await interaction.response.defer(ephemeral=True, thinking=True)

        try:
            if attachment:
                raw = await self._download_attachment_bytes(attachment)
            elif url:
                raw = await self._download_url_bytes(url)
            else:  # unreachable: the guard above requires exactly one source
                raise RuntimeError("no upload source provided")

            try:
                payload = self._parse_payload(raw)
            except ValueError as e:
                await interaction.followup.send(str(e), ephemeral=True)
                return

            bak_path = self._atomic_replace(payload)

            # Tiny summary of what was imported.
            top_keys = sorted(payload)
            shown = ", ".join(top_keys[:12]) + ("…" if len(top_keys) > 12 else "")
            if bak_path:
                backup_note = f"Previous file backed up to `{bak_path}`."
            else:
                backup_note = "No previous file existed, so no backup was written."
            await interaction.followup.send(
                f"✅ Imported data.\n"
                f"{backup_note}\n"
                f"Path: `{dm_path}`\n"
                f"Top-level keys ({len(top_keys)}): {shown}",
                ephemeral=True,
            )
        except Exception as e:
            await interaction.followup.send(f"Import failed: {e}", ephemeral=True)
|
|
|
|
|
|
async def setup(bot: commands.Bot):
    """Extension entry point: construct the DataAdmin cog and register it on *bot*."""
    cog = DataAdmin(bot)
    await bot.add_cog(cog)
|