Skip to content

Commit f0edeee

Browse files
authored
Merge pull request #181 from demilade18-git/fix/issues-124-125-126-127
fix: resolve issues #124, #125, #126, #127
2 parents a87c350 + 8d409f9 commit f0edeee

File tree

7 files changed

+335
-8
lines changed

7 files changed

+335
-8
lines changed

src/chat.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
log_info,
1414
log_warning,
1515
)
16+
from src.chat_store import chat_store
1617

1718

1819
class ChatMessage(BaseModel):
@@ -121,6 +122,8 @@ async def send_message(self, message: ChatMessage) -> bool:
121122
self.message_history[message.conversation_id] = []
122123
self.message_history[message.conversation_id].append(message)
123124

125+
chat_store.save_message(message)
126+
124127
if message.conversation_id in self.active_connections:
125128
disconnected: List[WebSocket] = []
126129
for websocket in self.active_connections[message.conversation_id]:
@@ -171,8 +174,27 @@ async def broadcast_event(
171174
def get_message_history(
    self, conversation_id: str, limit: int = 50
) -> List[ChatMessage]:
    """Return the most recent messages for a conversation, oldest first.

    Falls back to the DB when the in-memory cache is empty (e.g. after a
    restart) and warms the cache with whatever the DB returned.

    Args:
        conversation_id: Conversation whose history is requested.
        limit: Maximum number of messages to return (and to fetch from
            the DB on a cold cache).
    """
    messages = self.message_history.get(conversation_id, [])
    if not messages:
        # Cold cache: rehydrate from persistent storage.
        db_rows = chat_store.get_messages(conversation_id, limit=limit)
        messages = [
            ChatMessage(
                id=r["id"],
                sender_id=r["sender_id"],
                sender_type=r["sender_type"],
                content=r["content"],
                timestamp=r["timestamp"],
                conversation_id=r["conversation_id"],
            )
            for r in db_rows
        ]
        if messages:
            # Warm the cache directly; the previous `async_lock_safe`
            # alias was a no-op and implied a lock-safety guarantee the
            # code does not provide.
            self.message_history[conversation_id] = messages
    return messages[-limit:] if len(messages) > limit else messages
177199

178200
def get_user_conversations(self, user_id: str) -> List[str]:
@@ -200,6 +222,8 @@ async def escalate_conversation(
200222
self.conversation_assignments[conversation_id] = None
201223
self.conversation_escalated_at[conversation_id] = escalation.timestamp
202224

225+
chat_store.save_escalation(escalation)
226+
203227
if conversation_id in self.active_connections:
204228
escalation_notification = {
205229
"type": "escalation",

src/chat_store.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Persistent storage for chat messages and escalations using SQLAlchemy.
2+
3+
NOTE: Active WebSocket connections are intentionally kept in-memory only
4+
(in ChatManager). Only messages and escalation events are persisted here.
5+
"""
6+
from __future__ import annotations
7+
8+
import logging
9+
from datetime import datetime
10+
from typing import Any, Dict, List, Optional
11+
12+
from sqlalchemy import Column, DateTime, Integer, String, Text, create_engine, text
13+
from sqlalchemy import MetaData, Table
14+
from sqlalchemy.dialects.postgresql import insert as pg_insert
15+
16+
import src.db as _db
17+
18+
logger = logging.getLogger("veritix.chat_store")
19+
20+
_MESSAGES_TABLE = "chat_messages"
21+
_ESCALATIONS_TABLE = "chat_escalations"
22+
23+
24+
def _get_engine():
    """Return the shared SQLAlchemy engine from ``src.db`` (may be None when the DB is unavailable)."""
    return _db.get_engine()
26+
27+
28+
def _ensure_tables(engine) -> None:
    """Create the chat persistence tables when they are missing (idempotent)."""
    schema = MetaData()
    message_columns = (
        Column("id", String, primary_key=True),
        Column("conversation_id", String, nullable=False),
        Column("sender_id", String, nullable=False),
        Column("sender_type", String, nullable=False),
        Column("content", Text, nullable=False),
        Column("timestamp", DateTime, nullable=False),
        Column("metadata_json", Text),
    )
    escalation_columns = (
        Column("id", String, primary_key=True),
        Column("conversation_id", String, nullable=False),
        Column("reason", String, nullable=False),
        Column("timestamp", DateTime, nullable=False),
        Column("metadata_json", Text),
    )
    Table(_MESSAGES_TABLE, schema, *message_columns)
    Table(_ESCALATIONS_TABLE, schema, *escalation_columns)
    with engine.begin() as conn:
        # create_all skips tables that already exist.
        schema.create_all(conn)  # type: ignore[arg-type]
52+
53+
54+
class ChatStore:
    """Persists chat messages and escalation events to Postgres.

    Every operation is best-effort: when the engine is unavailable or a
    statement fails, the error is logged and the caller proceeds with the
    in-memory state only.
    """

    def __init__(self) -> None:
        # Flipped to True once the backing tables are known to exist.
        self._ready = False

    def _init(self, engine) -> None:
        """Create the backing tables on first use (best-effort, idempotent)."""
        if not self._ready:
            try:
                _ensure_tables(engine)
                self._ready = True
            except Exception as exc:
                logger.error("ChatStore: failed to create tables: %s", exc)

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    def save_message(self, message: Any) -> None:
        """Persist a ChatMessage to the DB (best-effort).

        Duplicate ids are ignored via ``ON CONFLICT (id) DO NOTHING`` so
        retries are safe. ``message.metadata`` is stored JSON-encoded.
        """
        engine = _get_engine()
        if engine is None:
            return
        self._init(engine)
        import json

        try:
            with engine.begin() as conn:
                conn.execute(
                    text(
                        f"INSERT INTO {_MESSAGES_TABLE} "  # noqa: S608
                        "(id, conversation_id, sender_id, sender_type, content, timestamp, metadata_json) "
                        "VALUES (:id, :conv, :sender, :stype, :content, :ts, :meta) "
                        "ON CONFLICT (id) DO NOTHING"
                    ),
                    {
                        "id": message.id,
                        "conv": message.conversation_id,
                        "sender": message.sender_id,
                        "stype": message.sender_type,
                        "content": message.content,
                        "ts": message.timestamp,
                        "meta": json.dumps(message.metadata or {}),
                    },
                )
        except Exception as exc:
            logger.error("ChatStore: save_message failed: %s", exc)

    def get_messages(self, conversation_id: str, limit: int = 50) -> List[Dict[str, Any]]:
        """Retrieve the most recent messages for a conversation from the DB.

        Rows are returned oldest-first. ``metadata`` is decoded back into a
        dict (``save_message`` stores it JSON-encoded); if decoding fails,
        the raw stored value is passed through unchanged.
        """
        engine = _get_engine()
        if engine is None:
            return []
        self._init(engine)
        import json

        try:
            with engine.connect() as conn:
                rows = conn.execute(
                    text(
                        f"SELECT id, conversation_id, sender_id, sender_type, content, timestamp, metadata_json "  # noqa: S608
                        f"FROM {_MESSAGES_TABLE} "
                        "WHERE conversation_id = :conv "
                        "ORDER BY timestamp DESC "
                        "LIMIT :lim"
                    ),
                    {"conv": conversation_id, "lim": limit},
                ).fetchall()
            # DESC + reversed() yields the newest `limit` rows in
            # chronological order.
            results: List[Dict[str, Any]] = []
            for r in reversed(rows):
                try:
                    # Fix: metadata was previously returned as the raw JSON
                    # string written by save_message; round-trip it back to
                    # a dict for callers.
                    meta = json.loads(r[6]) if r[6] else {}
                except (TypeError, ValueError):
                    meta = r[6]
                results.append(
                    {
                        "id": r[0],
                        "conversation_id": r[1],
                        "sender_id": r[2],
                        "sender_type": r[3],
                        "content": r[4],
                        "timestamp": r[5],
                        "metadata": meta,
                    }
                )
            return results
        except Exception as exc:
            logger.error("ChatStore: get_messages failed: %s", exc)
            return []

    # ------------------------------------------------------------------
    # Escalations
    # ------------------------------------------------------------------

    def save_escalation(self, escalation: Any) -> None:
        """Persist an EscalationEvent to the DB (best-effort).

        Duplicate ids are ignored via ``ON CONFLICT (id) DO NOTHING``.
        """
        engine = _get_engine()
        if engine is None:
            return
        self._init(engine)
        import json

        try:
            with engine.begin() as conn:
                conn.execute(
                    text(
                        f"INSERT INTO {_ESCALATIONS_TABLE} "  # noqa: S608
                        "(id, conversation_id, reason, timestamp, metadata_json) "
                        "VALUES (:id, :conv, :reason, :ts, :meta) "
                        "ON CONFLICT (id) DO NOTHING"
                    ),
                    {
                        "id": escalation.id,
                        "conv": escalation.conversation_id,
                        "reason": escalation.reason,
                        "ts": escalation.timestamp,
                        "meta": json.dumps(escalation.metadata or {}),
                    },
                )
        except Exception as exc:
            logger.error("ChatStore: save_escalation failed: %s", exc)
165+
166+
167+
# Module-level singleton shared by ChatManager and the API layer.
chat_store = ChatStore()

src/etl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,4 +584,5 @@ def run_etl_once() -> None:
584584
status=status,
585585
rejected_count=rejected_count,
586586
)
587+
ETL_JOBS_TOTAL.labels(status=status).inc()
587588
log_info("ETL job completed", {"status": status, "rejected_count": rejected_count})

src/event_store.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Event store: loads events from Postgres event_sales_summary with a 60-second TTL cache."""
2+
import time
3+
from typing import Any, Dict, List, Optional
4+
5+
from sqlalchemy import text
6+
7+
import src.db as _db
8+
from src.logging_config import log_info
9+
10+
_cache: Optional[List[Dict[str, Any]]] = None
11+
_cache_ts: float = 0.0
12+
_CACHE_TTL = 60.0 # seconds
13+
14+
15+
def get_events_from_db() -> List[Dict[str, Any]]:
    """Return events from Postgres ``event_sales_summary``, cached for 60 s.

    Falls back to an empty list when the DB is unavailable.
    """
    global _cache, _cache_ts
    now = time.monotonic()
    if _cache is not None and (now - _cache_ts) < _CACHE_TTL:
        # Cache hit inside the TTL window.
        return _cache

    engine = _db.get_engine()
    if engine is None:
        # No DB: serve the last good snapshot, or nothing.
        return _cache or []

    try:
        with engine.connect() as conn:
            result = conn.execute(
                text(
                    "SELECT event_id, event_name, total_tickets, total_revenue, last_updated "
                    "FROM event_sales_summary"
                )
            ).fetchall()
        loaded: List[Dict[str, Any]] = []
        for row in result:
            loaded.append(
                {
                    "id": str(row[0]),
                    "name": str(row[1] or ""),
                    "description": "",
                    "event_type": "general",
                    "location": "",
                    "date": row[4].isoformat() if row[4] else "",
                    # presumably average ticket price; divisor clamped to
                    # avoid division by zero — TODO confirm semantics
                    "price": float(row[3] or 0) / max(int(row[2] or 1), 1),
                    "capacity": int(row[2] or 0),
                }
            )
        _cache = loaded
        _cache_ts = now
        log_info("event_store: loaded events from DB", {"count": len(loaded)})
        return loaded
    except Exception as exc:
        from src.logging_config import log_error

        log_error("event_store: DB query failed", {"error": str(exc)})
        return _cache or []
58+
59+
60+
def invalidate_cache() -> None:
    """Drop the cached event list so the next call re-queries the database."""
    global _cache, _cache_ts
    _cache_ts = 0.0
    _cache = None

src/main.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@
4242
log_warning,
4343
setup_logging,
4444
)
45-
from src.mock_events import get_mock_events
45+
from src.event_store import get_events_from_db
46+
from src.mock_events import get_mock_events # kept for test usage only
4647
from src.report_service import (
4748
_query_daily_sales,
4849
_query_invalid_scans,
@@ -531,7 +532,7 @@ def search_events(payload: SearchEventsRequest) -> Any:
531532
"""Search for events using natural language keyword extraction."""
532533
try:
533534
keywords = extract_keywords(payload.query)
534-
all_events = get_mock_events()
535+
all_events = get_events_from_db() or get_mock_events()
535536
matching_events = filter_events_by_keywords(
536537
all_events,
537538
keywords,
@@ -570,8 +571,6 @@ def recommend_events(payload: RecommendRequest) -> RecommendResponse:
570571
user_id = payload.user_id
571572
# Prefer DB-sourced history; fall back to mock data when DB is unavailable.
572573
user_events_dict = get_user_events_from_db()
573-
if not user_events_dict:
574-
user_events_dict = mock_user_events
575574
similarity_matrix = build_item_similarity_matrix(user_events_dict)
576575
recommended = get_item_recommendations(
577576
user_id=user_id,

src/revenue_sharing_service.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from decimal import ROUND_HALF_UP, Decimal
55
from typing import Dict, List, Optional, Tuple
66

7+
from fastapi import HTTPException
78
from src.logging_config import log_error, log_info
89
from src.revenue_sharing_models import (
910
EventRevenueInput,
@@ -52,10 +53,19 @@ def calculate_revenue_shares(self, input_data: EventRevenueInput) -> RevenueCalc
5253

5354
# Get stakeholders for this event (in a real implementation, this would come from DB)
5455
stakeholders = self._get_default_stakeholders(input_data.event_id)
55-
56+
5657
# Apply custom rules if provided
5758
rules = input_data.custom_rules or self._get_default_rules()
58-
59+
60+
# Guard: total rule percentages must not exceed 100%
61+
total_pct = sum(rule.percentage for rule in rules)
62+
if total_pct > 100.0:
63+
breakdown = {rule.id: rule.percentage for rule in rules}
64+
raise HTTPException(
65+
status_code=400,
66+
detail=f"Rule percentages sum to {total_pct:.2f}% which exceeds 100%. Breakdown: {breakdown}",
67+
)
68+
5969
# Calculate distributions
6070
distributions, remaining_balance = self._calculate_distributions(
6171
net_revenue, stakeholders, rules

0 commit comments

Comments
 (0)