308 lines
11 KiB
Python
308 lines
11 KiB
Python
# modules/common/boot_notice.py
|
||
import os
|
||
import re
|
||
import base64
|
||
import json
|
||
import time
|
||
from datetime import datetime, timezone, timedelta
|
||
from urllib.parse import urlparse
|
||
|
||
import discord
|
||
import aiohttp
|
||
|
||
from modules.common.settings import cfg
|
||
|
||
# ---------------- Version helpers ----------------
|
||
|
||
# Matches a dotted version token: four numeric parts plus an optional fifth
# alphanumeric part, e.g. "1.2.3.4" or "1.2.3.4.a2". Three-part versions such
# as "1.2.3" do not match. Used by _extract_version on commit subjects.
_VERSION_RE = re.compile(r'\b\d+\.\d+\.\d+\.\d+(?:\.[A-Za-z0-9]+)?\b')
|
||
|
||
def _extract_version(subject: str) -> str | None:
|
||
if not subject:
|
||
return None
|
||
m = _VERSION_RE.search(subject)
|
||
return m.group(0) if m else None
|
||
|
||
def _split_subject_body(full_message: str) -> tuple[str | None, str | None]:
|
||
if not full_message:
|
||
return None, None
|
||
lines = [ln.rstrip() for ln in full_message.splitlines()]
|
||
# subject = first non-empty line
|
||
subject = None
|
||
i = 0
|
||
while i < len(lines) and subject is None:
|
||
if lines[i].strip():
|
||
subject = lines[i].strip()
|
||
i += 1
|
||
body = '\n'.join(lines[i:]).strip() if i < len(lines) else ''
|
||
return subject or None, (body or None)
|
||
|
||
def _cmp_versions(a: str | None, b: str | None) -> int:
|
||
"""
|
||
Compare your version style: 1.2.3.4.a2 (last segment alnum optional).
|
||
Returns: -1 if a<b, 0 if equal/unknown, +1 if a>b.
|
||
If either is None, treat as equal (0) to avoid false rollback/upgrade.
|
||
"""
|
||
if not a or not b:
|
||
return 0
|
||
pa = a.split('.')
|
||
pb = b.split('.')
|
||
# pad to 5 parts
|
||
while len(pa) < 5: pa.append('0')
|
||
while len(pb) < 5: pb.append('0')
|
||
|
||
def part_key(x: str):
|
||
# numeric if digits; else (numeric_prefix, alpha_suffix)
|
||
if x.isdigit():
|
||
return (int(x), '', 1)
|
||
# split alnum: digits prefix (if any) + rest
|
||
m = re.match(r'(\d+)(.*)', x)
|
||
if m:
|
||
return (int(m.group(1)), m.group(2), 2)
|
||
return (0, x, 3)
|
||
|
||
for xa, xb in zip(pa, pb):
|
||
ka, kb = part_key(xa), part_key(xb)
|
||
if ka[0] != kb[0]:
|
||
return 1 if ka[0] > kb[0] else -1
|
||
if ka[2] != kb[2]:
|
||
return 1 if ka[2] < kb[2] else -1 # prefer pure numeric (1) > num+alpha (2) > alpha (3)
|
||
if ka[1] != kb[1]:
|
||
return 1 if ka[1] > kb[1] else -1
|
||
return 0
|
||
|
||
# ---------------- Gitea helpers ----------------
|
||
|
||
def _parse_repo_url(repo_url: str) -> tuple[str | None, str | None, str | None]:
|
||
"""
|
||
From https://host/owner/repo(.git) -> (api_base, owner, repo)
|
||
api_base = https://host/api/v1
|
||
"""
|
||
try:
|
||
pr = urlparse(repo_url)
|
||
parts = [p for p in pr.path.split('/') if p]
|
||
if len(parts) >= 2:
|
||
owner = parts[0]
|
||
repo = parts[1]
|
||
if repo.endswith('.git'):
|
||
repo = repo[:-4]
|
||
api_base = f"{pr.scheme}://{pr.netloc}/api/v1"
|
||
return api_base, owner, repo
|
||
except Exception:
|
||
pass
|
||
return None, None, None
|
||
|
||
async def _gitea_get_json(url: str, token: str | None, user: str | None, timeout_sec: int = 10):
    """Fetch *url* from the Gitea API and return its decoded JSON body.

    Authorization header, chosen from the credentials supplied:
      * token and user -> ``Basic base64(user:token)``
      * token only     -> ``token <token>``
      * no token       -> no header (anonymous request)

    A fresh aiohttp session with a total timeout of *timeout_sec* seconds is
    opened for this single request. Any status other than 200 raises
    RuntimeError carrying the first 200 characters of the response text.
    """
    auth: dict[str, str] = {}
    if token:
        if user:
            basic = base64.b64encode(f"{user}:{token}".encode()).decode()
            auth['Authorization'] = f"Basic {basic}"
        else:
            auth['Authorization'] = f"token {token}"

    session_timeout = aiohttp.ClientTimeout(total=timeout_sec)
    async with aiohttp.ClientSession(timeout=session_timeout, headers=auth) as session:
        async with session.get(url) as response:
            if response.status == 200:
                return await response.json()
            detail = await response.text()
            raise RuntimeError(f"Gitea GET {url} -> {response.status}: {detail[:200]}")
