Files
2026-04-17 16:01:50 +03:00

412 lines
15 KiB
Python

"""
Phase 5 — MQTT live data functions backed by Postgres.
device_logs is a partitioned table; heartbeats and commands are plain tables.
All three are accessed via raw SQL (not ORM) because device_logs partitioning
does not play well with SQLAlchemy's declarative ORM.
device_alerts is an ORM model (devices/orm.py) and is handled here via raw SQL
to keep a single consistent interface for callers that used to import from database.core.
"""
import asyncio
import json
import logging
from datetime import date, datetime, timedelta, timezone
from sqlalchemy import text
from config import settings
from database.postgres import AsyncSessionLocal
# Module-level logger for this file, registered under the fixed name "database.pg_mqtt".
logger = logging.getLogger("database.pg_mqtt")
# ---------------------------------------------------------------------------
# Insert operations
# ---------------------------------------------------------------------------
async def insert_log(device_serial: str, level: str, message: str,
                     device_timestamp: int | None = None) -> int:
    """Insert one row into ``device_logs`` and return its generated id.

    ``received_at`` is filled in by the database with ``now()``;
    ``device_timestamp`` is passed through as given and may be None.
    """
    statement = text("""
INSERT INTO device_logs (device_serial, level, message, device_timestamp, received_at)
VALUES (:serial, :level, :message, :ts, now())
RETURNING id
""")
    bind_values = {
        "serial": device_serial,
        "level": level,
        "message": message,
        "ts": device_timestamp,
    }
    async with AsyncSessionLocal() as db:
        outcome = await db.execute(statement, bind_values)
        # Read the RETURNING row before committing, then hand back its id.
        inserted = outcome.fetchone()
        await db.commit()
        return inserted[0]
async def insert_heartbeat(device_serial: str, device_id: str,
                           firmware_version: str, ip_address: str,
                           gateway: str, uptime_ms: int, uptime_display: str) -> int:
    """Insert one row into ``heartbeats`` and return its generated id.

    All arguments are stored as given; ``received_at`` is set by the
    database with ``now()``.
    """
    statement = text("""
INSERT INTO heartbeats
(device_serial, device_id, firmware_version, ip_address,
gateway, uptime_ms, uptime_display, received_at)
VALUES
(:serial, :device_id, :fw, :ip, :gw, :uptime_ms, :uptime_display, now())
RETURNING id
""")
    bind_values = dict(
        serial=device_serial,
        device_id=device_id,
        fw=firmware_version,
        ip=ip_address,
        gw=gateway,
        uptime_ms=uptime_ms,
        uptime_display=uptime_display,
    )
    async with AsyncSessionLocal() as db:
        outcome = await db.execute(statement, bind_values)
        inserted = outcome.fetchone()
        await db.commit()
        return inserted[0]
async def insert_command(device_serial: str, command_name: str,
                         command_payload: dict) -> int:
    """Insert one row into ``commands`` and return its generated id.

    ``command_payload`` is serialized with ``json.dumps`` before being bound;
    ``sent_at`` is set by the database with ``now()``.
    """
    statement = text("""
INSERT INTO commands (device_serial, command_name, command_payload, sent_at)
VALUES (:serial, :name, :payload, now())
RETURNING id
""")
    encoded_payload = json.dumps(command_payload)
    async with AsyncSessionLocal() as db:
        outcome = await db.execute(
            statement,
            {"serial": device_serial, "name": command_name, "payload": encoded_payload},
        )
        inserted = outcome.fetchone()
        await db.commit()
        return inserted[0]
async def update_command_response(command_id: int, status: str,
                                  response_payload: dict | None = None):
    """Record a response on an existing ``commands`` row.

    Sets ``status``, ``response_payload`` and ``responded_at`` (database
    ``now()``) on the row whose ``id`` equals ``command_id``.

    Args:
        command_id: Primary key of the command row to update.
        status: New status value to store.
        response_payload: Payload to store as JSON text. ``None`` stores
            NULL; any dict, including an empty one, is serialized.
    """
    # Compare against None explicitly: an empty dict is still a payload and
    # must be stored as "{}", not silently collapsed to NULL.
    serialized = None if response_payload is None else json.dumps(response_payload)
    async with AsyncSessionLocal() as session:
        await session.execute(
            text("""
UPDATE commands
SET status = :status,
response_payload = :payload,
responded_at = now()
WHERE id = :id
"""),
            {
                "id": command_id,
                "status": status,
                "payload": serialized,
            },
        )
        await session.commit()
