""" Phase 5 — MQTT live data functions backed by Postgres. device_logs is a partitioned table; heartbeats and commands are plain tables. All three are accessed via raw SQL (not ORM) because device_logs partitioning does not play well with SQLAlchemy's declarative ORM. device_alerts is an ORM model (devices/orm.py) and is handled here via raw SQL to keep a single consistent interface for callers that used to import from database.core. """ import asyncio import json import logging from datetime import date, datetime, timedelta, timezone from dateutil.relativedelta import relativedelta from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession from config import settings from database.postgres import AsyncSessionLocal logger = logging.getLogger("database.pg_mqtt") # --------------------------------------------------------------------------- # Internal session helper — used by fire-and-forget insert paths that are # called from the MQTT ingestion thread (not inside a FastAPI request). 
# ---------------------------------------------------------------------------


async def _session() -> AsyncSession:
    """Hand back a brand-new session; the caller is responsible for closing it."""
    fresh = AsyncSessionLocal()
    return fresh


# ---------------------------------------------------------------------------
# Insert operations
# ---------------------------------------------------------------------------


async def _insert_returning_id(statement, params: dict) -> int:
    """Execute one ``INSERT ... RETURNING id`` in its own session and commit it.

    Returns the first column of the first row produced by RETURNING.
    """
    async with AsyncSessionLocal() as db:
        outcome = await db.execute(statement, params)
        inserted = outcome.fetchone()
        await db.commit()
        return inserted[0]


async def insert_log(device_serial: str, level: str, message: str,
                     device_timestamp: int | None = None) -> int:
    """Store one log line reported by a device and return the new row id."""
    statement = text("""
        INSERT INTO device_logs (device_serial, level, message, device_timestamp, received_at)
        VALUES (:serial, :level, :message, :ts, now())
        RETURNING id
    """)
    values = {
        "serial": device_serial,
        "level": level,
        "message": message,
        "ts": device_timestamp,
    }
    return await _insert_returning_id(statement, values)


async def insert_heartbeat(device_serial: str, device_id: str, firmware_version: str,
                           ip_address: str, gateway: str, uptime_ms: int,
                           uptime_display: str) -> int:
    """Store one heartbeat message from a device and return the new row id."""
    statement = text("""
        INSERT INTO heartbeats
            (device_serial, device_id, firmware_version, ip_address, gateway,
             uptime_ms, uptime_display, received_at)
        VALUES (:serial, :device_id, :fw, :ip, :gw, :uptime_ms, :uptime_display, now())
        RETURNING id
    """)
    values = {
        "serial": device_serial,
        "device_id": device_id,
        "fw": firmware_version,
        "ip": ip_address,
        "gw": gateway,
        "uptime_ms": uptime_ms,
        "uptime_display": uptime_display,
    }
    return await _insert_returning_id(statement, values)


async def insert_command(device_serial: str, command_name: str, command_payload: dict) -> int:
    """Record a command sent to a device (payload stored as JSON text); return its id."""
    statement = text("""
        INSERT INTO commands (device_serial, command_name, command_payload, sent_at)
        VALUES (:serial, :name, :payload, now())
        RETURNING id
    """)
    values = {
        "serial": device_serial,
        "name": command_name,
        "payload": json.dumps(command_payload),
    }
    return await _insert_returning_id(statement, values)
