"""SQLite persistence for MQTT device logs, heartbeats, commands and melody drafts.

A single module-level ``aiosqlite`` connection is opened by :func:`init_db`
(or lazily by :func:`get_db`) and shared by every helper in this module.
Timestamp columns are filled by SQLite's ``datetime('now')``, which yields UTC
text of the form ``YYYY-MM-DD HH:MM:SS``.
"""

import asyncio
import json
import logging
from datetime import datetime, timedelta, timezone

import aiosqlite

from config import settings

logger = logging.getLogger("mqtt.database")

# Shared connection; set by init_db() and cleared by close_db().
_db: aiosqlite.Connection | None = None

# Layout produced by SQLite's datetime('now'). Cutoffs compared against the
# TEXT timestamp columns must use exactly this layout so that plain string
# comparison is chronological.
_SQLITE_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"

SCHEMA_STATEMENTS = [
    """CREATE TABLE IF NOT EXISTS device_logs (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        device_serial TEXT NOT NULL,
        level TEXT NOT NULL,
        message TEXT NOT NULL,
        device_timestamp INTEGER,
        received_at TEXT NOT NULL DEFAULT (datetime('now'))
    )""",
    """CREATE TABLE IF NOT EXISTS heartbeats (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        device_serial TEXT NOT NULL,
        device_id TEXT,
        firmware_version TEXT,
        ip_address TEXT,
        gateway TEXT,
        uptime_ms INTEGER,
        uptime_display TEXT,
        received_at TEXT NOT NULL DEFAULT (datetime('now'))
    )""",
    """CREATE TABLE IF NOT EXISTS commands (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        device_serial TEXT NOT NULL,
        command_name TEXT NOT NULL,
        command_payload TEXT,
        status TEXT NOT NULL DEFAULT 'pending',
        response_payload TEXT,
        sent_at TEXT NOT NULL DEFAULT (datetime('now')),
        responded_at TEXT
    )""",
    "CREATE INDEX IF NOT EXISTS idx_logs_serial_time ON device_logs(device_serial, received_at)",
    "CREATE INDEX IF NOT EXISTS idx_logs_level ON device_logs(level)",
    "CREATE INDEX IF NOT EXISTS idx_heartbeats_serial_time ON heartbeats(device_serial, received_at)",
    "CREATE INDEX IF NOT EXISTS idx_commands_serial_time ON commands(device_serial, sent_at)",
    "CREATE INDEX IF NOT EXISTS idx_commands_status ON commands(status)",
    # Melody drafts table
    """CREATE TABLE IF NOT EXISTS melody_drafts (
        id TEXT PRIMARY KEY,
        status TEXT NOT NULL DEFAULT 'draft',
        data TEXT NOT NULL,
        created_at TEXT NOT NULL DEFAULT (datetime('now')),
        updated_at TEXT NOT NULL DEFAULT (datetime('now'))
    )""",
    "CREATE INDEX IF NOT EXISTS idx_melody_drafts_status ON melody_drafts(status)",
]


async def init_db() -> None:
    """Open the shared SQLite connection and create any missing schema objects.

    Every statement in ``SCHEMA_STATEMENTS`` uses ``IF NOT EXISTS``, so this is
    safe to run against an existing database file. Rows are returned as
    ``aiosqlite.Row`` so callers can convert them with ``dict(row)``.
    """
    global _db
    _db = await aiosqlite.connect(settings.sqlite_db_path)
    _db.row_factory = aiosqlite.Row
    for stmt in SCHEMA_STATEMENTS:
        await _db.execute(stmt)
    await _db.commit()
    logger.info("SQLite database initialized at %s", settings.sqlite_db_path)


async def close_db() -> None:
    """Close the shared connection, if one is open, and forget it."""
    global _db
    if _db is not None:
        await _db.close()
        _db = None


async def get_db() -> aiosqlite.Connection:
    """Return the shared connection, initializing it on first use."""
    if _db is None:
        await init_db()
    return _db