# ---------------------------------------------------------------------------
# Query operations
# ---------------------------------------------------------------------------
async def get_logs(device_serial: str, level: str | None = None,
                   search: str | None = None,
                   limit: int = 100, offset: int = 0) -> tuple[list, int]:
    """Return one page of log rows for a device plus the total match count.

    ``level`` and ``search`` only add a filter when they are truthy.
    ``search`` is wrapped in ``%...%`` and matched against ``message`` with
    ILIKE. The count uses the same filter but ignores ``limit``/``offset``.

    Returns:
        ``(rows, total)`` where ``rows`` are plain dicts produced by
        ``_row_to_dict``, newest ``received_at`` first.
    """
    conditions = ["device_serial = :serial"]
    bind_values: dict = {"serial": device_serial, "limit": limit, "offset": offset}
    if level:
        conditions.append("level = :level")
        bind_values["level"] = level
    if search:
        conditions.append("message ILIKE :search")
        bind_values["search"] = f"%{search}%"
    where_clause = " AND ".join(conditions)

    count_sql = text(f"SELECT COUNT(*) FROM device_logs WHERE {where_clause}")
    page_sql = text(f"""
SELECT id, device_serial, level, message, device_timestamp,
received_at AT TIME ZONE 'UTC' AS received_at
FROM device_logs
WHERE {where_clause}
ORDER BY received_at DESC
LIMIT :limit OFFSET :offset
""")

    async with AsyncSessionLocal() as db:
        total = (await db.execute(count_sql, bind_values)).scalar()
        page = (await db.execute(page_sql, bind_values)).mappings().all()
        return [_row_to_dict(record) for record in page], total
async def get_heartbeats(device_serial: str, limit: int = 100,
                         offset: int = 0) -> tuple[list, int]:
    """Return one page of heartbeat rows for a device plus the total count.

    Returns:
        ``(rows, total)`` where ``rows`` are plain dicts produced by
        ``_row_to_dict``, newest ``received_at`` first, and ``total`` counts
        every heartbeat of the device regardless of ``limit``/``offset``.
    """
    count_sql = text("SELECT COUNT(*) FROM heartbeats WHERE device_serial = :serial")
    page_sql = text("""
SELECT id, device_serial, device_id, firmware_version, ip_address,
gateway, uptime_ms, uptime_display,
received_at AT TIME ZONE 'UTC' AS received_at
FROM heartbeats
WHERE device_serial = :serial
ORDER BY received_at DESC
LIMIT :limit OFFSET :offset
""")
    serial_only = {"serial": device_serial}
    paged = {**serial_only, "limit": limit, "offset": offset}

    async with AsyncSessionLocal() as db:
        total = (await db.execute(count_sql, serial_only)).scalar()
        page = (await db.execute(page_sql, paged)).mappings().all()
        return [_row_to_dict(record) for record in page], total
async def get_commands(device_serial: str, limit: int = 100,
                       offset: int = 0) -> tuple[list, int]:
    """Return one page of command rows for a device plus the total count.

    Returns:
        ``(rows, total)`` where ``rows`` are plain dicts produced by
        ``_row_to_dict``, newest ``sent_at`` first, and ``total`` counts
        every command of the device regardless of ``limit``/``offset``.
    """
    count_sql = text("SELECT COUNT(*) FROM commands WHERE device_serial = :serial")
    page_sql = text("""
SELECT id, device_serial, command_name, command_payload, status,
response_payload,
sent_at AT TIME ZONE 'UTC' AS sent_at,
responded_at AT TIME ZONE 'UTC' AS responded_at
FROM commands
WHERE device_serial = :serial
ORDER BY sent_at DESC
LIMIT :limit OFFSET :offset
""")
    serial_only = {"serial": device_serial}
    paged = {**serial_only, "limit": limit, "offset": offset}

    async with AsyncSessionLocal() as db:
        total = (await db.execute(count_sql, serial_only)).scalar()
        page = (await db.execute(page_sql, paged)).mappings().all()
        return [_row_to_dict(record) for record in page], total
async def get_latest_heartbeats() -> list:
    """Return the most recent heartbeat row of every device.

    Uses ``DISTINCT ON (device_serial)`` ordered by ``received_at DESC`` so
    exactly one row per serial is returned. Rows are plain dicts produced by
    ``_row_to_dict``.
    """
    latest_sql = text("""
SELECT DISTINCT ON (device_serial)
id, device_serial, device_id, firmware_version, ip_address,
gateway, uptime_ms, uptime_display,
received_at AT TIME ZONE 'UTC' AS received_at
FROM heartbeats
ORDER BY device_serial, received_at DESC
""")
    async with AsyncSessionLocal() as db:
        outcome = await db.execute(latest_sql)
        records = outcome.mappings().all()
        return [_row_to_dict(record) for record in records]