async def update_command_response(command_id: int, status: str, response_payload: dict | None = None): async with AsyncSessionLocal() as session: await session.execute( text(""" UPDATE commands SET status = :status, response_payload = :payload, responded_at = now() WHERE id = :id """), { "id": command_id, "status": status, "payload": json.dumps(response_payload) if response_payload else None, }, ) await session.commit() # --------------------------------------------------------------------------- # Query operations # --------------------------------------------------------------------------- async def get_logs(device_serial: str, level: str | None = None, search: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list, int]: where = "device_serial = :serial" params: dict = {"serial": device_serial, "limit": limit, "offset": offset} if level: where += " AND level = :level" params["level"] = level if search: where += " AND message ILIKE :search" params["search"] = f"%{search}%" async with AsyncSessionLocal() as session: count_result = await session.execute( text(f"SELECT COUNT(*) FROM device_logs WHERE {where}"), params ) total = count_result.scalar() rows_result = await session.execute( text(f""" SELECT id, device_serial, level, message, device_timestamp, received_at AT TIME ZONE 'UTC' AS received_at FROM device_logs WHERE {where} ORDER BY received_at DESC LIMIT :limit OFFSET :offset """), params, ) rows = rows_result.mappings().all() return [_row_to_dict(r) for r in rows], total async def get_heartbeats(device_serial: str, limit: int = 100, offset: int = 0) -> tuple[list, int]: async with AsyncSessionLocal() as session: count_result = await session.execute( text("SELECT COUNT(*) FROM heartbeats WHERE device_serial = :serial"), {"serial": device_serial}, ) total = count_result.scalar() rows_result = await session.execute( text(""" SELECT id, device_serial, device_id, firmware_version, ip_address, gateway, uptime_ms, uptime_display, received_at AT TIME 
ZONE 'UTC' AS received_at FROM heartbeats WHERE device_serial = :serial ORDER BY received_at DESC LIMIT :limit OFFSET :offset """), {"serial": device_serial, "limit": limit, "offset": offset}, ) rows = rows_result.mappings().all() return [_row_to_dict(r) for r in rows], total async def get_commands(device_serial: str, limit: int = 100, offset: int = 0) -> tuple[list, int]: async with AsyncSessionLocal() as session: count_result = await session.execute( text("SELECT COUNT(*) FROM commands WHERE device_serial = :serial"), {"serial": device_serial}, ) total = count_result.scalar() rows_result = await session.execute( text(""" SELECT id, device_serial, command_name, command_payload, status, response_payload, sent_at AT TIME ZONE 'UTC' AS sent_at, responded_at AT TIME ZONE 'UTC' AS responded_at FROM commands WHERE device_serial = :serial ORDER BY sent_at DESC LIMIT :limit OFFSET :offset """), {"serial": device_serial, "limit": limit, "offset": offset}, ) rows = rows_result.mappings().all() return [_row_to_dict(r) for r in rows], total async def get_latest_heartbeats() -> list: async with AsyncSessionLocal() as session: rows_result = await session.execute( text(""" SELECT DISTINCT ON (device_serial) id, device_serial, device_id, firmware_version, ip_address, gateway, uptime_ms, uptime_display, received_at AT TIME ZONE 'UTC' AS received_at FROM heartbeats ORDER BY device_serial, received_at DESC """) ) rows = rows_result.mappings().all() return [_row_to_dict(r) for r in rows] async def get_pending_command(device_serial: str) -> dict | None: async with AsyncSessionLocal() as session: result = await session.execute( text(""" SELECT id, device_serial, command_name, command_payload, status, response_payload, sent_at AT TIME ZONE 'UTC' AS sent_at, responded_at AT TIME ZONE 'UTC' AS responded_at FROM commands WHERE device_serial = :serial AND status = 'pending' ORDER BY sent_at DESC LIMIT 1 """), {"serial": device_serial}, ) row = result.mappings().fetchone() return 
_row_to_dict(row) if row else None # --------------------------------------------------------------------------- # Device alerts # --------------------------------------------------------------------------- async def upsert_alert(device_serial: str, subsystem: str, state: str, message: str | None = None): async with AsyncSessionLocal() as session: await session.execute( text(""" INSERT INTO device_alerts (device_serial, subsystem, state, message, updated_at) VALUES (:serial, :subsystem, :state, :message, now()) ON CONFLICT (device_serial, subsystem) DO UPDATE SET state = EXCLUDED.state, message = EXCLUDED.message, updated_at = EXCLUDED.updated_at """), {"serial": device_serial, "subsystem": subsystem, "state": state, "message": message}, ) await session.commit() async def delete_alert(device_serial: str, subsystem: str): async with AsyncSessionLocal() as session: await session.execute( text("DELETE FROM device_alerts WHERE device_serial = :serial AND subsystem = :subsystem"), {"serial": device_serial, "subsystem": subsystem}, ) await session.commit() async def get_alerts(device_serial: str) -> list: async with AsyncSessionLocal() as session: result = await session.execute( text(""" SELECT id, device_serial, subsystem, state, message, updated_at AT TIME ZONE 'UTC' AS updated_at FROM device_alerts WHERE device_serial = :serial ORDER BY updated_at DESC """), {"serial": device_serial}, ) rows = result.mappings().all() return [_row_to_dict(r) for r in rows] # --------------------------------------------------------------------------- # Partition management # --------------------------------------------------------------------------- async def ensure_current_partitions(): """Create device_logs partitions for the current and next month if missing.""" async with AsyncSessionLocal() as session: for month_offset in (0, 1): d = date.today().replace(day=1) + relativedelta(months=month_offset) partition_name = f"device_logs_{d.strftime('%Y_%m')}" start = d.isoformat() end = (d 
+ relativedelta(months=1)).isoformat() await session.execute(text(f""" CREATE TABLE IF NOT EXISTS {partition_name} PARTITION OF device_logs FOR VALUES FROM ('{start}') TO ('{end}') """)) await session.commit() logger.info("Partition check complete") async def drop_old_partitions(keep_months: int = 6): """Drop device_logs partitions older than keep_months.""" cutoff = date.today().replace(day=1) - relativedelta(months=keep_months) async with AsyncSessionLocal() as session: result = await session.execute(text(""" SELECT tablename FROM pg_tables WHERE schemaname = 'public' AND tablename LIKE 'device_logs_%' """)) partitions = [r[0] for r in result.fetchall()] for name in partitions: # name format: device_logs_YYYY_MM parts = name.split("_") if len(parts) != 4: continue try: partition_date = date(int(parts[2]), int(parts[3]), 1) except ValueError: continue if partition_date < cutoff: async with AsyncSessionLocal() as session: await session.execute(text(f"DROP TABLE IF EXISTS {name}")) await session.commit() logger.info(f"Dropped old partition: {name}") async def partition_manager_loop(): """Runs once on startup, then monthly thereafter.""" await ensure_current_partitions() while True: # Sleep ~30 days, wake up and ensure next month's partition exists await asyncio.sleep(30 * 24 * 3600) try: await ensure_current_partitions() await drop_old_partitions() except Exception as e: logger.error(f"Partition manager error: {e}") # --------------------------------------------------------------------------- # Cleanup (replaces SQLite purge_loop — now a no-op since Postgres uses # partition drops instead of row-by-row deletes for device_logs; heartbeats # and commands are still purged by row deletion) # --------------------------------------------------------------------------- async def purge_old_data(retention_days: int | None = None): days = retention_days or settings.mqtt_data_retention_days cutoff = datetime.now(timezone.utc) - timedelta(days=days) async with 
AsyncSessionLocal() as session: await session.execute( text("DELETE FROM heartbeats WHERE received_at < :cutoff"), {"cutoff": cutoff}, ) await session.execute( text("DELETE FROM commands WHERE sent_at < :cutoff"), {"cutoff": cutoff}, ) await session.commit() logger.info(f"Purged heartbeats and commands older than {days} days") async def purge_loop(): while True: await asyncio.sleep(86400) try: await purge_old_data() except Exception as e: logger.error(f"Purge failed: {e}") # --------------------------------------------------------------------------- # Stub — no longer needed but kept so nothing that imports init_db/close_db breaks # --------------------------------------------------------------------------- async def init_db(): """No-op: Postgres schema is managed by Alembic, not runtime init.""" logger.info("Postgres MQTT backend active — no SQLite init needed") async def close_db(): """No-op: SQLAlchemy engine lifecycle is managed by the process.""" # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _row_to_dict(row) -> dict: """Convert a SQLAlchemy RowMapping to a plain dict with ISO string timestamps.""" d = dict(row) for key, val in d.items(): if isinstance(val, datetime): d[key] = val.isoformat() return d