11import json
2+ import logging
23import textwrap
3- from typing import List , Optional
4+ from typing import Literal , Optional , Union
45
56from django .conf import settings
7+ from django .core .validators import URLValidator
68from pydantic import (
9+ field_validator ,
10+ ValidationInfo ,
11+ ValidationError as PydanticValidationError ,
712 BaseModel ,
813 Field ,
914 model_validator ,
10- ValidationError as PydanticValidationError ,
1115)
1216from rest_framework .exceptions import ValidationError
1317
14- from comments .models import KeyFactor , KeyFactorDriver
18+ from comments .models import KeyFactor , KeyFactorDriver , KeyFactorNews , KeyFactorBaseRate
1519from posts .models import Post
1620from questions .models import Question
1721from utils .openai import pydantic_to_openai_json_schema , get_openai_client
1822
19- # Central constraints
2023MAX_LENGTH = 50
2124
25+ logger = logging .getLogger (__name__ )
26+
27+
28+ def _normalize_impact_fields (data : dict ) -> dict :
29+ """
30+ Normalize impact_direction and certainty fields.
31+ - Coerces values to allowed sets {1, -1} for impact_direction and {-1} for certainty
32+ - Enforces XOR: certainty (-1) overrides impact_direction
33+ """
34+ if not isinstance (data , dict ):
35+ return data
36+
37+ def coerce (value , allowed ):
38+ try :
39+ v = int (value )
40+ return v if v in allowed else None
41+ except (TypeError , ValueError ):
42+ return None
43+
44+ impact_direction = coerce (data .get ("impact_direction" ), {1 , - 1 })
45+ certainty = coerce (data .get ("certainty" ), {- 1 })
46+
47+ # Enforce XOR preference: certainty (-1) overrides impact_direction
48+ if certainty == - 1 :
49+ impact_direction = None
2250
23- # TODO: unit tests!
24- class KeyFactorResponse (BaseModel ):
51+ data .update (impact_direction = impact_direction , certainty = certainty )
52+
53+ return data
54+
55+
56+ class DriverResponse (BaseModel ):
57+ type : str = Field ("driver" , description = "Type identifier" )
2558 text : str = Field (
2659 ..., description = "Concise single-sentence key factor (<= 50 chars)"
2760 )
@@ -41,44 +74,102 @@ class KeyFactorResponse(BaseModel):
4174 @model_validator (mode = "before" )
4275 @classmethod
4376 def normalize_fields (cls , data ):
44- if not isinstance (data , dict ):
45- return data
77+ return _normalize_impact_fields (data )
4678
47- def coerce (value , allowed ):
48- try :
49- v = int (value )
50- return v if v in allowed else None
51- except (TypeError , ValueError ):
52- return None
5379
54- impact_direction = coerce (data .get ("impact_direction" ), {1 , - 1 })
55- certainty = coerce (data .get ("certainty" ), {- 1 })
80+ class NewsResponse (BaseModel ):
81+ url : str = Field (
82+ None ,
83+ description = "URL of the news article extracted from the comment" ,
84+ )
85+ impact_direction : Optional [int ] = Field (
86+ None ,
87+ description = "Set to 1 or -1 to indicate direction; omit if certainty is set" ,
88+ )
89+ certainty : Optional [int ] = Field (
90+ None ,
91+ description = "Set to -1 only if the article increases uncertainty; else omit" ,
92+ )
5693
57- # Enforce XOR preference: certainty (-1) overrides impact_direction
58- if certainty == - 1 :
59- impact_direction = None
94+ @model_validator (mode = "before" )
95+ @classmethod
96+ def normalize_fields (cls , data ):
97+ return _normalize_impact_fields (data )
6098
61- data .update (impact_direction = impact_direction , certainty = certainty )
6299
63- return data
100+ class BaseRateResponse (BaseModel ):
101+ type : str = Field ("base_rate" , description = "Type identifier" )
102+ base_rate_type : str = Field (
103+ ..., description = "'frequency' or 'trend' - must be one of these two types"
104+ )
105+ reference_class : str = Field (
106+ ..., description = "Reference class for the base rate (required)"
107+ )
108+ unit : str = Field (..., description = "Unit of measurement (required)" )
109+ source_url : str = Field (
110+ ..., description = "URL of the base rate data source (required)"
111+ )
64112
113+ # Frequency-specific fields
114+ rate_numerator : Optional [int ] = Field (
115+ None , description = "Numerator for frequency type (required for frequency)"
116+ )
117+ rate_denominator : Optional [int ] = Field (
118+ None , description = "Denominator for frequency type (required for frequency)"
119+ )
120+ # Trend-specific fields
121+ projected_value : Optional [float ] = Field (
122+ None , description = "Projected value for trend type (required for trend)"
123+ )
124+ projected_by_year : Optional [int ] = Field (
125+ None , description = "Year for trend projection (required for trend)"
126+ )
127+ extrapolation : Optional [Literal ["linear" , "exponential" , "other" ]] = Field (
128+ None ,
129+ description = (
130+ "Extrapolation method for trend type (required for trend); "
131+ "must be one of: linear, exponential, other"
132+ ),
133+ )
134+ based_on : Optional [str ] = Field (
135+ None , description = "What the trend is based on (optional)"
136+ )
65137
66- class KeyFactorsResponse (BaseModel ):
67- key_factors : List [KeyFactorResponse ]
138+ @field_validator ("source_url" )
139+ @classmethod
140+ def validate_source_url (cls , value : str , info : ValidationInfo ) -> str :
141+ # Validate URL
142+ URLValidator ()(value )
68143
144+ if not info .context :
145+ return value
146+
147+ comment = info .context .get ("comment" )
148+
149+ if not comment :
150+ return value
151+
152+ if value .lower ().strip ("/" ) in comment .lower ():
153+ return value
154+
155+ raise ValueError ("URL must be present in the comment" )
156+
157+
158+ KeyFactorResponseType = Union [DriverResponse , NewsResponse , BaseRateResponse ]
69159
70- def _convert_llm_response_to_key_factor (
71- post : Post , response : KeyFactorResponse
72- ) -> KeyFactor :
73- """
74- Generating and normalizing KeyFactor object (but not saving, just for the structure) from LLM payload
75- """
76160
77- option = response .option .lower () if response .option else None
161+ class KeyFactorsResponse (BaseModel ):
162+ key_factors : list [KeyFactorResponseType ]
163+
164+
165+ def _create_driver_key_factor (post : Post , response : DriverResponse ) -> KeyFactor :
166+ """Create a KeyFactor with a Driver type from LLM response."""
78167 question_id = None
79168 question_option = None
80169
81- if option :
170+ # Resolve option field to question_id and question_option
171+ if response .option :
172+ option = response .option .lower ()
82173 if (
83174 post .question
84175 and post .question .type == Question .QuestionType .MULTIPLE_CHOICE
@@ -91,19 +182,68 @@ def _convert_llm_response_to_key_factor(
91182
92183 if post .group_of_questions :
93184 question_id = next (
94- (q .id for q in post .get_questions () if q .label .lower () == option ), None
185+ (q .id for q in post .get_questions () if q .label .lower () == option ),
186+ None ,
95187 )
96188
97- return KeyFactor (
98- question_id = question_id ,
99- question_option = question_option ,
100- driver = KeyFactorDriver (
101- text = response .text ,
102- certainty = response .certainty ,
103- impact_direction = response .impact_direction ,
104- ),
189+ kf = KeyFactor (question_id = question_id , question_option = question_option )
190+ kf .driver = KeyFactorDriver (
191+ text = response .text ,
192+ certainty = response .certainty ,
193+ impact_direction = response .impact_direction ,
105194 )
106195
196+ return kf
197+
198+
199+ def _create_news_key_factor (response : NewsResponse ) -> KeyFactor :
200+ """
201+ Create a KeyFactor with a News type from LLM response.
202+ """
203+ kf = KeyFactor ()
204+ kf .news = KeyFactorNews (
205+ url = response .url ,
206+ # Other fields will be extracted on the frontend side
207+ # by fetching link-preview endpoint
208+ certainty = response .certainty ,
209+ impact_direction = response .impact_direction ,
210+ )
211+
212+ return kf
213+
214+
215+ def _create_base_rate_key_factor (response : BaseRateResponse ) -> KeyFactor :
216+ """Create a KeyFactor with a BaseRate type from LLM response."""
217+ kf = KeyFactor ()
218+ kf .base_rate = KeyFactorBaseRate (
219+ type = response .base_rate_type ,
220+ reference_class = response .reference_class ,
221+ rate_numerator = response .rate_numerator ,
222+ rate_denominator = response .rate_denominator ,
223+ projected_value = response .projected_value ,
224+ projected_by_year = response .projected_by_year ,
225+ unit = response .unit ,
226+ extrapolation = response .extrapolation ,
227+ based_on = response .based_on or "" ,
228+ source = response .source_url ,
229+ )
230+ return kf
231+
232+
233+ def _convert_llm_response_to_key_factor (
234+ post : Post , response : KeyFactorResponseType
235+ ) -> KeyFactor :
236+ """
237+ Convert LLM response to KeyFactor object (but not saving, just for the structure).
238+ Dispatches to appropriate type-specific converter.
239+ """
240+ if isinstance (response , DriverResponse ):
241+ return _create_driver_key_factor (post , response )
242+ elif isinstance (response , NewsResponse ):
243+ return _create_news_key_factor (response )
244+ elif isinstance (response , BaseRateResponse ):
245+ return _create_base_rate_key_factor (response )
246+
107247
108248def build_post_question_summary (post : Post ) -> tuple [str , Question .QuestionType ]:
109249 """
@@ -177,7 +317,7 @@ def generate_keyfactors(
177317 comment : str ,
178318 existing_key_factors : list [dict ],
179319 type_instructions : str ,
180- ) -> list [KeyFactorResponse ]:
320+ ) -> list [KeyFactorResponseType ]:
181321 """
182322 Generate key factors based on question type and comment.
183323 """
@@ -198,21 +338,33 @@ def generate_keyfactors(
198338 key factors should only be relate to that.
199339 The key factors should be the most important things that the user is trying to say
200340 in the comment and how it might influence the predictions on the question.
201- The key factors text should be single sentences, not longer than { MAX_LENGTH } characters
202- and they should only contain the key factor, no other text (e.g.: do not reference the user).
203341
204- Each key factor should describe something that could influence the forecast for the question.
205- Also specify the direction of impact as described below.
342+ You can generate three types of key factors:
343+ 1. Driver: A factor that drives the outcome (e.g., "X policy change increases likelihood")
344+ - Text should be single sentences under { MAX_LENGTH } characters.
345+ - Should only contain the key factor, no other text (e.g.: do not reference the user).
346+ - Specify impact_direction (1/-1) or certainty (-1)
347+ 2. News: A relevant news article found in the comment.
348+ - url: REQUIRED. Must be extracted from the comment body. If no URL, skip this factor.
349+ - Specify impact_direction (1/-1) or certainty (-1) for how this article affects the forecast
350+ 3. BaseRate: A historical base rate or reference frequency/trend.
351+ - source_url: REQUIRED. Must be a real URL extracted verbatim from <user_comment>.
352+ - CRITICAL: Do not use URLs from <question_summary> or hallucinate URLs.
353+ - If no valid URL is found in the comment, skip this factor entirely.
354+
355+ The key factors should represent the most important things influencing the forecast.
206356
207357 { type_instructions }
208358
209359 Output rules:
210360 - Return valid JSON only, matching the schema.
211- - Each key factor is under { MAX_LENGTH } characters.
212- - Do not include any key factors that are already in the existing key factors list. Read that carefully and make sure you don't have any duplicates.
213- - Be conservative and only include clearly relevant factors.
361+ - Include a "type" field for each factor: "driver", "news", or "base_rate"
362+ - For base_rate: source_url must come from <user_comment> only, never from <question_summary>
363+ - Do not duplicate existing key factors - check the list carefully
364+ - Ensure suggested key factors do not duplicate each other
365+ - Be conservative and only include clearly relevant factors
214366 - Do not include any formatting like quotes, numbering or other punctuation
215- - If the comment provides no meaningful forecasting insight, return the literal string "None" .
367+ - If the comment provides no meaningful forecasting insight, return empty list e.g {{"key_factors":[]}} .
216368
217369 The question details are:
218370 <question_summary>
@@ -255,28 +407,64 @@ def generate_keyfactors(
255407
256408 content = response .choices [0 ].message .content
257409
258- if content is None or content .lower () == "none" :
259- return []
260-
261410 try :
262411 data = json .loads (content )
263- # TODO: replace KeyFactorsResponse with plain list
264- parsed = KeyFactorsResponse (** data )
265- return parsed .key_factors
266- except (json .JSONDecodeError , PydanticValidationError ):
412+ except json .JSONDecodeError :
267413 return []
268414
415+ # Validate each key factor individually
416+ type_map = {
417+ "driver" : DriverResponse ,
418+ "news" : NewsResponse ,
419+ "base_rate" : BaseRateResponse ,
420+ }
421+
422+ validated_key_factors = []
423+ for item in data .get ("key_factors" , []):
424+ try :
425+ kf_type = item .get ("type" )
426+ model_class = type_map .get (kf_type )
427+
428+ if model_class :
429+ model_instance = model_class .model_validate (
430+ item , context = {"comment" : comment }
431+ )
432+ validated_key_factors .append (model_instance )
433+ except PydanticValidationError as e :
434+ logger .debug (f"Validation error for key factor at index: { e } " )
435+
436+ return validated_key_factors
437+
269438
270439def _serialize_key_factor (kf : KeyFactor ):
271440 option = kf .question .label if kf .question else kf .question_option
272441
273442 if kf .driver_id :
274443 return {
444+ "type" : "driver" ,
275445 "text" : kf .driver .text ,
276446 "impact_direction" : kf .driver .impact_direction ,
277447 "certainty" : kf .driver .certainty ,
278448 "option" : option or None ,
279449 }
450+ elif kf .news_id :
451+ return {
452+ "type" : "news" ,
453+ "title" : kf .news .title ,
454+ "url" : kf .news .url ,
455+ }
456+ elif kf .base_rate_id :
457+ return {
458+ "type" : "base_rate" ,
459+ "base_rate_type" : kf .base_rate .type ,
460+ "reference_class" : kf .base_rate .reference_class ,
461+ "rate_numerator" : kf .base_rate .rate_numerator ,
462+ "rate_denominator" : kf .base_rate .rate_denominator ,
463+ "projected_value" : kf .base_rate .projected_value ,
464+ "projected_by_year" : kf .base_rate .projected_by_year ,
465+ "unit" : kf .base_rate .unit ,
466+ "extrapolation" : kf .base_rate .extrapolation ,
467+ }
280468
281469
282470def generate_key_factors_for_comment (
0 commit comments