async def get_pending_command(device_serial: str) -> dict | None:
    """Return the newest command with status ``'pending'`` for a device.

    Returns:
        The row as a plain dict (via ``_row_to_dict``), or None when the
        device has no pending command.
    """
    pending_sql = text("""
SELECT id, device_serial, command_name, command_payload, status,
response_payload,
sent_at AT TIME ZONE 'UTC' AS sent_at,
responded_at AT TIME ZONE 'UTC' AS responded_at
FROM commands
WHERE device_serial = :serial AND status = 'pending'
ORDER BY sent_at DESC
LIMIT 1
""")
    async with AsyncSessionLocal() as db:
        outcome = await db.execute(pending_sql, {"serial": device_serial})
        record = outcome.mappings().fetchone()
        if not record:
            return None
        return _row_to_dict(record)
# ---------------------------------------------------------------------------
# Device alerts
# ---------------------------------------------------------------------------
async def upsert_alert(device_serial: str, subsystem: str, state: str,
                       message: str | None = None):
    """Insert or update the alert row for a ``(device_serial, subsystem)`` pair.

    On conflict the existing row's ``state``, ``message`` and ``updated_at``
    are overwritten with the new values; ``updated_at`` comes from the
    database ``now()``.
    """
    upsert_sql = text("""
INSERT INTO device_alerts (device_serial, subsystem, state, message, updated_at)
VALUES (:serial, :subsystem, :state, :message, now())
ON CONFLICT (device_serial, subsystem)
DO UPDATE SET
state = EXCLUDED.state,
message = EXCLUDED.message,
updated_at = EXCLUDED.updated_at
""")
    bind_values = {
        "serial": device_serial,
        "subsystem": subsystem,
        "state": state,
        "message": message,
    }
    async with AsyncSessionLocal() as db:
        await db.execute(upsert_sql, bind_values)
        await db.commit()
async def delete_alert(device_serial: str, subsystem: str):
    """Delete the alert row for a ``(device_serial, subsystem)`` pair, if any."""
    delete_sql = text(
        "DELETE FROM device_alerts WHERE device_serial = :serial AND subsystem = :subsystem"
    )
    bind_values = {"serial": device_serial, "subsystem": subsystem}
    async with AsyncSessionLocal() as db:
        await db.execute(delete_sql, bind_values)
        await db.commit()
async def get_alerts(device_serial: str) -> list:
    """Return all alert rows for a device, most recently updated first.

    Rows are plain dicts produced by ``_row_to_dict``.
    """
    alerts_sql = text("""
SELECT id, device_serial, subsystem, state, message,
updated_at AT TIME ZONE 'UTC' AS updated_at
FROM device_alerts
WHERE device_serial = :serial
ORDER BY updated_at DESC
""")
    async with AsyncSessionLocal() as db:
        outcome = await db.execute(alerts_sql, {"serial": device_serial})
        records = outcome.mappings().all()
        return [_row_to_dict(record) for record in records]
# ---------------------------------------------------------------------------
# Partition management
# ---------------------------------------------------------------------------
def _add_months(d: date, months: int) -> date:
    """Return the first day of the month ``months`` months away from ``d``.

    ``months`` may be negative or zero. The day of ``d`` is always reset to 1.
    """
    # Count months since year 0 so that a single divmod handles both
    # forward and backward year roll-over.
    months_since_year_zero = d.year * 12 + (d.month - 1) + months
    target_year, zero_based_month = divmod(months_since_year_zero, 12)
    return d.replace(year=target_year, month=zero_based_month + 1, day=1)
async def ensure_current_partitions():
    """Create device_logs partitions for the current and next month if missing.

    Partitions are named ``device_logs_YYYY_MM`` and each covers the range
    from the first day of its month up to the first day of the following
    month.
    """
    async with AsyncSessionLocal() as db:
        for months_ahead in (0, 1):
            month_start = _add_months(date.today().replace(day=1), months_ahead)
            next_month_start = _add_months(month_start, 1)
            partition_name = f"device_logs_{month_start.strftime('%Y_%m')}"
            start = month_start.isoformat()
            end = next_month_start.isoformat()
            # Name and bounds are generated locally from dates, never from
            # caller input, so they are interpolated directly into the DDL.
            create_sql = text(f"""
CREATE TABLE IF NOT EXISTS {partition_name}
PARTITION OF device_logs
FOR VALUES FROM ('{start}') TO ('{end}')
""")
            await db.execute(create_sql)
        await db.commit()
    logger.info("Partition check complete")
def _partition_month(name: str) -> date | None:
    """Return the first day of the month encoded in a ``device_logs_<year>_<month>`` name.

    Returns None for any other name. The prefix is checked explicitly and
    the year/month parts must be plain ASCII digits, so only names safe to
    interpolate into DDL are accepted.
    """
    parts = name.split("_")
    if len(parts) != 4 or parts[:2] != ["device", "logs"]:
        return None
    year, month = parts[2], parts[3]
    # int() alone would also accept signs, whitespace and non-ASCII digits.
    if not (year.isascii() and year.isdigit() and month.isascii() and month.isdigit()):
        return None
    try:
        return date(int(year), int(month), 1)
    except ValueError:
        # Out-of-range year or month (e.g. month 13).
        return None