|
||
|
||
def _commit_message_of(payload) -> str | None:
    """Return the commit message inside a Gitea commit-like JSON object, or None.

    Accepts a top-level ``message`` key or a nested ``commit.message`` key (the
    two shapes this module reads); non-dict payloads and non-string or empty
    messages yield None.
    """
    if not isinstance(payload, dict):
        return None
    msg = payload.get('message')
    if not msg:
        inner = payload.get('commit')
        if isinstance(inner, dict):
            msg = inner.get('message')
    return msg if isinstance(msg, str) and msg else None


async def _fetch_latest_commit(api_base: str, owner: str, repo: str, branch: str,
                               token: str | None, user: str | None) -> tuple[str | None, str | None, str | None]:
    """
    Returns (sha, subject, body) for the latest commit on branch.

    The sha is resolved via:
      1. ``GET {api_base}/repos/{owner}/{repo}/branches/{branch}`` -> ``commit.id`` / ``commit.sha``
      2. fallback ``GET {api_base}/repos/{owner}/{repo}/commits?sha={branch}&limit=1``
         -> first item's ``sha`` / ``id``

    The message is then read from ``git/commits/{sha}`` and, failing that,
    ``commits/{sha}``. If neither yields a message, any message already present
    in the step 1/2 payload is used; otherwise subject and body are "".
    Once a sha is known this function does not raise, so callers keep the sha
    for update detection even when the message is unavailable.

    Raises:
        RuntimeError: if no sha could be obtained by either route.
    """
    from urllib.parse import urlencode

    repo_api = f"{api_base}/repos/{owner}/{repo}"
    sha = None
    seen_msg = None  # message carried by the branch/list payload, if any

    # Fast path: branch -> head commit
    try:
        bjson = await _gitea_get_json(f"{repo_api}/branches/{branch}", token, user)
        head = bjson.get('commit') if isinstance(bjson, dict) else None
        if isinstance(head, dict):
            sha = head.get('id') or head.get('sha')
            seen_msg = _commit_message_of(head)
        if not sha:
            raise RuntimeError("No commit sha on branch")
    except Exception as e:
        # Fallback: the list-commits endpoint, filtered to the branch, newest first.
        query = urlencode({'sha': branch, 'limit': 1})
        commits_url = f"{repo_api}/commits?{query}"
        try:
            cjson = await _gitea_get_json(commits_url, token, user)
            if not (isinstance(cjson, list) and cjson and isinstance(cjson[0], dict)):
                raise RuntimeError("Empty commits list")
            sha = cjson[0].get('sha') or cjson[0].get('id')
            if not sha:
                # Never go on to request ".../git/commits/None".
                raise RuntimeError("No commit sha in commits list")
            seen_msg = _commit_message_of(cjson[0])
        except Exception as e2:
            raise RuntimeError(f"Failed to get latest commit: {e} / {e2}") from e2

    # Now fetch the full commit message; try git/commits first.
    msg = None
    for endpoint in (f"{repo_api}/git/commits/{sha}",
                     f"{repo_api}/commits/{sha}"):
        try:
            data = await _gitea_get_json(endpoint, token, user)
        except Exception:
            continue  # best-effort: try the next endpoint
        msg = _commit_message_of(data)
        if msg:
            break
    if not msg:
        msg = seen_msg

    subject, body = _split_subject_body(msg or "")
    return sha, (subject or ""), (body or "")