def _escape_like(text: str) -> str:
    """Escape LIKE wildcards so *text* matches literally (use with ESCAPE '\\')."""
    return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


# --- Insert Operations ---

async def insert_log(device_serial: str, level: str, message: str,
                     device_timestamp: int | None = None) -> int | None:
    """Insert one device log line and return its row id."""
    db = await get_db()
    cursor = await db.execute(
        "INSERT INTO device_logs (device_serial, level, message, device_timestamp) VALUES (?, ?, ?, ?)",
        (device_serial, level, message, device_timestamp),
    )
    await db.commit()
    return cursor.lastrowid


async def insert_heartbeat(device_serial: str, device_id: str, firmware_version: str,
                           ip_address: str, gateway: str, uptime_ms: int,
                           uptime_display: str) -> int | None:
    """Insert one heartbeat record and return its row id."""
    db = await get_db()
    cursor = await db.execute(
        """INSERT INTO heartbeats
           (device_serial, device_id, firmware_version, ip_address, gateway, uptime_ms, uptime_display)
           VALUES (?, ?, ?, ?, ?, ?, ?)""",
        (device_serial, device_id, firmware_version, ip_address, gateway, uptime_ms, uptime_display),
    )
    await db.commit()
    return cursor.lastrowid


async def insert_command(device_serial: str, command_name: str, command_payload: dict) -> int:
    """Record an outgoing command (status defaults to 'pending') and return its id.

    ``command_payload`` is stored as JSON text.
    """
    db = await get_db()
    cursor = await db.execute(
        "INSERT INTO commands (device_serial, command_name, command_payload) VALUES (?, ?, ?)",
        (device_serial, command_name, json.dumps(command_payload)),
    )
    await db.commit()
    return cursor.lastrowid


async def update_command_response(command_id: int, status: str,
                                  response_payload: dict | None = None) -> None:
    """Set a command's status and response, stamping ``responded_at`` with now.

    ``response_payload`` is stored as JSON text; ``None`` stores SQL NULL. An
    empty dict is a real response and is stored as ``"{}"``.
    """
    db = await get_db()
    payload_json = json.dumps(response_payload) if response_payload is not None else None
    await db.execute(
        """UPDATE commands SET status = ?, response_payload = ?, responded_at = datetime('now')
           WHERE id = ?""",
        (status, payload_json, command_id),
    )
    await db.commit()


# --- Query Operations ---