async def drop_old_partitions(keep_months: int = 6):
    """Drop device_logs partitions older than keep_months.

    Lists tables in the ``public`` schema whose name matches
    ``device_logs_%`` and drops those whose encoded month is strictly before
    the first day of the month ``keep_months`` months ago. Each drop is
    committed in its own session.

    Args:
        keep_months: Number of months of partitions to keep.
    """
    cutoff = _add_months(date.today().replace(day=1), -keep_months)

    async with AsyncSessionLocal() as session:
        result = await session.execute(text("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
AND tablename LIKE 'device_logs_%'
"""))
        partitions = [r[0] for r in result.fetchall()]
    # The listing session is closed before any DROP runs, so at most one
    # connection is held at a time.

    for name in partitions:
        # "_" is a single-character wildcard in LIKE, so the query can return
        # tables that merely resemble the pattern; validate each name strictly.
        partition_date = _partition_month(name)
        if partition_date is None or partition_date >= cutoff:
            continue
        async with AsyncSessionLocal() as session:
            await session.execute(text(f"DROP TABLE IF EXISTS {name}"))
            await session.commit()
        logger.info("Dropped old partition: %s", name)
async def partition_manager_loop():
    """Run partition maintenance once on startup, then once a day forever.

    Each pass makes sure the current and next month's device_logs partitions
    exist; passes after the first also drop expired partitions. Errors in
    the recurring passes are logged and the loop continues. An error in the
    initial startup pass propagates to the caller, as before.
    """
    # Check daily rather than every 30 days. With a 30-day sleep and only a
    # one-month look-ahead, a wake-up can land after a month boundary
    # (e.g. Jan 1 -> Jan 31 -> Mar 2), leaving the new month without a
    # partition. Both maintenance steps are safe to repeat.
    check_interval_seconds = 24 * 3600

    await ensure_current_partitions()
    while True:
        await asyncio.sleep(check_interval_seconds)
        try:
            await ensure_current_partitions()
            await drop_old_partitions()
        except Exception as e:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Partition manager error: %s", e)
# ---------------------------------------------------------------------------
# Cleanup (replaces the SQLite purge_loop). device_logs rows are not purged
# here: old logs are removed by dropping whole partitions (see "Partition
# management" above). Heartbeats and commands are still purged by row deletion.
# ---------------------------------------------------------------------------
async def purge_old_data(retention_days: int | None = None):
    """Delete heartbeats and commands older than the retention window.

    Args:
        retention_days: Age limit in days. When falsy (None or 0), the value
            of ``settings.mqtt_data_retention_days`` is used instead.
    """
    if retention_days:
        days = retention_days
    else:
        days = settings.mqtt_data_retention_days
    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    bind_values = {"cutoff": cutoff}

    # Both deletes run in one session and are committed together.
    delete_statements = (
        "DELETE FROM heartbeats WHERE received_at < :cutoff",
        "DELETE FROM commands WHERE sent_at < :cutoff",
    )
    async with AsyncSessionLocal() as db:
        for sql in delete_statements:
            await db.execute(text(sql), {"cutoff": bind_values["cutoff"]})
        await db.commit()
    logger.info(f"Purged heartbeats and commands older than {days} days")
async def purge_loop():
    """Run ``purge_old_data`` once every 24 hours, forever.

    The first purge happens after the first sleep. Failures are logged and
    do not stop the loop.
    """
    seconds_per_day = 24 * 60 * 60
    while True:
        await asyncio.sleep(seconds_per_day)
        try:
            await purge_old_data()
        except Exception as exc:
            logger.error(f"Purge failed: {exc}")
# ---------------------------------------------------------------------------
# Stub — no longer needed but kept so nothing that imports init_db/close_db breaks
# ---------------------------------------------------------------------------
async def init_db():
    """Compatibility stub: only logs that the Postgres backend is active.

    No schema is created here; per the original note, the Postgres schema is
    managed by Alembic rather than at runtime.
    """
    startup_note = "Postgres MQTT backend active — no SQLite init needed"
    logger.info(startup_note)
async def close_db():
    """Compatibility stub: does nothing and returns None.

    Kept so existing callers can still await it; per the original note, the
    SQLAlchemy engine lifecycle is managed by the process.
    """
    return None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _row_to_dict(row) -> dict:
    """Copy a row mapping into a plain dict, rendering datetimes as ISO strings.

    Values that are ``datetime`` instances are replaced by their
    ``isoformat()`` string; every other value is copied unchanged. The input
    is not modified.
    """
    return {
        column: value.isoformat() if isinstance(value, datetime) else value
        for column, value in dict(row).items()
    }