|
||
|
||
# ---------------- Boot reason inference ----------------
|
||
|
||
def _is_near_scheduled(now_utc: datetime, hhmm_utc: str | None, window_min: int = 5) -> bool:
|
||
if not hhmm_utc:
|
||
return False
|
||
try:
|
||
hh, mm = [int(x) for x in hhmm_utc.strip().split(':', 1)]
|
||
except Exception:
|
||
return False
|
||
sched = now_utc.replace(hour=hh, minute=mm, second=0, microsecond=0)
|
||
delta = abs((now_utc - sched).total_seconds())
|
||
return delta <= window_min * 60
|
||
|
||
def _format_status_line(kind: str, old_ver: str | None, new_ver: str | None) -> str:
|
||
if kind == "updated":
|
||
return f"✅ Updated from **{old_ver or 'unknown'}** → **{new_ver or 'unknown'}**"
|
||
if kind == "scheduled":
|
||
return "🕒 Scheduled restart executed"
|
||
if kind == "manual":
|
||
return "🟢 Manual restart detected"
|
||
if kind == "rollback":
|
||
return f"⚠️ Version rollback detected: **{old_ver or 'unknown'}** → **{new_ver or 'unknown'}**"
|
||
return "🟢 Bot started"
|
||
|
||
# ---------------- Main entry ----------------
|
||
|
||
# Discord rejects message content longer than this many characters.
_DISCORD_MESSAGE_LIMIT = 2000


def _last_boot_state(dm) -> dict:
    """Return the newest entry of the 'boot_state' history held by *dm*, or {}.

    post_boot_notice appends one dict per boot via ``dm.add('boot_state', ...)``.
    An absent/empty history, a value that cannot be indexed with [-1], or a
    non-dict last entry all yield {} instead of raising.
    """
    history = dm.get('boot_state')
    if not history:
        return {}
    try:
        last = history[-1]
    except (IndexError, KeyError, TypeError):
        return {}
    return last if isinstance(last, dict) else {}


def _decide_boot_reason(prev_ver, curr_ver, prev_sha, sha, now_utc, check_time_utc) -> tuple[str, bool]:
    """Classify this boot for _format_status_line; returns (reason, mention_owner).

    - version increased -> "updated"; version decreased -> "rollback" (ping owner)
    - otherwise (same or unknown version) a changed sha -> "updated"
    - otherwise "scheduled" if now is near check_time_utc, else "manual"
    """
    if prev_ver and curr_ver:
        cmpv = _cmp_versions(prev_ver, curr_ver)
        if cmpv < 0:
            return "updated", False
        if cmpv > 0:
            return "rollback", True
    # A new commit without a version bump is still an update.
    if prev_sha and sha and prev_sha != sha:
        return "updated", False
    if _is_near_scheduled(now_utc, check_time_utc):
        return "scheduled", False
    return "manual", False


def _build_commit_message(curr_ver, subject, body, limit: int = _DISCORD_MESSAGE_LIMIT) -> str | None:
    """Return the commit post (bold title line, then the markdown body), or None.

    The title is the version if known, else the subject. None is returned when
    no subject/body was fetched, so no empty placeholder is posted. Output is
    truncated to *limit* characters (ending in an ellipsis) so Discord accepts it.
    """
    if not (subject or body):
        return None
    title = curr_ver or subject or "Latest commit"
    msg = f"**{title}**\n{body}" if body else f"**{title}**"
    if len(msg) > limit:
        msg = msg[: max(limit - 1, 0)] + "…"
    return msg