async def get_logs(device_serial: str, level: str | None = None, search: str | None = None,
                   limit: int = 100, offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's logs (newest first) and the total match count.

    Args:
        device_serial: Device whose logs are returned.
        level: If given, only rows with exactly this level.
        search: If given, only rows whose message contains this text literally
            (``%`` and ``_`` are not treated as wildcards).
        limit: Maximum rows in the page.
        offset: Rows to skip before the page.

    Returns:
        ``(rows_as_dicts, total_matching_rows)``.
    """
    db = await get_db()
    # Only fixed clause strings are interpolated below; all user values are bound.
    where_clauses = ["device_serial = ?"]
    params: list = [device_serial]
    if level:
        where_clauses.append("level = ?")
        params.append(level)
    if search:
        where_clauses.append("message LIKE ? ESCAPE '\\'")
        params.append(f"%{_escape_like(search)}%")
    where = " AND ".join(where_clauses)

    count_row = await db.execute_fetchall(
        f"SELECT COUNT(*) as cnt FROM device_logs WHERE {where}", params
    )
    total = count_row[0][0]
    # id breaks ties between rows logged in the same second so pages are stable.
    rows = await db.execute_fetchall(
        f"SELECT * FROM device_logs WHERE {where} ORDER BY received_at DESC, id DESC LIMIT ? OFFSET ?",
        params + [limit, offset],
    )
    return [dict(r) for r in rows], total


async def get_heartbeats(device_serial: str, limit: int = 100,
                         offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's heartbeats (newest first) and the total count."""
    db = await get_db()
    count_row = await db.execute_fetchall(
        "SELECT COUNT(*) FROM heartbeats WHERE device_serial = ?", (device_serial,)
    )
    total = count_row[0][0]
    rows = await db.execute_fetchall(
        "SELECT * FROM heartbeats WHERE device_serial = ? "
        "ORDER BY received_at DESC, id DESC LIMIT ? OFFSET ?",
        (device_serial, limit, offset),
    )
    return [dict(r) for r in rows], total


async def get_commands(device_serial: str, limit: int = 100,
                       offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's commands (newest first) and the total count."""
    db = await get_db()
    count_row = await db.execute_fetchall(
        "SELECT COUNT(*) FROM commands WHERE device_serial = ?", (device_serial,)
    )
    total = count_row[0][0]
    rows = await db.execute_fetchall(
        "SELECT * FROM commands WHERE device_serial = ? "
        "ORDER BY sent_at DESC, id DESC LIMIT ? OFFSET ?",
        (device_serial, limit, offset),
    )
    return [dict(r) for r in rows], total


async def get_latest_heartbeats() -> list:
    """Return the most recent heartbeat row for every device, one row per device.

    "Most recent" is the greatest ``received_at``. When several rows for the
    same device share that (second-resolution) timestamp, the one with the
    highest ``id`` is returned, so a device never appears twice.
    """
    db = await get_db()
    rows = await db.execute_fetchall("""
        SELECT h.*
        FROM heartbeats h
        WHERE h.id IN (
            SELECT MAX(h2.id)
            FROM heartbeats h2
            INNER JOIN (
                SELECT device_serial, MAX(received_at) AS max_time
                FROM heartbeats
                GROUP BY device_serial
            ) latest
              ON h2.device_serial = latest.device_serial
             AND h2.received_at = latest.max_time
            GROUP BY h2.device_serial
        )
    """)
    return [dict(r) for r in rows]


async def get_pending_command(device_serial: str) -> dict | None:
    """Return the newest command still in 'pending' status for a device, or None."""
    db = await get_db()
    rows = await db.execute_fetchall(
        """SELECT * FROM commands
           WHERE device_serial = ? AND status = 'pending'
           ORDER BY sent_at DESC, id DESC LIMIT 1""",
        (device_serial,),
    )
    return dict(rows[0]) if rows else None


# --- Cleanup ---

async def purge_old_data(retention_days: int | None = None) -> None:
    """Delete logs, heartbeats and commands older than the retention window.

    Args:
        retention_days: Age limit in days; ``None`` uses
            ``settings.mqtt_data_retention_days``. An explicit ``0`` is honored.
    """
    days = settings.mqtt_data_retention_days if retention_days is None else retention_days
    # Must match datetime('now')'s "YYYY-MM-DD HH:MM:SS" (UTC) layout; an ISO
    # string with a 'T' separator would sort after every same-day stored value
    # and purge up to a day too much.
    cutoff = (datetime.now(timezone.utc) - timedelta(days=days)).strftime(_SQLITE_TIMESTAMP_FORMAT)
    db = await get_db()
    await db.execute("DELETE FROM device_logs WHERE received_at < ?", (cutoff,))
    await db.execute("DELETE FROM heartbeats WHERE received_at < ?", (cutoff,))
    await db.execute("DELETE FROM commands WHERE sent_at < ?", (cutoff,))
    await db.commit()
    logger.info("Purged MQTT data older than %s days", days)


async def purge_loop(interval_seconds: float = 86400) -> None:
    """Run :func:`purge_old_data` forever, sleeping *interval_seconds* before each run.

    A failed purge is logged with its traceback and the loop keeps going;
    task cancellation still propagates because it is not an ``Exception``.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        try:
            await purge_old_data()
        except Exception as e:
            logger.exception("Purge failed: %s", e)