417 lines
15 KiB
Python
417 lines
15 KiB
Python
"""
|
|
Phase 5 — MQTT live data functions backed by Postgres.
|
|
|
|
device_logs is a partitioned table; heartbeats and commands are plain tables.
|
|
All three are accessed via raw SQL (not ORM) because device_logs partitioning
|
|
does not play well with SQLAlchemy's declarative ORM.
|
|
|
|
device_alerts is an ORM model (devices/orm.py) and is handled here via raw SQL
|
|
to keep a single consistent interface for callers that used to import from database.core.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from datetime import date, datetime, timedelta, timezone
|
|
|
|
from dateutil.relativedelta import relativedelta
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from config import settings
|
|
from database.postgres import AsyncSessionLocal
|
|
|
|
logger = logging.getLogger("database.pg_mqtt")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
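# Internal session helper intended for fire-and-forget insert paths that are
# called from the MQTT ingestion thread (not inside a FastAPI request).
# NOTE: the functions in this module currently open sessions via
# AsyncSessionLocal() directly rather than through this helper.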
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _session() -> AsyncSession:
    """Create and return a fresh AsyncSession from AsyncSessionLocal.

    The session is handed back as-is; closing it (for example by using it as
    an ``async with`` context manager) is the caller's responsibility.
    """
    new_session = AsyncSessionLocal()
    return new_session
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Insert operations
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def insert_log(device_serial: str, level: str, message: str,
                     device_timestamp: int | None = None) -> int:
    """Insert one device log line and return the id of the new row.

    ``received_at`` is stamped server-side with ``now()``; ``device_timestamp``
    is stored exactly as given (``None`` is allowed).
    """
    statement = text("""
        INSERT INTO device_logs (device_serial, level, message, device_timestamp, received_at)
        VALUES (:serial, :level, :message, :ts, now())
        RETURNING id
    """)
    bind_values = {
        "serial": device_serial,
        "level": level,
        "message": message,
        "ts": device_timestamp,
    }
    async with AsyncSessionLocal() as session:
        outcome = await session.execute(statement, bind_values)
        # RETURNING id yields exactly one single-column row.
        (new_id,) = outcome.fetchone()
        await session.commit()
    return new_id
|
|
|
|
|
|
async def insert_heartbeat(device_serial: str, device_id: str,
                           firmware_version: str, ip_address: str,
                           gateway: str, uptime_ms: int, uptime_display: str) -> int:
    """Store one heartbeat message and return the id of the inserted row.

    ``received_at`` is set server-side via ``now()``; every other column is
    written exactly as passed in.
    """
    statement = text("""
        INSERT INTO heartbeats
            (device_serial, device_id, firmware_version, ip_address,
             gateway, uptime_ms, uptime_display, received_at)
        VALUES
            (:serial, :device_id, :fw, :ip, :gw, :uptime_ms, :uptime_display, now())
        RETURNING id
    """)
    bind_values = dict(
        serial=device_serial,
        device_id=device_id,
        fw=firmware_version,
        ip=ip_address,
        gw=gateway,
        uptime_ms=uptime_ms,
        uptime_display=uptime_display,
    )
    async with AsyncSessionLocal() as session:
        outcome = await session.execute(statement, bind_values)
        (new_id,) = outcome.fetchone()
        await session.commit()
    return new_id
|
|
|
|
|
|
async def insert_command(device_serial: str, command_name: str,
                         command_payload: dict) -> int:
    """Record an outgoing command and return the id of the new row.

    ``command_payload`` is serialised with ``json.dumps`` before it is bound,
    and ``sent_at`` is stamped server-side with ``now()``.
    """
    encoded_payload = json.dumps(command_payload)
    statement = text("""
        INSERT INTO commands (device_serial, command_name, command_payload, sent_at)
        VALUES (:serial, :name, :payload, now())
        RETURNING id
    """)
    async with AsyncSessionLocal() as session:
        outcome = await session.execute(
            statement,
            {"serial": device_serial, "name": command_name, "payload": encoded_payload},
        )
        (new_id,) = outcome.fetchone()
        await session.commit()
    return new_id
|
|
|
|
|
|
async def update_command_response(command_id: int, status: str,
                                  response_payload: dict | None = None):
    """Record a device's response to a previously sent command.

    Sets ``status``, ``response_payload`` and ``responded_at = now()`` on the
    ``commands`` row whose id is ``command_id``. If no such row exists the
    UPDATE matches nothing and no error is raised.

    Args:
        command_id: Row id of the command (as returned by insert_command).
        status: New status value to store.
        response_payload: Response body; JSON-encoded before storage. ``None``
            stores SQL NULL. An empty dict is stored as ``{}`` rather than being
            collapsed into NULL.
    """
    # Compare against None explicitly: a plain truthiness test would turn a
    # legitimate empty response ({}) into NULL.
    encoded = None if response_payload is None else json.dumps(response_payload)

    async with AsyncSessionLocal() as session:
        await session.execute(
            text("""
                UPDATE commands
                SET status = :status,
                    response_payload = :payload,
                    responded_at = now()
                WHERE id = :id
            """),
            {
                "id": command_id,
                "status": status,
                "payload": encoded,
            },
        )
        await session.commit()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Query operations
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def get_logs(device_serial: str, level: str | None = None,
                   search: str | None = None,
                   limit: int = 100, offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's log lines plus the total match count.

    Args:
        device_serial: Device whose logs are queried.
        level: If truthy, only rows whose ``level`` equals this value.
        search: If truthy, case-insensitive substring match on ``message``.
            ``%``, ``_`` and ``\\`` in the search text are matched literally.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (for pagination).

    Returns:
        ``(rows, total)``: rows newest first as plain dicts (datetime values
        converted to ISO-8601 strings), and the number of rows matching the
        filters, ignoring ``limit``/``offset``.
    """
    # Only fixed SQL fragments are concatenated here; all user values are bound.
    where = "device_serial = :serial"
    params: dict = {"serial": device_serial, "limit": limit, "offset": offset}

    if level:
        where += " AND level = :level"
        params["level"] = level
    if search:
        # Escape LIKE metacharacters so the term is a literal substring.
        # Without this, "_" or "%" typed by a user act as wildcards.
        # Backslash is Postgres' default LIKE/ILIKE escape character, and it
        # must be escaped first.
        escaped = (
            search.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
        )
        where += " AND message ILIKE :search"
        params["search"] = f"%{escaped}%"

    async with AsyncSessionLocal() as session:
        count_result = await session.execute(
            text(f"SELECT COUNT(*) FROM device_logs WHERE {where}"), params
        )
        total = count_result.scalar()

        # Sort on the qualified stored column. A bare "received_at" in ORDER BY
        # resolves to the AT TIME ZONE output alias, which an index on the
        # stored column (if one exists) cannot serve. The resulting order is
        # the same either way.
        rows_result = await session.execute(
            text(f"""
                SELECT id, device_serial, level, message, device_timestamp,
                       received_at AT TIME ZONE 'UTC' AS received_at
                FROM device_logs
                WHERE {where}
                ORDER BY device_logs.received_at DESC
                LIMIT :limit OFFSET :offset
            """),
            params,
        )
        rows = rows_result.mappings().all()

    return [_row_to_dict(r) for r in rows], total
|
|
|
|
|
|
async def get_heartbeats(device_serial: str, limit: int = 100,
                         offset: int = 0) -> tuple[list, int]:
    """Return one page of a device's heartbeats plus the device's total count.

    Args:
        device_serial: Device whose heartbeats are queried.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (for pagination).

    Returns:
        ``(rows, total)``: rows newest first as plain dicts (datetime values
        converted to ISO-8601 strings), and the total number of heartbeats
        stored for the device.
    """
    async with AsyncSessionLocal() as session:
        count_result = await session.execute(
            text("SELECT COUNT(*) FROM heartbeats WHERE device_serial = :serial"),
            {"serial": device_serial},
        )
        total = count_result.scalar()

        # Qualified sort column. A bare "received_at" would resolve to the
        # AT TIME ZONE output alias, which an index on the stored column
        # (if any) cannot serve.
        rows_result = await session.execute(
            text("""
                SELECT id, device_serial, device_id, firmware_version, ip_address,
                       gateway, uptime_ms, uptime_display,
                       received_at AT TIME ZONE 'UTC' AS received_at
                FROM heartbeats
                WHERE device_serial = :serial
                ORDER BY heartbeats.received_at DESC
                LIMIT :limit OFFSET :offset
            """),
            {"serial": device_serial, "limit": limit, "offset": offset},
        )
        rows = rows_result.mappings().all()

    return [_row_to_dict(r) for r in rows], total
|
|
|
|
|
|
async def get_commands(device_serial: str, limit: int = 100,
                       offset: int = 0) -> tuple[list, int]:
    """Return one page of commands sent to a device plus the total count.

    Args:
        device_serial: Device whose commands are queried.
        limit: Maximum number of rows to return.
        offset: Number of rows to skip (for pagination).

    Returns:
        ``(rows, total)``: rows ordered by ``sent_at`` newest first, as plain
        dicts (datetime values converted to ISO-8601 strings), and the total
        number of commands stored for the device.
    """
    async with AsyncSessionLocal() as session:
        count_result = await session.execute(
            text("SELECT COUNT(*) FROM commands WHERE device_serial = :serial"),
            {"serial": device_serial},
        )
        total = count_result.scalar()

        # Qualified sort column. A bare "sent_at" would resolve to the
        # AT TIME ZONE output alias, which an index on the stored column
        # (if any) cannot serve.
        rows_result = await session.execute(
            text("""
                SELECT id, device_serial, command_name, command_payload, status,
                       response_payload,
                       sent_at AT TIME ZONE 'UTC' AS sent_at,
                       responded_at AT TIME ZONE 'UTC' AS responded_at
                FROM commands
                WHERE device_serial = :serial
                ORDER BY commands.sent_at DESC
                LIMIT :limit OFFSET :offset
            """),
            {"serial": device_serial, "limit": limit, "offset": offset},
        )
        rows = rows_result.mappings().all()

    return [_row_to_dict(r) for r in rows], total
|
|
|
|
|
|
async def get_latest_heartbeats() -> list:
    """Return the most recent heartbeat of every device (one row per serial).

    Rows come back ordered by ``device_serial``, as plain dicts with datetime
    values converted to ISO-8601 strings.
    """
    async with AsyncSessionLocal() as session:
        # DISTINCT ON keeps the first row per device_serial under the ORDER BY,
        # i.e. the newest one. received_at is qualified so the sort uses the
        # stored column rather than the AT TIME ZONE output alias, which lets
        # an index on the stored column (if any) be used.
        rows_result = await session.execute(
            text("""
                SELECT DISTINCT ON (device_serial)
                       id, device_serial, device_id, firmware_version, ip_address,
                       gateway, uptime_ms, uptime_display,
                       received_at AT TIME ZONE 'UTC' AS received_at
                FROM heartbeats
                ORDER BY device_serial, heartbeats.received_at DESC
            """)
        )
        rows = rows_result.mappings().all()

    return [_row_to_dict(r) for r in rows]
|
|
|
|
|
|
async def get_pending_command(device_serial: str) -> dict | None:
    """Return the newest command for a device whose status is 'pending'.

    Returns:
        The row as a plain dict (datetime values converted to ISO-8601
        strings), or ``None`` when the device has no pending command.
    """
    async with AsyncSessionLocal() as session:
        # Qualified sort column. A bare "sent_at" would resolve to the
        # AT TIME ZONE output alias, which an index on the stored column
        # (if any) cannot serve for this LIMIT 1 lookup.
        result = await session.execute(
            text("""
                SELECT id, device_serial, command_name, command_payload, status,
                       response_payload,
                       sent_at AT TIME ZONE 'UTC' AS sent_at,
                       responded_at AT TIME ZONE 'UTC' AS responded_at
                FROM commands
                WHERE device_serial = :serial AND status = 'pending'
                ORDER BY commands.sent_at DESC
                LIMIT 1
            """),
            {"serial": device_serial},
        )
        row = result.mappings().fetchone()

    return _row_to_dict(row) if row else None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Device alerts
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def upsert_alert(device_serial: str, subsystem: str, state: str,
                       message: str | None = None):
    """Create or refresh the alert row for one (device_serial, subsystem) pair.

    If a row for the pair already exists (ON CONFLICT on those two columns),
    its state, message and updated_at are overwritten with the new values.
    updated_at is always set to ``now()``.
    """
    upsert_sql = text("""
        INSERT INTO device_alerts (device_serial, subsystem, state, message, updated_at)
        VALUES (:serial, :subsystem, :state, :message, now())
        ON CONFLICT (device_serial, subsystem)
        DO UPDATE SET
            state = EXCLUDED.state,
            message = EXCLUDED.message,
            updated_at = EXCLUDED.updated_at
    """)
    bind_values = dict(serial=device_serial, subsystem=subsystem,
                       state=state, message=message)
    async with AsyncSessionLocal() as session:
        await session.execute(upsert_sql, bind_values)
        await session.commit()
|
|
|
|
|
|
async def delete_alert(device_serial: str, subsystem: str):
    """Remove the alert row for one (device_serial, subsystem) pair, if present."""
    delete_sql = text(
        "DELETE FROM device_alerts WHERE device_serial = :serial AND subsystem = :subsystem"
    )
    async with AsyncSessionLocal() as session:
        await session.execute(delete_sql, dict(serial=device_serial, subsystem=subsystem))
        await session.commit()
|
|
|
|
|
|
async def get_alerts(device_serial: str) -> list:
    """Return all alert rows for a device, most recently updated first.

    Each row is a plain dict with datetime values converted to ISO-8601 strings.
    """
    async with AsyncSessionLocal() as session:
        # Qualified sort column. A bare "updated_at" would resolve to the
        # AT TIME ZONE output alias, which an index on the stored column
        # (if any) cannot serve.
        result = await session.execute(
            text("""
                SELECT id, device_serial, subsystem, state, message,
                       updated_at AT TIME ZONE 'UTC' AS updated_at
                FROM device_alerts
                WHERE device_serial = :serial
                ORDER BY device_alerts.updated_at DESC
            """),
            {"serial": device_serial},
        )
        rows = result.mappings().all()

    return [_row_to_dict(r) for r in rows]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Partition management
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def ensure_current_partitions():
    """Create device_logs partitions for the current and next month if missing.

    Each partition is named ``device_logs_YYYY_MM`` and covers
    ``[first of month, first of following month)``. ``CREATE TABLE IF NOT
    EXISTS`` makes the call idempotent, so it is safe to run repeatedly.
    """
    # Evaluate "this month" once. Calling date.today() on every iteration
    # could straddle midnight at a month boundary, and the "next month"
    # partition would then be skipped.
    first_of_month = date.today().replace(day=1)

    async with AsyncSessionLocal() as session:
        for month_offset in (0, 1):
            d = first_of_month + relativedelta(months=month_offset)
            partition_name = f"device_logs_{d.strftime('%Y_%m')}"
            start = d.isoformat()
            end = (d + relativedelta(months=1)).isoformat()
            # DDL cannot take bind parameters, so the identifier and bounds are
            # interpolated. All of them are derived from a date object, not
            # from external input.
            await session.execute(text(f"""
                CREATE TABLE IF NOT EXISTS {partition_name}
                PARTITION OF device_logs
                FOR VALUES FROM ('{start}') TO ('{end}')
            """))
        await session.commit()

    logger.info("Partition check complete")
|
|
|
|
|
|
async def drop_old_partitions(keep_months: int = 6):
    """Drop device_logs partitions older than keep_months.

    Candidates are tables in the ``public`` schema named exactly
    ``device_logs_YYYY_MM``, the naming ensure_current_partitions uses.
    A partition is dropped when its month starts before the first day of the
    month ``keep_months`` months ago. The current month, the previous
    ``keep_months`` months and any future months are therefore kept. Each drop
    is committed on its own.

    Args:
        keep_months: Number of full past months to retain.
    """
    import re

    # Strict ASCII-digit match. Loose split("_") + int() parsing would also
    # accept names such as "device_logs_24_05" (year 24) and drop them.
    name_re = re.compile(r"device_logs_([0-9]{4})_([0-9]{2})")
    cutoff = date.today().replace(day=1) - relativedelta(months=keep_months)

    async with AsyncSessionLocal() as session:
        result = await session.execute(text("""
            SELECT tablename FROM pg_tables
            WHERE schemaname = 'public'
            AND tablename LIKE 'device_logs_%'
        """))
        partitions = [r[0] for r in result.fetchall()]

        for name in partitions:
            match = name_re.fullmatch(name)
            if match is None:
                continue
            try:
                partition_date = date(int(match[1]), int(match[2]), 1)
            except ValueError:
                # Digits that do not form a real month, e.g. "device_logs_2024_13".
                continue
            if partition_date < cutoff:
                # Identifiers cannot be bind parameters. `name` is fully
                # constrained by name_re, so quoting it is safe.
                await session.execute(text(f'DROP TABLE IF EXISTS "{name}"'))
                await session.commit()
                logger.info("Dropped old partition: %s", name)
|
|
|
|
|
|
async def partition_manager_loop(check_interval_seconds: float = 24 * 3600):
    """Keep device_logs partitions provisioned and prune expired ones, forever.

    On startup, ensures the current and next month's partitions exist. Then,
    every ``check_interval_seconds``, it re-runs that check and drops
    partitions past retention. Both steps are idempotent, so frequent runs
    are cheap.

    Failures are logged with traceback and the loop keeps running.
    Cancellation still propagates, because ``asyncio.CancelledError`` is not
    an ``Exception`` subclass.

    Args:
        check_interval_seconds: Delay between maintenance passes (default: one
            day).
    """
    # A transient DB error at startup should not end the background task.
    try:
        await ensure_current_partitions()
    except Exception as e:
        logger.exception("Partition manager error: %s", e)

    while True:
        # Check daily instead of every ~30 days. Only one month of headroom is
        # created, so a 30-day period can wake after the 1st of the
        # month-after-next (started Jan 31 -> wakes Mar 2). That would leave
        # no partition for rows in between.
        await asyncio.sleep(check_interval_seconds)
        try:
            await ensure_current_partitions()
            await drop_old_partitions()
        except Exception as e:
            logger.exception("Partition manager error: %s", e)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Cleanup (replaces the SQLite purge_loop). device_logs is no longer purged
# row by row — old partitions are dropped instead (see drop_old_partitions);
# heartbeats and commands are still purged by row deletion.
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def purge_old_data(retention_days: int | None = None):
    """Delete heartbeats and commands older than the retention window.

    device_logs is not touched here; its old rows are removed by dropping
    whole partitions instead.

    Args:
        retention_days: Age threshold in days. ``None`` (the default) falls
            back to ``settings.mqtt_data_retention_days``. An explicit ``0`` is
            honoured rather than being mistaken for "use the default".
    """
    # Compare against None: `retention_days or default` would treat 0 as unset.
    days = (settings.mqtt_data_retention_days
            if retention_days is None else retention_days)
    cutoff = datetime.now(timezone.utc) - timedelta(days=days)

    async with AsyncSessionLocal() as session:
        await session.execute(
            text("DELETE FROM heartbeats WHERE received_at < :cutoff"),
            {"cutoff": cutoff},
        )
        await session.execute(
            text("DELETE FROM commands WHERE sent_at < :cutoff"),
            {"cutoff": cutoff},
        )
        await session.commit()

    logger.info("Purged heartbeats and commands older than %s days", days)
|
|
|
|
|
|
async def purge_loop(interval_seconds: float = 86400):
    """Run purge_old_data every ``interval_seconds``, forever.

    The first purge happens after one full interval. Failures are logged with
    traceback and the loop continues. Cancellation still propagates, because
    ``asyncio.CancelledError`` is not an ``Exception`` subclass.

    Args:
        interval_seconds: Delay between purges (default: one day).
    """
    while True:
        await asyncio.sleep(interval_seconds)
        try:
            await purge_old_data()
        except Exception as e:
            # logger.exception keeps the traceback, which a bare message loses.
            logger.exception("Purge failed: %s", e)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Stub — no longer needed but kept so nothing that imports init_db/close_db breaks
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def init_db():
    """Legacy startup hook that only logs; the Postgres schema is managed by Alembic."""
    banner = "Postgres MQTT backend active — no SQLite init needed"
    logger.info(banner)
|
|
|
|
|
|
async def close_db():
    """Legacy shutdown hook kept for existing importers; intentionally does nothing.

    The SQLAlchemy engine lifecycle is managed by the process, not here.
    """
    return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _row_to_dict(row) -> dict:
|
|
"""Convert a SQLAlchemy RowMapping to a plain dict with ISO string timestamps."""
|
|
d = dict(row)
|
|
for key, val in d.items():
|
|
if isinstance(val, datetime):
|
|
d[key] = val.isoformat()
|
|
return d
|