async def post_boot_notice(bot):
    """
    Always post a boot status to the modlog channel.

    Logic:
    - Wait until bot is ready (guilds/channels cached).
    - Resolve repo from cfg(repo_url/repo_branch); attempt to fetch latest commit (sha, subject, body).
    - Compare to the newest stored boot_state entry (last_sha/last_version):
        * version increased, or sha changed with same/unknown version -> Updated
        * version decreased -> Rollback (ping guild owner)
        * otherwise near check_time_utc -> Scheduled restart, else Manual restart
    - Post status line.
    - Post commit message (bold version/subject + md body) when one was fetched,
      truncated to Discord's message length limit.
    - Persist new boot_state. If no commit could be fetched, the previous
      sha/version/subject are carried forward so the next boot keeps a baseline.

    Returns early (after printing why) when the modlog channel or
    bot.data_manager is unavailable, or when the status line cannot be sent.
    """
    try:
        await bot.wait_until_ready()
    except Exception as e:
        print(f"[boot_notice] wait_until_ready failed: {e}")

    r = cfg(bot)

    # Resolve modlog channel
    modlog_channel_id = r.int('modlog_channel_id', 0)
    if not modlog_channel_id:
        print("[boot_notice] modlog_channel_id not configured; skipping.")
        return

    ch = bot.get_channel(modlog_channel_id)
    if not ch:
        # fallback: search across guilds
        for g in bot.guilds:
            ch = g.get_channel(modlog_channel_id)
            if ch:
                break
    if not ch:
        print(f"[boot_notice] channel id {modlog_channel_id} not found; skipping.")
        return

    # Repo info
    repo_url = r.get('repo_url', '')
    branch = r.get('repo_branch', 'main') or 'main'
    api_base = owner = repo = None
    if repo_url:
        api_base, owner, repo = _parse_repo_url(repo_url)

    token = os.getenv("SHAI_GITEA_TOKEN", "").strip() or None
    user = os.getenv("SHAI_GITEA_USER", "").strip() or None
    check_time_utc = r.get('check_time_utc', '')  # e.g., "03:00"
    now_utc = datetime.now(timezone.utc)

    # State store
    dm = getattr(bot, "data_manager", None)
    if not dm:
        print("[boot_notice] data_manager missing on bot; cannot persist state.")
        return

    prev = _last_boot_state(dm)
    prev_sha = prev.get('last_sha') or None
    prev_ver = prev.get('last_version') or None

    # Fetch latest commit (sha, subject, body)
    sha = subject = body = None
    if api_base and owner and repo:
        try:
            sha, subject, body = await _fetch_latest_commit(api_base, owner, repo, branch, token, user)
        except Exception as e:
            print(f"[boot_notice] fetch latest commit failed: {e}")

    # Derive current version (from subject)
    curr_ver = _extract_version(subject) if subject else None

    reason, mention_owner = _decide_boot_reason(prev_ver, curr_ver, prev_sha, sha,
                                                now_utc, check_time_utc)

    # Post status line
    status_line = _format_status_line(reason, prev_ver, curr_ver)
    try:
        # ping owner only on rollback
        allowed = discord.AllowedMentions(everyone=False, users=mention_owner,
                                          roles=False, replied_user=False)
        # Not every channel type has a guild; look it up defensively.
        owner_id = getattr(getattr(ch, 'guild', None), 'owner_id', None)
        if mention_owner and owner_id:
            status_line = f"{status_line}\n<@{owner_id}>"
        await ch.send(status_line, allowed_mentions=allowed)
    except Exception as e:
        print(f"[boot_notice] failed to send status line: {e}")
        return

    # Post commit message (if we have it). Format: **Version**\n<md body>
    commit_msg = _build_commit_message(curr_ver, subject, body)
    if commit_msg:
        try:
            await ch.send(commit_msg, allowed_mentions=discord.AllowedMentions.none())
        except Exception as e:
            print(f"[boot_notice] failed to send commit message: {e}")

    # Persist new state
    try:
        # Without a fetched commit, keep the previous baseline instead of
        # overwriting it with None (which would hide the next real update).
        fetched = sha is not None
        new_state = {
            'last_sha': sha if fetched else prev_sha,
            'last_version': curr_ver if fetched else prev_ver,
            'last_subject': subject if fetched else prev.get('last_subject'),
            'last_boot_ts': time.time(),
        }
        # keep boot_state as list to preserve history
        dm.add('boot_state', new_state)
    except Exception as e:
        print(f"[boot_notice] failed to persist boot_state: {e}")
|