Skip to content

Commit 2bd4879

Browse files
authored
Merge branch 'main' into claude/issue-2915-20251209-1250
2 parents 077d6ec + aa203f1 commit 2bd4879

File tree

142 files changed

+10039
-1919
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

142 files changed

+10039
-1919
lines changed

comments/services/key_factors/suggestions.py

Lines changed: 244 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,60 @@
11
import json
2+
import logging
23
import textwrap
3-
from typing import List, Optional
4+
from typing import Literal, Optional, Union
45

56
from django.conf import settings
7+
from django.core.validators import URLValidator
68
from pydantic import (
9+
field_validator,
10+
ValidationInfo,
11+
ValidationError as PydanticValidationError,
712
BaseModel,
813
Field,
914
model_validator,
10-
ValidationError as PydanticValidationError,
1115
)
1216
from rest_framework.exceptions import ValidationError
1317

14-
from comments.models import KeyFactor, KeyFactorDriver
18+
from comments.models import KeyFactor, KeyFactorDriver, KeyFactorNews, KeyFactorBaseRate
1519
from posts.models import Post
1620
from questions.models import Question
1721
from utils.openai import pydantic_to_openai_json_schema, get_openai_client
1822

19-
# Central constraints
2023
MAX_LENGTH = 50
2124

25+
logger = logging.getLogger(__name__)
26+
27+
28+
def _normalize_impact_fields(data: dict) -> dict:
29+
"""
30+
Normalize impact_direction and certainty fields.
31+
- Coerces values to allowed sets {1, -1} for impact_direction and {-1} for certainty
32+
- Enforces XOR: certainty (-1) overrides impact_direction
33+
"""
34+
if not isinstance(data, dict):
35+
return data
36+
37+
def coerce(value, allowed):
38+
try:
39+
v = int(value)
40+
return v if v in allowed else None
41+
except (TypeError, ValueError):
42+
return None
43+
44+
impact_direction = coerce(data.get("impact_direction"), {1, -1})
45+
certainty = coerce(data.get("certainty"), {-1})
46+
47+
# Enforce XOR preference: certainty (-1) overrides impact_direction
48+
if certainty == -1:
49+
impact_direction = None
2250

23-
# TODO: unit tests!
24-
class KeyFactorResponse(BaseModel):
51+
data.update(impact_direction=impact_direction, certainty=certainty)
52+
53+
return data
54+
55+
56+
class DriverResponse(BaseModel):
57+
type: str = Field("driver", description="Type identifier")
2558
text: str = Field(
2659
..., description="Concise single-sentence key factor (<= 50 chars)"
2760
)
@@ -41,44 +74,102 @@ class KeyFactorResponse(BaseModel):
4174
@model_validator(mode="before")
4275
@classmethod
4376
def normalize_fields(cls, data):
44-
if not isinstance(data, dict):
45-
return data
77+
return _normalize_impact_fields(data)
4678

47-
def coerce(value, allowed):
48-
try:
49-
v = int(value)
50-
return v if v in allowed else None
51-
except (TypeError, ValueError):
52-
return None
5379

54-
impact_direction = coerce(data.get("impact_direction"), {1, -1})
55-
certainty = coerce(data.get("certainty"), {-1})
80+
class NewsResponse(BaseModel):
81+
url: str = Field(
82+
None,
83+
description="URL of the news article extracted from the comment",
84+
)
85+
impact_direction: Optional[int] = Field(
86+
None,
87+
description="Set to 1 or -1 to indicate direction; omit if certainty is set",
88+
)
89+
certainty: Optional[int] = Field(
90+
None,
91+
description="Set to -1 only if the article increases uncertainty; else omit",
92+
)
5693

57-
# Enforce XOR preference: certainty (-1) overrides impact_direction
58-
if certainty == -1:
59-
impact_direction = None
94+
@model_validator(mode="before")
95+
@classmethod
96+
def normalize_fields(cls, data):
97+
return _normalize_impact_fields(data)
6098

61-
data.update(impact_direction=impact_direction, certainty=certainty)
6299

63-
return data
100+
class BaseRateResponse(BaseModel):
101+
type: str = Field("base_rate", description="Type identifier")
102+
base_rate_type: str = Field(
103+
..., description="'frequency' or 'trend' - must be one of these two types"
104+
)
105+
reference_class: str = Field(
106+
..., description="Reference class for the base rate (required)"
107+
)
108+
unit: str = Field(..., description="Unit of measurement (required)")
109+
source_url: str = Field(
110+
..., description="URL of the base rate data source (required)"
111+
)
64112

113+
# Frequency-specific fields
114+
rate_numerator: Optional[int] = Field(
115+
None, description="Numerator for frequency type (required for frequency)"
116+
)
117+
rate_denominator: Optional[int] = Field(
118+
None, description="Denominator for frequency type (required for frequency)"
119+
)
120+
# Trend-specific fields
121+
projected_value: Optional[float] = Field(
122+
None, description="Projected value for trend type (required for trend)"
123+
)
124+
projected_by_year: Optional[int] = Field(
125+
None, description="Year for trend projection (required for trend)"
126+
)
127+
extrapolation: Optional[Literal["linear", "exponential", "other"]] = Field(
128+
None,
129+
description=(
130+
"Extrapolation method for trend type (required for trend); "
131+
"must be one of: linear, exponential, other"
132+
),
133+
)
134+
based_on: Optional[str] = Field(
135+
None, description="What the trend is based on (optional)"
136+
)
65137

66-
class KeyFactorsResponse(BaseModel):
67-
key_factors: List[KeyFactorResponse]
138+
@field_validator("source_url")
139+
@classmethod
140+
def validate_source_url(cls, value: str, info: ValidationInfo) -> str:
141+
# Validate URL
142+
URLValidator()(value)
68143

144+
if not info.context:
145+
return value
146+
147+
comment = info.context.get("comment")
148+
149+
if not comment:
150+
return value
151+
152+
if value.lower().strip("/") in comment.lower():
153+
return value
154+
155+
raise ValueError("URL must be present in the comment")
156+
157+
158+
KeyFactorResponseType = Union[DriverResponse, NewsResponse, BaseRateResponse]
69159

70-
def _convert_llm_response_to_key_factor(
71-
post: Post, response: KeyFactorResponse
72-
) -> KeyFactor:
73-
"""
74-
Generating and normalizing KeyFactor object (but not saving, just for the structure) from LLM payload
75-
"""
76160

77-
option = response.option.lower() if response.option else None
161+
class KeyFactorsResponse(BaseModel):
162+
key_factors: list[KeyFactorResponseType]
163+
164+
165+
def _create_driver_key_factor(post: Post, response: DriverResponse) -> KeyFactor:
166+
"""Create a KeyFactor with a Driver type from LLM response."""
78167
question_id = None
79168
question_option = None
80169

81-
if option:
170+
# Resolve option field to question_id and question_option
171+
if response.option:
172+
option = response.option.lower()
82173
if (
83174
post.question
84175
and post.question.type == Question.QuestionType.MULTIPLE_CHOICE
@@ -91,19 +182,68 @@ def _convert_llm_response_to_key_factor(
91182

92183
if post.group_of_questions:
93184
question_id = next(
94-
(q.id for q in post.get_questions() if q.label.lower() == option), None
185+
(q.id for q in post.get_questions() if q.label.lower() == option),
186+
None,
95187
)
96188

97-
return KeyFactor(
98-
question_id=question_id,
99-
question_option=question_option,
100-
driver=KeyFactorDriver(
101-
text=response.text,
102-
certainty=response.certainty,
103-
impact_direction=response.impact_direction,
104-
),
189+
kf = KeyFactor(question_id=question_id, question_option=question_option)
190+
kf.driver = KeyFactorDriver(
191+
text=response.text,
192+
certainty=response.certainty,
193+
impact_direction=response.impact_direction,
105194
)
106195

196+
return kf
197+
198+
199+
def _create_news_key_factor(response: NewsResponse) -> KeyFactor:
200+
"""
201+
Create a KeyFactor with a News type from LLM response.
202+
"""
203+
kf = KeyFactor()
204+
kf.news = KeyFactorNews(
205+
url=response.url,
206+
# Other fields will be extracted on the frontend side
207+
# by fetching link-preview endpoint
208+
certainty=response.certainty,
209+
impact_direction=response.impact_direction,
210+
)
211+
212+
return kf
213+
214+
215+
def _create_base_rate_key_factor(response: BaseRateResponse) -> KeyFactor:
216+
"""Create a KeyFactor with a BaseRate type from LLM response."""
217+
kf = KeyFactor()
218+
kf.base_rate = KeyFactorBaseRate(
219+
type=response.base_rate_type,
220+
reference_class=response.reference_class,
221+
rate_numerator=response.rate_numerator,
222+
rate_denominator=response.rate_denominator,
223+
projected_value=response.projected_value,
224+
projected_by_year=response.projected_by_year,
225+
unit=response.unit,
226+
extrapolation=response.extrapolation,
227+
based_on=response.based_on or "",
228+
source=response.source_url,
229+
)
230+
return kf
231+
232+
233+
def _convert_llm_response_to_key_factor(
234+
post: Post, response: KeyFactorResponseType
235+
) -> KeyFactor:
236+
"""
237+
Convert LLM response to KeyFactor object (but not saving, just for the structure).
238+
Dispatches to appropriate type-specific converter.
239+
"""
240+
if isinstance(response, DriverResponse):
241+
return _create_driver_key_factor(post, response)
242+
elif isinstance(response, NewsResponse):
243+
return _create_news_key_factor(response)
244+
elif isinstance(response, BaseRateResponse):
245+
return _create_base_rate_key_factor(response)
246+
107247

108248
def build_post_question_summary(post: Post) -> tuple[str, Question.QuestionType]:
109249
"""
@@ -177,7 +317,7 @@ def generate_keyfactors(
177317
comment: str,
178318
existing_key_factors: list[dict],
179319
type_instructions: str,
180-
) -> list[KeyFactorResponse]:
320+
) -> list[KeyFactorResponseType]:
181321
"""
182322
Generate key factors based on question type and comment.
183323
"""
@@ -198,21 +338,33 @@ def generate_keyfactors(
198338
key factors should only be relate to that.
199339
The key factors should be the most important things that the user is trying to say
200340
in the comment and how it might influence the predictions on the question.
201-
The key factors text should be single sentences, not longer than {MAX_LENGTH} characters
202-
and they should only contain the key factor, no other text (e.g.: do not reference the user).
203341
204-
Each key factor should describe something that could influence the forecast for the question.
205-
Also specify the direction of impact as described below.
342+
You can generate three types of key factors:
343+
1. Driver: A factor that drives the outcome (e.g., "X policy change increases likelihood")
344+
- Text should be single sentences under {MAX_LENGTH} characters.
345+
- Should only contain the key factor, no other text (e.g.: do not reference the user).
346+
- Specify impact_direction (1/-1) or certainty (-1)
347+
2. News: A relevant news article found in the comment.
348+
- url: REQUIRED. Must be extracted from the comment body. If no URL, skip this factor.
349+
- Specify impact_direction (1/-1) or certainty (-1) for how this article affects the forecast
350+
3. BaseRate: A historical base rate or reference frequency/trend.
351+
- source_url: REQUIRED. Must be a real URL extracted verbatim from <user_comment>.
352+
- CRITICAL: Do not use URLs from <question_summary> or hallucinate URLs.
353+
- If no valid URL is found in the comment, skip this factor entirely.
354+
355+
The key factors should represent the most important things influencing the forecast.
206356
207357
{type_instructions}
208358
209359
Output rules:
210360
- Return valid JSON only, matching the schema.
211-
- Each key factor is under {MAX_LENGTH} characters.
212-
- Do not include any key factors that are already in the existing key factors list. Read that carefully and make sure you don't have any duplicates.
213-
- Be conservative and only include clearly relevant factors.
361+
- Include a "type" field for each factor: "driver", "news", or "base_rate"
362+
- For base_rate: source_url must come from <user_comment> only, never from <question_summary>
363+
- Do not duplicate existing key factors - check the list carefully
364+
- Ensure suggested key factors do not duplicate each other
365+
- Be conservative and only include clearly relevant factors
214366
- Do not include any formatting like quotes, numbering or other punctuation
215-
- If the comment provides no meaningful forecasting insight, return the literal string "None".
367+
- If the comment provides no meaningful forecasting insight, return empty list e.g {{"key_factors":[]}}.
216368
217369
The question details are:
218370
<question_summary>
@@ -255,28 +407,64 @@ def generate_keyfactors(
255407

256408
content = response.choices[0].message.content
257409

258-
if content is None or content.lower() == "none":
259-
return []
260-
261410
try:
262411
data = json.loads(content)
263-
# TODO: replace KeyFactorsResponse with plain list
264-
parsed = KeyFactorsResponse(**data)
265-
return parsed.key_factors
266-
except (json.JSONDecodeError, PydanticValidationError):
412+
except json.JSONDecodeError:
267413
return []
268414

415+
# Validate each key factor individually
416+
type_map = {
417+
"driver": DriverResponse,
418+
"news": NewsResponse,
419+
"base_rate": BaseRateResponse,
420+
}
421+
422+
validated_key_factors = []
423+
for item in data.get("key_factors", []):
424+
try:
425+
kf_type = item.get("type")
426+
model_class = type_map.get(kf_type)
427+
428+
if model_class:
429+
model_instance = model_class.model_validate(
430+
item, context={"comment": comment}
431+
)
432+
validated_key_factors.append(model_instance)
433+
except PydanticValidationError as e:
434+
logger.debug(f"Validation error for key factor at index: {e}")
435+
436+
return validated_key_factors
437+
269438

270439
def _serialize_key_factor(kf: KeyFactor):
271440
option = kf.question.label if kf.question else kf.question_option
272441

273442
if kf.driver_id:
274443
return {
444+
"type": "driver",
275445
"text": kf.driver.text,
276446
"impact_direction": kf.driver.impact_direction,
277447
"certainty": kf.driver.certainty,
278448
"option": option or None,
279449
}
450+
elif kf.news_id:
451+
return {
452+
"type": "news",
453+
"title": kf.news.title,
454+
"url": kf.news.url,
455+
}
456+
elif kf.base_rate_id:
457+
return {
458+
"type": "base_rate",
459+
"base_rate_type": kf.base_rate.type,
460+
"reference_class": kf.base_rate.reference_class,
461+
"rate_numerator": kf.base_rate.rate_numerator,
462+
"rate_denominator": kf.base_rate.rate_denominator,
463+
"projected_value": kf.base_rate.projected_value,
464+
"projected_by_year": kf.base_rate.projected_by_year,
465+
"unit": kf.base_rate.unit,
466+
"extrapolation": kf.base_rate.extrapolation,
467+
}
280468

281469

282470
def generate_key_factors_for_comment(

0 commit comments

Comments
 (0)