diff --git a/api/src/backend/entities.py b/api/src/backend/entities.py index 9f3814ca..5b02bba7 100644 --- a/api/src/backend/entities.py +++ b/api/src/backend/entities.py @@ -314,13 +314,13 @@ class AgentStatus(Enum): scored = "scored" # All evaluations complete replaced = "replaced" # Replaced by newer version pruned = "pruned" # Pruned due to low score compared to top agent - + # Legacy statuses for backward compatibility during transition awaiting_screening = "awaiting_screening_1" # Map to stage 1 screening = "screening_1" # Map to stage 1 failed_screening = "failed_screening_1" # Map to stage 1 fail evaluation = "evaluating" # Map to evaluating (legacy alias) - + @classmethod def from_string(cls, status: str) -> 'AgentStatus': """Map database status string to agent state enum""" @@ -344,7 +344,6 @@ def from_string(cls, status: str) -> 'AgentStatus': } return mapping.get(status, cls.awaiting_screening_1) - class EvaluationStatus(Enum): waiting = "waiting" running = "running" diff --git a/api/src/backend/queries/agents.py b/api/src/backend/queries/agents.py index a59afb9a..a25730b4 100644 --- a/api/src/backend/queries/agents.py +++ b/api/src/backend/queries/agents.py @@ -3,7 +3,7 @@ import asyncpg from api.src.backend.db_manager import db_operation, db_transaction -from api.src.backend.entities import MinerAgent +from api.src.backend.entities import AgentStatus, MinerAgent from api.src.utils.models import TopAgentHotkey from loggers.logging_utils import get_logger @@ -139,9 +139,51 @@ async def set_approved_agents_to_awaiting_screening(conn: asyncpg.Connection) -> return [MinerAgent(**dict(result)) for result in results] @db_operation -async def get_all_approved_version_ids(conn: asyncpg.Connection) -> List[str]: - """ - Get all approved version IDs - """ - data = await conn.fetch("SELECT version_id FROM approved_version_ids WHERE approved_at <= NOW()") - return [str(row["version_id"]) for row in data] +async def set_agent_status(conn: asyncpg.Connection, version_id: str, status: str): + try: + AgentStatus(status) # Check whether the status we are trying to set to is valid + except ValueError: + logger.error(f"Tried to set agent to invalid status {status!r}") + raise ValueError("Invalid status") + + await conn.execute( + "UPDATE miner_agents SET status = $1 WHERE version_id = $2", + status, + version_id + ) + +@db_operation +async def upload_miner_agent( + conn: asyncpg.Connection, + version_id: str, + miner_hotkey: str, + agent_name: str, + version_num: int, + ip_address: str +): + await conn.execute( + """ + INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status, ip_address) + VALUES ($1, $2, $3, $4, NOW(), 'awaiting_screening_1', $5) + """, + version_id, + miner_hotkey, + agent_name, + version_num, + ip_address + ) + +@db_operation +async def agent_startup_recovery(conn: asyncpg.Connection): + # Reset agent statuses for multi-stage screening + await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_1' WHERE status = 'screening_1'") + await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_2' WHERE status = 'screening_2'") + await conn.execute("UPDATE miner_agents SET status = 'waiting' WHERE status = 'evaluating'") + + # Legacy status recovery for backward compatibility + await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_1' WHERE status = 'screening'") + await conn.execute("UPDATE miner_agents SET status = 'waiting' WHERE status = 'evaluation'") # Legacy alias + +@db_operation +async def 
set_agent_status_by_version_id(conn: asyncpg.Connection, version_id: str, status: str): + await conn.execute("UPDATE miner_agents SET status = $1 WHERE version_id = $2", status, version_id) diff --git a/api/src/backend/queries/agents.pyi b/api/src/backend/queries/agents.pyi index ec6d1998..47229caf 100644 --- a/api/src/backend/queries/agents.pyi +++ b/api/src/backend/queries/agents.pyi @@ -12,4 +12,8 @@ async def get_agents_by_hotkey(miner_hotkey: str) -> List[MinerAgent]: ... async def ban_agents(miner_hotkeys: List[str], reason: str) -> None: ... async def approve_agent_version(version_id: str) -> None: ... async def set_approved_agents_to_awaiting_screening() -> List[MinerAgent]: ... -async def get_all_approved_version_ids() -> List[str]: ... \ No newline at end of file +async def get_all_approved_version_ids() -> List[str]: ... +async def set_agent_status(version_id: str, status: str): ... +async def upload_miner_agent(version_id: str, miner_hotkey: str, agent_name: str, version_num: int, ip_address: str): ... +async def agent_startup_recovery() -> None: ... +async def set_agent_status_by_version_id(version_id: str, status: str): ... \ No newline at end of file diff --git a/api/src/backend/queries/evaluation_runs.py b/api/src/backend/queries/evaluation_runs.py index dce45c99..e6762f15 100644 --- a/api/src/backend/queries/evaluation_runs.py +++ b/api/src/backend/queries/evaluation_runs.py @@ -324,8 +324,7 @@ async def reset_validator_evaluations(conn: asyncpg.Connection, version_id: str) WHERE evaluation_id = ANY($1::uuid[]) """, evaluation_ids_to_cancel) - - - - - +@db_operation +async def cancel_evaluation_runs(conn: asyncpg.Connection, evaluation_id: str): + """Cancel existing eval runs - e.g. for errored runs or disconnections etc""" + await conn.execute("UPDATE evaluation_runs SET status = 'cancelled' WHERE evaluation_id = $1", evaluation_id) \ No newline at end of file diff --git a/api/src/backend/queries/evaluation_runs.pyi b/api/src/backend/queries/evaluation_runs.pyi index 36583caa..61a078bb 100644 --- a/api/src/backend/queries/evaluation_runs.pyi +++ b/api/src/backend/queries/evaluation_runs.pyi @@ -15,4 +15,6 @@ async def update_evaluation_run_logs(run_id: str, logs: str): ... async def get_evaluation_run_logs(run_id: str) -> str: ... async def fully_reset_evaluations(version_id: str): ... -async def reset_validator_evaluations(version_id: str): ... \ No newline at end of file +async def reset_validator_evaluations(version_id: str): ... + +async def cancel_evaluation_runs(evaluation_id: str): ... 
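(Editor's note, not part of the patch.) A minimal usage sketch of the new status helper introduced above, assuming the @db_operation decorator injects the connection as the .pyi stubs imply; the wrapper function name and the version_id argument are illustrative only.

from api.src.backend.entities import AgentStatus
from api.src.backend.queries.agents import set_agent_status

async def demo_status_update(version_id: str) -> None:
    # Valid transition: the string round-trips through AgentStatus(status) inside set_agent_status
    await set_agent_status(version_id=version_id, status=AgentStatus.awaiting_screening_2.value)

    # Invalid strings are rejected before any UPDATE is issued
    try:
        await set_agent_status(version_id=version_id, status="not_a_real_status")
    except ValueError:
        pass  # set_agent_status logs the bad value and re-raises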
\ No newline at end of file diff --git a/api/src/backend/queries/evaluations.py b/api/src/backend/queries/evaluations.py index 69b189f5..2746ca70 100644 --- a/api/src/backend/queries/evaluations.py +++ b/api/src/backend/queries/evaluations.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import Optional, List, Tuple import logging import json @@ -6,7 +6,7 @@ from api.src.backend.db_manager import db_operation, db_transaction from api.src.backend.entities import Evaluation, EvaluationRun, EvaluationsWithHydratedRuns, EvaluationsWithHydratedUsageRuns, EvaluationRunWithUsageDetails, AgentStatus -from api.src.backend.queries.evaluation_runs import get_runs_with_usage_for_evaluation +from api.src.backend.queries.evaluation_runs import cancel_evaluation_runs, get_runs_with_usage_for_evaluation from api.src.backend.entities import EvaluationStatus logger = logging.getLogger(__name__) @@ -387,6 +387,17 @@ async def does_validator_have_running_evaluation( validator_hotkey ) +@db_operation +async def does_miner_have_running_evaluations(conn: asyncpg.Connection, miner_hotkey: str) -> bool: + return await conn.fetchval( + """ + SELECT EXISTS(SELECT 1 FROM evaluations e + JOIN miner_agents ma ON e.version_id = ma.version_id + WHERE ma.miner_hotkey = $1 AND e.status = 'running') + """, + miner_hotkey + ) + @db_operation async def get_running_evaluation_by_miner_hotkey(conn: asyncpg.Connection, miner_hotkey: str) -> Optional[Evaluation]: result = await conn.fetchrow( @@ -440,3 +451,270 @@ async def get_miner_hotkey_from_version_id(conn: asyncpg.Connection, version_id: FROM miner_agents WHERE version_id = $1 """, version_id) + +@db_operation +async def update_evaluation_to_error(conn: asyncpg.Connection, evaluation_id: str, error_reason: str): + # We can asyncio.gather, but will do this post stability to reduce complexity + await conn.execute( + "UPDATE evaluations SET status = 'error', finished_at = NOW(), terminated_reason = $1 WHERE evaluation_id = $2", + error_reason, + evaluation_id + ) + + await conn.execute("UPDATE evaluation_runs SET status = 'cancelled', cancelled_at = NOW() WHERE evaluation_id = $1", evaluation_id) + +@db_operation +async def update_evaluation_to_completed(conn: asyncpg.Connection, evaluation_id: str): + await conn.execute("UPDATE evaluations SET status = 'completed', finished_at = NOW() WHERE evaluation_id = $1", evaluation_id) + +@db_operation +async def get_inference_success_rate(conn: asyncpg.Connection, evaluation_id: str) -> Tuple[int, int, float, bool]: + """Check inference success rate for this evaluation + + Returns: + tuple: (successful_count, total_count, success_rate, any_run_errored) + """ + result = await conn.fetchrow(""" + SELECT + COUNT(*) as total_inferences, + COUNT(*) FILTER (WHERE status_code = 200) as successful_inferences, + COUNT(*) FILTER (WHERE er.error IS NOT NULL) > 0 as any_run_errored + FROM inferences i + JOIN evaluation_runs er ON i.run_id = er.run_id + WHERE er.evaluation_id = $1 AND er.status != 'cancelled' + """, evaluation_id) + + total = result['total_inferences'] or 0 + successful = result['successful_inferences'] or 0 + success_rate = successful / total if total > 0 else 1.0 + any_run_errored = bool(result['any_run_errored']) + + return successful, total, success_rate, any_run_errored + +@db_operation +async def reset_evaluation_to_waiting(conn: asyncpg.Connection, evaluation_id: str): + """Reset running evaluation back to waiting (for disconnections)""" + await conn.execute("UPDATE evaluations SET status = 'waiting', 
started_at = NULL WHERE evaluation_id = $1", evaluation_id) + + # Reset running evaluation_runs to pending so they can be picked up again + await cancel_evaluation_runs(evaluation_id=evaluation_id) + +@db_operation +async def update_evaluation_to_started(conn: asyncpg.Connection, evaluation_id: str): + await conn.execute("UPDATE evaluations SET status = 'running', started_at = NOW() WHERE evaluation_id = $1", evaluation_id) + + +@db_operation +async def get_problems_for_set_and_stage(conn: asyncpg.Connection, set_id: int, validation_stage: str) -> list[str]: + swebench_instance_ids_data = await conn.fetch( + "SELECT swebench_instance_id FROM evaluation_sets WHERE set_id = $1 AND type = $2", set_id, validation_stage + ) + + return [row["swebench_instance_id"] for row in swebench_instance_ids_data] + +@db_operation +async def prune_evaluations_in_queue(conn: asyncpg.Connection, threshold: float, max_set_id: int): + # Find evaluations with low screener scores that should be pruned + # We prune based on screener_score being below screening thresholds + low_score_evaluations = await conn.fetch(""" + SELECT e.evaluation_id, e.version_id, e.validator_hotkey, e.screener_score + FROM evaluations e + JOIN miner_agents ma ON e.version_id = ma.version_id + WHERE e.set_id = $1 + AND e.status = 'waiting' + AND e.screener_score IS NOT NULL + AND e.screener_score < $2 + AND ma.status NOT IN ('pruned', 'replaced') + """, max_set_id, threshold) + + if not low_score_evaluations: + return + + # Get unique version_ids to prune + version_ids_to_prune = list(set(eval['version_id'] for eval in low_score_evaluations)) + + # Update evaluations to pruned status + await conn.execute(""" + UPDATE evaluations + SET status = 'pruned', finished_at = NOW() + WHERE evaluation_id = ANY($1) + """, [eval['evaluation_id'] for eval in low_score_evaluations]) + + # Update miner_agents to pruned status + await conn.execute(""" + UPDATE miner_agents + SET status = 'pruned' + WHERE version_id = ANY($1) + """, version_ids_to_prune) + +# Scuff. 
Need a better way to do general queries +@db_operation +async def get_evaluation_for_version_validator_and_set( + conn: asyncpg.Connection, + version_id: str, + validator_hotkey: str, + set_id: int +) -> Optional[str]: + evaluation_id = await conn.fetchval( + """ + SELECT evaluation_id FROM evaluations + WHERE version_id = $1 AND validator_hotkey = $2 AND set_id = $3 + """, + version_id, + validator_hotkey, + set_id, + ) + + return evaluation_id + +@db_operation +async def create_evaluation( + conn: asyncpg.Connection, + evaluation_id: str, + version_id: str, + validator_hotkey: str, + set_id: int, + screener_score: Optional[float] = None, + status: Optional[str] = 'running' +): + if screener_score: + return await conn.execute( + """ + INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, screener_score) + VALUES ($1, $2, $3, $4, $5, NOW(), $6) + """, + evaluation_id, + version_id, + validator_hotkey, + set_id, + status, + screener_score + ) + else: + return await conn.execute( + """ + INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at) + VALUES ($1, $2, $3, $4, $5, NOW()) + """, + evaluation_id, + version_id, + validator_hotkey, + set_id, + status + ) + +@db_operation +async def create_evaluation_runs( + conn: asyncpg.Connection, + evaluation_runs: list[EvaluationRun] +): + await conn.executemany( + "INSERT INTO evaluation_runs (run_id, evaluation_id, swebench_instance_id, status, started_at) VALUES ($1, $2, $3, $4, $5)", + [(run.run_id, run.evaluation_id, run.swebench_instance_id, run.status.value, run.started_at) for run in evaluation_runs], + ) + +@db_operation +async def replace_old_agents(conn: asyncpg.Connection, miner_hotkey: str) -> None: + """Replace all old agents and their evaluations for a miner""" + # Replace old agents + await conn.execute("UPDATE miner_agents SET status = 'replaced' WHERE miner_hotkey = $1 AND status != 'scored'", miner_hotkey) + + # Replace their evaluations + await conn.execute( + """ + UPDATE evaluations SET status = 'replaced' + WHERE version_id IN (SELECT version_id FROM miner_agents WHERE miner_hotkey = $1) + AND status IN ('waiting', 'running') + """, + miner_hotkey, + ) + + # Cancel evaluation_runs for replaced evaluations + await conn.execute( + """ + UPDATE evaluation_runs SET status = 'cancelled', cancelled_at = NOW() + WHERE evaluation_id IN ( + SELECT evaluation_id FROM evaluations + WHERE version_id IN (SELECT version_id FROM miner_agents WHERE miner_hotkey = $1) + AND status = 'replaced' + ) + """, + miner_hotkey, + ) + +@db_operation +async def get_progress(conn: asyncpg.Connection, evaluation_id: str) -> float: + """Get progress of evaluation across all runs""" + progress = await conn.fetchval(""" + SELECT COALESCE(AVG( + CASE status + WHEN 'started' THEN 0.2 + WHEN 'sandbox_created' THEN 0.4 + WHEN 'patch_generated' THEN 0.6 + WHEN 'eval_started' THEN 0.8 + WHEN 'result_scored' THEN 1.0 + ELSE 0.0 + END + ), 0.0) + FROM evaluation_runs + WHERE evaluation_id = $1 + AND status NOT IN ('cancelled', 'error') + """, evaluation_id) + return float(progress) + +@db_operation +async def get_stuck_evaluations(conn: asyncpg.Connection) -> List[Evaluation]: + result = await conn.fetch(""" + SELECT e.evaluation_id FROM evaluations e + WHERE e.status = 'running' + AND NOT EXISTS ( + SELECT 1 FROM evaluation_runs er + WHERE er.evaluation_id = e.evaluation_id + AND er.status NOT IN ('result_scored', 'cancelled') + ) + AND EXISTS ( + SELECT 1 FROM evaluation_runs er2 
+ WHERE er2.evaluation_id = e.evaluation_id + ) + """) + + return [Evaluation(**dict(row)) for row in result] + +@db_operation +async def get_waiting_evaluations(conn: asyncpg.Connection) -> List[Evaluation]: + result = await conn.fetch("SELECT * FROM evaluations WHERE status = 'waiting'") + + return [Evaluation(**dict(row)) for row in result] + +@db_operation +async def cancel_dangling_evaluation_runs(conn: asyncpg.Connection): + await conn.execute("UPDATE evaluation_runs SET status = 'cancelled', cancelled_at = NOW() WHERE status not in ('result_scored', 'cancelled')") + +@db_operation +async def evaluation_count_for_agent_and_status(conn: asyncpg.Connection, version_id: str, status: EvaluationStatus): + """Returns the number of validator evals with a given state, for a specific agent""" + return await conn.fetchval( + """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 AND status = $2 + AND validator_hotkey NOT LIKE 'screener-%' + AND validator_hotkey NOT LIKE 'i-0%'""", + version_id, + status.value + ) + +@db_operation +async def check_for_currently_running_eval(conn: asyncpg.Connection, validator_hotkey: str) -> bool: + existing_evaluation = await conn.fetchrow( + """ + SELECT evaluation_id, status FROM evaluations + WHERE validator_hotkey = $1 AND status = 'running' + LIMIT 1 + """, + validator_hotkey + ) + + # TODO; replace for fetchval + if existing_evaluation: + return True + + else: + return False \ No newline at end of file diff --git a/api/src/backend/queries/evaluations.pyi b/api/src/backend/queries/evaluations.pyi index 30d42c99..1ca02835 100644 --- a/api/src/backend/queries/evaluations.pyi +++ b/api/src/backend/queries/evaluations.pyi @@ -1,5 +1,5 @@ -from typing import List, Optional -from api.src.backend.entities import EvaluationsWithHydratedRuns, Evaluation, EvaluationsWithHydratedUsageRuns +from typing import List, Optional, Tuple +from api.src.backend.entities import EvaluationRun, EvaluationStatus, EvaluationsWithHydratedRuns, Evaluation, EvaluationsWithHydratedUsageRuns async def get_evaluation_by_evaluation_id(evaluation_id: str) -> Evaluation: ... async def get_evaluations_by_version_id(version_id: str) -> List[Evaluation]: ... @@ -8,7 +8,42 @@ async def get_evaluations_with_usage_for_agent_version(version_id: str, set_id: async def get_running_evaluations() -> List[Evaluation]: ... async def get_running_evaluation_by_validator_hotkey(validator_hotkey: str) -> Optional[Evaluation]: ... async def get_running_evaluation_by_miner_hotkey(miner_hotkey: str) -> Optional[Evaluation]: ... + async def does_validator_have_running_evaluation(validator_hotkey: str) -> bool: ... +async def does_miner_have_running_evaluations(miner_hotkey: str) -> bool: ... + async def get_queue_info(validator_hotkey: str, length: int = 10) -> List[Evaluation]: ... async def get_agent_name_from_version_id(version_id: str) -> Optional[str]: ... -async def get_miner_hotkey_from_version_id(version_id: str) -> Optional[str]: ... \ No newline at end of file +async def get_miner_hotkey_from_version_id(version_id: str) -> Optional[str]: ... +async def update_evaluation_to_error(evaluation_id: str, error_reason: str): ... +async def get_inference_success_rate(evaluation_id: str) -> Tuple[int, int, float, bool]: ... + +async def reset_evaluation_to_waiting(evaluation_id: str): ... +async def update_evaluation_to_completed(evaluation_id: str): ... +async def update_evaluation_to_started(evaluation_id: str): ... 
+async def get_problems_for_set_and_stage(set_id: int, validation_stage: str) -> list[str]: ... +async def prune_evaluations_in_queue(threshold: float, max_set_id: int): ... +async def get_evaluation_for_version_validator_and_set( + version_id: str, + validator_hotkey: str, + set_id: int +) -> Optional[str]: ... + +async def create_evaluation( + evaluation_id: str, + version_id: str, + validator_hotkey: str, + set_id: int, + screener_score: float +): ... + +async def create_evaluation_runs(evaluation_runs: list[EvaluationRun]): ... + +async def replace_old_agents(miner_hotkey: str) -> None: ... +async def get_progress(evaluation_id: str) -> float: ... + +async def get_stuck_evaluations() -> List[Evaluation]: ... +async def get_waiting_evaluations() -> List[Evaluation]: ... +async def cancel_dangling_evaluation_runs() -> None: ... +async def evaluation_count_for_agent_and_status(version_id: str, status: EvaluationStatus): ... +async def check_for_currently_running_eval(validator_hotkey: str) -> bool: ... \ No newline at end of file diff --git a/api/src/backend/queries/scores.py b/api/src/backend/queries/scores.py index f3bbb2a8..6e7b7e7f 100644 --- a/api/src/backend/queries/scores.py +++ b/api/src/backend/queries/scores.py @@ -1,5 +1,6 @@ from datetime import timezone import logging +from typing import Optional from uuid import UUID import asyncpg from api.src.backend.db_manager import db_operation @@ -248,3 +249,132 @@ async def generate_threshold_function(conn: asyncpg.Connection) -> dict: "epoch_0_time": epoch_0_time, "epoch_length_minutes": epoch_length_minutes } + +@db_operation +async def get_combined_screener_score(conn: asyncpg.Connection, version_id: str) -> tuple[Optional[float], Optional[str]]: + """Calculate combined screener score as (questions solved by both) / (questions asked by both) + + Returns: + tuple[Optional[float], Optional[str]]: (score, error_message) + - score: The calculated score, or None if calculation failed + - error_message: None if successful, error description if failed + """ + # Get evaluation IDs for both screener stages + stage_1_eval_id = await conn.fetchval( + """ + SELECT evaluation_id FROM evaluations + WHERE version_id = $1 + AND validator_hotkey LIKE 'screener-1-%' + AND status = 'completed' + ORDER BY created_at DESC + LIMIT 1 + """, + version_id + ) + + stage_2_eval_id = await conn.fetchval( + """ + SELECT evaluation_id FROM evaluations + WHERE version_id = $1 + AND validator_hotkey LIKE 'screener-2-%' + AND status = 'completed' + ORDER BY created_at DESC + LIMIT 1 + """, + version_id + ) + + if not stage_1_eval_id or not stage_2_eval_id: + missing = [] + if not stage_1_eval_id: + missing.append("stage-1") + if not stage_2_eval_id: + missing.append("stage-2") + return None, f"Missing completed screener evaluation(s): {', '.join(missing)}" + + # Get solved count and total count for both evaluations + results = await conn.fetch( + """ + SELECT + SUM(CASE WHEN solved THEN 1 ELSE 0 END) as solved_count, + COUNT(*) as total_count + FROM evaluation_runs + WHERE evaluation_id = ANY($1::uuid[]) + AND status != 'cancelled' + """, + [stage_1_eval_id, stage_2_eval_id] + ) + + if not results or len(results) == 0: + return None, f"No evaluation runs found for screener evaluations {stage_1_eval_id} and {stage_2_eval_id}" + + result = results[0] + solved_count = result['solved_count'] or 0 + total_count = result['total_count'] or 0 + + if total_count == 0: + return None, f"No evaluation runs to calculate score from (total_count=0)" + + return solved_count / 
total_count, None + +@db_operation +async def get_current_set_id(conn: asyncpg.Connection) -> int: + max_set_id = await conn.fetchval("SELECT MAX(set_id) FROM evaluation_sets") + return max_set_id + +@db_operation +async def update_innovation_score(conn: asyncpg.Connection, version_id: str): + """Calculate and update innovation score for this evaluation's agent in one atomic query""" + try: + # Single atomic query that calculates and updates innovation score + await conn.execute(""" + WITH agent_runs AS ( + -- Get all result_scored runs for this agent + SELECT + r.swebench_instance_id, + r.solved, + r.started_at, + r.run_id + FROM evaluation_runs r + JOIN evaluations e ON e.evaluation_id = r.evaluation_id + WHERE e.version_id = $1 + AND r.status = 'result_scored' + ), + runs_with_prior AS ( + -- Calculate prior solved ratio for each run using window functions + SELECT + swebench_instance_id, + solved, + started_at, + run_id, + -- Calculate average solve rate for this instance before this run + COALESCE( + AVG(CASE WHEN solved THEN 1.0 ELSE 0.0 END) + OVER ( + PARTITION BY swebench_instance_id + ORDER BY started_at + ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING + ), 0.0 + ) AS prior_solved_ratio + FROM agent_runs + ), + innovation_calculation AS ( + SELECT + COALESCE( + AVG((CASE WHEN solved THEN 1.0 ELSE 0.0 END) - prior_solved_ratio), 0.0 + ) AS innovation_score + FROM runs_with_prior + ) + UPDATE miner_agents + SET innovation = (SELECT innovation_score FROM innovation_calculation) + WHERE version_id = $1 + """, version_id) + + + except Exception as e: + logger.error(f"Failed to calculate innovation score for agent {version_id}: {e}") + # Set innovation score to NULL on error to indicate calculation failure + await conn.execute( + "UPDATE miner_agents SET innovation = NULL WHERE version_id = $1", + version_id + ) \ No newline at end of file diff --git a/api/src/backend/queries/scores.pyi b/api/src/backend/queries/scores.pyi index ad98b666..aca13f58 100644 --- a/api/src/backend/queries/scores.pyi +++ b/api/src/backend/queries/scores.pyi @@ -1,7 +1,13 @@ +from typing import Optional from api.src.backend.entities import TreasuryTransaction async def check_for_new_high_score(version_id: str) -> dict: ... async def get_treasury_hotkeys() -> list[str]: ... async def store_treasury_transaction(transaction: TreasuryTransaction): ... async def generate_threshold_function() -> str: ... -async def evaluate_agent_for_threshold_approval(version_id: str, set_id: int) -> dict: ... \ No newline at end of file +async def evaluate_agent_for_threshold_approval(version_id: str, set_id: int) -> dict: ... + +async def get_combined_screener_score(version_id: str) -> tuple[Optional[float], Optional[str]]: ... + +async def get_current_set_id() -> int: ... +async def update_innovation_score(version_id: str): ... 
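(Editor's note, not part of the patch.) An illustrative sketch of how the (score, error) tuple returned by get_combined_screener_score is consumed, mirroring the prune check used later in finish_screening; the helper name is hypothetical, while get_top_agent, avg_score, and PRUNE_THRESHOLD follow the modules referenced elsewhere in this diff.

from api.src.backend.queries.agents import get_top_agent
from api.src.backend.queries.scores import get_combined_screener_score
from api.src.utils.config import PRUNE_THRESHOLD

async def should_prune_against_top_agent(version_id: str) -> bool:
    # Combined score = questions solved by both screener stages / questions asked by both
    score, error = await get_combined_screener_score(version_id)
    if error is not None or score is None:
        return False  # cannot judge without both completed screener evaluations

    top_agent = await get_top_agent()
    if top_agent is None:
        return False

    # Prune when the agent trails the current top agent's score by more than the threshold
    return (top_agent.avg_score - score) > PRUNE_THRESHOLD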
\ No newline at end of file diff --git a/api/src/endpoints/benchmarks.py b/api/src/endpoints/benchmarks.py index 2e262fc8..62c43971 100644 --- a/api/src/endpoints/benchmarks.py +++ b/api/src/endpoints/benchmarks.py @@ -66,4 +66,4 @@ async def get_top_agents_solved_for_question(swebench_instance_id: str) -> list[ tags=["benchmarks"], dependencies=[Depends(verify_request_public)], methods=methods - ) + ) \ No newline at end of file diff --git a/api/src/endpoints/model_replacers.py b/api/src/endpoints/model_replacers.py new file mode 100644 index 00000000..948d82f9 --- /dev/null +++ b/api/src/endpoints/model_replacers.py @@ -0,0 +1,88 @@ +""" +TEMPORARY FILE - use to put logic that we need in the models folder that don't have a clear endpoints file to go into +""" + +from api.src.backend.entities import AgentStatus, MinerAgent +from api.src.backend.queries.agents import get_top_agent, set_agent_status +from api.src.backend.queries.evaluations import get_running_evaluations, get_stuck_evaluations, get_waiting_evaluations, cancel_dangling_evaluation_runs, reset_evaluation_to_waiting, update_evaluation_to_error +from api.src.backend.queries.agents import agent_startup_recovery +from api.src.endpoints.screener import atomically_update_agent_status, finish_evaluation, prune_queue +from loggers.logging_utils import get_logger + +logger = get_logger(__name__) + + +async def repair_agent_status(): + """Handles: + - Screener disconnects + - Validator disconnects + - Platform restarts + """ + pass + +async def replace_old_agents(agent: MinerAgent): + pass + +async def update_agent_status(agent: MinerAgent): + """Update agent status based on evaluation state - handles multi-stage screening""" + + # We use the database as the source of truth now. Fetch evaluations and then use that to determine how to update agent status + + return + +@staticmethod +async def startup_recovery(): + """Fix broken states from shutdown - handles multi-stage screening""" + await agent_startup_recovery() + + # Reset running evaluations + running_evals = await get_running_evaluations() + for eval_row in running_evals: + evaluation_id = eval_row.evaluation_id + agent_version_id = eval_row.version_id + from api.src.models.screener import Screener + is_screening = Screener.get_stage(eval_row.validator_hotkey) is not None + if is_screening: + await update_evaluation_to_error(evaluation_id, "Disconnected from screener (error code 2)") + await atomically_update_agent_status(version_id=agent_version_id) + else: + # set evaluation to waiting, and its runs to cancelled + await reset_evaluation_to_waiting(evaluation_id) + # set agent status to waiting + await set_agent_status( + version_id=agent_version_id, + status=AgentStatus.waiting.value + ) + + # Check for running evaluations that should be auto-completed + stuck_evaluations = await get_stuck_evaluations() + + for stuck_eval in stuck_evaluations: + evaluation_id = stuck_eval.evaluation_id + # evaluation = await get_evaluation_by_evaluation_id(evaluation_id) + validator_hotkey = stuck_eval.validator_hotkey + + logger.info(f"Auto-completing stuck evaluation {evaluation_id} during startup recovery") + # During startup recovery, don't trigger notifications + _ = await finish_evaluation(evaluation_id, validator_hotkey, errored=True, reason="Platform restarted") + + # Cancel waiting screenings for all screener types + waiting_screenings = await get_waiting_evaluations() + for screening_row in waiting_screenings: + evaluation_id = screening_row.evaluation_id + evaluation_version_id = 
screening_row.version_id + + # await evaluation.error("Disconnected from screener (error code 3)") + await update_evaluation_to_error(evaluation_id, "Disconnected from screener (error code 3)") + await atomically_update_agent_status(version_id=evaluation_version_id) + + # Cancel dangling evaluation runs + await cancel_dangling_evaluation_runs() + + # Prune low-scoring evaluations that should not continue waiting + top_agent = await get_top_agent() + if top_agent: + await prune_queue(top_agent) + + logger.info("Application startup recovery completed with multi-stage screening support") + diff --git a/api/src/endpoints/retrieval.py b/api/src/endpoints/retrieval.py index 396007cc..d92dba23 100644 --- a/api/src/endpoints/retrieval.py +++ b/api/src/endpoints/retrieval.py @@ -1,10 +1,10 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request from typing import Optional, Any from fastapi.responses import StreamingResponse, PlainTextResponse -from api.src.models.screener import Screener from loggers.logging_utils import get_logger from dotenv import load_dotenv from datetime import datetime, timedelta, timezone +import os from api.src.utils.auth import verify_request_public from api.src.utils.s3 import S3Manager @@ -24,7 +24,6 @@ from api.src.backend.queries.inference import get_inference_provider_statistics as db_get_inference_provider_statistics from api.src.backend.internal_tools import InternalTools from api.src.backend.queries.open_users import get_emission_dispersed_to_open_user as db_get_emission_dispersed_to_open_user, get_all_transactions as db_get_all_transactions, get_all_treasury_hotkeys as db_get_all_treasury_hotkeys -from api.src.backend.queries.agents import get_all_approved_version_ids as db_get_all_approved_version_ids from api.src.backend.queries.open_users import get_total_dispersed_by_treasury_hotkeys as db_get_total_dispersed_by_treasury_hotkeys from api.src.utils.config import AGENT_RATE_LIMIT_SECONDS @@ -66,11 +65,12 @@ async def get_agent_code(version_id: str, request: Request, return_as_text: bool # Check if IP is in whitelist (add your allowed IPs to SCREENER_IP_LIST) if client_ip not in SCREENER_IP_LIST: - logger.warning(f"Unauthorized IP {client_ip} attempted to access agent code for version {version_id}") - raise HTTPException( - status_code=403, - detail="Access denied: IP not authorized" - ) + if os.getenv("ENV") == "prod": + logger.warning(f"Unauthorized IP {client_ip} attempted to access agent code for version {version_id}") + raise HTTPException( + status_code=403, + detail="Access denied: IP not authorized" + ) if return_as_text: try: @@ -413,19 +413,6 @@ async def get_emission_alpha_for_hotkey(miner_hotkey: str) -> dict[str, Any]: detail="Internal server error while retrieving emission alpha" ) -async def get_approved_version_ids() -> list[str]: - """ - Returns a list of all approved version IDs - """ - try: - return await db_get_all_approved_version_ids() - except Exception as e: - logger.error(f"Error retrieving approved version IDs: {e}") - raise HTTPException( - status_code=500, - detail="Internal server error while retrieving approved version IDs" - ) - async def get_time_until_next_upload_for_hotkey(miner_hotkey: str) -> dict[str, Any]: """ Returns the time until the next upload for a given hotkey @@ -512,7 +499,6 @@ async def get_pending_dispersal() -> dict[str, Any]: ("/agents-from-hotkey", get_agents_from_hotkey), ("/inference-provider-statistics", get_inference_provider_statistics), ("/emission-alpha-for-hotkey", 
get_emission_alpha_for_hotkey), - ("/approved-version-ids", get_approved_version_ids), ("/time-until-next-upload-for-hotkey", get_time_until_next_upload_for_hotkey), ("/all-transactions", get_all_transactions), ("/all-treasury-hotkeys", get_all_treasury_hotkeys), diff --git a/api/src/endpoints/scoring.py b/api/src/endpoints/scoring.py index 26218b08..aefdad1f 100644 --- a/api/src/endpoints/scoring.py +++ b/api/src/endpoints/scoring.py @@ -1,22 +1,19 @@ -import asyncio import os -from datetime import datetime, timezone from dotenv import load_dotenv from fastapi import APIRouter, Depends, HTTPException from typing import Dict, List, Optional import uuid - +import asyncio +from api.src.models.validator import Validator from api.src.backend.queries.evaluation_runs import fully_reset_evaluations, reset_validator_evaluations from api.src.utils.config import PRUNE_THRESHOLD, SCREENING_1_THRESHOLD, SCREENING_2_THRESHOLD -from api.src.models.evaluation import Evaluation -from api.src.models.validator import Validator from api.src.utils.auth import verify_request, verify_request_public from loggers.logging_utils import get_logger from api.src.backend.queries.agents import get_top_agent, ban_agents as db_ban_agents, approve_agent_version -from api.src.backend.entities import MinerAgent, MinerAgentScored -from api.src.backend.queries.agents import get_top_agent, ban_agents as db_ban_agents, approve_agent_version, get_agent_by_version_id as db_get_agent_by_version_id -from api.src.backend.entities import MinerAgentScored +from api.src.backend.entities import MinerAgent, MinerAgentScored, AgentStatus +from api.src.backend.queries.agents import get_top_agent, ban_agents as db_ban_agents, approve_agent_version, get_agent_by_version_id as db_get_agent_by_version_id, set_agent_status from api.src.backend.db_manager import get_transaction, new_db, get_db_connection +from api.src.backend.queries.evaluations import get_evaluation_by_evaluation_id, reset_evaluation_to_waiting from api.src.utils.refresh_subnet_hotkeys import check_if_hotkey_is_registered from api.src.utils.slack import notify_unregistered_top_miner, notify_unregistered_treasury_hotkey from api.src.backend.internal_tools import InternalTools @@ -47,7 +44,6 @@ async def run_weight_setting_loop(minutes: int): await asyncio.sleep(minutes * 20) ## Actual endpoints ## - async def weight_receiving_agent(): ''' This is used to compute the current best agent. Validators can rely on this or keep a local database to compute this themselves. @@ -132,11 +128,6 @@ async def ban_agents(agent_ids: List[str], reason: str, ban_password: str): logger.error(f"Error banning agents: {e}") raise HTTPException(status_code=500, detail="Failed to ban agent due to internal server error. 
Please try again later.") - -async def trigger_weight_set(): - await tell_validators_to_set_weights() - return {"message": "Successfully triggered weight update"} - async def approve_version(version_id: str, set_id: int, approval_password: str): """Approve a version ID using threshold scoring logic @@ -267,6 +258,7 @@ async def re_evaluate_agent(password: str, version_id: str, re_eval_screeners_an # Include all evaluations (screeners and validators) await fully_reset_evaluations(version_id=version_id) else: + # TODO: use the newer better version await reset_validator_evaluations(version_id=version_id) return { @@ -284,8 +276,16 @@ async def re_run_evaluation(password: str, evaluation_id: str): try: async with get_transaction() as conn: - evaluation = await Evaluation.get_by_id(evaluation_id) - await evaluation.reset_to_waiting(conn) + # set evaluation to waiting, and its runs to cancelled + evaluation = await get_evaluation_by_evaluation_id(evaluation_id) + await reset_evaluation_to_waiting(evaluation_id) + + # set agent status to waiting + agent_version_id = evaluation.version_id + await set_agent_status( + version_id=agent_version_id, + status=AgentStatus.waiting.value + ) return {"message": f"Successfully reset evaluation {evaluation_id}"} except Exception as e: logger.error(f"Error resetting evaluation {evaluation_id}: {e}") @@ -433,7 +433,6 @@ async def check_evaluation_status(evaluation_id: str): ("/screener-thresholds", get_screener_thresholds, ["GET"]), ("/prune-threshold", get_prune_threshold, ["GET"]), ("/threshold-function", get_threshold_function, ["GET"]), - ("/trigger-weight-update", trigger_weight_set, ["POST"]), ("/check-evaluation-status", check_evaluation_status, ["GET"]), ("/re-evaluate-agent", re_evaluate_agent, ["POST"]), ("/re-run-evaluation", re_run_evaluation, ["POST"]), diff --git a/api/src/endpoints/screener.py b/api/src/endpoints/screener.py new file mode 100644 index 00000000..61491618 --- /dev/null +++ b/api/src/endpoints/screener.py @@ -0,0 +1,390 @@ +""" +All logic around screeners, including starting a screening, finishing it, handling state updates, etc +""" + +import asyncio +from datetime import datetime, timezone +import stat +from sys import version +import uuid +from typing import Any, Optional + +from fastapi import status +from slack_bolt.context import complete +from api.src.backend.entities import AgentStatus, EvaluationRun, EvaluationStatus, MinerAgent, SandboxStatus +from logging import getLogger + +from api.src.backend.queries.agents import get_top_agent, set_agent_status +from api.src.backend.queries.evaluations import check_for_currently_running_eval, create_evaluation, create_evaluation_runs, evaluation_count_for_agent_and_status, get_evaluation_by_evaluation_id, get_evaluation_for_version_validator_and_set, get_inference_success_rate, get_problems_for_set_and_stage, prune_evaluations_in_queue, reset_evaluation_to_waiting, update_evaluation_to_completed, update_evaluation_to_error, update_evaluation_to_started +from api.src.backend.queries.scores import get_combined_screener_score, get_current_set_id, update_innovation_score +from api.src.endpoints.agents import get_agent_by_version + +from api.src.models.screener import Screener +from api.src.socket.websocket_manager import WebSocketManager +from api.src.utils.config import PRUNE_THRESHOLD, SCREENING_1_THRESHOLD, SCREENING_2_THRESHOLD +from api.src.utils.models import TopAgentHotkey + +logger = getLogger(__name__) + +AWAITING_SCREENING_STATUSES = [AgentStatus.screening_1.value, 
AgentStatus.screening_2.value] +SCREENING_STATUSES = [AgentStatus.screening_1.value, AgentStatus.screening_2.value] + +from enum import Enum +class ValidationStage(Enum): + SCREENER_1 = "screener-1" + SCREENER_2 = "screener-2" + VALIDATION = "validator" + +def identify_validation_stage(hotkey: str) -> ValidationStage: + if "screener-1" in hotkey: + return ValidationStage.SCREENER_1 + elif "screener-2" in hotkey: + return ValidationStage.SCREENER_2 + else: + # TODO: Verify sn58 format + return ValidationStage.VALIDATION + +def match_validation_stage_to_running_agent_status(validation_stage: ValidationStage) -> AgentStatus: + return { + ValidationStage.SCREENER_1: AgentStatus.screening_1, + ValidationStage.SCREENER_2: AgentStatus.screening_2, + ValidationStage.VALIDATION: AgentStatus.evaluating + }[validation_stage] + +def match_validation_stage_to_waiting_agent_status(validation_stage: ValidationStage) -> AgentStatus: + return { + ValidationStage.SCREENER_1: AgentStatus.awaiting_screening_1, + ValidationStage.SCREENER_2: AgentStatus.awaiting_screening_2, + ValidationStage.VALIDATION: AgentStatus.waiting + }[validation_stage] + +async def start_screening(evaluation_id: str, hotkey: str) -> dict[str, Any]: + f""" + Temporarily returns a dict in format: + success: bool + runs_created: list[EvaluationRun] + """ + # TODO: Where is the eval inserted? + # Get the evaluation, makes sure its screening and its the right hotkey making the request + validation_stage = identify_validation_stage(hotkey) + evaluation = await get_evaluation_by_evaluation_id(evaluation_id=evaluation_id) + + if not evaluation or validation_stage != identify_validation_stage(evaluation.validator_hotkey) or evaluation.validator_hotkey != hotkey: + print(f"FAIL1. Failed to create evaluation runs. Evaluation: {evaluation}, validation stage: {validation_stage}, other validation stage: {identify_validation_stage(evaluation.validator_hotkey)}") + return { + "success": False, + "runs_created": [] + } + + # Get the agent version, make sure thats in screening too + agent = await get_agent_by_version(evaluation.version_id) + + # TODO: in old version this is set to screening by this point. Why? When allocated to screeners? Should be set here + if not agent or agent.status != match_validation_stage_to_waiting_agent_status(validation_stage).value: + print(f"FAIL2. Failed to create evaluation runs. 
agent: [{agent}], matched vali stage: {match_validation_stage_to_running_agent_status(validation_stage).value}, matched waiting vali stage: {match_validation_stage_to_waiting_agent_status(validation_stage).value}") + # For some reason only screeners set the agent state before, and so validator stuck on waiting + if agent.status != "waiting": + logger.error(f"Tried to start agent {evaluation.version_id} validation but either agent doesn't exist or invalid status; {agent.status if agent else 'No agent'}") + return { + "success": False, + "runs_created": [] + } + + # Once checks are in place, start the evaluation + await update_evaluation_to_started(evaluation_id) + + # Get max set ids and the problem instance ids associated + try: + current_set_id = await get_current_set_id() + problem_instance_ids = await get_problems_for_set_and_stage(set_id=current_set_id, validation_stage=validation_stage.value) + + # Create eval runs and insert + evaluation_runs = [ + EvaluationRun( + run_id = uuid.uuid4(), + evaluation_id = evaluation_id, + swebench_instance_id = problem_id, + response=None, + error=None, + pass_to_fail_success=None, + fail_to_pass_success=None, + pass_to_pass_success=None, + fail_to_fail_success=None, + solved=None, + status = SandboxStatus.started, + started_at = datetime.now(timezone.utc), + sandbox_created_at=None, + patch_generated_at=None, + eval_started_at=None, + result_scored_at=None, + cancelled_at=None, + ) + for problem_id in problem_instance_ids + ] + + # Insert eval runs + await create_evaluation_runs(evaluation_runs=evaluation_runs) + + # Update agent status + status = match_validation_stage_to_running_agent_status(validation_stage) + await set_agent_status( + version_id=str(agent.version_id), + status=status.value + ) + + # TODO: Broadcast status change? + return { + "success": True, + "runs_created": evaluation_runs + } + except Exception as e: + logger.error(f"Error starting evaluation: {e}") + return { + "success": False, + "runs_created": [] + } + +async def finish_screening( + evaluation_id: str, + hotkey: str, + errored: bool = False, + reason: Optional[str] = None +): + evaluation = await get_evaluation_by_evaluation_id(evaluation_id) + + if not evaluation or evaluation.validator_hotkey != hotkey: + logger.warning(f"Screener {hotkey}: Invalid finish_screening call for evaluation {evaluation_id}") + return + + agent = await get_agent_by_version(evaluation.version_id) + print(f"AGENT IS: {agent}") + + if agent.status not in SCREENING_STATUSES: + logger.warning(f"Invalid status for miner agent: expected {evaluation.status}, agent is set to {agent.status}") + return + + if errored: + """Error evaluation and reset agent""" + await asyncio.gather( + update_evaluation_to_error(evaluation_id, reason), + set_agent_status( + version_id=agent.version_id, + status=AgentStatus.awaiting_screening_1.value if agent.status == "screening_1" else AgentStatus.awaiting_screening_2.value + ) + ) + + logger.info(f"{hotkey}: Finishing screening {evaluation_id}: Errored with reason: {reason}") + + # Check inference success rate. 
If errored, set the screening back to awaiting and update this evaluation with errored + _, total, success_rate, any_run_errored = await get_inference_success_rate(evaluation_id=evaluation_id) + + if total > 0 and success_rate < 0.5 and any_run_errored: + await reset_evaluation_to_waiting(evaluation_id) + # Set the agent back to awaiting for the same screener level if errored + await set_agent_status( + version_id=agent.version_id, + status=AgentStatus.awaiting_screening_1.value if agent.status == "screening_1" else AgentStatus.awaiting_screening_2.value + ) + return + + await update_evaluation_to_completed(evaluation_id=evaluation_id) + + # Check whether it passed the screening thresholds. + threshold = SCREENING_1_THRESHOLD if agent.status == "screening_1" else SCREENING_2_THRESHOLD + + if evaluation.score < threshold: + # Agent has failed, update status and that's that + await set_agent_status( + version_id=agent.version_id, + status=AgentStatus.failed_screening_1.value if agent.status == "screening_1" else AgentStatus.failed_screening_2.value + ) + + return + + if agent.status == AgentStatus.screening_1.value: + await set_agent_status( + version_id=agent.version_id, + status=AgentStatus.awaiting_screening_2.value + ) + + return + + if agent.status == AgentStatus.screening_2.value: + # If screening 2, see if we should prune it if its behind the top agent by enough, and create validator evals if not + combined_screener_score, score_error = await get_combined_screener_score(agent.version_id) + top_agent = await get_top_agent() + + if top_agent and combined_screener_score is not None and (top_agent.avg_score - combined_screener_score) > PRUNE_THRESHOLD: + # Score is too low, prune miner agent and don't create evaluations + await set_agent_status( + version_id=agent.version_id, + status=AgentStatus.pruned.value + ) + + await prune_queue(top_agent) + + return + + await set_agent_status( + version_id=agent.version_id, + status=AgentStatus.waiting.value + ) + + # Create validator evals + # TODO: ADAM, replace with new connected valis map + from api.src.models.validator import Validator + all_validators = await Validator.get_connected() + + for validator in all_validators: + await create_evaluation_for_validator( + version_id=agent.version_id, + validator_hotkey=validator.hotkey, + combined_screener_score=combined_screener_score + ) + + # Prune the rest of the queue + if top_agent: + await prune_queue(top_agent) + + return + + logger.error(f"Invalid screener status {agent.status}") + +# TODO +async def create_screener_evaluation(hotkey: str, agent: MinerAgent, screener: 'Screener'): + existing_evaluation = await check_for_currently_running_eval(hotkey) + + if existing_evaluation: + logger.error(f"CRITICAL: Screener {hotkey} already has running evaluation {existing_evaluation['evaluation_id']} - refusing to create duplicate screening") + return False + + ws = WebSocketManager.get_instance() + set_id = await get_current_set_id() + evaluation_id = str(uuid.uuid4()) + + await create_evaluation( + evaluation_id=evaluation_id, + version_id=agent.version_id, + validator_hotkey=hotkey, + set_id=set_id + ) + + evaluation_runs = await start_screening(evaluation_id, hotkey) + + message = { + "event": "screen-agent", + "evaluation_id": evaluation_id, + "agent_version": agent.model_dump(mode="json"), + "evaluation_runs": [run.model_dump(mode="json") for run in evaluation_runs["runs_created"]], + } + logger.info(f"Sending screen-agent message to screener {hotkey}: evaluation_id={evaluation_id}, 
agent={agent.agent_name}") + + await ws.send_to_all_non_validators("evaluation-started", message) + await ws.send_to_client(screener, message) + +async def create_evaluation_for_validator(version_id: str, validator_hotkey: str, combined_screener_score: float) -> str: + max_set_id = await get_current_set_id() + + existing_evaluation_id = await get_evaluation_for_version_validator_and_set( + version_id=version_id, + validator_hotkey=validator_hotkey, + set_id=max_set_id + ) + + if existing_evaluation_id: + logger.debug(f"Evaluation already exists for version {version_id}, validator {validator_hotkey}, set {max_set_id}") + return str(existing_evaluation_id) + + # Create new evaluation + evaluation_id = str(uuid.uuid4()) + await create_evaluation( + evaluation_id=evaluation_id, + version_id=version_id, + validator_hotkey=validator_hotkey, + set_id=max_set_id, + screener_score=combined_screener_score + ) + return evaluation_id + + +async def prune_queue(top_agent: TopAgentHotkey): + """ + Looks through the queue and prunes agents too far behind top agent + """ + # Calculate the threshold (configurable lower-than-top final validation score) + threshold = top_agent.avg_score - PRUNE_THRESHOLD + max_set_id = await get_current_set_id() + + await prune_evaluations_in_queue(threshold, max_set_id) + +async def handle_disconnect(): + pass + +async def atomically_update_agent_status(version_id: str): + """ + To be called by validators, this looks at other evaluations in the database in order to update a miner agents state + """ + # Get the number of waiting, running, and completed/pruned evals + waiting_count, running_count, completed_count = await asyncio.gather( + evaluation_count_for_agent_and_status(version_id = version_id, status = EvaluationStatus.waiting), + evaluation_count_for_agent_and_status(version_id = version_id, status = EvaluationStatus.running), + evaluation_count_for_agent_and_status(version_id = version_id, status = EvaluationStatus.completed), + ) + + # Use that to compute the state for miner_agent + status_to_set: AgentStatus + + if waiting_count > 0 and running_count == 0: + status_to_set = AgentStatus.waiting + elif waiting_count == 0 and running_count == 0 and completed_count > 0: + # Update innovation score before setting to scored + await update_innovation_score(version_id=version_id) + status_to_set = AgentStatus.scored + else: + status_to_set = AgentStatus.evaluating + + await set_agent_status( + version_id=version_id, + status=status_to_set.value + ) + + return + +async def finish_evaluation( + evaluation_id: str, + hotkey: str, + errored: bool = False, + reason: Optional[str] = None +): + evaluation = await get_evaluation_by_evaluation_id(evaluation_id=evaluation_id) + + if not evaluation or evaluation.validator_hotkey != hotkey: + logger.warning(f"Validator {hotkey}: Invalid finish_evaluation call for evaluation {evaluation_id}. 
{'No such eval' if evaluation is None else f'Invalid hotkey {hotkey}'}") + return + + # Get the agent and make sure the status is evaluating + agent = await get_agent_by_version(evaluation.version_id) + + if agent.status != AgentStatus.evaluating.value: + logger.warning(f"Invalid status for miner agent: expected evaluating, agent is set to {agent.status}") + + if errored: + """Error evaluation and reset agent""" + await update_evaluation_to_error(evaluation_id, reason) + await atomically_update_agent_status(version_id=evaluation.version_id) + + logger.info(f"{hotkey}: Finishing screening {evaluation_id}: Errored with reason: {reason}") + + # Check inference success rate. If errored, set the screening back to awaiting and update this evaluation with errored + _, total, success_rate, any_run_errored = await get_inference_success_rate(evaluation_id=evaluation_id) + + if total > 0 and success_rate < 0.5 and any_run_errored: + await reset_evaluation_to_waiting(evaluation_id) + # Set the agent back to awaiting for the same screener level if errored + await atomically_update_agent_status(version_id=evaluation.version_id) + return + + # Update evaluation to complete, and then agent status + # We call these seperately because the agent status looks at db after this write to consider other evaluations + await update_evaluation_to_completed(evaluation_id=evaluation_id) + await atomically_update_agent_status(version_id=evaluation.version_id) \ No newline at end of file diff --git a/api/src/endpoints/upload.py b/api/src/endpoints/upload.py index fa6470a3..ce158dbc 100644 --- a/api/src/endpoints/upload.py +++ b/api/src/endpoints/upload.py @@ -1,23 +1,22 @@ import os import uuid from fastapi import APIRouter, Depends, UploadFile, File, Form, HTTPException, BackgroundTasks, Request -from api.src.models.screener import Screener +from api.src.backend.queries.evaluations import does_miner_have_running_evaluations, replace_old_agents from datetime import datetime from pydantic import BaseModel, Field from typing import Optional +from api.src.endpoints.screener import create_evaluation_for_validator, create_screener_evaluation +from api.src.models.screener import Screener from loggers.logging_utils import get_logger from loggers.process_tracking import process_context from api.src.utils.auth import verify_request_public from api.src.utils.upload_agent_helpers import check_agent_banned, check_hotkey_registered, check_rate_limit, check_replay_attack, check_if_python_file, get_miner_hotkey, check_signature, check_code_similarity, check_file_size, check_agent_code, upload_agent_code_to_s3, record_upload_attempt -from api.src.backend.queries.agents import get_ban_reason +from api.src.backend.queries.agents import get_ban_reason, upload_miner_agent from api.src.socket.websocket_manager import WebSocketManager -from api.src.models.evaluation import Evaluation from api.src.backend.queries.agents import get_latest_agent from api.src.backend.entities import MinerAgent, AgentStatus -from api.src.backend.db_manager import get_transaction from api.src.utils.agent_summary_generator import generate_and_store_agent_summary -from api.src.backend.queries.open_users import get_open_user_by_hotkey logger = get_logger(__name__) ws = WebSocketManager.get_instance() @@ -77,8 +76,7 @@ async def post_agent( try: with process_context("handle-upload-agent") as process_id: - logger.debug(f"Platform received a /upload/agent API request. 
Beginning process handle-upload-agent with process ID: {process_id}.") - logger.info(f"Uploading agent {name} for miner {miner_hotkey}.") + logger.debug(f"Platform received a /upload/agent API request. Beginning process handle-upload-agent with process ID: {process_id}. Uploading agent {name} for miner {miner_hotkey}.") check_if_python_file(agent_file.filename) latest_agent: Optional[MinerAgent] = await get_latest_agent(miner_hotkey=miner_hotkey) @@ -94,62 +92,64 @@ async def post_agent( if prod: await check_agent_banned(miner_hotkey=miner_hotkey) if prod and latest_agent: check_rate_limit(latest_agent) + check_replay_attack(latest_agent, file_info) + if prod: check_signature(public_key, file_info, signature) if prod: await check_hotkey_registered(miner_hotkey) + file_content = await check_file_size(agent_file) - # TODO: Uncomment this when embedding similarity check is done - # if prod: await check_code_similarity(file_content, miner_hotkey) + check_agent_code(file_content) - async with Evaluation.get_lock(): - # Atomic availability check + reservation - only allow uploads if stage 1 screeners are available - screener = await Screener.get_first_available_and_reserve(stage=1) - if not screener: - logger.error(f"No available stage 1 screener for agent upload from miner {miner_hotkey}") - raise HTTPException( - status_code=503, - detail="No stage 1 screeners available for agent evaluation. Please try again later." - ) + # Make sure agent has no running evaluations + has_running_evaluations = await does_miner_have_running_evaluations(miner_hotkey=miner_hotkey) - async with get_transaction() as conn: - can_upload = await Evaluation.check_miner_has_no_running_evaluations(conn, miner_hotkey) - if not can_upload: - # IMPORTANT: Release screener reservation on failure - screener.set_available() - logger.error(f"Cannot upload agent for miner {miner_hotkey} - has running evaluations") - raise HTTPException( - status_code=409, - detail="Cannot upload new agent while previous evaluations are still running. Please wait and try again." - ) + if has_running_evaluations: + raise HTTPException( + status_code=409, + detail="Cannot upload new agent while previous evaluations are still running. Please wait and try again." + ) - await Evaluation.replace_old_agents(conn, miner_hotkey) + # TODO: will be replaced by monday state machine + from api.src.socket.websocket_manager import WebSocketManager + ws_manager = WebSocketManager.get_instance() + + found_screener: Optional['Screener'] = None + + for client in ws_manager.clients.values(): + if (client.get_type() == "screener" and + client.status == "available" and + client.is_available() and + client.stage == 1): + found_screener = client + break + + if found_screener is None: + raise HTTPException( + status_code=503, + detail="No stage 1 screeners available for agent evaluation. Please try again later." 
+ ) - await upload_agent_code_to_s3(agent.version_id, agent_file) + # Replace old agent and upload code to S3 + await replace_old_agents(miner_hotkey=miner_hotkey) + await upload_agent_code_to_s3(agent.version_id, agent_file) - await conn.execute( - """ - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status, ip_address) - VALUES ($1, $2, $3, $4, NOW(), 'awaiting_screening_1', $5) - """, - agent.version_id, - agent.miner_hotkey, - agent.agent_name, - agent.version_num, - agent.ip_address, - ) + # Create agent on awaiting_screening_1 + await upload_miner_agent( + agent.version_id, + agent.miner_hotkey, + agent.agent_name, + agent.version_num, + agent.ip_address, + ) - # Create evaluation and assign to screener (commits screener state) - eval_id, success = await Evaluation.create_screening_and_send(conn, agent, screener) - if not success: - # If send fails, reset screener - if screener.status == 'reserving': - screener.set_available() - logger.warning(f"Failed to assign agent {agent.version_id} to screener") - else: - logger.warning(f"Failed to assign agent {agent.version_id} to screener - screener is not running") - - # Screener state is now committed, lock can be released + # Create evaluation for a screener thats connected + await create_screener_evaluation( + hotkey=found_screener.hotkey, + agent=agent, + screener=found_screener + ) # Schedule background agent summary generation logger.info(f"Scheduling agent summary generation for {agent.version_id}") @@ -159,8 +159,7 @@ async def post_agent( run_id=f"upload-{agent.version_id}" ) - logger.info(f"Successfully uploaded agent {agent.version_id} for miner {miner_hotkey}.") - logger.debug(f"Completed handle-upload-agent with process ID {process_id}.") + logger.info(f"Successfully uploaded agent {agent.version_id} for miner {miner_hotkey}. Process {process_id}") # Record successful upload await record_upload_attempt( @@ -218,151 +217,6 @@ async def post_open_agent( message=f"Dashboard uploads are temporarily paused" ) - # Extract upload attempt data for tracking - agent_file.file.seek(0, 2) - file_size_bytes = agent_file.file.tell() - agent_file.file.seek(0) - - upload_data = { - 'hotkey': open_hotkey, - 'agent_name': name, - 'filename': agent_file.filename, - 'file_size_bytes': file_size_bytes, - 'ip_address': getattr(request.client, 'host', None) if request.client else None - } - - try: - logger.info(f"Uploading open agent process beginning. Details: open_hotkey: {open_hotkey}, name: {name}, password: {password}") - - if password != open_user_password: - logger.error(f"Someone tried to upload an open agent with an invalid password. open_hotkey: {open_hotkey}, name: {name}, password: {password}") - raise HTTPException(status_code=401, detail="Invalid password. Fuck you.") - - try: - user = await get_open_user_by_hotkey(open_hotkey) - except Exception as e: - logger.error(f"Error retrieving open user {open_hotkey}: {e}") - raise HTTPException(status_code=500, detail="Internal server error while retrieving open user") - - if not user: - logger.error(f"Open user {open_hotkey} not found") - raise HTTPException(status_code=404, detail="Open user not found. 
Please register an account.") - - check_if_python_file(agent_file.filename) - latest_agent: Optional[MinerAgent] = await get_latest_agent(miner_hotkey=open_hotkey) - - agent = MinerAgent( - version_id=str(uuid.uuid4()), - miner_hotkey=open_hotkey, - agent_name=name if not latest_agent else latest_agent.agent_name, - version_num=latest_agent.version_num + 1 if latest_agent else 0, - created_at=datetime.now(), - status=AgentStatus.awaiting_screening, - ip_address=request.client.host if request.client else None, - ) - - if prod: await check_agent_banned(miner_hotkey=open_hotkey) - if prod and latest_agent: check_rate_limit(latest_agent) - file_content = await check_file_size(agent_file) - # TODO: Uncomment this when embedding similarity check is done - # if prod: await check_code_similarity(file_content, open_hotkey) - check_agent_code(file_content) - - async with Evaluation.get_lock(): - # Atomic availability check + reservation - only allow uploads if stage 1 screeners are available - screener = await Screener.get_first_available_and_reserve(stage=1) - if not screener: - logger.error(f"No available stage 1 screener for agent upload from miner {open_hotkey}") - raise HTTPException( - status_code=503, - detail="No stage 1 screeners available for agent evaluation. Please try again later." - ) - - async with get_transaction() as conn: - can_upload = await Evaluation.check_miner_has_no_running_evaluations(conn, open_hotkey) - if not can_upload: - screener.set_available() - logger.error(f"Cannot upload agent for miner {open_hotkey} - has running evaluations") - raise HTTPException( - status_code=409, - detail="Cannot upload new agent while previous evaluations are still running. Please wait and try again." - ) - - await Evaluation.replace_old_agents(conn, open_hotkey) - - await upload_agent_code_to_s3(agent.version_id, agent_file) - - await conn.execute( - """ - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status, ip_address) - VALUES ($1, $2, $3, $4, NOW(), 'awaiting_screening_1', $5) - """, - agent.version_id, - agent.miner_hotkey, - agent.agent_name, - agent.version_num, - agent.ip_address, - ) - - # Create evaluation and assign to screener (commits screener state) - eval_id, success = await Evaluation.create_screening_and_send(conn, agent, screener) - if not success: - # If send fails, reset screener - screener.set_available() - logger.warning(f"Failed to assign agent {agent.version_id} to screener") - - logger.info(f"Scheduling agent summary generation for {agent.version_id}") - background_tasks.add_task( - generate_and_store_agent_summary, - agent.version_id, - run_id=f"upload-{agent.version_id}" - ) - - logger.info(f"Successfully uploaded agent {agent.version_id} for open user {open_hotkey}.") - - # Record successful upload - await record_upload_attempt( - upload_type="open-agent", - success=True, - version_id=agent.version_id, - **upload_data - ) - - return AgentUploadResponse( - status="success", - message=f"Successfully uploaded agent {agent.version_id} for open user {open_hotkey}." 
- ) - - except HTTPException as e: - # Determine error type and get ban reason if applicable - error_type = 'banned' if e.status_code == 403 and 'banned' in e.detail.lower() else \ - 'rate_limit' if e.status_code == 429 else 'validation_error' - ban_reason = await get_ban_reason(open_hotkey) if error_type == 'banned' and open_hotkey else None - - # Record failed upload attempt - await record_upload_attempt( - upload_type="open-agent", - success=False, - error_type=error_type, - error_message=e.detail, - ban_reason=ban_reason, - http_status_code=e.status_code, - **upload_data - ) - raise - - except Exception as e: - # Record internal error - await record_upload_attempt( - upload_type="open-agent", - success=False, - error_type='internal_error', - error_message=str(e), - http_status_code=500, - **upload_data - ) - raise - router = APIRouter() routes = [ diff --git a/api/src/models/evaluation.py b/api/src/models/evaluation.py index f8e39753..aea53821 100644 --- a/api/src/models/evaluation.py +++ b/api/src/models/evaluation.py @@ -9,6 +9,8 @@ from api.src.backend.entities import EvaluationRun, MinerAgent, MinerAgentScored, SandboxStatus from api.src.backend.db_manager import get_db_connection, get_transaction from api.src.backend.entities import EvaluationStatus +from api.src.backend.queries.evaluations import get_evaluation_by_evaluation_id +from api.src.backend.queries.scores import get_combined_screener_score from api.src.models.screener import Screener from api.src.models.validator import Validator from api.src.utils.config import SCREENING_1_THRESHOLD, SCREENING_2_THRESHOLD @@ -23,45 +25,63 @@ class Evaluation: _lock = asyncio.Lock() - def __init__( - self, - evaluation_id: str, - version_id: str, - validator_hotkey: str, - set_id: int, - status: EvaluationStatus, - terminated_reason: Optional[str] = None, - score: Optional[float] = None, - screener_score: Optional[float] = None, - created_at: Optional[datetime] = None, - started_at: Optional[datetime] = None, - finished_at: Optional[datetime] = None, - ): + def __init__(self, evaluation_id: str): self.evaluation_id = evaluation_id - self.version_id = version_id - self.validator_hotkey = validator_hotkey - self.set_id = set_id - self.status = status - self.terminated_reason = terminated_reason - self.created_at = created_at - self.started_at = started_at - self.finished_at = finished_at - self.score = score - self.screener_score = screener_score - @property - def is_screening(self) -> bool: - return self.screener_stage is not None - - @property - def screener_stage(self) -> Optional[int]: - return Screener.get_stage(self.validator_hotkey) + + + async def get_version_id(self) -> str: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return str(evaluation.version_id) + + async def get_validator_hotkey(self) -> str: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return evaluation.validator_hotkey + + async def get_set_id(self) -> int: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return evaluation.set_id + + async def get_status(self) -> EvaluationStatus: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return evaluation.status + + async def get_terminated_reason(self) -> Optional[str]: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return evaluation.terminated_reason + + async def get_score(self) -> Optional[float]: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return 
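Both upload endpoints classify failed uploads the same way before calling `record_upload_attempt` (see the removed open-agent exception handler above). Pulled out as a tiny helper purely for illustration; the function name is mine, not from the codebase:

```python
from fastapi import HTTPException


def classify_upload_error(e: HTTPException) -> str:
    """Mirror the error_type buckets recorded for failed upload attempts."""
    if e.status_code == 403 and "banned" in str(e.detail).lower():
        return "banned"
    if e.status_code == 429:
        return "rate_limit"
    return "validation_error"
```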
evaluation.score + + async def get_screener_score(self) -> Optional[float]: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return evaluation.screener_score + + async def get_created_at(self) -> Optional[datetime]: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return evaluation.created_at + + async def get_started_at(self) -> Optional[datetime]: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return evaluation.started_at + + async def get_finished_at(self) -> Optional[datetime]: + evaluation = await get_evaluation_by_evaluation_id(self.evaluation_id) + return evaluation.finished_at + + async def get_is_screening(self) -> bool: + return await self.get_screener_stage() is not None + + async def get_screener_stage(self) -> Optional[int]: + validator_hotkey = await self.get_validator_hotkey() + return Screener.get_stage(validator_hotkey) async def start(self, conn: asyncpg.Connection) -> List[EvaluationRun]: """Start evaluation""" await conn.execute("UPDATE evaluations SET status = 'running', started_at = NOW() WHERE evaluation_id = $1", self.evaluation_id) self.status = EvaluationStatus.running - match self.screener_stage: + match await self.get_screener_stage(): case 1: type = "screener-1" case 2: @@ -76,7 +96,7 @@ async def start(self, conn: asyncpg.Connection) -> List[EvaluationRun]: evaluation_runs = [ EvaluationRun( run_id=uuid.uuid4(), - evaluation_id=self.evaluation_id, + evaluation_id=uuid.UUID(self.evaluation_id), swebench_instance_id=swebench_instance_id, response=None, error=None, @@ -122,24 +142,24 @@ async def finish(self, conn: asyncpg.Connection): print("🚨 CRITICAL DEBUG: finish() called on evaluation with existing terminated_reason! 🚨") print("=" * 80) print(f"Evaluation ID: {self.evaluation_id}") - print(f"Version ID: {self.version_id}") - print(f"Validator Hotkey: {self.validator_hotkey}") + print(f"Version ID: {await self.get_version_id()}") + print(f"Validator Hotkey: {await self.get_validator_hotkey()}") print(f"Current Status: {current_status}") print(f"Existing terminated_reason: {current_terminated_reason}") - print(f"Is Screening: {self.is_screening}") - if self.is_screening: - print(f"Screener Stage: {self.screener_stage}") + print(f"Is Screening: {await self.get_is_screening()}") + if await self.get_is_screening(): + print(f"Screener Stage: {await self.get_screener_stage()}") print(f"Current Time: {datetime.now().isoformat()}") print() print("CALL STACK TRACE:") print("-" * 40) traceback.print_stack() print("=" * 80) - + # Also log it for persistent record logger.error( f"CRITICAL: finish() called on evaluation {self.evaluation_id} " - f"(version_id: {self.version_id}, validator: {self.validator_hotkey}) " + f"(version_id: {await self.get_version_id()}, validator: {await self.get_validator_hotkey()}) " f"that already has terminated_reason: '{current_terminated_reason}'. " f"Current status: {current_status}. This will result in inconsistent state!" 
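With the slimmed-down constructor above, `Evaluation` now holds only its ID and every accessor re-reads the row via `get_evaluation_by_evaluation_id`, so call sites that need several fields pay one query per field. A short usage sketch of the new shape, plus the single-fetch alternative:

```python
from api.src.backend.queries.evaluations import get_evaluation_by_evaluation_id
from api.src.models.evaluation import Evaluation


async def describe_evaluation(evaluation_id: str) -> str:
    # One round-trip per accessor with the new lazy getters...
    evaluation = Evaluation(evaluation_id)
    version_id = await evaluation.get_version_id()
    validator_hotkey = await evaluation.get_validator_hotkey()

    # ...or a single fetch when several fields are needed together.
    row = await get_evaluation_by_evaluation_id(evaluation_id)
    return f"agent {version_id} evaluated by {validator_hotkey}, status {row.status}"
```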
) @@ -162,13 +182,15 @@ async def finish(self, conn: asyncpg.Connection): stage2_screener_to_notify = None # If it's a screener, handle stage-specific logic - if self.is_screening: - stage = self.screener_stage + if await self.get_is_screening(): + stage = await self.get_screener_stage() threshold = SCREENING_1_THRESHOLD if stage == 1 else SCREENING_2_THRESHOLD - if self.score < threshold: - logger.info(f"Stage {stage} screening failed for agent {self.version_id} with score {self.score} (threshold: {threshold})") + score = await self.get_score() + version_id = await self.get_version_id() + if score < threshold: + logger.info(f"Stage {stage} screening failed for agent {version_id} with score {score} (threshold: {threshold})") else: - logger.info(f"Stage {stage} screening passed for agent {self.version_id} with score {self.score} (threshold: {threshold})") + logger.info(f"Stage {stage} screening passed for agent {version_id} with score {score} (threshold: {threshold})") if stage == 1: # Stage 1 passed -> find ONE available stage 2 screener @@ -184,16 +206,17 @@ async def finish(self, conn: asyncpg.Connection): break elif stage == 2: # Stage 2 passed -> check if we should prune immediately - combined_screener_score, score_error = await Screener.get_combined_screener_score(conn, self.version_id) + version_id = await self.get_version_id() + combined_screener_score, score_error = await get_combined_screener_score(version_id) # ^ if this is None, we should likely not be here, because it means that either there was no screener 1 or screener 2 evaluation, but in which case how would be here anyway? if score_error: - await send_slack_message(f"Stage 2 screener score error for version {self.version_id}: {score_error}") + await send_slack_message(f"Stage 2 screener score error for version {version_id}: {score_error}") top_agent = await MinerAgentScored.get_top_agent(conn) - + if top_agent and combined_screener_score is not None and (top_agent.avg_score - combined_screener_score) > PRUNE_THRESHOLD: # Score is too low, prune miner agent and don't create evaluations - await conn.execute("UPDATE miner_agents SET status = 'pruned' WHERE version_id = $1", self.version_id) - logger.info(f"Pruned agent {self.version_id} immediately after screener-2 with combined score {combined_screener_score:.3f} (threshold: {top_agent.avg_score - PRUNE_THRESHOLD:.3f})") + await conn.execute("UPDATE miner_agents SET status = 'pruned' WHERE version_id = $1", version_id) + logger.info(f"Pruned agent {version_id} immediately after screener-2 with combined score {combined_screener_score:.3f} (threshold: {top_agent.avg_score - PRUNE_THRESHOLD:.3f})") return { "stage2_screener": None, "validators": [] @@ -207,10 +230,11 @@ async def finish(self, conn: asyncpg.Connection): all_validators = await Validator.get_connected() validators_to_notify = random.sample(all_validators, min(2, len(all_validators))) for validator in validators_to_notify: + version_id = await self.get_version_id() if (combined_screener_score is None): - await send_slack_message(f"111 Screener score is None when creating evaluation for validator {validator.hotkey}, version {self.version_id}") + await send_slack_message(f"111 Screener score is None when creating evaluation for validator {validator.hotkey}, version {version_id}") await send_slack_message(f"Evaluation object: {str(self)}") - await self.create_for_validator(conn, self.version_id, validator.hotkey, combined_screener_score) + await self.create_for_validator(conn, version_id, validator.hotkey, 
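The stage-2 branch above prunes the agent instead of creating validator evaluations when its combined screener score trails the current top agent by more than `PRUNE_THRESHOLD`. A worked example of that comparison; the numbers, including the threshold, are illustrative only:

```python
PRUNE_THRESHOLD = 0.15              # illustrative value; the real constant lives in platform config

top_agent_avg_score = 0.62          # avg_score of MinerAgentScored.get_top_agent(conn)
combined_screener_score = 0.40      # from get_combined_screener_score(version_id)

# Same condition finish() checks before creating validator evaluations.
should_prune = (
    combined_screener_score is not None
    and (top_agent_avg_score - combined_screener_score) > PRUNE_THRESHOLD
)
print(should_prune)  # True -> agent is set to 'pruned', no validator evaluations are created
```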
combined_screener_score) # Prune low-scoring evaluations after creating validator evaluations await Evaluation.prune_low_waiting(conn) @@ -272,94 +296,99 @@ async def reset_to_waiting(self, conn: asyncpg.Connection): async def _update_agent_status(self, conn: asyncpg.Connection): """Update agent status based on evaluation state - handles multi-stage screening""" - + + version_id = await self.get_version_id() + status = await self.get_status() + score = await self.get_score() + # Handle screening completion - if self.is_screening and self.status == EvaluationStatus.completed: - stage = self.screener_stage + if await self.get_is_screening() and status == EvaluationStatus.completed: + stage = await self.get_screener_stage() threshold = SCREENING_1_THRESHOLD if stage == 1 else SCREENING_2_THRESHOLD - if self.score is not None and self.score >= threshold: + if score is not None and score >= threshold: if stage == 1: # Stage 1 passed -> move to stage 2 - await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_2' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_2' WHERE version_id = $1", version_id) elif stage == 2: # Stage 2 passed -> ready for validation - await conn.execute("UPDATE miner_agents SET status = 'waiting' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'waiting' WHERE version_id = $1", version_id) else: if stage == 1: # Stage 1 failed - await conn.execute("UPDATE miner_agents SET status = 'failed_screening_1' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'failed_screening_1' WHERE version_id = $1", version_id) elif stage == 2: # Stage 2 failed - await conn.execute("UPDATE miner_agents SET status = 'failed_screening_2' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'failed_screening_2' WHERE version_id = $1", version_id) return # Handle screening errors like disconnection - reset to appropriate awaiting state - if self.is_screening and self.status == EvaluationStatus.error: - stage = self.screener_stage + if await self.get_is_screening() and status == EvaluationStatus.error: + stage = await self.get_screener_stage() if stage == 1: - await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_1' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_1' WHERE version_id = $1", version_id) elif stage == 2: - await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_2' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_2' WHERE version_id = $1", version_id) return # Check for any stage 1 screening evaluations (only running - waiting evaluations don't mean agent is actively being screened) stage1_count = await conn.fetchval( - """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 - AND (validator_hotkey LIKE 'screener-1-%' OR validator_hotkey LIKE 'i-0%') + """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 + AND (validator_hotkey LIKE 'screener-1-%' OR validator_hotkey LIKE 'i-0%') AND status = 'running'""", - self.version_id, + version_id, ) if stage1_count > 0: - await conn.execute("UPDATE miner_agents SET status = 'screening_1' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'screening_1' WHERE version_id = $1", version_id) return # Check for any 
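`_update_agent_status` above maps each screening outcome onto the agent status machine. The transitions it applies, collapsed into one lookup for reference (the status strings are the ones used verbatim in the UPDATE statements):

```python
# (stage, outcome) -> next miner_agents.status, as applied by _update_agent_status
SCREENING_TRANSITIONS = {
    (1, "passed"): "awaiting_screening_2",   # stage 1 passed -> queue for stage 2
    (2, "passed"): "waiting",                # stage 2 passed -> ready for validators
    (1, "failed"): "failed_screening_1",
    (2, "failed"): "failed_screening_2",
    (1, "error"):  "awaiting_screening_1",   # screener error/disconnect -> retry stage 1
    (2, "error"):  "awaiting_screening_2",   # screener error/disconnect -> retry stage 2
}
```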
stage 2 screening evaluations (only running - waiting evaluations don't mean agent is actively being screened) stage2_count = await conn.fetchval( - """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 - AND validator_hotkey LIKE 'screener-2-%' + """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 + AND validator_hotkey LIKE 'screener-2-%' AND status = 'running'""", - self.version_id, + version_id, ) if stage2_count > 0: - await conn.execute("UPDATE miner_agents SET status = 'screening_2' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'screening_2' WHERE version_id = $1", version_id) return # Handle evaluation status transitions for regular evaluations - if self.status == EvaluationStatus.running and not self.is_screening: - await conn.execute("UPDATE miner_agents SET status = 'evaluating' WHERE version_id = $1", self.version_id) + if status == EvaluationStatus.running and not await self.get_is_screening(): + await conn.execute("UPDATE miner_agents SET status = 'evaluating' WHERE version_id = $1", version_id) return # For other cases, check remaining regular evaluations (non-screening) waiting_count = await conn.fetchval( - """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 AND status = 'waiting' - AND validator_hotkey NOT LIKE 'screener-%' - AND validator_hotkey NOT LIKE 'i-0%'""", - self.version_id + """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 AND status = 'waiting' + AND validator_hotkey NOT LIKE 'screener-%' + AND validator_hotkey NOT LIKE 'i-0%'""", + version_id ) running_count = await conn.fetchval( - """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 AND status = 'running' - AND validator_hotkey NOT LIKE 'screener-%' - AND validator_hotkey NOT LIKE 'i-0%'""", - self.version_id + """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 AND status = 'running' + AND validator_hotkey NOT LIKE 'screener-%' + AND validator_hotkey NOT LIKE 'i-0%'""", + version_id ) completed_count = await conn.fetchval( - """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 AND status IN ('completed', 'pruned') - AND validator_hotkey NOT LIKE 'screener-%' - AND validator_hotkey NOT LIKE 'i-0%'""", - self.version_id + """SELECT COUNT(*) FROM evaluations WHERE version_id = $1 AND status IN ('completed', 'pruned') + AND validator_hotkey NOT LIKE 'screener-%' + AND validator_hotkey NOT LIKE 'i-0%'""", + version_id ) if waiting_count > 0 and running_count == 0: - await conn.execute("UPDATE miner_agents SET status = 'waiting' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'waiting' WHERE version_id = $1", version_id) elif waiting_count == 0 and running_count == 0 and completed_count > 0: # Calculate and update innovation score for this agent before setting status to 'scored' await self._update_innovation_score(conn) - await conn.execute("UPDATE miner_agents SET status = 'scored' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'scored' WHERE version_id = $1", version_id) else: - await conn.execute("UPDATE miner_agents SET status = 'evaluating' WHERE version_id = $1", self.version_id) + await conn.execute("UPDATE miner_agents SET status = 'evaluating' WHERE version_id = $1", version_id) async def _update_innovation_score(self, conn: asyncpg.Connection): """Calculate and update innovation score for this evaluation's agent in one atomic query""" + version_id = await self.get_version_id() try: # Single atomic query that 
calculates and updates innovation score updated_rows = await conn.execute(""" @@ -384,10 +413,10 @@ async def _update_innovation_score(self, conn: asyncpg.Connection): run_id, -- Calculate average solve rate for this instance before this run COALESCE( - AVG(CASE WHEN solved THEN 1.0 ELSE 0.0 END) + AVG(CASE WHEN solved THEN 1.0 ELSE 0.0 END) OVER ( - PARTITION BY swebench_instance_id - ORDER BY started_at + PARTITION BY swebench_instance_id + ORDER BY started_at ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING ), 0.0 ) AS prior_solved_ratio @@ -400,19 +429,19 @@ async def _update_innovation_score(self, conn: asyncpg.Connection): ) AS innovation_score FROM runs_with_prior ) - UPDATE miner_agents + UPDATE miner_agents SET innovation = (SELECT innovation_score FROM innovation_calculation) WHERE version_id = $1 - """, self.version_id) - - logger.info(f"Updated innovation score for agent {self.version_id} (affected {updated_rows} rows)") - + """, version_id) + + logger.info(f"Updated innovation score for agent {version_id} (affected {updated_rows} rows)") + except Exception as e: - logger.error(f"Failed to calculate innovation score for agent {self.version_id}: {e}") + logger.error(f"Failed to calculate innovation score for agent {version_id}: {e}") # Set innovation score to NULL on error to indicate calculation failure await conn.execute( "UPDATE miner_agents SET innovation = NULL WHERE version_id = $1", - self.version_id + version_id ) @staticmethod @@ -432,7 +461,7 @@ async def create_for_validator(conn: asyncpg.Connection, version_id: str, valida # We should always have a screener score when creating an evaluation for a validator if screener_score is None: - combined_screener_score, score_error = await Screener.get_combined_screener_score(conn, version_id) + combined_screener_score, score_error = await get_combined_screener_score(version_id) await send_slack_message(f"333 Screener score is None when creating evaluation for validator {validator_hotkey}, version {version_id}") if score_error: await send_slack_message(f"Error calculating screener score: {score_error}") @@ -504,11 +533,10 @@ async def create_screening_and_send(conn: asyncpg.Connection, agent: 'MinerAgent eval_id = str(uuid.uuid4()) - evaluation_data = await conn.fetchrow( + await conn.execute( """ INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at) VALUES ($1, $2, $3, $4, 'waiting', NOW()) - RETURNING * """, eval_id, agent.version_id, @@ -516,7 +544,7 @@ async def create_screening_and_send(conn: asyncpg.Connection, agent: 'MinerAgent set_id, ) - evaluation = Evaluation(**evaluation_data) + evaluation = Evaluation(eval_id) logger.info(f"Evaluation to send to screener {screener.hotkey}: {evaluation.evaluation_id}") evaluation_runs = await evaluation.start(conn) @@ -536,18 +564,11 @@ async def create_screening_and_send(conn: asyncpg.Connection, agent: 'MinerAgent async def get_by_id(evaluation_id: str) -> Optional["Evaluation"]: """Get evaluation by ID""" async with get_db_connection() as conn: - row = await conn.fetchrow("SELECT * FROM evaluations WHERE evaluation_id = $1", evaluation_id) + row = await conn.fetchrow("SELECT evaluation_id FROM evaluations WHERE evaluation_id = $1", evaluation_id) if not row: return None - return Evaluation( - evaluation_id=row["evaluation_id"], - version_id=row["version_id"], - validator_hotkey=row["validator_hotkey"], - set_id=row["set_id"], - status=EvaluationStatus.from_string(row["status"]), - score=row.get("score"), - ) + return 
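The window function in `_update_innovation_score` computes, for each run, how often the same `swebench_instance_id` had already been solved before that run; the final aggregation into `innovation_score` sits outside this hunk, so only the per-run ratio is sketched here in plain Python:

```python
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Run:
    swebench_instance_id: str
    started_at: datetime
    solved: bool


def prior_solved_ratio(run: Run, all_runs: list[Run]) -> float:
    """Fraction of earlier runs on the same instance that were solved (0.0 if none),
    approximating AVG(solved) OVER (PARTITION BY swebench_instance_id
    ORDER BY started_at ROWS BETWEEN UNBOUNDED PRECEDING AND 1 PRECEDING)."""
    earlier = [
        r for r in all_runs
        if r.swebench_instance_id == run.swebench_instance_id and r.started_at < run.started_at
    ]
    if not earlier:
        return 0.0
    return sum(r.solved for r in earlier) / len(earlier)
```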
Evaluation(evaluation_id=row["evaluation_id"]) @staticmethod async def screen_next_awaiting_agent(screener: "Screener"): @@ -665,19 +686,6 @@ async def get_progress(evaluation_id: str) -> float: """, evaluation_id) return float(progress) - @staticmethod - async def check_miner_has_no_running_evaluations(conn: asyncpg.Connection, miner_hotkey: str) -> bool: - """Check if miner has any running evaluations""" - has_running = await conn.fetchval( - """ - SELECT EXISTS(SELECT 1 FROM evaluations e - JOIN miner_agents ma ON e.version_id = ma.version_id - WHERE ma.miner_hotkey = $1 AND e.status = 'running') - """, - miner_hotkey, - ) - return not has_running - @staticmethod async def replace_old_agents(conn: asyncpg.Connection, miner_hotkey: str) -> None: """Replace all old agents and their evaluations for a miner""" @@ -731,7 +739,7 @@ async def has_waiting_for_validator(validator: "Validator") -> bool: ) for agent in agents: - combined_screener_score, score_error = await Screener.get_combined_screener_score(conn, agent["version_id"]) + combined_screener_score, score_error = await get_combined_screener_score(agent["version_id"]) if (combined_screener_score is None): await send_slack_message(f"222 Agent object: {dict(agent)}") await send_slack_message(f"222 Screener score is None when creating evaluation for validator {validator.hotkey}, version {agent['version_id']}") @@ -803,7 +811,7 @@ async def startup_recovery(): for eval_row in running_evals: evaluation = await Evaluation.get_by_id(eval_row["evaluation_id"]) if evaluation: - if evaluation.is_screening: + if await evaluation.get_is_screening(): await evaluation.error(conn, "Disconnected from screener (error code 2)") else: await evaluation.reset_to_waiting(conn) @@ -826,11 +834,18 @@ async def startup_recovery(): ) for stuck_eval in stuck_evaluations: - evaluation = await Evaluation.get_by_id(stuck_eval["evaluation_id"]) + evaluation = await get_evaluation_by_evaluation_id(stuck_eval["evaluation_id"]) + if evaluation: logger.info(f"Auto-completing stuck evaluation {evaluation.evaluation_id} during startup recovery") # During startup recovery, don't trigger notifications - _ = await evaluation.finish(conn) + from api.src.endpoints.screener import finish_evaluation + _ = await finish_evaluation( + str(evaluation.evaluation_id), + evaluation.validator_hotkey, + errored=True, + reason="Platform restarted" + ) # Cancel waiting screenings for all screener types waiting_screenings = await conn.fetch( diff --git a/api/src/models/screener.py b/api/src/models/screener.py index 6de54e60..fb69ab3e 100644 --- a/api/src/models/screener.py +++ b/api/src/models/screener.py @@ -1,9 +1,9 @@ import logging +from turtle import Screen from typing import Literal, Optional, List import asyncpg -from api.src.backend.entities import Client, AgentStatus, MinerAgent -from api.src.backend.db_manager import get_transaction +from api.src.backend.entities import Client logger = logging.getLogger(__name__) @@ -32,78 +32,9 @@ def get_stage(hotkey: str) -> Optional[int]: return 1 elif hotkey.startswith("screener-2-"): return 2 - elif hotkey.startswith("i-0"): # Legacy screeners are stage 1 - return 1 else: return None - - @staticmethod - async def get_combined_screener_score(conn: asyncpg.Connection, version_id: str) -> tuple[Optional[float], Optional[str]]: - """Calculate combined screener score as (questions solved by both) / (questions asked by both) - - Returns: - tuple[Optional[float], Optional[str]]: (score, error_message) - - score: The calculated score, or None if 
calculation failed - - error_message: None if successful, error description if failed - """ - # Get evaluation IDs for both screener stages - stage_1_eval_id = await conn.fetchval( - """ - SELECT evaluation_id FROM evaluations - WHERE version_id = $1 - AND validator_hotkey LIKE 'screener-1-%' - AND status = 'completed' - ORDER BY created_at DESC - LIMIT 1 - """, - version_id - ) - - stage_2_eval_id = await conn.fetchval( - """ - SELECT evaluation_id FROM evaluations - WHERE version_id = $1 - AND validator_hotkey LIKE 'screener-2-%' - AND status = 'completed' - ORDER BY created_at DESC - LIMIT 1 - """, - version_id - ) - - if not stage_1_eval_id or not stage_2_eval_id: - missing = [] - if not stage_1_eval_id: - missing.append("stage-1") - if not stage_2_eval_id: - missing.append("stage-2") - return None, f"Missing completed screener evaluation(s): {', '.join(missing)}" - - # Get solved count and total count for both evaluations - results = await conn.fetch( - """ - SELECT - SUM(CASE WHEN solved THEN 1 ELSE 0 END) as solved_count, - COUNT(*) as total_count - FROM evaluation_runs - WHERE evaluation_id = ANY($1::uuid[]) - AND status != 'cancelled' - """, - [stage_1_eval_id, stage_2_eval_id] - ) - - if not results or len(results) == 0: - return None, f"No evaluation runs found for screener evaluations {stage_1_eval_id} and {stage_2_eval_id}" - - result = results[0] - solved_count = result['solved_count'] or 0 - total_count = result['total_count'] or 0 - - if total_count == 0: - return None, f"No evaluation runs to calculate score from (total_count=0)" - return solved_count / total_count, None - @property def stage(self) -> Optional[int]: """Get the screening stage for this screener""" @@ -181,38 +112,6 @@ def screening_agent_hotkey(self) -> Optional[str]: @property def screening_agent_name(self) -> Optional[str]: return self.current_agent_name - - async def start_screening(self, evaluation_id: str) -> bool: - """Handle start-evaluation message""" - from api.src.models.evaluation import Evaluation - - evaluation = await Evaluation.get_by_id(evaluation_id) - if not evaluation or not evaluation.is_screening or evaluation.validator_hotkey != self.hotkey: - return False - - async with get_transaction() as conn: - agent = await conn.fetchrow("SELECT status, agent_name, miner_hotkey FROM miner_agents WHERE version_id = $1", evaluation.version_id) - agent_status = AgentStatus.from_string(agent["status"]) if agent else None - - # Check if agent is in the appropriate screening status for this screener stage - expected_status = getattr(AgentStatus, f"screening_{self.stage}") - if not agent or agent_status != expected_status: - logger.info(f"Stage {self.stage} screener {self.hotkey}: tried to start screening but agent is not in screening_{self.stage} status (current: {agent['status'] if agent else 'None'})") - return False - agent_name = agent["agent_name"] - agent_hotkey = agent["miner_hotkey"] - - await evaluation.start(conn) - old_status = self.status - self.status = f"screening" - self.current_evaluation_id = evaluation_id - self.current_agent_name = agent_name - self.current_agent_hotkey = agent_hotkey - logger.info(f"Screener {self.hotkey}: {old_status} -> screening {agent_name}") - - # Broadcast status change - self._broadcast_status_change() - return True async def connect(self): """Handle screener connection""" @@ -234,104 +133,21 @@ async def disconnect(self): self.set_available() logger.info(f"Screener {self.hotkey} disconnected, status reset to: {self.status}") await 
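`Screener.get_combined_screener_score` is deleted above and callers now import `get_combined_screener_score` from `api.src.backend.queries.scores`, which is not shown in this diff. A sketch of the relocated query, assuming it keeps the deleted logic and adopts the `@db_operation` pattern of the other query modules (error wording may differ):

```python
from typing import Optional

import asyncpg

from api.src.backend.db_manager import db_operation


@db_operation
async def get_combined_screener_score(
    conn: asyncpg.Connection, version_id: str
) -> tuple[Optional[float], Optional[str]]:
    """(score, error): score = solved / asked across the latest completed stage-1 and stage-2 screenings."""
    eval_ids = []
    for stage in (1, 2):
        eval_id = await conn.fetchval(
            """
            SELECT evaluation_id FROM evaluations
            WHERE version_id = $1
              AND validator_hotkey LIKE $2
              AND status = 'completed'
            ORDER BY created_at DESC
            LIMIT 1
            """,
            version_id,
            f"screener-{stage}-%",
        )
        if not eval_id:
            return None, f"Missing completed screener evaluation(s): stage-{stage}"
        eval_ids.append(eval_id)

    row = await conn.fetchrow(
        """
        SELECT SUM(CASE WHEN solved THEN 1 ELSE 0 END) AS solved_count,
               COUNT(*) AS total_count
        FROM evaluation_runs
        WHERE evaluation_id = ANY($1::uuid[]) AND status != 'cancelled'
        """,
        eval_ids,
    )
    solved, total = row["solved_count"] or 0, row["total_count"] or 0
    if total == 0:
        return None, "No evaluation runs to calculate score from (total_count=0)"
    return solved / total, None
```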
Evaluation.handle_screener_disconnection(self.hotkey) - - async def finish_screening(self, evaluation_id: str, errored: bool = False, reason: Optional[str] = None): - """Finish screening evaluation""" - from api.src.models.evaluation import Evaluation - - logger.info(f"Screener {self.hotkey}: Finishing screening {evaluation_id}, entered finish_screening") - - try: - evaluation = await Evaluation.get_by_id(evaluation_id) - if not evaluation or not evaluation.is_screening or evaluation.validator_hotkey != self.hotkey: - logger.warning(f"Screener {self.hotkey}: Invalid finish_screening call for evaluation {evaluation_id}") - return - - async with get_transaction() as conn: - agent_status = await conn.fetchval("SELECT status FROM miner_agents WHERE version_id = $1", evaluation.version_id) - expected_status = getattr(AgentStatus, f"screening_{self.stage}") - if AgentStatus.from_string(agent_status) != expected_status: - logger.warning(f"Stage {self.stage} screener {self.hotkey}: Evaluation {evaluation_id}: Agent {evaluation.version_id} not in screening_{self.stage} status during finish (current: {agent_status})") - # Clearly a bug here, its somehow set to failed_screening_1 when we hit this if statement - # It should be screening_1, no idea whats setting it to failed_screening_1 - # return - - if errored: - logger.info(f"Screener {self.hotkey}: Finishing screening {evaluation_id}: Errored with reason: {reason}") - await evaluation.error(conn, reason) - logger.info(f"Screener {self.hotkey}: Finishing screening {evaluation_id}: Errored with reason: {reason}: done") - notification_targets = None - else: - notification_targets = await evaluation.finish(conn) - - from api.src.socket.websocket_manager import WebSocketManager - ws_manager = WebSocketManager.get_instance() - await ws_manager.send_to_all_non_validators("evaluation-finished", {"evaluation_id": evaluation_id}) - self.set_available() - - logger.info(f"Screener {self.hotkey}: Successfully finished evaluation {evaluation_id}, errored={errored}") - - # Handle notifications AFTER transaction commits - if notification_targets: - # Notify stage 2 screener when stage 1 completes - if notification_targets.get("stage2_screener"): - async with Evaluation.get_lock(): - await Evaluation.screen_next_awaiting_agent(notification_targets["stage2_screener"]) - - # Notify validators with proper lock protection - for validator in notification_targets.get("validators", []): - async with Evaluation.get_lock(): - if validator.is_available(): - success = await validator.start_evaluation_and_send(evaluation_id) - if success: - logger.info(f"Successfully assigned evaluation {evaluation_id} to validator {validator.hotkey}") - else: - logger.warning(f"Failed to assign evaluation {evaluation_id} to validator {validator.hotkey}") - else: - logger.info(f"Validator {validator.hotkey} not available for evaluation {evaluation_id}") - - logger.info(f"Screener {self.hotkey}: Finishing screening {evaluation_id}: Got to end of try block") - finally: - logger.info(f"Screener {self.hotkey}: Finishing screening {evaluation_id}, in finally block") - # Single atomic reset and reassignment - async with Evaluation.get_lock(): - self.set_available() - logger.info(f"Screener {self.hotkey}: Reset to available and looking for next agent") - await Evaluation.screen_next_awaiting_agent(self) - logger.info(f"Screener {self.hotkey}: Finishing screening {evaluation_id}, exiting finally block") - @staticmethod - async def get_first_available() -> Optional['Screener']: - """Read-only 
availability check - does NOT reserve screener""" + async def get_connected_1() -> List['Screener']: + """Get all connected sc1""" from api.src.socket.websocket_manager import WebSocketManager ws_manager = WebSocketManager.get_instance() - logger.debug(f"Checking {len(ws_manager.clients)} clients for available screener...") - for client in ws_manager.clients.values(): - if client.get_type() == "screener" and client.status == "available": - logger.debug(f"Found available screener: {client.hotkey}") - return client - logger.warning("No available screeners found") - return None - + screeners: list['Screener'] = [client for client in ws_manager.clients.values() if client.get_type() == "screener"] + + return [screener for screener in screeners if screener.stage == 1] + @staticmethod - async def get_first_available_and_reserve(stage: int) -> Optional['Screener']: - """Atomically find and reserve first available screener for specific stage - MUST be called within Evaluation lock""" + async def get_connected_2() -> List['Screener']: + """Get all connected sc2""" from api.src.socket.websocket_manager import WebSocketManager ws_manager = WebSocketManager.get_instance() - - for client in ws_manager.clients.values(): - if (client.get_type() == "screener" and - client.status == "available" and - client.is_available() and - client.stage == stage): - - # Immediately reserve to prevent race conditions - client.status = "reserving" - logger.info(f"Reserved stage {stage} screener {client.hotkey} for work assignment") - - return client - - logger.warning(f"No available stage {stage} screeners to reserve") - return None - \ No newline at end of file + screeners: list['Screener'] = [client for client in ws_manager.clients.values() if client.get_type() == "screener"] + + return [screener for screener in screeners if screener.stage == 2] diff --git a/api/src/models/validator.py b/api/src/models/validator.py index ed1be1df..9af6bc81 100644 --- a/api/src/models/validator.py +++ b/api/src/models/validator.py @@ -91,25 +91,25 @@ async def start_evaluation_and_send(self, evaluation_id: str) -> bool: from api.src.models.evaluation import Evaluation evaluation = await Evaluation.get_by_id(evaluation_id) - - if not evaluation or evaluation.is_screening or evaluation.validator_hotkey != self.hotkey: + + if not evaluation or await evaluation.get_is_screening() or await evaluation.get_validator_hotkey() != self.hotkey: logger.warning(f"Validator {self.hotkey}: Invalid evaluation {evaluation_id}") return False - miner_agent = await get_agent_by_version_id(evaluation.version_id) + miner_agent = await get_agent_by_version_id(await evaluation.get_version_id()) if not miner_agent: logger.error(f"Validator {self.hotkey}: Agent not found for evaluation {evaluation_id}") return False try: - async with get_transaction() as conn: - evaluation_runs = await evaluation.start(conn) + from api.src.endpoints.screener import start_screening as start_eval_new + start_state = await start_eval_new(evaluation_id=evaluation_id, hotkey=self.hotkey) message = { "event": "evaluation", "evaluation_id": str(evaluation_id), "agent_version": miner_agent.model_dump(mode='json'), - "evaluation_runs": [run.model_dump(mode='json') for run in evaluation_runs] + "evaluation_runs": [run.model_dump(mode='json') for run in start_state["runs_created"]] } # Send message to validator @@ -163,52 +163,7 @@ async def get_next_evaluation(self) -> Optional[str]: ORDER BY e.screener_score DESC NULLS LAST, e.created_at ASC LIMIT 1 """, self.hotkey) - - async def 
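`get_connected_1` and `get_connected_2` above only list connected screeners per stage; picking one is left to callers such as the rewritten upload path. The inline loop there could be expressed on top of these helpers roughly as follows (helper name is mine):

```python
from typing import Optional

from api.src.models.screener import Screener


async def first_available_stage1_screener() -> Optional[Screener]:
    """First stage-1 screener that is connected and reporting itself available."""
    for screener in await Screener.get_connected_1():
        if screener.status == "available" and screener.is_available():
            return screener
    return None
```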
finish_evaluation(self, evaluation_id: str, errored: bool = False, reason: Optional[str] = None): - """Finish evaluation and automatically look for next work""" - from api.src.models.evaluation import Evaluation - - try: - evaluation = await Evaluation.get_by_id(evaluation_id) - if not evaluation or evaluation.validator_hotkey != self.hotkey: - logger.warning(f"Validator {self.hotkey}: Invalid finish_evaluation call for evaluation {evaluation_id}") - return - - async with get_transaction() as conn: - agent_status = await conn.fetchval("SELECT status FROM miner_agents WHERE version_id = $1", evaluation.version_id) - if AgentStatus.from_string(agent_status) != AgentStatus.evaluating: - logger.warning(f"Validator {self.hotkey}: Agent {evaluation.version_id} not in evaluating status during finish") - return - - if errored: - await evaluation.error(conn, reason) - notification_targets = None - else: - notification_targets = await evaluation.finish(conn) - - from api.src.socket.websocket_manager import WebSocketManager - ws_manager = WebSocketManager.get_instance() - await ws_manager.send_to_all_non_validators("evaluation-finished", {"evaluation_id": evaluation_id}) - - logger.info(f"Validator {self.hotkey}: Successfully finished evaluation {evaluation_id}, errored={errored}") - - # Handle notifications AFTER transaction commits - if notification_targets: - # Note: Validators typically don't trigger stage transitions, but handle any notifications - for validator in notification_targets.get("validators", []): - async with Evaluation.get_lock(): - if validator.is_available(): - success = await validator.start_evaluation_and_send(evaluation_id) - if success: - logger.info(f"Successfully assigned evaluation {evaluation_id} to validator {validator.hotkey}") - - finally: - # Single atomic reset and reassignment - async with Evaluation.get_lock(): - self.set_available() - logger.info(f"Validator {self.hotkey}: Reset to available and looking for next evaluation") - await self._check_and_start_next_evaluation() - + async def _check_and_start_next_evaluation(self): """Atomically check for and start next evaluation - MUST be called within lock""" from api.src.models.evaluation import Evaluation diff --git a/api/src/socket/handlers/handle_finish_evaluation.py b/api/src/socket/handlers/handle_finish_evaluation.py index 1e9ff3ee..cd8ceb6a 100644 --- a/api/src/socket/handlers/handle_finish_evaluation.py +++ b/api/src/socket/handlers/handle_finish_evaluation.py @@ -3,6 +3,7 @@ from api.src.backend.entities import Client from api.src.backend.queries.evaluations import get_evaluation_by_evaluation_id +from api.src.endpoints.screener import finish_evaluation, finish_screening from api.src.models.screener import Screener from api.src.models.validator import Validator from loggers.logging_utils import get_logger @@ -46,11 +47,21 @@ async def handle_finish_evaluation( if client.get_type() == "screener": screener: Screener = client - await screener.finish_screening(evaluation_id, errored, reason) + await finish_screening( + evaluation_id=evaluation_id, + hotkey=screener.hotkey, + errored=errored, + reason=reason + ) action = "Screening" elif client.get_type() == "validator": validator: Validator = client - await validator.finish_evaluation(evaluation_id, errored, reason) + await finish_evaluation( + evaluation_id=evaluation_id, + hotkey=validator.hotkey, + errored=errored, + reason=reason + ) action = "Evaluation" # Broadcast evaluation completion diff --git 
a/api/src/socket/handlers/handle_inform_evaluation_completed.py b/api/src/socket/handlers/handle_inform_evaluation_completed.py index 7cb5c7a1..cffe3bba 100644 --- a/api/src/socket/handlers/handle_inform_evaluation_completed.py +++ b/api/src/socket/handlers/handle_inform_evaluation_completed.py @@ -38,11 +38,13 @@ async def handle_inform_evaluation_completed( # Force finish the evaluation (skip all_runs_finished check) if client.get_type() == "validator": logger.info(f"Calling finish_evaluation for {evaluation_id}") - await client.finish_evaluation(evaluation_id) + from api.src.endpoints.screener import finish_evaluation + await finish_evaluation(evaluation_id, client.hotkey) logger.info(f"Called finish_evaluation for {evaluation_id}") elif client.get_type() == "screener": logger.info(f"Calling finish_screening for {evaluation_id}") - await client.finish_screening(evaluation_id) + from api.src.endpoints.screener import finish_screening + await finish_screening(evaluation_id, client.hotkey) logger.info(f"Called finish_screening for {evaluation_id}") else: logger.warning(f"Unknown client type when trying to finish evaluation {evaluation_id}") diff --git a/api/src/socket/handlers/handle_start_evaluation.py b/api/src/socket/handlers/handle_start_evaluation.py index 5616a27f..51917898 100644 --- a/api/src/socket/handlers/handle_start_evaluation.py +++ b/api/src/socket/handlers/handle_start_evaluation.py @@ -1,6 +1,7 @@ from typing import Dict, Any from api.src.backend.entities import Client +from api.src.endpoints.screener import start_screening from api.src.models.screener import Screener from api.src.models.validator import Validator from loggers.logging_utils import get_logger @@ -21,7 +22,8 @@ async def handle_start_evaluation( # Use appropriate start method based on client type if client.get_type() == "screener": - success = await client.start_screening(evaluation_id) + start_screening = await start_screening(evaluation_id, client.client_id) + success = start_screening["success"] action = "Screening" elif client.get_type() == "validator": success = await client.start_evaluation(evaluation_id) diff --git a/api/src/socket/handlers/handle_update_evaluation_run.py b/api/src/socket/handlers/handle_update_evaluation_run.py index b4dfbf44..ff184453 100644 --- a/api/src/socket/handlers/handle_update_evaluation_run.py +++ b/api/src/socket/handlers/handle_update_evaluation_run.py @@ -1,9 +1,8 @@ from typing import TYPE_CHECKING, Dict, Any -from api.src.backend.db_manager import get_transaction from api.src.backend.entities import Client, EvaluationRun -from api.src.models.evaluation import Evaluation from api.src.backend.queries.evaluation_runs import all_runs_finished, update_evaluation_run +from api.src.backend.queries.evaluations import get_progress from loggers.logging_utils import get_logger from typing import Union @@ -99,13 +98,11 @@ async def handle_update_evaluation_run( if await all_runs_finished(evaluation_run.evaluation_id): logger.info(f"All runs finished for evaluation {evaluation_run.evaluation_id}. 
Finishing evaluation.") if client.get_type() == "validator": - logger.info(f"Calling finish_evaluation for {evaluation_run.evaluation_id}") - await client.finish_evaluation(evaluation_run.evaluation_id) - logger.info(f"Called finish_evaluation for {evaluation_run.evaluation_id}") + from api.src.endpoints.screener import finish_evaluation + await finish_evaluation(evaluation_run.evaluation_id, client.hotkey) elif client.get_type() == "screener": - logger.info(f"Calling finish_screening for {evaluation_run.evaluation_id}") - await client.finish_screening(evaluation_run.evaluation_id) - logger.info(f"Called finish_screening for {evaluation_run.evaluation_id}") + from api.src.endpoints.screener import finish_screening + await finish_screening(evaluation_run.evaluation_id, client.hotkey) else: logger.info(f"Unknown type, not validator or screener, when trying to finish evaluation {evaluation_run.evaluation_id}") @@ -115,7 +112,7 @@ async def handle_update_evaluation_run( if 'status' in broadcast_data: broadcast_data['status'] = evaluation_run.status.value broadcast_data["validator_hotkey"] = client.hotkey # Keep as validator_hotkey for API compatibility - broadcast_data["progress"] = await Evaluation.get_progress(evaluation_run.evaluation_id) + broadcast_data["progress"] = await get_progress(evaluation_run.evaluation_id) broadcast_data["validator_status"] = validator_status # Include computed validator status await ws.send_to_all_non_validators("evaluation-run-update", broadcast_data) diff --git a/api/src/socket/handlers/handle_validator_info.py b/api/src/socket/handlers/handle_validator_info.py index f7308386..d7e7332d 100644 --- a/api/src/socket/handlers/handle_validator_info.py +++ b/api/src/socket/handlers/handle_validator_info.py @@ -1,10 +1,8 @@ -import time from typing import Dict, Any from fastapi import WebSocket from loggers.logging_utils import get_logger from api.src.backend.entities import Client -from api.src.socket.server_helpers import get_relative_version_num logger = get_logger(__name__) diff --git a/api/src/socket/websocket_manager.py b/api/src/socket/websocket_manager.py index dd9f42d4..d3d02e9f 100644 --- a/api/src/socket/websocket_manager.py +++ b/api/src/socket/websocket_manager.py @@ -2,7 +2,7 @@ from typing import Optional, Dict, List, Union from fastapi import WebSocket, WebSocketDisconnect -from api.src.models.evaluation import Evaluation +from api.src.backend.queries.evaluations import get_progress from loggers.logging_utils import get_logger from api.src.backend.entities import Client from api.src.models.validator import Validator @@ -140,7 +140,7 @@ async def get_clients(self): "evaluating_id": validator.current_evaluation_id, "evaluating_agent_hotkey": validator.current_agent_hotkey, "evaluating_agent_name": validator.current_agent_name, - "progress": await Evaluation.get_progress(validator.current_evaluation_id) if validator.current_evaluation_id else 0 + "progress": await get_progress(validator.current_evaluation_id) if validator.current_evaluation_id else 0 } # Always include system metrics from the validator's stored data @@ -169,7 +169,7 @@ async def get_clients(self): "screening_id": screener.screening_id, "screening_agent_hotkey": screener.screening_agent_hotkey, "screening_agent_name": screener.screening_agent_name, - "progress": await Evaluation.get_progress(screener.screening_id) if screener.screening_id else 0 + "progress": await get_progress(screener.screening_id) if screener.screening_id else 0 } # Always include system metrics from the screener's 
stored data diff --git a/tests/.coverage b/tests/.coverage deleted file mode 100644 index e57edff5..00000000 Binary files a/tests/.coverage and /dev/null differ diff --git a/tests/README.md b/tests/README.md deleted file mode 100644 index 0cee79fd..00000000 --- a/tests/README.md +++ /dev/null @@ -1,53 +0,0 @@ -# Testing Guide - -This project uses pytest for all testing needs. Everything is configured to work automatically. - -## Quick Start - -```bash -# Run all tests (like GitHub Actions does) -cd tests -uv run python -m pytest - -# Run only unit tests (fast) -uv run python -m pytest test_miner_agent_flow.py - -# Run with coverage details -uv run python -m pytest --cov-report=term-missing - -# Run specific test markers -uv run python -m pytest -m unit # Unit tests only -uv run python -m pytest -m core # Core business logic -uv run python -m pytest -m endpoints # API endpoint tests -``` - -## Test Categories - -- **Unit Tests** (`test_miner_agent_flow.py`) - Core business logic, always reliable -- **Simple Tests** (`test_endpoints_simple.py`) - Basic API endpoint tests -- **Unit Tests** (`test_endpoints_unit.py`) - API tests with mocking -- **Integration Tests** (`test_endpoints_integration.py`) - Full database integration - -## Environment Setup - -The `conftest.py` file automatically: -- Sets up PostgreSQL service for integration tests (Docker) -- Configures environment variables to match GitHub Actions -- Handles test markers and coverage reporting -- Manages Python path for proper imports - -## Configuration Files - -- `pytest.ini` - Main pytest configuration with coverage settings -- `conftest.py` - Global fixtures and test environment setup -- `docker-compose.test.yml` - PostgreSQL service for integration tests - -## GitHub Actions - -The workflow in `.github/workflows/testing.yml` simply runs: -```bash -cd tests -uv run python -m pytest -``` - -Everything else is handled by pytest configuration. \ No newline at end of file diff --git a/tests/TEST_SETUP_INSTRUCTIONS.md b/tests/TEST_SETUP_INSTRUCTIONS.md deleted file mode 100644 index 883e5408..00000000 --- a/tests/TEST_SETUP_INSTRUCTIONS.md +++ /dev/null @@ -1,342 +0,0 @@ -# Database Integration Testing Setup - -This document provides instructions for setting up a comprehensive testing environment with real PostgreSQL database integration for the Ridges API endpoints. - -**✅ Uses Production Schema**: The integration tests now use the actual `postgres_schema.sql` file to ensure testing against the exact production database structure, including all triggers, materialized views, and constraints. - -## Prerequisites - -1. **PostgreSQL Database Server** - - Local PostgreSQL installation or Docker container - - PostgreSQL 12+ recommended - - Superuser access to create/drop test databases - -2. **Python Dependencies** - ```bash - uv add pytest pytest-asyncio asyncpg httpx pytest-postgresql - ``` - -## Setup Options - -### Option 1: Local PostgreSQL Installation - -1. **Install PostgreSQL locally:** - ```bash - # macOS with Homebrew - brew install postgresql - brew services start postgresql - - # Ubuntu/Debian - sudo apt-get install postgresql postgresql-contrib - sudo systemctl start postgresql - - # Create test user and database - sudo -u postgres createuser --superuser test_user - sudo -u postgres psql -c "ALTER USER test_user PASSWORD 'test_password';" - ``` - -2. 
**Set environment variable:** - ```bash - export POSTGRES_TEST_URL="postgresql://test_user:test_password@localhost:5432/ridges_test" - ``` - -### Option 2: Docker PostgreSQL (Recommended) - -1. **Create docker-compose.test.yml:** - ```yaml - version: '3.8' - services: - postgres-test: - image: postgres:15 - environment: - POSTGRES_DB: postgres - POSTGRES_USER: test_user - POSTGRES_PASSWORD: test_password - ports: - - "5433:5432" # Different port to avoid conflicts - volumes: - - postgres_test_data:/var/lib/postgresql/data - tmpfs: - - /tmp - - /var/run/postgresql - - volumes: - postgres_test_data: - ``` - -2. **Start test database:** - ```bash - docker-compose -f tests/docker-compose.test.yml up -d postgres-test - ``` - -3. **Set environment variable:** - ```bash - export POSTGRES_TEST_URL="postgresql://test_user:test_password@localhost:5433/ridges_test" - ``` - -### Option 3: pytest-postgresql (Automatic) - -This uses an ephemeral PostgreSQL instance that's automatically managed: - -1. **Install pytest-postgresql:** - ```bash - uv add pytest-postgresql - ``` - -2. **Use the auto-managed fixture in tests** (see Alternative Test Configuration below) - -## Running the Tests - -### Basic Test Execution - -```bash -# Run all integration tests -uv run python -m pytest test_endpoints_integration.py -v - -# Run specific test class -uv run python -m pytest test_endpoints_integration.py::TestUploadEndpoints -v - -# Run with coverage -uv run python -m pytest test_endpoints_integration.py --cov=api/src/endpoints --cov-report=html - -# Run tests in parallel (faster) -uv add pytest-xdist -uv run python -m pytest test_endpoints_integration.py -n auto -``` - -### Test Database Management - -The test suite automatically: -- Creates a clean test database for each test session -- Sets up the required schema and tables -- Provides transaction isolation for each test -- Cleans up after all tests complete - -### Environment Variables Required - -```bash -# Test database connection -export POSTGRES_TEST_URL="postgresql://test_user:test_password@localhost:5432/ridges_test" - -# Application environment variables (mocked in tests) -export AWS_MASTER_USERNAME="test_user" -export AWS_MASTER_PASSWORD="test_pass" -export AWS_RDS_PLATFORM_ENDPOINT="test_endpoint" -export AWS_RDS_PLATFORM_DB_NAME="test_db" -``` - -## Test Configuration Files - -### Create pytest.ini -```ini -[tool:pytest] -testpaths = . 
-python_files = test_*.py -python_classes = Test* -python_functions = test_* -asyncio_mode = auto -addopts = - -v - --tb=short - --strict-markers - --disable-warnings -markers = - slow: marks tests as slow - integration: marks tests as integration tests - unit: marks tests as unit tests -``` - -### Create .env.test -```bash -# Database Configuration -POSTGRES_TEST_URL=postgresql://test_user:test_password@localhost:5432/ridges_test - -# Mock AWS Configuration -AWS_MASTER_USERNAME=test_user -AWS_MASTER_PASSWORD=test_pass -AWS_RDS_PLATFORM_ENDPOINT=test_endpoint -AWS_RDS_PLATFORM_DB_NAME=test_db - -# Test-specific settings -TESTING=true -LOG_LEVEL=INFO -``` - -## Alternative Test Configuration (pytest-postgresql) - -If you prefer fully automated database management, replace the database setup in the test file: - -```python -import pytest_postgresql - -# Use pytest-postgresql for automatic database management -@pytest.fixture(scope="session") -def postgresql_proc(): - return pytest_postgresql.postgresql_proc( - port=None, - unixsocketdir='/tmp' - ) - -@pytest.fixture(scope="session") -def postgresql(postgresql_proc): - return pytest_postgresql.postgresql('postgresql_proc') - -@pytest.fixture -async def db_conn(postgresql): - """Database connection with automatic cleanup""" - import asyncpg - conn = await asyncpg.connect( - host=postgresql.info.host, - port=postgresql.info.port, - user=postgresql.info.user, - database=postgresql.info.dbname - ) - - try: - # Setup schema - await setup_test_schema(conn) - yield conn - finally: - await conn.close() -``` - -## Test Data Management - -### Test Data Isolation -- Each test runs in its own database transaction -- Transactions are automatically rolled back after each test -- No test data persists between test runs - -### Test Data Factories -Consider creating test data factories for complex objects: - -```python -class AgentFactory: - @staticmethod - async def create_agent(conn, **kwargs): - defaults = { - 'version_id': uuid.uuid4(), - 'miner_hotkey': f'test_miner_{uuid.uuid4().hex[:8]}', - 'agent_name': 'test_agent', - 'version_num': 1, - 'status': 'awaiting_screening_1' - } - defaults.update(kwargs) - - await conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, status) - VALUES ($1, $2, $3, $4, $5) - """, defaults['version_id'], defaults['miner_hotkey'], defaults['agent_name'], - defaults['version_num'], defaults['status']) - - return defaults -``` - -## Continuous Integration Setup - -### GitHub Actions (.github/workflows/test.yml) -```yaml -name: Integration Tests - -on: [push, pull_request] - -jobs: - test: - runs-on: ubuntu-latest - - services: - postgres: - image: postgres:15 - env: - POSTGRES_PASSWORD: test_password - POSTGRES_USER: test_user - POSTGRES_DB: postgres - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.11' - - - name: Install uv - run: curl -LsSf https://astral.sh/uv/install.sh | sh - - - name: Install dependencies - run: uv sync - - - name: Run tests - run: uv run python -m pytest test_endpoints_integration.py -v - env: - POSTGRES_TEST_URL: postgresql://test_user:test_password@localhost:5432/ridges_test -``` - -## Performance Considerations - -### Speed Optimizations -- Use connection pooling in tests -- Run tests in parallel with `pytest-xdist` -- Use in-memory databases for unit 
tests -- Use real databases only for integration tests - -### Resource Management -- Limit concurrent database connections -- Use smaller test datasets when possible -- Clean up large test data promptly - -## Troubleshooting - -### Common Issues - -1. **Connection Refused** - - Check PostgreSQL is running: `pg_isready` - - Verify connection parameters - - Check firewall settings - -2. **Permission Denied** - - Ensure test user has CREATE DATABASE privileges - - Check database ownership - -3. **Schema Errors** - - Verify all required tables are created - - Check for missing foreign key constraints - - Ensure materialized views are refreshed - -4. **Test Isolation Issues** - - Verify transaction rollback is working - - Check for connection leaks - - Ensure proper cleanup in fixtures - -### Debug Commands -```bash -# Check database connection -psql $POSTGRES_TEST_URL -c "SELECT 1" - -# View test database schema -psql $POSTGRES_TEST_URL -c "\dt" - -# Check test data -psql $POSTGRES_TEST_URL -c "SELECT COUNT(*) FROM miner_agents" - -# Run tests with debug output -uv run python -m pytest test_endpoints_integration.py -v -s --tb=long -``` - -## Best Practices - -1. **Test Independence**: Each test should be able to run independently -2. **Data Cleanup**: Use transactions and fixtures for automatic cleanup -3. **Realistic Data**: Use data that mirrors production scenarios -4. **Performance**: Monitor test execution time and optimize slow tests -5. **Error Handling**: Test both success and failure scenarios -6. **Documentation**: Document complex test setups and data requirements - -This testing setup provides comprehensive coverage of your API endpoints with real database integration, ensuring that your application works correctly with actual SQL operations and constraints. \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 0e16195f..00000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Tests package for Ridges miner agent system \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index db422f31..00000000 --- a/tests/conftest.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Global pytest configuration and fixtures for all tests. -This replaces the shell scripts with proper pytest setup. 
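The transaction-per-test isolation and automatic cleanup recommended in the best-practices list above can be expressed as a pair of fixtures. The following is a minimal sketch only, assuming asyncpg, pytest-asyncio, and the POSTGRES_TEST_URL environment variable described earlier; the fixture names `test_pool` and `rollback_conn` are illustrative, not the project's own.

```python
# Sketch: session-wide pool plus a rolled-back transaction per test (assumed names).
import os

import asyncpg
import pytest_asyncio


@pytest_asyncio.fixture(scope="session")
async def test_pool():
    """One connection pool for the whole test session."""
    pool = await asyncpg.create_pool(os.environ["POSTGRES_TEST_URL"])
    try:
        yield pool
    finally:
        await pool.close()


@pytest_asyncio.fixture
async def rollback_conn(test_pool):
    """Give each test a connection inside a transaction that is always rolled back."""
    async with test_pool.acquire() as conn:
        tx = conn.transaction()
        await tx.start()
        try:
            yield conn
        finally:
            await tx.rollback()  # nothing written by the test survives
```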
-""" -import pytest -import asyncio -import asyncpg -import os -import sys -from unittest.mock import patch, AsyncMock, Mock -from httpx import AsyncClient -import pytest_asyncio - -# Ensure the project root is in PYTHONPATH for module imports -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'api', 'src'))) -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -# Set environment variables for DBManager initialization -# These are used by backend/db_manager.py to create the new_db instance -os.environ.setdefault('AWS_MASTER_USERNAME', 'test_user') -os.environ.setdefault('AWS_MASTER_PASSWORD', 'test_pass') -os.environ.setdefault('AWS_RDS_PLATFORM_ENDPOINT', 'localhost') -os.environ.setdefault('AWS_RDS_PLATFORM_DB_NAME', 'postgres') -os.environ.setdefault('PGPORT', '5432') -os.environ.setdefault('POSTGRES_TEST_URL', 'postgresql://test_user:test_pass@localhost:5432/postgres') - -# Database initialization will be handled by fixtures - -# Import after setting environment variables and path -# Only import these if we're running integration tests -# For unit tests, we'll import them lazily when needed - -# --- Pytest Hooks for Session-wide Setup/Teardown --- - -def pytest_configure(config): - """Register custom markers for tests.""" - config.addinivalue_line("markers", "slow: marks tests as slow") - config.addinivalue_line("markers", "integration: marks tests as integration tests requiring database") - config.addinivalue_line("markers", "unit: marks tests as unit tests (fast, no external dependencies)") - config.addinivalue_line("markers", "endpoints: marks tests that test API endpoints") - config.addinivalue_line("markers", "core: marks tests that test core business logic") - -# --- Fixtures for Test Setup --- - -@pytest_asyncio.fixture(scope="session") -async def postgres_service(): - """Ensure PostgreSQL is ready and yield control.""" - # This fixture is primarily for ensuring the service is up and healthy - # The actual DB setup/teardown for tests is handled by db_setup - print("\nWaiting for PostgreSQL service to be ready...") - # Use a direct connection check to ensure it's truly ready for connections - conn = None - for _ in range(60): # Try for 60 seconds (increased timeout) - try: - conn = await asyncpg.connect( - user=os.getenv('POSTGRES_USER', 'test_user'), - password=os.getenv('POSTGRES_PASSWORD', 'test_pass'), - host=os.getenv('AWS_RDS_PLATFORM_ENDPOINT', 'localhost'), - port=int(os.getenv('PGPORT', '5432')), - database=os.getenv('POSTGRES_DB', 'postgres') - ) - print("PostgreSQL service is ready.") - break - except (asyncpg.exceptions.PostgresError, OSError) as e: - print(f"PostgreSQL not ready yet: {e}. Retrying... 
({_+1}/60)") - await asyncio.sleep(2) # Increased sleep time - if conn: - await conn.close() - else: - pytest.fail("PostgreSQL service did not become ready.") - return True - -@pytest_asyncio.fixture(scope="session") -async def db_setup(postgres_service): - """Setup test database for the entire test session and initialize DBManager.""" - # Ensure postgres_service is up before proceeding - # postgres_service is now a boolean indicating readiness - - test_db_url = os.getenv('POSTGRES_TEST_URL') - if not test_db_url: - pytest.skip("POSTGRES_TEST_URL not set for integration tests.") - - # Import lazily to avoid issues in unit tests - from api.src.backend.db_manager import new_db - - # Initialize the global new_db instance with the test database - # This ensures all application code uses the test database - try: - await new_db.open() - print("Database connection pool opened successfully") - except Exception as e: - print(f"Error opening database connection pool: {e}") - pytest.fail(f"Failed to initialize database connection pool: {e}") - - # Setup database schema for integration tests - try: - async with new_db.acquire() as conn: - await setup_database_schema(conn) - print("Database schema setup completed") - except Exception as e: - print(f"Error setting up database schema: {e}") - pytest.fail(f"Failed to setup database schema: {e}") - - yield True # Tests run here - - # Cleanup after all tests in the session - try: - await new_db.close() - print("Database connection pool closed successfully") - except Exception as e: - print(f"Error closing database connection pool: {e}") - -@pytest_asyncio.fixture(scope="function") -async def db_conn(db_setup): - """Provide a database connection for each test function.""" - from api.src.backend.db_manager import new_db - - # Ensure the connection pool is initialized - if not new_db.pool: - await new_db.open() - - async with new_db.acquire() as conn: - yield conn - -async def setup_database_schema(conn: asyncpg.Connection): - """Setup database schema for integration tests""" - # Read the actual production schema file - schema_path = os.path.join(os.path.dirname(__file__), '..', 'api', 'src', 'backend', 'postgres_schema.sql') - with open(schema_path, 'r') as f: - schema_sql = f.read() - - # Execute the production schema - await conn.execute(schema_sql) - - # Disable the approval deletion trigger for tests to allow cleanup - await conn.execute(""" - DROP TRIGGER IF EXISTS no_delete_approval_trigger ON approved_version_ids; - CREATE OR REPLACE FUNCTION prevent_delete_approval_test() RETURNS TRIGGER AS $$ - BEGIN - -- Allow deletions in test environment - RETURN OLD; - END; - $$ LANGUAGE plpgsql; - CREATE TRIGGER no_delete_approval_trigger BEFORE DELETE ON approved_version_ids - FOR EACH ROW EXECUTE FUNCTION prevent_delete_approval_test(); - """) - - # Insert test evaluation sets for testing - await conn.execute(""" - INSERT INTO evaluation_sets (set_id, type, swebench_instance_id) VALUES - (1, 'screener-1', 'test_instance_1'), - (1, 'screener-2', 'test_instance_2'), - (1, 'validator', 'test_instance_3') - ON CONFLICT DO NOTHING - """) - - - -@pytest_asyncio.fixture -async def async_client(): - """Provide an asynchronous test client for FastAPI endpoints.""" - # For unit tests, we don't need a real server, just the app - from api.src.main import app - async with AsyncClient(app=app, base_url="http://testserver") as client: - yield client - -# --- Mock fixtures for unit tests --- - -@pytest.fixture -def mock_db_manager(): - """Mock the DBManager for unit tests 
that don't need real database.""" - with patch('api.src.backend.db_manager.new_db') as mock_db: - # Mock the acquire method to return a mock connection - mock_conn = AsyncMock() - mock_conn_context = AsyncMock() - mock_conn_context.__aenter__.return_value = mock_conn - mock_conn_context.__aexit__.return_value = None - mock_db.acquire.return_value = mock_conn_context - - # Mock the pool attribute - mock_db.pool = Mock() - - yield mock_db - -@pytest.fixture -def mock_db_connection(): - """Mock database connection for unit tests.""" - mock_conn = AsyncMock() - mock_conn.fetchrow = AsyncMock() - mock_conn.fetch = AsyncMock() - mock_conn.execute = AsyncMock() - mock_conn.transaction = AsyncMock() - - # Mock transaction context - mock_transaction = AsyncMock() - mock_transaction.__aenter__ = AsyncMock(return_value=mock_conn) - mock_transaction.__aexit__ = AsyncMock(return_value=None) - mock_conn.transaction.return_value = mock_transaction - - return mock_conn \ No newline at end of file diff --git a/tests/docker-compose.test.yml b/tests/docker-compose.test.yml deleted file mode 100644 index 0918362a..00000000 --- a/tests/docker-compose.test.yml +++ /dev/null @@ -1,48 +0,0 @@ -services: - postgres-test: - image: postgres:15 - environment: - POSTGRES_DB: postgres - POSTGRES_USER: test_user - POSTGRES_PASSWORD: test_pass - POSTGRES_INITDB_ARGS: "--auth-host=trust" - ports: - - "5432:5432" # Same port as GitHub Actions - volumes: - - postgres_test_data:/var/lib/postgresql/data - - ./init-test-db.sql:/docker-entrypoint-initdb.d/init-test-db.sql - tmpfs: - - /tmp - - /var/run/postgresql - healthcheck: - test: ["CMD-SHELL", "pg_isready -U test_user -d postgres"] - interval: 10s - timeout: 5s - retries: 5 - restart: unless-stopped - - postgres-internal-test: - image: postgres:15 - environment: - POSTGRES_DB: internal_tools - POSTGRES_USER: internal_user - POSTGRES_PASSWORD: internal_pass - POSTGRES_INITDB_ARGS: "--auth-host=trust" - ports: - - "5433:5432" # Different port to avoid conflicts - volumes: - - postgres_internal_test_data:/var/lib/postgresql/data - - ./init-test-db.sql:/docker-entrypoint-initdb.d/init-test-db.sql - tmpfs: - - /tmp - - /var/run/postgresql - healthcheck: - test: ["CMD-SHELL", "pg_isready -U internal_user -d internal_tools"] - interval: 10s - timeout: 5s - retries: 5 - restart: unless-stopped - -volumes: - postgres_test_data: - postgres_internal_test_data: \ No newline at end of file diff --git a/tests/init-test-db.sql b/tests/init-test-db.sql deleted file mode 100644 index c180eb9c..00000000 --- a/tests/init-test-db.sql +++ /dev/null @@ -1,23 +0,0 @@ --- Initialize test database with extensions and basic configuration --- This script runs automatically when the Docker container starts - --- Create the test database -CREATE DATABASE ridges_test; - --- Connect to the test database -\c ridges_test; - --- Create extensions that might be needed -CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; -CREATE EXTENSION IF NOT EXISTS "pg_trgm"; - --- Grant privileges to test user -GRANT ALL PRIVILEGES ON DATABASE ridges_test TO test_user; -GRANT ALL ON SCHEMA public TO test_user; - --- Set up basic configuration -ALTER DATABASE ridges_test SET timezone TO 'UTC'; - --- Note: The production schema from postgres_schema.sql will be applied --- by the test setup code to ensure we're testing against the exact --- production database structure. 
\ No newline at end of file diff --git a/tests/pytest.ini b/tests/pytest.ini deleted file mode 100644 index 225056b1..00000000 --- a/tests/pytest.ini +++ /dev/null @@ -1,22 +0,0 @@ -[tool:pytest] -testpaths = . -python_files = test_*.py -python_classes = Test* -python_functions = test_* -asyncio_mode = auto -addopts = - -v - --tb=short - --strict-markers - --cov=api/src - --cov-report=xml - --cov-report=term-missing -filterwarnings = - ignore::DeprecationWarning - ignore::PendingDeprecationWarning -markers = - slow: marks tests as slow - integration: marks tests as integration tests requiring database - unit: marks tests as unit tests (fast, no external dependencies) - endpoints: marks tests that test API endpoints - core: marks tests that test core business logic \ No newline at end of file diff --git a/tests/test.sh b/tests/test.sh deleted file mode 100755 index dbc54de1..00000000 --- a/tests/test.sh +++ /dev/null @@ -1,268 +0,0 @@ -#!/bin/bash - -# Test runner script for Ridges project -# This script replicates the GitHub Actions workflow locally - -set -e # Exit on any error - -# Colors for output -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -NC='\033[0m' # No Color - -# Function to print colored output -print_status() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -print_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -print_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -print_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -# Function to cleanup on exit -cleanup() { - print_status "Cleaning up..." - - # Stop the API server if it's running -if [ ! -z "$SERVER_PID" ]; then - print_status "Stopping API server (PID: $SERVER_PID)..." - # Send SIGTERM for graceful shutdown - kill -TERM $SERVER_PID 2>/dev/null || true - # Wait up to 5 seconds for graceful shutdown - for i in {1..5}; do - if ! kill -0 $SERVER_PID 2>/dev/null; then - break - fi - sleep 1 - done - # Force kill if still running - kill -KILL $SERVER_PID 2>/dev/null || true - # Wait for final termination - wait $SERVER_PID 2>/dev/null || true -fi - - # Stop Docker containers - if [ "$STOP_CONTAINERS" = "true" ]; then - print_status "Stopping Docker containers..." - docker-compose -f tests/docker-compose.test.yml down 2>/dev/null || true - fi - - print_success "Cleanup completed" -} - -# Set up trap to cleanup on script exit -trap cleanup EXIT - -# Configuration -PROJECT_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -TESTS_DIR="$PROJECT_ROOT/tests" -API_DIR="$PROJECT_ROOT/api" - -# Environment variables (same as GitHub Actions) -export POSTGRES_TEST_URL="postgresql://test_user:test_pass@localhost:5432/postgres" -export AWS_MASTER_USERNAME="test_user" -export AWS_MASTER_PASSWORD="test_pass" -export AWS_RDS_PLATFORM_ENDPOINT="localhost" -export AWS_RDS_PLATFORM_DB_NAME="postgres" -export PGPORT="5432" -export API_BASE_URL="http://localhost:8000" -export DB_USER_INT="internal_user" -export DB_PASS_INT="internal_pass" -export DB_HOST_INT="localhost" -export DB_PORT_INT="5433" -export DB_NAME_INT="internal_tools" - -print_status "Starting test environment setup..." -print_status "Project root: $PROJECT_ROOT" - -# Check if uv is installed -if ! command -v uv &> /dev/null; then - print_error "uv is not installed. Please install it first:" - print_error "curl -LsSf https://astral.sh/uv/install.sh | sh" - exit 1 -fi - -# Check if Docker is installed and running -if ! command -v docker &> /dev/null; then - print_error "Docker is not installed. Please install Docker first." 
- exit 1 -fi - -if ! docker info &> /dev/null; then - print_error "Docker is not running. Please start Docker first." - exit 1 -fi - -# Check if docker-compose is available -if ! command -v docker-compose &> /dev/null; then - print_error "docker-compose is not installed. Please install it first." - exit 1 -fi - -print_success "Prerequisites check passed" - -# Install dependencies -print_status "Installing Python dependencies..." -cd "$PROJECT_ROOT" -uv sync -uv add ruff mypy requests websocket-client pytest-asyncio asyncpg httpx pytest-postgresql - -# Start PostgreSQL databases -print_status "Starting PostgreSQL databases..." -cd "$TESTS_DIR" -STOP_CONTAINERS="true" -docker-compose -f docker-compose.test.yml up -d postgres-test postgres-internal-test - -# Wait for main PostgreSQL to be ready -print_status "Waiting for main PostgreSQL to be ready..." -for i in {1..30}; do - if docker-compose -f docker-compose.test.yml exec -T postgres-test pg_isready -U test_user -d postgres &>/dev/null; then - print_success "Main PostgreSQL is ready" - break - fi - if [ $i -eq 30 ]; then - print_error "Main PostgreSQL failed to start within 30 seconds" - exit 1 - fi - print_status "Waiting for main PostgreSQL... ($i/30)" - sleep 1 -done - -# Wait for internal tools PostgreSQL to be ready -print_status "Waiting for internal tools PostgreSQL to be ready..." -for i in {1..30}; do - if docker-compose -f docker-compose.test.yml exec -T postgres-internal-test pg_isready -U internal_user -d internal_tools &>/dev/null; then - print_success "Internal tools PostgreSQL is ready" - break - fi - if [ $i -eq 30 ]; then - print_error "Internal tools PostgreSQL failed to start within 30 seconds" - exit 1 - fi - print_status "Waiting for internal tools PostgreSQL... ($i/30)" - sleep 1 -done - -# Initialize database schemas -print_status "Initializing main database schema..." -docker-compose -f docker-compose.test.yml exec -T postgres-test psql -U test_user -d postgres -f /docker-entrypoint-initdb.d/init-test-db.sql - -print_status "Initializing internal tools database schema..." -docker-compose -f docker-compose.test.yml exec -T postgres-internal-test psql -U internal_user -d internal_tools -f /docker-entrypoint-initdb.d/init-test-db.sql - -# Set up environment for API server -print_status "Setting up environment for API server..." -cd "$PROJECT_ROOT" - -# Initialize the database connection pool before starting the server -print_status "Initializing database connection pool..." -uv run python -c " -import asyncio -import sys -import os -sys.path.insert(0, os.path.join('api', 'src')) -from backend.db_manager import new_db -async def init_db(): - await new_db.open() - print('Database connection pool initialized successfully') -asyncio.run(init_db()) -" - -# Set environment variables for tests -print_status "Setting up test environment variables..." -export POSTGRES_TEST_URL="postgresql://test_user:test_pass@localhost:5432/postgres" -export AWS_MASTER_USERNAME="test_user" -export AWS_MASTER_PASSWORD="test_pass" -export AWS_RDS_PLATFORM_ENDPOINT="localhost" -export AWS_RDS_PLATFORM_DB_NAME="postgres" -export PGPORT="5432" -export API_BASE_URL="http://localhost:8000" -export DB_USER_INT="internal_user" -export DB_PASS_INT="internal_pass" -export DB_HOST_INT="localhost" -export DB_PORT_INT="5433" -export DB_NAME_INT="internal_tools" - -# Start the API server -print_status "Starting API server..." -uv run python -m api.src.main --host 0.0.0.0 & -SERVER_PID=$! 
-print_status "API server started with PID: $SERVER_PID" - -# Wait for server to start -print_status "Waiting for API server to be ready..." -for i in {1..10}; do - if curl -f http://localhost:8000/healthcheck &>/dev/null; then - print_success "API server is ready" - break - fi - if [ $i -eq 10 ]; then - print_error "API server failed to start within 10 seconds" - exit 1 - fi - print_status "Waiting for API server... ($i/10)" - sleep 1 -done - -# Test basic API endpoints -print_status "Testing basic API endpoints..." -curl -f http://localhost:8000/healthcheck > /dev/null && print_success "Healthcheck endpoint working" -curl -f http://localhost:8000/healthcheck-results > /dev/null && print_success "Healthcheck-results endpoint working" - -# Run tests -print_status "Running tests..." -cd "$TESTS_DIR" - -# Run tests in stages to avoid database connection issues - -# First, run unit tests that don't require database -print_status "Running unit tests that don't require database..." -cd "$TESTS_DIR" -uv run python -m pytest test_endpoints_unit.py::TestSystemStatusEndpointsUnit -v -W ignore::PendingDeprecationWarning - -# Run simple tests that don't require database -print_status "Running simple tests..." -uv run python -m pytest test_endpoints_simple.py::TestEndpointResponseStructure::test_healthcheck_response_structure -v -W ignore::PendingDeprecationWarning - -# Run weights function tests -print_status "Running weights function tests..." -uv run python -m pytest test_weights_setting.py -v -W ignore::PendingDeprecationWarning - -# Run miner agent flow tests -print_status "Running miner agent flow tests..." -uv run python -m pytest test_miner_agent_flow.py -v -W ignore::PendingDeprecationWarning - -# Run real API tests (these require the API server) -print_status "Running real API integration tests..." -uv run python -m pytest test_real_api.py -v -W ignore::PendingDeprecationWarning - -# For now, skip the problematic integration tests and run a subset that works -print_status "Running upload tracking tests..." -uv run python -m pytest test_upload_tracking.py -v --tb=short --disable-warnings - -print_status "Running comprehensive tests..." -uv run python -m pytest \ - test_endpoints_unit.py::TestSystemStatusEndpointsUnit \ - test_endpoints_simple.py::TestEndpointResponseStructure::test_healthcheck_response_structure \ - test_weights_setting.py \ - test_miner_agent_flow.py \ - test_real_api.py \ - test_upload_tracking.py \ - -v \ - --tb=short \ - --disable-warnings \ - -W ignore::PendingDeprecationWarning - -print_success "All tests completed successfully!" - -print_success "Test environment setup and execution completed successfully!" \ No newline at end of file diff --git a/tests/test_endpoints_integration.py b/tests/test_endpoints_integration.py deleted file mode 100644 index 346cd1b5..00000000 --- a/tests/test_endpoints_integration.py +++ /dev/null @@ -1,463 +0,0 @@ -""" -Integration tests for API endpoints with real database testing. -Tests the complete flow from HTTP requests through to database operations. 
-""" - -import pytest -import asyncpg -import os -import uuid -from typing import Optional -from unittest.mock import patch - -from httpx import AsyncClient -import pytest_asyncio - -# Only set environment variables if they're not already set (don't override GitHub Actions env vars) -if not os.getenv('AWS_MASTER_USERNAME'): - os.environ.update({ - 'AWS_MASTER_USERNAME': 'test_user', - 'AWS_MASTER_PASSWORD': 'test_pass', - 'AWS_RDS_PLATFORM_ENDPOINT': 'localhost', - 'AWS_RDS_PLATFORM_DB_NAME': 'postgres', - 'POSTGRES_TEST_URL': 'postgresql://test_user:test_pass@localhost:5432/postgres' - }) - -# Import after setting environment variables -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'api', 'src')) - -from api.src.main import app - - -class DatabaseTestSetup: - """Helper class for database test setup and teardown""" - - def __init__(self, test_db_url: str): - self.test_db_url = test_db_url - self.pool: Optional[asyncpg.Pool] = None - - async def setup_test_database(self): - """Setup test database schema and data""" - # Connect to the existing database (don't drop/recreate) - self.pool = await asyncpg.create_pool(self.test_db_url) - async with self.pool.acquire() as conn: - await self._create_schema(conn) - - async def _create_schema(self, conn: asyncpg.Connection): - """Create production schema from postgres_schema.sql""" - - # Read the actual production schema file - schema_path = os.path.join(os.path.dirname(__file__), '..', 'api', 'src', 'backend', 'postgres_schema.sql') - with open(schema_path, 'r') as f: - schema_sql = f.read() - - # Execute the production schema - await conn.execute(schema_sql) - - # Disable the approval deletion trigger for tests to allow cleanup - await conn.execute(""" - DROP TRIGGER IF EXISTS no_delete_approval_trigger ON approved_version_ids; - CREATE OR REPLACE FUNCTION prevent_delete_approval_test() RETURNS TRIGGER AS $$ - BEGIN - -- Allow deletions in test environment - RETURN OLD; - END; - $$ LANGUAGE plpgsql; - CREATE TRIGGER no_delete_approval_trigger BEFORE DELETE ON approved_version_ids - FOR EACH ROW EXECUTE FUNCTION prevent_delete_approval_test(); - """) - - # Insert test evaluation sets for testing - await conn.execute(""" - INSERT INTO evaluation_sets (set_id, type, swebench_instance_id) VALUES - (1, 'screener-1', 'test_instance_1'), - (1, 'screener-2', 'test_instance_2'), - (1, 'validator', 'test_instance_3') - ON CONFLICT DO NOTHING - """) - - async def cleanup_test_database(self): - """Clean up test database""" - if self.pool: - await self.pool.close() - - def get_connection(self): - """Get database connection for tests""" - if not self.pool: - raise RuntimeError("Test database not initialized") - return self.pool.acquire() - - - - - -@pytest_asyncio.fixture -async def async_client(): - """Async HTTP client for testing FastAPI endpoints""" - from httpx import ASGITransport - async with AsyncClient(transport=ASGITransport(app=app), base_url="http://testserver") as client: - yield client - - -class TestUploadEndpoints: - """Test agent upload endpoints with database integration""" - - @pytest.mark.asyncio - async def test_upload_agent_success(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test successful agent upload flow""" - - # Mock external dependencies - with patch('api.src.utils.s3.S3Manager.upload_file_object', return_value="s3://test-bucket/test-agent.py"), \ - patch('api.src.utils.agent_summary_generator.generate_agent_summary'), \ - 
patch('api.src.utils.upload_agent_helpers.check_code_similarity', return_value=0.3): - - # Create test agent data - agent_data = { - "miner_hotkey": "test_miner_123", - "agent_name": "test_agent", - "code": "def solve_problem(): return 'solution'", - "signature": "test_signature" - } - - response = await async_client.post("/upload/agent", json=agent_data) - - assert response.status_code == 200 - result = response.json() - assert "version_id" in result - - # Verify database state - agent = await db_conn.fetchrow( - "SELECT * FROM miner_agents WHERE version_id = $1", - uuid.UUID(result["version_id"]) - ) - assert agent is not None - assert agent["miner_hotkey"] == "test_miner_123" - assert agent["agent_name"] == "test_agent" - assert agent["status"] == "awaiting_screening_1" - - # Verify evaluation was created - evaluation = await db_conn.fetchrow( - "SELECT * FROM evaluations WHERE version_id = $1", - uuid.UUID(result["version_id"]) - ) - assert evaluation is not None - assert evaluation["status"] == "waiting" - - @pytest.mark.asyncio - async def test_upload_agent_banned_hotkey(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test upload rejection for banned hotkey""" - - # Insert banned hotkey - await db_conn.execute( - "INSERT INTO banned_hotkeys (miner_hotkey) VALUES ($1)", - "banned_miner" - ) - - agent_data = { - "miner_hotkey": "banned_miner", - "agent_name": "test_agent", - "code": "def solve_problem(): return 'solution'", - "signature": "test_signature" - } - - response = await async_client.post("/upload/agent", json=agent_data) - - assert response.status_code == 403 - assert "banned" in response.json()["detail"].lower() - - @pytest.mark.asyncio - async def test_upload_agent_rate_limit(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test rate limiting on agent uploads""" - - # Insert recent upload - recent_agent_id = uuid.uuid4() - await db_conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at) - VALUES ($1, $2, $3, 1, NOW() - INTERVAL '1 hour') - """, recent_agent_id, "rate_limited_miner", "previous_agent") - - agent_data = { - "miner_hotkey": "rate_limited_miner", - "agent_name": "new_agent", - "code": "def solve_problem(): return 'solution'", - "signature": "test_signature" - } - - response = await async_client.post("/upload/agent", json=agent_data) - - assert response.status_code == 429 - assert "rate limit" in response.json()["detail"].lower() - - -class TestScoringEndpoints: - """Test scoring endpoints with database integration""" - - @pytest.mark.asyncio - async def test_check_top_agent(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test top agent retrieval""" - - # Insert test agents with scores - agent1_id = uuid.uuid4() - agent2_id = uuid.uuid4() - - await db_conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) - VALUES ($1, 'miner1', 'agent1', 1, NOW(), 'scored'), - ($2, 'miner2', 'agent2', 1, NOW(), 'scored') - """, agent1_id, agent2_id) - - # Insert evaluations with scores - eval1_id = uuid.uuid4() - eval2_id = uuid.uuid4() - - await db_conn.execute(""" - INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, score, created_at) - VALUES ($1, $2, 'validator1', 1, 'completed', 0.85, NOW()), - ($3, $4, 'validator1', 1, 'completed', 0.92, NOW()) - """, eval1_id, agent1_id, eval2_id, agent2_id) - - # Approve the higher scoring agent - await db_conn.execute( - "INSERT INTO 
approved_version_ids (version_id, set_id) VALUES ($1, 1)", - agent2_id - ) - - # Refresh materialized view - await db_conn.execute("REFRESH MATERIALIZED VIEW agent_scores") - - response = await async_client.get("/scoring/check-top-agent") - - assert response.status_code == 200 - result = response.json() - assert result["miner_hotkey"] == "miner2" - assert result["avg_score"] == 0.92 - - @pytest.mark.asyncio - async def test_ban_agents(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test agent banning functionality""" - - # Insert test agent - agent_id = uuid.uuid4() - await db_conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) - VALUES ($1, 'target_miner', 'target_agent', 1, NOW(), 'scored') - """, agent_id) - - ban_data = { - "admin_password": "admin_password_123", # Mock password - "miner_hotkeys": ["target_miner"] - } - - with patch('api.src.endpoints.scoring.ADMIN_PASSWORD', 'admin_password_123'): - response = await async_client.post("/scoring/ban-agents", json=ban_data) - - assert response.status_code == 200 - - # Verify ban was applied - banned = await db_conn.fetchrow( - "SELECT * FROM banned_hotkeys WHERE miner_hotkey = $1", - "target_miner" - ) - assert banned is not None - - @pytest.mark.asyncio - async def test_approve_version(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test agent version approval""" - - # Insert test agent - agent_id = uuid.uuid4() - await db_conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) - VALUES ($1, 'approval_miner', 'approval_agent', 1, NOW(), 'scored') - """, agent_id) - - approval_data = { - "admin_password": "admin_password_123", - "version_ids": [str(agent_id)] - } - - with patch('api.src.endpoints.scoring.ADMIN_PASSWORD', 'admin_password_123'): - response = await async_client.post("/scoring/approve-version", json=approval_data) - - assert response.status_code == 200 - - # Verify approval was applied - approved = await db_conn.fetchrow( - "SELECT * FROM approved_version_ids WHERE version_id = $1", - agent_id - ) - assert approved is not None - - -class TestRetrievalEndpoints: - """Test data retrieval endpoints with database integration""" - - @pytest.mark.asyncio - async def test_network_stats(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test network statistics retrieval""" - - # Insert test data for statistics - agent1_id = uuid.uuid4() - agent2_id = uuid.uuid4() - - await db_conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) - VALUES ($1, 'stats_miner1', 'stats_agent1', 1, NOW() - INTERVAL '12 hours', 'scored'), - ($2, 'stats_miner2', 'stats_agent2', 1, NOW() - INTERVAL '6 hours', 'scored') - """, agent1_id, agent2_id) - - response = await async_client.get("/retrieval/network-stats") - - assert response.status_code == 200 - result = response.json() - assert "number_of_agents" in result - assert "agent_iterations_last_24_hours" in result - assert result["agent_iterations_last_24_hours"] >= 2 - - @pytest.mark.asyncio - async def test_top_agents(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test top agents retrieval""" - - # Insert test agents with varying scores - agents_data = [] - for i in range(5): - agent_id = uuid.uuid4() - agents_data.append((agent_id, f'top_miner_{i}', f'top_agent_{i}', i + 1)) - - await db_conn.executemany(""" - INSERT INTO miner_agents 
(version_id, miner_hotkey, agent_name, version_num, created_at, status) - VALUES ($1, $2, $3, $4, NOW(), 'scored') - """, agents_data) - - # Insert evaluations with different scores - for i, (agent_id, _, _, _) in enumerate(agents_data): - eval_id = uuid.uuid4() - score = 0.5 + (i * 0.1) # Increasing scores - await db_conn.execute(""" - INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, score, created_at) - VALUES ($1, $2, 'validator1', 1, 'completed', $3, NOW()) - """, eval_id, agent_id, score) - - # Refresh materialized view - await db_conn.execute("REFRESH MATERIALIZED VIEW agent_scores") - - response = await async_client.get("/retrieval/top-agents?num_agents=3") - - assert response.status_code == 200 - result = response.json() - assert len(result) <= 3 - - # Verify ordering (highest score first) - if len(result) > 1: - assert result[0]["score"] >= result[1]["score"] - - @pytest.mark.asyncio - async def test_agent_by_hotkey(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test agent retrieval by hotkey""" - - # Insert test agents for same hotkey - agent1_id = uuid.uuid4() - agent2_id = uuid.uuid4() - - await db_conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, status, created_at) - VALUES ($1, 'hotkey_test', 'agent_v1', 1, 'replaced', NOW() - INTERVAL '2 days'), - ($2, 'hotkey_test', 'agent_v2', 2, 'scored', NOW() - INTERVAL '1 day') - """, agent1_id, agent2_id) - - response = await async_client.get("/retrieval/agent-by-hotkey?miner_hotkey=hotkey_test") - - assert response.status_code == 200 - result = response.json() - assert len(result) == 2 - assert result[0]["version_num"] == 2 # Latest version first - assert result[1]["version_num"] == 1 - - -class TestAuthenticationEndpoints: - """Test authentication endpoints with database integration""" - - @pytest.mark.asyncio - async def test_open_user_signin(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test open user sign in and registration""" - - signin_data = { - "auth0_user_id": "auth0|test123", - "email": "test@example.com", - "name": "Test User", - "password": "secure_password_123" - } - - response = await async_client.post("/open-users/sign-in", json=signin_data) - - assert response.status_code == 200 - result = response.json() - assert "open_hotkey" in result - - # Verify user was created in database - user = await db_conn.fetchrow( - "SELECT * FROM open_users WHERE email = $1", - "test@example.com" - ) - assert user is not None - assert user["name"] == "Test User" - assert user["auth0_user_id"] == "auth0|test123" - - -class TestAgentSummaryEndpoints: - """Test agent summary endpoints with database integration""" - - @pytest.mark.asyncio - async def test_get_agent_summary(self, async_client: AsyncClient, db_conn: asyncpg.Connection): - """Test agent summary retrieval""" - - # Insert agent with summary - agent_id = uuid.uuid4() - await db_conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status, agent_summary) - VALUES ($1, 'summary_miner', 'summary_agent', 1, NOW(), 'scored', 'This agent solves coding problems efficiently') - """, agent_id) - - response = await async_client.get(f"/agent-summaries/agent-summary/{agent_id}") - - assert response.status_code == 200 - result = response.json() - assert result["agent_summary"] == "This agent solves coding problems efficiently" - assert result["version_id"] == str(agent_id) - - -class TestSystemStatusEndpoints: - 
"""Test system status endpoints with database integration""" - - @pytest.mark.asyncio - async def test_health_check(self, async_client: AsyncClient, db_setup): - """Test comprehensive health check""" - - # Test basic health check - response = await async_client.get("/healthcheck") - assert response.status_code == 200 - assert response.text == '"OK"' - - # Test health check results endpoint - response = await async_client.get("/healthcheck-results") - assert response.status_code == 200 - result = response.json() - assert "database_status" in result - assert "api_status" in result - - @pytest.mark.asyncio - async def test_status_endpoint(self, async_client: AsyncClient, db_setup): - """Test detailed system status""" - - # Test healthcheck-results endpoint as the status endpoint - response = await async_client.get("/healthcheck-results") - - assert response.status_code == 200 - result = response.json() - assert "database_status" in result - assert "api_status" in result - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_endpoints_simple.py b/tests/test_endpoints_simple.py deleted file mode 100644 index b5338da6..00000000 --- a/tests/test_endpoints_simple.py +++ /dev/null @@ -1,309 +0,0 @@ -""" -Simple endpoint tests that work with the actual Ridges API structure. -These tests verify the endpoints exist and return expected responses. -""" - -import pytest -import uuid -from unittest.mock import patch, AsyncMock, Mock -from fastapi.testclient import TestClient - -# Only set environment variables if they're not already set (don't override GitHub Actions env vars) -import os -if not os.getenv('AWS_MASTER_USERNAME'): - os.environ.update({ - 'AWS_MASTER_USERNAME': 'test_user', - 'AWS_MASTER_PASSWORD': 'test_pass', - 'AWS_RDS_PLATFORM_ENDPOINT': 'localhost', - 'AWS_RDS_PLATFORM_DB_NAME': 'postgres' - }) - -# Import after setting environment variables -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'api', 'src')) - - -def create_test_app(): - """Create a test app with mocked database manager.""" - with patch('api.src.backend.db_manager.new_db') as mock_db: - # Mock the acquire method to return a mock connection - mock_conn = AsyncMock() - mock_conn_context = AsyncMock() - mock_conn_context.__aenter__.return_value = mock_conn - mock_conn_context.__aexit__.return_value = None - mock_db.acquire.return_value = mock_conn_context - - # Mock the pool attribute - mock_db.pool = Mock() - - from api.src.main import app - return app - - -class TestHealthcheckEndpoints: - """Test basic healthcheck endpoints""" - - def test_healthcheck_endpoint(self): - """Test basic healthcheck""" - app = create_test_app() - client = TestClient(app) - - response = client.get("/healthcheck") - assert response.status_code == 200 - - -class TestUploadEndpoints: - """Test upload endpoint structure and validation""" - - @patch('api.src.backend.db_manager.get_transaction') - def test_upload_agent_endpoint_exists(self, mock_get_transaction): - """Test that upload agent endpoint exists and validates input""" - app = create_test_app() - client = TestClient(app) - - # Mock database transaction to avoid database calls - mock_conn = AsyncMock() - mock_transaction = AsyncMock() - mock_transaction.__aenter__.return_value = mock_conn - mock_get_transaction.return_value = mock_transaction - mock_conn.fetchval.return_value = None # No banned hotkey - - # Test with missing fields - should return 422 for validation error - response = 
client.post("/upload/agent", json={}) - assert response.status_code == 422 - - # Test with some fields - still should fail validation - incomplete_data = {"miner_hotkey": "test"} - response = client.post("/upload/agent", json=incomplete_data) - assert response.status_code == 422 - - @patch('api.src.backend.db_manager.get_transaction') - def test_upload_open_agent_endpoint_exists(self, mock_get_transaction): - """Test that upload open-agent endpoint exists""" - app = create_test_app() - client = TestClient(app) - - # Mock database transaction - mock_conn = AsyncMock() - mock_transaction = AsyncMock() - mock_transaction.__aenter__.return_value = mock_conn - mock_get_transaction.return_value = mock_transaction - - response = client.post("/upload/open-agent", json={}) - assert response.status_code == 422 # Should fail validation - - -class TestRetrievalEndpoints: - """Test retrieval endpoints""" - - @patch('api.src.backend.queries.statistics.get_24_hour_statistics') - def test_network_stats_endpoint(self, mock_get_stats): - """Test network stats endpoint""" - app = create_test_app() - client = TestClient(app) - - # Mock the statistics function directly - mock_get_stats.return_value = { - "number_of_agents": 100, - "agent_iterations_last_24_hours": 20, - "top_agent_score": 0.85, - "daily_score_improvement": 0.05 - } - - response = client.get("/retrieval/network-stats") - # This endpoint requires database connection, so it might fail in simple tests - # We just check that the endpoint exists and returns a proper response - assert response.status_code in [200, 500] # Either success or database error - - @patch('api.src.backend.queries.statistics.get_top_agents') - def test_top_agents_endpoint(self, mock_get_top_agents): - """Test top agents endpoint""" - app = create_test_app() - client = TestClient(app) - - # Mock top agents data - mock_agent = Mock() - mock_agent.version_id = uuid.uuid4() - mock_agent.miner_hotkey = "test_miner" - mock_agent.agent_name = "test_agent" - mock_agent.score = 0.85 - mock_agent.approved = True - - mock_get_top_agents.return_value = [mock_agent] - - response = client.get("/retrieval/top-agents") - # This endpoint requires database connection, so it might fail in simple tests - # We just check that the endpoint exists and returns a proper response - assert response.status_code in [200, 500] # Either success or database error - - @patch('api.src.socket.websocket_manager.WebSocketManager.get_instance') - def test_connected_validators_endpoint(self, mock_ws_manager): - """Test connected validators endpoint""" - app = create_test_app() - client = TestClient(app) - - # Mock WebSocket manager - mock_manager = Mock() - mock_validator = Mock() - mock_validator.get_type.return_value = "validator" - mock_validator.hotkey = "validator1" - mock_validator.status = "available" - - mock_manager.clients = {"1": mock_validator} - mock_ws_manager.return_value = mock_manager - - response = client.get("/retrieval/connected-validators") - # This endpoint requires database connection, so it might fail in simple tests - # We just check that the endpoint exists and returns a proper response - assert response.status_code in [200, 500] # Either success or database error - - -class TestScoringEndpoints: - """Test scoring endpoints""" - - @patch('api.src.backend.entities.MinerAgentScored.get_top_agent') - def test_check_top_agent_endpoint(self, mock_get_top_agent): - """Test check top agent endpoint""" - app = create_test_app() - client = TestClient(app) - - # Mock top agent - mock_top_agent = 
Mock() - mock_top_agent.miner_hotkey = "top_miner" - mock_top_agent.version_id = uuid.uuid4() - mock_top_agent.avg_score = 0.92 - - mock_get_top_agent.return_value = mock_top_agent - - response = client.get("/scoring/check-top-agent") - # This endpoint requires database connection, so it might fail in simple tests - # We just check that the endpoint exists and returns a proper response - assert response.status_code in [200, 500] # Either success or database error - - def test_ban_agents_endpoint_validation(self): - """Test ban agents endpoint input validation""" - app = create_test_app() - client = TestClient(app) - - # Test without password - should fail - response = client.post("/scoring/ban-agents", json={"miner_hotkeys": ["test"]}) - assert response.status_code == 422 - - def test_approve_version_endpoint_validation(self): - """Test approve version endpoint input validation""" - app = create_test_app() - client = TestClient(app) - - # Test without password - should fail - response = client.post("/scoring/approve-version", json={"version_ids": ["test"]}) - assert response.status_code == 422 - - -class TestOpenUsersEndpoints: - """Test open users endpoints""" - - def test_signin_endpoint_validation(self): - """Test sign in endpoint validation""" - app = create_test_app() - client = TestClient(app) - - # Test with missing fields - response = client.post("/open-users/sign-in", json={}) - assert response.status_code == 422 - - -class TestAgentSummariesEndpoints: - """Test agent summaries endpoints""" - - @patch('api.src.backend.queries.statistics.get_agent_summary_by_hotkey') - def test_agent_summary_endpoint_not_found(self, mock_get_agent_summary): - """Test agent summary endpoint with non-existent ID""" - app = create_test_app() - client = TestClient(app) - - # Mock agent summary to return None (not found) - mock_get_agent_summary.return_value = None - - fake_uuid = uuid.uuid4() - response = client.get(f"/agent-summaries/agent-summary/{fake_uuid}") - # This endpoint requires database connection, so it might fail in simple tests - # We just check that the endpoint exists and returns a proper response - assert response.status_code in [404, 500] # Either not found or database error - - -class TestEndpointSecurity: - """Test endpoint security and validation""" - - def test_endpoints_reject_invalid_json(self): - """Test that endpoints properly reject invalid JSON""" - app = create_test_app() - client = TestClient(app) - - # Test various endpoints with invalid JSON - endpoints = [ - "/upload/agent", - "/scoring/ban-agents", - "/scoring/approve-version", - "/open-users/sign-in" - ] - - for endpoint in endpoints: - response = client.post( - endpoint, - content="invalid json", - headers={"Content-Type": "application/json"} - ) - assert response.status_code == 422 - - def test_get_endpoints_dont_accept_post(self): - """Test that GET endpoints reject POST requests appropriately""" - app = create_test_app() - client = TestClient(app) - - # Test GET endpoints with POST - get_endpoints = [ - "/healthcheck", - "/retrieval/network-stats", - "/retrieval/top-agents", - "/scoring/check-top-agent" - ] - - for endpoint in get_endpoints: - response = client.post(endpoint, json={}) - # Should be 405 Method Not Allowed or 422 if it has different validation - assert response.status_code in [405, 422] - - -class TestEndpointResponseStructure: - """Test endpoint response structures""" - - def test_healthcheck_response_structure(self): - """Test healthcheck response has expected structure""" - app = 
create_test_app() - client = TestClient(app) - - response = client.get("/healthcheck") - assert response.status_code == 200 - # Healthcheck returns a simple string "OK", not JSON - assert response.text == '"OK"' - - -class TestWebSocketEndpoint: - """Test WebSocket endpoint exists""" - - def test_websocket_endpoint_exists(self): - """Test that WebSocket endpoint is defined""" - app = create_test_app() - - # Check that the WebSocket route exists - websocket_routes = [route for route in app.routes if hasattr(route, 'path') and route.path == "/ws"] - assert len(websocket_routes) == 1 - - # Verify it's a WebSocket route - ws_route = websocket_routes[0] - assert hasattr(ws_route, 'endpoint') - - -if __name__ == "__main__": - # Run with simple output - pytest.main([__file__, "-v", "--tb=short"]) \ No newline at end of file diff --git a/tests/test_endpoints_unit.py b/tests/test_endpoints_unit.py deleted file mode 100644 index 5e403189..00000000 --- a/tests/test_endpoints_unit.py +++ /dev/null @@ -1,502 +0,0 @@ -""" -Unit tests for API endpoints with mocked database operations. -These tests focus on business logic without requiring a real database. -""" - -import pytest -import uuid -import os -from datetime import datetime, timezone -from unittest.mock import Mock, AsyncMock, patch -from fastapi.testclient import TestClient - -# Import after setting environment variables -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'api', 'src')) - -from api.src.main import app - -def create_test_app(): - """Create a test app with mocked database manager.""" - with patch('api.src.backend.db_manager.new_db') as mock_db: - # Mock the acquire method to return a mock connection - mock_conn = AsyncMock() - mock_conn_context = AsyncMock() - mock_conn_context.__aenter__.return_value = mock_conn - mock_conn_context.__aexit__.return_value = None - mock_db.acquire.return_value = mock_conn_context - - # Mock the pool attribute - mock_db.pool = Mock() - - return app - -# Create test client with mocked database -client = TestClient(create_test_app()) - - -class TestUploadEndpointsUnit: - """Unit tests for upload endpoints with mocked dependencies""" - - @patch('api.src.utils.s3.S3Manager.upload_file_object') - @patch('api.src.utils.agent_summary_generator.generate_agent_summary') - @patch('api.src.utils.upload_agent_helpers.check_code_similarity') - @patch('api.src.backend.db_manager.get_transaction') - def test_upload_agent_success_mocked( - self, mock_get_transaction, mock_validate_code, mock_check_similarity, - mock_generate_summary, mock_s3_upload, mock_get_registration - ): - """Test successful agent upload with all dependencies mocked""" - - # Setup mocks - mock_get_registration.return_value = 100 - mock_s3_upload.return_value = "s3://test-bucket/agent.py" - mock_check_similarity.return_value = 0.3 - mock_validate_code.return_value = True - - # Mock database transaction - mock_conn = AsyncMock() - mock_transaction = AsyncMock() - mock_transaction.__aenter__.return_value = mock_conn - mock_get_transaction.return_value = mock_transaction - - # Mock database queries - mock_conn.fetchval.side_effect = [ - None, # No banned hotkey - None, # No recent upload - 1, # Version number - None # No available screener - ] - mock_conn.execute.return_value = None - mock_conn.fetchrow.return_value = None - - agent_data = { - "miner_hotkey": "test_miner_123", - "agent_name": "test_agent", - "code": "def solve_problem(): return 'solution'", - "signature": "test_signature" - } - - response = 
client.post("/upload/agent", json=agent_data) - - assert response.status_code == 200 - result = response.json() - assert "version_id" in result - - # Verify mocks were called appropriately - mock_validate_code.assert_called_once() - mock_check_similarity.assert_called_once() - mock_s3_upload.assert_called_once() - - @patch('api.src.backend.db_manager.get_transaction') - def test_upload_agent_banned_hotkey_mocked(self, mock_get_transaction): - """Test upload rejection for banned hotkey""" - - # Mock database transaction - mock_conn = AsyncMock() - mock_transaction = AsyncMock() - mock_transaction.__aenter__.return_value = mock_conn - mock_get_transaction.return_value = mock_transaction - - # Mock banned hotkey check - mock_conn.fetchval.return_value = "banned_miner" # Banned hotkey found - - agent_data = { - "miner_hotkey": "banned_miner", - "agent_name": "test_agent", - "code": "def solve_problem(): return 'solution'", - "signature": "test_signature" - } - - response = client.post("/upload/agent", json=agent_data) - - assert response.status_code == 403 - assert "banned" in response.json()["detail"].lower() - - @patch('api.src.backend.db_manager.get_transaction') - def test_upload_agent_rate_limit_mocked(self, mock_get_transaction): - """Test rate limiting on agent uploads""" - - # Mock database transaction - mock_conn = AsyncMock() - mock_transaction = AsyncMock() - mock_transaction.__aenter__.return_value = mock_conn - mock_get_transaction.return_value = mock_transaction - - # Mock recent upload found - mock_conn.fetchval.side_effect = [ - None, # Not banned - datetime.now(timezone.utc) # Recent upload found - ] - - agent_data = { - "miner_hotkey": "rate_limited_miner", - "agent_name": "new_agent", - "code": "def solve_problem(): return 'solution'", - "signature": "test_signature" - } - - response = client.post("/upload/agent", json=agent_data) - - assert response.status_code == 429 - assert "rate limit" in response.json()["detail"].lower() - - -class TestScoringEndpointsUnit: - """Unit tests for scoring endpoints with mocked dependencies""" - - @patch('api.src.backend.entities.MinerAgentScored.get_top_agent') - def test_check_top_agent_mocked(self, mock_get_top_agent): - """Test top agent retrieval with mocked database""" - - # Mock top agent data - mock_top_agent = Mock() - mock_top_agent.miner_hotkey = "top_miner_123" - mock_top_agent.version_id = uuid.uuid4() - mock_top_agent.avg_score = 0.92 - - mock_get_top_agent.return_value = mock_top_agent - - response = client.get("/scoring/check-top-agent") - - assert response.status_code == 200 - result = response.json() - assert result["miner_hotkey"] == "top_miner_123" - assert result["avg_score"] == 0.92 - - @patch('api.src.backend.entities.MinerAgentScored.get_top_agent') - def test_check_top_agent_none_mocked(self, mock_get_top_agent): - """Test top agent retrieval when no agents exist""" - - mock_get_top_agent.return_value = None - - response = client.get("/scoring/check-top-agent") - - assert response.status_code == 404 - assert "no top agent" in response.json()["detail"].lower() - - @patch('api.src.endpoints.scoring.ADMIN_PASSWORD', 'test_admin_pass') - @patch('api.src.backend.db_manager.get_transaction') - def test_ban_agents_mocked(self, mock_get_transaction): - """Test agent banning with mocked database""" - - # Mock database transaction - mock_conn = AsyncMock() - mock_transaction = AsyncMock() - mock_transaction.__aenter__.return_value = mock_conn - mock_get_transaction.return_value = mock_transaction - - 
mock_conn.executemany.return_value = None - - ban_data = { - "admin_password": "test_admin_pass", - "miner_hotkeys": ["bad_miner_1", "bad_miner_2"] - } - - response = client.post("/scoring/ban-agents", json=ban_data) - - assert response.status_code == 200 - result = response.json() - assert result["banned_count"] == 2 - - # Verify database call was made - mock_conn.executemany.assert_called_once() - - def test_ban_agents_wrong_password(self): - """Test agent banning with wrong password""" - - ban_data = { - "admin_password": "wrong_password", - "miner_hotkeys": ["some_miner"] - } - - response = client.post("/scoring/ban-agents", json=ban_data) - - assert response.status_code == 401 - assert "unauthorized" in response.json()["detail"].lower() - - @patch('api.src.endpoints.scoring.ADMIN_PASSWORD', 'test_admin_pass') - @patch('api.src.backend.db_manager.get_transaction') - def test_approve_version_mocked(self, mock_get_transaction): - """Test version approval with mocked database""" - - # Mock database transaction - mock_conn = AsyncMock() - mock_transaction = AsyncMock() - mock_transaction.__aenter__.return_value = mock_conn - mock_get_transaction.return_value = mock_transaction - - version_id = uuid.uuid4() - mock_conn.fetchrow.return_value = { - 'version_id': version_id, - 'status': 'scored', - 'miner_hotkey': 'test_miner' - } - mock_conn.execute.return_value = None - - approval_data = { - "admin_password": "test_admin_pass", - "version_ids": [str(version_id)] - } - - response = client.post("/scoring/approve-version", json=approval_data) - - assert response.status_code == 200 - result = response.json() - assert result["approved_count"] == 1 - - -class TestRetrievalEndpointsUnit: - """Unit tests for retrieval endpoints with mocked dependencies""" - - @patch('api.src.backend.entities.MinerAgentScored.get_24_hour_statistics') - def test_network_stats_mocked(self, mock_get_stats): - """Test network statistics with mocked data""" - - mock_stats = { - "number_of_agents": 150, - "agent_iterations_last_24_hours": 25, - "top_agent_score": 0.895, - "daily_score_improvement": 0.023 - } - mock_get_stats.return_value = mock_stats - - response = client.get("/retrieval/network-stats") - - assert response.status_code == 200 - result = response.json() - assert result["number_of_agents"] == 150 - assert result["agent_iterations_last_24_hours"] == 25 - assert result["top_agent_score"] == 0.895 - - @patch('api.src.backend.entities.MinerAgentScored.get_top_agents') - def test_top_agents_mocked(self, mock_get_top_agents): - """Test top agents retrieval with mocked data""" - - # Mock top agents data - mock_agents = [] - for i in range(3): - agent = Mock() - agent.version_id = uuid.uuid4() - agent.miner_hotkey = f"top_miner_{i}" - agent.agent_name = f"top_agent_{i}" - agent.score = 0.9 - (i * 0.05) # Decreasing scores - agent.approved = True - mock_agents.append(agent) - - mock_get_top_agents.return_value = mock_agents - - response = client.get("/retrieval/top-agents?num_agents=3") - - assert response.status_code == 200 - result = response.json() - assert len(result) == 3 - assert result[0]["score"] > result[1]["score"] # Verify ordering - - @patch('api.src.backend.entities.MinerAgentScored.get_agent_summary_by_hotkey') - def test_agent_by_hotkey_mocked(self, mock_get_agent_summary): - """Test agent retrieval by hotkey with mocked data""" - - # Mock agent data - mock_agents = [] - for i in range(2): - agent = Mock() - agent.version_id = uuid.uuid4() - agent.miner_hotkey = "test_hotkey" - agent.agent_name = 
f"agent_v{i+1}" - agent.version_num = i + 1 - agent.status = "scored" if i == 1 else "replaced" - agent.score = 0.8 + (i * 0.1) - mock_agents.append(agent) - - mock_get_agent_summary.return_value = mock_agents - - response = client.get("/retrieval/agent-by-hotkey?miner_hotkey=test_hotkey") - - assert response.status_code == 200 - result = response.json() - assert len(result) == 2 - assert result[0]["version_num"] == 1 # Check data structure - - @patch('api.src.socket.websocket_manager.WebSocketManager.get_instance') - def test_connected_validators_mocked(self, mock_ws_manager): - """Test connected validators retrieval""" - - # Mock WebSocket manager with connected validators - mock_manager = Mock() - mock_validator1 = Mock() - mock_validator1.get_type.return_value = "validator" - mock_validator1.hotkey = "validator_1" - mock_validator1.status = "available" - mock_validator1.connected_at = datetime.now(timezone.utc) - - mock_validator2 = Mock() - mock_validator2.get_type.return_value = "validator" - mock_validator2.hotkey = "validator_2" - mock_validator2.status = "evaluating" - mock_validator2.connected_at = datetime.now(timezone.utc) - - mock_manager.clients = {"1": mock_validator1, "2": mock_validator2} - mock_ws_manager.return_value = mock_manager - - response = client.get("/retrieval/connected-validators") - - assert response.status_code == 200 - result = response.json() - assert len(result) == 2 - - -class TestAuthenticationEndpointsUnit: - """Unit tests for authentication endpoints with mocked dependencies""" - - @patch('api.src.backend.db_manager.get_transaction') - def test_open_user_signin_success_mocked(self, mock_get_transaction): - """Test successful open user sign in""" - - # Mock database transaction - mock_conn = AsyncMock() - mock_transaction = AsyncMock() - mock_transaction.__aenter__.return_value = mock_conn - mock_get_transaction.return_value = mock_transaction - - # Mock user creation - mock_conn.fetchval.side_effect = [ - None # User doesn't exist yet - ] - mock_conn.execute.return_value = None - - signin_data = { - "auth0_user_id": "auth0|test123", - "email": "test@example.com", - "name": "Test User", - "password": "secure_password_123" - } - - response = client.post("/open-users/sign-in", json=signin_data) - - assert response.status_code == 200 - result = response.json() - assert "open_hotkey" in result - assert result["message"] == "User registered successfully" - - - - -class TestAgentSummaryEndpointsUnit: - """Unit tests for agent summary endpoints""" - - @patch('api.src.backend.db_manager.get_db_connection') - def test_get_agent_summary_mocked(self, mock_get_connection): - """Test agent summary retrieval""" - - # Mock database connection - mock_conn = AsyncMock() - mock_connection_context = AsyncMock() - mock_connection_context.__aenter__.return_value = mock_conn - mock_get_connection.return_value = mock_connection_context - - version_id = uuid.uuid4() - mock_conn.fetchrow.return_value = { - 'version_id': version_id, - 'agent_summary': 'This agent efficiently solves coding problems using advanced algorithms', - 'agent_name': 'test_agent', - 'miner_hotkey': 'test_miner' - } - - response = client.get(f"/agent-summaries/agent-summary/{version_id}") - - assert response.status_code == 200 - result = response.json() - assert result["agent_summary"] == "This agent efficiently solves coding problems using advanced algorithms" - assert result["version_id"] == str(version_id) - - @patch('api.src.backend.db_manager.get_db_connection') - def 
test_get_agent_summary_not_found_mocked(self, mock_get_connection): - """Test agent summary retrieval for non-existent agent""" - - # Mock database connection - mock_conn = AsyncMock() - mock_connection_context = AsyncMock() - mock_connection_context.__aenter__.return_value = mock_conn - mock_get_connection.return_value = mock_connection_context - - mock_conn.fetchrow.return_value = None # Agent not found - - version_id = uuid.uuid4() - response = client.get(f"/agent-summaries/agent-summary/{version_id}") - - assert response.status_code == 404 - assert "not found" in response.json()["detail"].lower() - - -class TestSystemStatusEndpointsUnit: - """Unit tests for system status endpoints""" - - def test_health_check_basic(self): - """Test basic health check endpoint""" - response = client.get("/healthcheck") - - assert response.status_code == 200 - assert response.text == '"OK"' - - def test_healthcheck_simple(self): - """Test simple healthcheck endpoint""" - response = client.get("/healthcheck") - - assert response.status_code == 200 - assert response.text == '"OK"' - - def test_healthcheck_results_endpoint(self): - """Test healthcheck results endpoint""" - response = client.get("/healthcheck-results") - - # This endpoint requires database connection, so it might fail in unit tests - # We just check that the endpoint exists and returns a proper response - assert response.status_code in [200, 500] # Either success or database error - - -class TestErrorHandling: - """Test error handling across endpoints""" - - def test_invalid_json(self): - """Test handling of invalid JSON in requests""" - response = client.post( - "/upload/agent", - data="invalid json", - headers={"Content-Type": "application/json"} - ) - - assert response.status_code == 422 - - def test_missing_required_fields(self): - """Test handling of missing required fields""" - incomplete_data = { - "miner_hotkey": "test_miner", - # Missing required fields - } - - response = client.post("/upload/agent", json=incomplete_data) - - assert response.status_code == 422 - - @patch('api.src.backend.db_manager.get_transaction') - def test_database_error_handling(self, mock_get_transaction): - """Test handling of database errors""" - - # Mock database error - mock_get_transaction.side_effect = Exception("Database connection failed") - - agent_data = { - "miner_hotkey": "test_miner", - "agent_name": "test_agent", - "code": "def solve(): pass", - "signature": "signature" - } - - response = client.post("/upload/agent", json=agent_data) - - assert response.status_code == 500 - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_miner_agent_flow.py b/tests/test_miner_agent_flow.py deleted file mode 100644 index 42563492..00000000 --- a/tests/test_miner_agent_flow.py +++ /dev/null @@ -1,1811 +0,0 @@ -""" -Comprehensive test suite for miner agent flow covering upload, screening, evaluation, and scoring. -Tests core status transitions and business logic with proper mocking. 
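The status flow this deleted suite exercises can be summarized as follows; this is a plain restatement of the statuses asserted in the tests below, not code from the module itself.

    # Summary of the agent status flow asserted throughout this suite (illustration only):
    HAPPY_PATH = [
        "awaiting_screening_1", "screening_1",   # stage 1 screening
        "awaiting_screening_2", "screening_2",   # stage 2 screening
        "waiting", "evaluating", "scored",       # validator evaluation
    ]
    # Side exits: failed_screening_1, failed_screening_2, replaced (newer version uploaded),
    # pruned (score too far below the current top agent).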
-""" - -import pytest -import uuid -from datetime import datetime, timezone -from unittest.mock import Mock, AsyncMock, patch -import asyncpg - -# Import the entities and models we're testing -import sys -import os -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'api', 'src')) - -from api.src.backend.entities import ( - AgentStatus, EvaluationStatus, SandboxStatus, - MinerAgent, MinerAgentWithScores, MinerAgentScored, - EvaluationRun -) -from api.src.models.screener import Screener -from api.src.models.evaluation import Evaluation - - -class TestAgentStatus: - """Test AgentStatus enum and transitions""" - - def test_agent_status_enum_values(self): - """Test all agent status enum values exist""" - expected_statuses = [ - "awaiting_screening_1", "screening_1", "failed_screening_1", - "awaiting_screening_2", "screening_2", "failed_screening_2", - "waiting", "evaluating", "scored", "replaced" - ] - - for status in expected_statuses: - assert hasattr(AgentStatus, status) - assert AgentStatus[status].value == status - - def test_agent_status_from_string(self): - """Test status string mapping""" - assert AgentStatus.from_string("awaiting_screening_1") == AgentStatus.awaiting_screening_1 - assert AgentStatus.from_string("screening_2") == AgentStatus.screening_2 - assert AgentStatus.from_string("scored") == AgentStatus.scored - - # Test legacy mappings - assert AgentStatus.from_string("awaiting_screening") == AgentStatus.awaiting_screening_1 - assert AgentStatus.from_string("screening") == AgentStatus.screening_1 - assert AgentStatus.from_string("evaluation") == AgentStatus.evaluating - - # Test invalid status defaults to awaiting_screening_1 - assert AgentStatus.from_string("invalid_status") == AgentStatus.awaiting_screening_1 - - def test_status_transitions_stage1_success(self): - """Test valid stage 1 screening success transition""" - # awaiting_screening_1 -> screening_1 -> awaiting_screening_2 - initial = AgentStatus.from_string("awaiting_screening_1") - screening = AgentStatus.from_string("screening_1") - next_stage = AgentStatus.from_string("awaiting_screening_2") - - assert initial == AgentStatus.awaiting_screening_1 - assert screening == AgentStatus.screening_1 - assert next_stage == AgentStatus.awaiting_screening_2 - - def test_status_transitions_stage1_failure(self): - """Test stage 1 screening failure transition""" - # awaiting_screening_1 -> screening_1 -> failed_screening_1 - initial = AgentStatus.from_string("awaiting_screening_1") - screening = AgentStatus.from_string("screening_1") - failed = AgentStatus.from_string("failed_screening_1") - - assert initial == AgentStatus.awaiting_screening_1 - assert screening == AgentStatus.screening_1 - assert failed == AgentStatus.failed_screening_1 - - def test_status_transitions_stage2_success(self): - """Test stage 2 screening success transition""" - # awaiting_screening_2 -> screening_2 -> waiting - initial = AgentStatus.from_string("awaiting_screening_2") - screening = AgentStatus.from_string("screening_2") - waiting = AgentStatus.from_string("waiting") - - assert initial == AgentStatus.awaiting_screening_2 - assert screening == AgentStatus.screening_2 - assert waiting == AgentStatus.waiting - - def test_status_transitions_evaluation_flow(self): - """Test evaluation flow transitions""" - # waiting -> evaluating -> scored - waiting = AgentStatus.from_string("waiting") - evaluating = AgentStatus.from_string("evaluating") - scored = AgentStatus.from_string("scored") - - assert waiting == AgentStatus.waiting - assert evaluating 
== AgentStatus.evaluating - assert scored == AgentStatus.scored - - -class TestScreener: - """Test Screener model and stage detection""" - - def test_screener_stage_detection(self): - """Test screener stage detection from hotkey""" - assert Screener.get_stage("screener-1-abc123") == 1 - assert Screener.get_stage("screener-2-def456") == 2 - assert Screener.get_stage("i-0123456789abcdef") == 1 # Legacy - assert Screener.get_stage("validator-xyz") is None - assert Screener.get_stage("invalid-hotkey") is None - - def test_screener_initialization(self): - """Test screener object initialization""" - screener = Screener( - hotkey="screener-1-test", - status="available" - ) - - assert screener.hotkey == "screener-1-test" - assert screener.stage == 1 - assert screener.status == "available" - assert screener.is_available() - assert screener.get_type() == "screener" - - def test_screener_state_management(self): - """Test screener availability state changes""" - screener = Screener(hotkey="screener-2-test", status="available") - - # Test initial state - assert screener.is_available() - assert screener.status == "available" - - # Test setting unavailable - screener.status = "screening" - screener.current_evaluation_id = "eval123" - screener.current_agent_name = "test_agent" - screener.current_agent_hotkey = "miner123" - - assert not screener.is_available() - assert screener.screening_id == "eval123" - assert screener.screening_agent_name == "test_agent" - - # Test reset to available - screener.set_available() - assert screener.is_available() - assert screener.current_evaluation_id is None - assert screener.current_agent_name is None - - def test_screener_start_screening_validation_logic(self): - """Test screener start screening validation logic without database""" - screener = Screener(hotkey="screener-1-test", status="available") - - # Test stage detection - assert screener.stage == 1 - - # Test availability check - assert screener.is_available() is True - - # Test state changes - screener.status = "screening" - screener.current_evaluation_id = "eval123" - assert screener.is_available() is False - assert screener.screening_id == "eval123" - - @pytest.mark.asyncio - async def test_screener_get_first_available_and_reserve(self): - """Test atomic screener reservation""" - from api.src.socket.websocket_manager import WebSocketManager - - # Mock WebSocket manager with available screeners - mock_ws_manager = Mock() - mock_screener1 = Mock() - mock_screener1.get_type.return_value = "screener" - mock_screener1.status = "available" - mock_screener1.is_available.return_value = True - mock_screener1.stage = 1 - mock_screener1.hotkey = "screener-1-test" - - mock_screener2 = Mock() - mock_screener2.get_type.return_value = "screener" - mock_screener2.status = "available" - mock_screener2.is_available.return_value = True - mock_screener2.stage = 2 - mock_screener2.hotkey = "screener-2-test" - - mock_ws_manager.clients = {"1": mock_screener1, "2": mock_screener2} - - with patch.object(WebSocketManager, 'get_instance', return_value=mock_ws_manager): - # Test stage 1 reservation - screener = await Screener.get_first_available_and_reserve(1) - assert screener == mock_screener1 - if screener is not None: - assert screener.status == "reserving" - - # Test stage 2 reservation - screener = await Screener.get_first_available_and_reserve(2) - assert screener == mock_screener2 - if screener is not None: - assert screener.status == "reserving" - - # Test no available screeners for stage 3 - screener = await 
Screener.get_first_available_and_reserve(3) - assert screener is None - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_combined_screener_score_calculation(self): - """Test that get_combined_screener_score calculates the correct score from evaluation runs""" - import os - import uuid - - # Create a direct database connection for this test - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Reset relevant tables - await db_conn.execute("TRUNCATE evaluation_runs, evaluations, miner_agents RESTART IDENTITY CASCADE") - - set_id = 1 - test_version = str(uuid.uuid4()) - - # Add evaluation sets for current set_id - await db_conn.execute( - "INSERT INTO evaluation_sets (set_id, type, swebench_instance_id) VALUES ($1, 'screener-1', 'test-instance-1'), ($1, 'screener-2', 'test-instance-2') ON CONFLICT DO NOTHING", - set_id - ) - - # Create test agent - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'test_miner','test_agent',1,NOW(),'awaiting_screening_1')", - test_version, - ) - - # Create stage 1 evaluation with known results - stage1_eval_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at) VALUES ($1,$2,'screener-1-test',$3,'completed',NOW(),NOW())", - stage1_eval_id, test_version, set_id - ) - - # Create stage 1 evaluation runs: 7 out of 10 questions solved - stage1_solved = 7 - stage1_total = 10 - for i in range(stage1_total): - run_id = str(uuid.uuid4()) - solved = i < stage1_solved # First 7 are solved - await db_conn.execute( - "INSERT INTO evaluation_runs (run_id, evaluation_id, swebench_instance_id, solved, status, started_at) VALUES ($1,$2,$3,$4,'result_scored',NOW())", - run_id, stage1_eval_id, f"stage1-test-{i+1}", solved - ) - - # Create stage 2 evaluation with known results - stage2_eval_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at) VALUES ($1,$2,'screener-2-test',$3,'completed',NOW(),NOW())", - stage2_eval_id, test_version, set_id - ) - - # Create stage 2 evaluation runs: 3 out of 5 questions solved - stage2_solved = 3 - stage2_total = 5 - for i in range(stage2_total): - run_id = str(uuid.uuid4()) - solved = i < stage2_solved # First 3 are solved - await db_conn.execute( - "INSERT INTO evaluation_runs (run_id, evaluation_id, swebench_instance_id, solved, status, started_at) VALUES ($1,$2,$3,$4,'result_scored',NOW())", - run_id, stage2_eval_id, f"stage2-test-{i+1}", solved - ) - - # Test the combined screener score calculation - combined_score, score_error = await Screener.get_combined_screener_score(db_conn, test_version) - - # Calculate expected combined score: (7 + 3) / (10 + 5) = 10/15 = 2/3 ≈ 0.6667 - expected_score = (stage1_solved + stage2_solved) / (stage1_total + stage2_total) - - # Verify the calculation is correct - assert combined_score is not None, "Combined score should not be None when both stages are completed" - assert score_error is None, f"Should not have error, but got: {score_error}" - assert abs(combined_score - expected_score) < 0.0001, f"Expected combined score {expected_score}, got {combined_score}" - - # Verify the specific calculation: 10 solved out of 15 total - assert abs(combined_score - (10/15)) < 0.0001, f"Expected 10/15 = {10/15}, 
got {combined_score}" - - # Test edge case: only stage 1 completed (should return None) - incomplete_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'incomplete_miner','incomplete_agent',1,NOW(),'awaiting_screening_2')", - incomplete_version, - ) - - incomplete_eval_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at) VALUES ($1,$2,'screener-1-test',$3,'completed',NOW(),NOW())", - incomplete_eval_id, incomplete_version, set_id - ) - - # Add some runs for the incomplete case - for i in range(3): - run_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO evaluation_runs (run_id, evaluation_id, swebench_instance_id, solved, status, started_at) VALUES ($1,$2,$3,$4,'result_scored',NOW())", - run_id, incomplete_eval_id, f"incomplete-test-{i+1}", True - ) - - incomplete_score, incomplete_error = await Screener.get_combined_screener_score(db_conn, incomplete_version) - assert incomplete_score is None, "Combined score should be None when only one stage is completed" - assert incomplete_error is not None, "Should have error message when incomplete" - - # Test edge case: no evaluations (should return None) - no_eval_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'no_eval_miner','no_eval_agent',1,NOW(),'awaiting_screening_1')", - no_eval_version, - ) - - no_eval_score, no_eval_error = await Screener.get_combined_screener_score(db_conn, no_eval_version) - assert no_eval_score is None, "Combined score should be None when no evaluations exist" - assert no_eval_error is not None, "Should have error message when no evaluations exist" - - finally: - await db_conn.close() - - -class TestEvaluationStatus: - """Test EvaluationStatus enum and transitions""" - - def test_evaluation_status_enum_values(self): - """Test all evaluation status values exist""" - expected_statuses = ["waiting", "running", "replaced", "error", "completed", "cancelled", "pruned"] - - for status in expected_statuses: - assert hasattr(EvaluationStatus, status) - assert EvaluationStatus[status].value == status - - def test_evaluation_status_from_string(self): - """Test evaluation status string mapping""" - assert EvaluationStatus.from_string("waiting") == EvaluationStatus.waiting - assert EvaluationStatus.from_string("running") == EvaluationStatus.running - assert EvaluationStatus.from_string("completed") == EvaluationStatus.completed - assert EvaluationStatus.from_string("invalid") == EvaluationStatus.error - - -class TestMinerAgent: - """Test MinerAgent model and operations""" - - def test_miner_agent_creation(self): - """Test MinerAgent object creation""" - agent_id = uuid.uuid4() - created_at = datetime.now(timezone.utc) - - agent = MinerAgent( - version_id=agent_id, - miner_hotkey="test_hotkey_123", - agent_name="test_agent", - version_num=1, - created_at=created_at, - status="awaiting_screening_1", - agent_summary="Test agent description" - ) - - assert agent.version_id == agent_id - assert agent.miner_hotkey == "test_hotkey_123" - assert agent.agent_name == "test_agent" - assert agent.version_num == 1 - assert agent.status == "awaiting_screening_1" - assert agent.agent_summary == "Test agent description" - - def test_miner_agent_with_scores(self): - """Test MinerAgentWithScores model""" - 
agent_id = uuid.uuid4() - created_at = datetime.now(timezone.utc) - - agent = MinerAgentWithScores( - version_id=agent_id, - miner_hotkey="test_hotkey", - agent_name="test_agent", - version_num=1, - created_at=created_at, - status="scored", - score=0.85, - set_id=1, - approved=True - ) - - assert agent.score == 0.85 - assert agent.set_id == 1 - assert agent.approved is True - - @pytest.mark.asyncio - async def test_high_score_detection_no_agent(self): - """Test high score detection when agent not found""" - mock_conn = AsyncMock() - mock_conn.fetchrow.return_value = None - - result = await MinerAgentScored.check_for_new_high_score(mock_conn, uuid.uuid4()) - - assert result["high_score_detected"] is False - assert "not found" in result["reason"] - - @pytest.mark.asyncio - async def test_high_score_detection_beats_previous(self): - """Test high score detection when agent beats previous best""" - mock_conn = AsyncMock() - agent_id = uuid.uuid4() - - # Mock current agent score - mock_conn.fetchrow.side_effect = [ - { - 'version_id': agent_id, - 'miner_hotkey': 'test_hotkey', - 'agent_name': 'test_agent', - 'version_num': 1, - 'created_at': datetime.now(timezone.utc), - 'status': 'scored', - 'agent_summary': 'test', - 'set_id': 1, - 'approved': False, - 'validator_count': 3, - 'final_score': 0.95 - }, - {'max_score': 0.85} # Previous max approved score - ] - - result = await MinerAgentScored.check_for_new_high_score(mock_conn, agent_id) - - assert result["high_score_detected"] is True - assert result["new_score"] == 0.95 - assert result["previous_max_score"] == 0.85 - assert result["miner_hotkey"] == "test_hotkey" - - @pytest.mark.asyncio - async def test_high_score_detection_no_previous_approved(self): - """Test high score detection when no previous approved agents""" - mock_conn = AsyncMock() - agent_id = uuid.uuid4() - - mock_conn.fetchrow.side_effect = [ - { - 'version_id': agent_id, - 'miner_hotkey': 'test_hotkey', - 'agent_name': 'test_agent', - 'version_num': 1, - 'created_at': datetime.now(timezone.utc), - 'status': 'scored', - 'agent_summary': 'test', - 'set_id': 1, - 'approved': False, - 'validator_count': 3, - 'final_score': 0.85 - }, - {'max_score': None} # No previous approved agents - ] - - result = await MinerAgentScored.check_for_new_high_score(mock_conn, agent_id) - - assert result["high_score_detected"] is True - assert result["new_score"] == 0.85 - assert result["previous_max_score"] == 0.0 - - @pytest.mark.asyncio - async def test_get_top_agent_with_leadership_rule(self): - """Test top agent selection with 1.5% leadership rule""" - mock_conn = AsyncMock() - - # Mock max set_id - mock_conn.fetchrow.side_effect = [ - {'max_set_id': 5}, # Max set_id - { # Current leader - 'version_id': uuid.uuid4(), - 'miner_hotkey': 'leader_hotkey', - 'final_score': 0.90, - 'created_at': datetime.now(timezone.utc) - }, - { # Challenger that beats by 1.5% - 'version_id': uuid.uuid4(), - 'miner_hotkey': 'challenger_hotkey', - 'final_score': 0.92 # 0.90 * 1.015 = 0.9135, so 0.92 beats this - } - ] - - with patch('api.src.utils.models.TopAgentHotkey') as mock_top_agent: - result = await MinerAgentScored.get_top_agent(mock_conn) - - # Verify challenger was selected (score >= required_score) - mock_top_agent.assert_called_once() - call_args = mock_top_agent.call_args[1] - assert call_args['miner_hotkey'] == 'challenger_hotkey' - assert call_args['avg_score'] == 0.92 - - @pytest.mark.asyncio - async def test_get_top_agent_no_challenger(self): - """Test top agent selection when no challenger beats 
1.5% rule""" - mock_conn = AsyncMock() - - mock_conn.fetchrow.side_effect = [ - {'max_set_id': 5}, - { # Current leader - 'version_id': uuid.uuid4(), - 'miner_hotkey': 'leader_hotkey', - 'final_score': 0.90, - 'created_at': datetime.now(timezone.utc) - }, - None # No challenger beats 1.5% rule - ] - - with patch('api.src.utils.models.TopAgentHotkey') as mock_top_agent: - result = await MinerAgentScored.get_top_agent(mock_conn) - - # Verify current leader remains - call_args = mock_top_agent.call_args[1] - assert call_args['miner_hotkey'] == 'leader_hotkey' - assert call_args['avg_score'] == 0.90 - - -class TestEvaluationModel: - """Test Evaluation model functionality""" - - def test_evaluation_initialization(self): - """Test evaluation object creation""" - eval_id = str(uuid.uuid4()) - version_id = str(uuid.uuid4()) - - evaluation = Evaluation( - evaluation_id=eval_id, - version_id=version_id, - validator_hotkey="screener-1-test", - set_id=1, - status=EvaluationStatus.waiting - ) - - assert evaluation.evaluation_id == eval_id - assert evaluation.version_id == version_id - assert evaluation.validator_hotkey == "screener-1-test" - assert evaluation.is_screening is True - assert evaluation.screener_stage == 1 - - def test_evaluation_screening_detection(self): - """Test screening vs validation detection""" - # Screening evaluation - screening_eval = Evaluation( - evaluation_id=str(uuid.uuid4()), - version_id=str(uuid.uuid4()), - validator_hotkey="screener-2-test", - set_id=1, - status=EvaluationStatus.waiting - ) - - assert screening_eval.is_screening is True - assert screening_eval.screener_stage == 2 - - # Validation evaluation - validation_eval = Evaluation( - evaluation_id=str(uuid.uuid4()), - version_id=str(uuid.uuid4()), - validator_hotkey="validator-hotkey", - set_id=1, - status=EvaluationStatus.waiting - ) - - assert validation_eval.is_screening is False - assert validation_eval.screener_stage is None - - @pytest.mark.asyncio - async def test_evaluation_start_screening(self): - """Test evaluation start with screening setup""" - mock_conn = AsyncMock() - - evaluation = Evaluation( - evaluation_id=str(uuid.uuid4()), - version_id=str(uuid.uuid4()), - validator_hotkey="screener-1-test", - set_id=1, - status=EvaluationStatus.waiting - ) - - # Mock database responses - mock_conn.fetchval.return_value = 5 # max_set_id - mock_conn.fetch.return_value = [ - {'swebench_instance_id': 'instance1'}, - {'swebench_instance_id': 'instance2'} - ] - - # Mock _update_agent_status method which doesn't exist on the basic Evaluation class - with patch.object(evaluation, 'start') as mock_start: - mock_start.return_value = [] # Mock return value - runs = await evaluation.start(mock_conn) - - # Verify start was called - mock_start.assert_called_once_with(mock_conn) - - def test_evaluation_properties_and_validation(self): - """Test evaluation properties and validation logic""" - eval_id = str(uuid.uuid4()) - - # Test screener evaluation - screening_eval = Evaluation( - evaluation_id=eval_id, - version_id=str(uuid.uuid4()), - validator_hotkey="screener-1-test", - set_id=1, - status=EvaluationStatus.waiting - ) - - assert screening_eval.evaluation_id == eval_id - assert screening_eval.is_screening is True - assert screening_eval.screener_stage == 1 - assert screening_eval.status == EvaluationStatus.waiting - - # Test validator evaluation - validation_eval = Evaluation( - evaluation_id=str(uuid.uuid4()), - version_id=str(uuid.uuid4()), - validator_hotkey="validator-hotkey", - set_id=1, - 
status=EvaluationStatus.waiting - ) - - assert validation_eval.is_screening is False - assert validation_eval.screener_stage is None - - @pytest.mark.asyncio - async def test_prune_low_waiting(self): - """Test pruning of low-scoring evaluations""" - mock_conn = AsyncMock() - - # Mock the database calls that get_top_agent would make - mock_conn.fetchrow.side_effect = [ - {'max_set_id': 1}, # max_set_id result - { # current_leader result - 'version_id': 'top_version', - 'miner_hotkey': 'top_hotkey', - 'final_score': 0.9, - 'created_at': '2023-01-01' - } - ] - - # Mock fetchval calls - mock_conn.fetchval.side_effect = [ - 1, # max_set_id (for prune_low_waiting) - ] - - # Mock low final validation score evaluations to be pruned - mock_conn.fetch.return_value = [ - { - 'evaluation_id': 'eval1', - 'version_id': 'version1', - 'validator_hotkey': 'validator1', - 'final_score': 0.6 # Below 0.72 threshold (0.9 * 0.8) - }, - { - 'evaluation_id': 'eval2', - 'version_id': 'version2', - 'validator_hotkey': 'validator2', - 'final_score': 0.5 # Below 0.72 threshold - } - ] - - # Test the core pruning logic directly - from api.src.utils.models import TopAgentHotkey - from uuid import uuid4 - top_agent = TopAgentHotkey( - miner_hotkey='top_hotkey', - version_id=str(uuid4()), - avg_score=0.9 - ) - - # Calculate threshold - threshold = top_agent.avg_score * 0.8 # 0.72 - - # Verify the logic works - assert 0.6 < threshold # eval1 should be pruned - assert 0.5 < threshold # eval2 should be pruned - - # Simulate the pruning - await mock_conn.execute("UPDATE evaluations SET status = 'pruned', finished_at = NOW() WHERE evaluation_id = ANY($1)", ['eval1', 'eval2']) - await mock_conn.execute("UPDATE miner_agents SET status = 'pruned' WHERE version_id = ANY($1)", ['version1', 'version2']) - - # Verify the calls were made - calls = mock_conn.execute.call_args_list - assert len(calls) == 2 - - # Find the evaluation update call - eval_call = None - agent_call = None - for call in calls: - args, kwargs = call - if 'evaluation_id' in args[0]: - eval_call = call - elif 'version_id' in args[0]: - agent_call = call - - assert eval_call is not None, "Evaluation update call not found" - assert agent_call is not None, "Agent update call not found" - - # Verify the parameters - eval_args, eval_kwargs = eval_call - agent_args, agent_kwargs = agent_call - - assert 'eval1' in eval_args[1] and 'eval2' in eval_args[1], f"Expected evaluation IDs in {eval_args[1]}" - assert 'version1' in agent_args[1] and 'version2' in agent_args[1], f"Expected version IDs in {agent_args[1]}" - - @pytest.mark.asyncio - async def test_prune_low_waiting_no_evaluations(self): - """Test pruning when no evaluations need to be pruned""" - mock_conn = AsyncMock() - - # Mock the database calls that get_top_agent would make - mock_conn.fetchrow.side_effect = [ - {'max_set_id': 1}, # max_set_id result - { # current_leader result - 'version_id': 'top_version', - 'miner_hotkey': 'top_hotkey', - 'final_score': 0.9, - 'created_at': '2023-01-01' - } - ] - - # Mock fetchval calls - mock_conn.fetchval.side_effect = [ - 1, # max_set_id (for prune_low_waiting) - ] - - # Mock no low score evaluations - mock_conn.fetch.return_value = [] - - # Test the core pruning logic directly - from api.src.utils.models import TopAgentHotkey - from uuid import uuid4 - top_agent = TopAgentHotkey( - miner_hotkey='top_hotkey', - version_id=str(uuid4()), - avg_score=0.9 - ) - - # Calculate threshold - threshold = top_agent.avg_score * 0.8 # 0.72 - - # Verify no evaluations would be pruned - 
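For reference, the pruning rule the tests in this class assert reduces to a single comparison. The unit tests above use a multiplicative cutoff (avg_score * 0.8) while the integration tests later use the additive form (avg_score - PRUNE_THRESHOLD), so the sketch below is illustrative rather than the production logic.

    def should_prune(final_score: float, top_agent_score: float, prune_threshold: float) -> bool:
        # Prune when the candidate's score falls more than prune_threshold below the top agent.
        return (top_agent_score - final_score) > prune_threshold

    # With the unit-test numbers: top agent 0.9, cutoff 0.72 (a 0.18 gap) -> 0.6 and 0.5 are pruned.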
assert len([]) == 0 # No evaluations to prune - - # Verify no pruning queries were called - mock_conn.execute.assert_not_called() - - @pytest.mark.asyncio - async def test_prune_low_waiting_no_top_score(self): - """Test pruning when no completed evaluations with final validation scores exist""" - mock_conn = AsyncMock() - - # Mock no top agent (no evaluation sets) - mock_conn.fetchrow.return_value = None - - # Test the core pruning logic directly - # When no top agent exists, no pruning should occur - assert None is None # No top agent - - # Verify no pruning queries were called - mock_conn.execute.assert_not_called() - - @pytest.mark.asyncio - async def test_screener2_immediate_prune_low_score(self): - """Test immediate pruning when screener-2 score is too low""" - mock_conn = AsyncMock() - - # Create evaluation with screener-2 hotkey to trigger the logic - evaluation = Evaluation( - evaluation_id="eval1", - version_id="version1", - validator_hotkey="screener-2-test", - set_id=1, - status=EvaluationStatus.completed, - score=0.6 # Low score - ) - - # Test the immediate pruning logic directly - from api.src.utils.models import TopAgentHotkey - from uuid import uuid4 - top_agent = TopAgentHotkey( - miner_hotkey='top_hotkey', - version_id=str(uuid4()), - avg_score=0.9 - ) - - # Verify the logic would trigger pruning - assert evaluation.score < top_agent.avg_score * 0.8 # 0.6 < 0.72 - - # Simulate the pruning - await mock_conn.execute("UPDATE evaluations SET status = 'pruned', finished_at = NOW() WHERE evaluation_id = $1", evaluation.evaluation_id) - await mock_conn.execute("UPDATE miner_agents SET status = 'pruned' WHERE version_id = $1", evaluation.version_id) - - # Verify the calls were made - mock_conn.execute.assert_any_call( - "UPDATE evaluations SET status = 'pruned', finished_at = NOW() WHERE evaluation_id = $1", - "eval1" - ) - mock_conn.execute.assert_any_call( - "UPDATE miner_agents SET status = 'pruned' WHERE version_id = $1", - "version1" - ) - - @pytest.mark.asyncio - async def test_screener2_no_immediate_prune_acceptable_score(self): - """Test no immediate pruning when screener-2 score is acceptable""" - mock_conn = AsyncMock() - - # Create evaluation with screener-2 hotkey to trigger the logic - evaluation = Evaluation( - evaluation_id="eval1", - version_id="version1", - validator_hotkey="screener-2-test", - set_id=1, - status=EvaluationStatus.completed, - score=0.8 # Acceptable score - ) - - # Test the immediate pruning logic directly - from api.src.utils.models import TopAgentHotkey - from uuid import uuid4 - top_agent = TopAgentHotkey( - miner_hotkey='top_hotkey', - version_id=str(uuid4()), - avg_score=0.9 - ) - - # Verify the logic would NOT trigger pruning - assert evaluation.score >= top_agent.avg_score * 0.8 # 0.8 >= 0.72 - - # Verify no pruning calls were made - mock_conn.execute.assert_not_called() - -class TestEvaluationRun: - """Test EvaluationRun model and sandbox statuses""" - - def test_evaluation_run_creation(self): - """Test evaluation run object creation""" - run_id = uuid.uuid4() - eval_id = uuid.uuid4() # Use UUID instead of string - started_at = datetime.now(timezone.utc) - - run = EvaluationRun( - run_id=run_id, - evaluation_id=eval_id, - swebench_instance_id="instance123", - status=SandboxStatus.started, - started_at=started_at - ) - - assert run.run_id == run_id - assert run.evaluation_id == eval_id - assert run.swebench_instance_id == "instance123" - assert run.status == SandboxStatus.started - assert run.started_at == started_at - assert run.response is 
None - assert run.solved is None - - def test_sandbox_status_progression(self): - """Test sandbox status progression through evaluation""" - run = EvaluationRun( - run_id=uuid.uuid4(), - evaluation_id=uuid.uuid4(), - swebench_instance_id="test", - status=SandboxStatus.started, - started_at=datetime.now(timezone.utc) - ) - - # Test status progression - assert run.status == SandboxStatus.started - - run.status = SandboxStatus.sandbox_created - run.sandbox_created_at = datetime.now(timezone.utc) - assert run.status == SandboxStatus.sandbox_created - assert run.sandbox_created_at is not None - - run.status = SandboxStatus.patch_generated - run.patch_generated_at = datetime.now(timezone.utc) - assert run.status == SandboxStatus.patch_generated - - run.status = SandboxStatus.result_scored - run.result_scored_at = datetime.now(timezone.utc) - run.solved = True - assert run.status == SandboxStatus.result_scored - assert run.solved is True - - -class TestAgentLifecycleFlow: - """Test complete agent lifecycle flows""" - - @pytest.mark.asyncio - async def test_complete_successful_flow(self): - """Test complete successful agent flow from upload to scoring""" - - # 1. Upload - Agent starts as awaiting_screening_1 - agent_id = uuid.uuid4() - agent = MinerAgent( - version_id=agent_id, - miner_hotkey="test_miner", - agent_name="test_agent", - version_num=1, - created_at=datetime.now(timezone.utc), - status="awaiting_screening_1" - ) - - assert AgentStatus.from_string(agent.status) == AgentStatus.awaiting_screening_1 - - # 2. Stage 1 Screening - transitions through screening_1 to awaiting_screening_2 - agent.status = "screening_1" - assert AgentStatus.from_string(agent.status) == AgentStatus.screening_1 - - # Simulate successful screening (score >= 0.6) - agent.status = "awaiting_screening_2" - assert AgentStatus.from_string(agent.status) == AgentStatus.awaiting_screening_2 - - # 3. Stage 2 Screening - transitions through screening_2 to waiting - agent.status = "screening_2" - assert AgentStatus.from_string(agent.status) == AgentStatus.screening_2 - - # Simulate successful screening (score >= 0.2) - agent.status = "waiting" - assert AgentStatus.from_string(agent.status) == AgentStatus.waiting - - # 4. 
Evaluation - transitions through evaluating to scored - agent.status = "evaluating" - assert AgentStatus.from_string(agent.status) == AgentStatus.evaluating - - agent.status = "scored" - assert AgentStatus.from_string(agent.status) == AgentStatus.scored - - @pytest.mark.asyncio - async def test_stage1_screening_failure_flow(self): - """Test agent flow when stage 1 screening fails""" - - agent_id = uuid.uuid4() - agent = MinerAgent( - version_id=agent_id, - miner_hotkey="test_miner", - agent_name="failing_agent", - version_num=1, - created_at=datetime.now(timezone.utc), - status="awaiting_screening_1" - ) - - # Stage 1 screening starts - agent.status = "screening_1" - assert AgentStatus.from_string(agent.status) == AgentStatus.screening_1 - - # Screening fails (score < 0.6) - agent.status = "failed_screening_1" - assert AgentStatus.from_string(agent.status) == AgentStatus.failed_screening_1 - - # Agent should not proceed to stage 2 - - @pytest.mark.asyncio - async def test_stage2_screening_failure_flow(self): - """Test agent flow when stage 2 screening fails""" - - agent_id = uuid.uuid4() - agent = MinerAgent( - version_id=agent_id, - miner_hotkey="test_miner", - agent_name="stage2_failing_agent", - version_num=1, - created_at=datetime.now(timezone.utc), - status="awaiting_screening_2" # Passed stage 1 - ) - - # Stage 2 screening starts - agent.status = "screening_2" - assert AgentStatus.from_string(agent.status) == AgentStatus.screening_2 - - # Screening fails (score < 0.2) - agent.status = "failed_screening_2" - assert AgentStatus.from_string(agent.status) == AgentStatus.failed_screening_2 - - # Agent should not proceed to evaluation - - @pytest.mark.asyncio - async def test_agent_replacement_flow(self): - """Test agent replacement when newer version uploaded""" - - # Original agent - original_agent = MinerAgent( - version_id=uuid.uuid4(), - miner_hotkey="test_miner", - agent_name="test_agent", - version_num=1, - created_at=datetime.now(timezone.utc), - status="scored" - ) - - # New version uploaded - new_agent = MinerAgent( - version_id=uuid.uuid4(), - miner_hotkey="test_miner", - agent_name="test_agent", - version_num=2, - created_at=datetime.now(timezone.utc), - status="awaiting_screening_1" - ) - - # Original should be marked as replaced - original_agent.status = "replaced" - assert AgentStatus.from_string(original_agent.status) == AgentStatus.replaced - assert AgentStatus.from_string(new_agent.status) == AgentStatus.awaiting_screening_1 - - # --- Integration pruning tests using real database --- - @pytest.mark.asyncio - @pytest.mark.integration - async def test_prune_low_waiting_integration(self): - """Batch pruning sets waiting evaluations and agent to pruned when below threshold.""" - from api.src.models.evaluation import Evaluation - import os - - # Create a direct database connection for this test to avoid event loop conflicts - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Reset relevant tables - await db_conn.execute("TRUNCATE evaluation_runs, evaluations, miner_agents, approved_version_ids, banned_hotkeys, top_agents RESTART IDENTITY CASCADE") - - set_id = 1 - # Top agent (approved) with completed validator evals (score 0.90) - top_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'miner_top','top_agent',1,NOW(),'scored')", - top_version, - ) - await 
db_conn.execute("INSERT INTO approved_version_ids (version_id, set_id) VALUES ($1, 1) ON CONFLICT DO NOTHING", top_version) - await db_conn.execute( - """ - INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at, score) - VALUES ($1,$2,'validator-1',$3,'completed',NOW(),NOW(),0.90), - ($4,$2,'validator-2',$3,'completed',NOW(),NOW(),0.90) - """, - str(uuid.uuid4()), top_version, set_id, str(uuid.uuid4()) - ) - - # Low agent with completed evals (0.60) and one waiting eval - low_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'miner_low','low_agent',1,NOW(),'waiting')", - low_version, - ) - await db_conn.execute( - """ - INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at, score) - VALUES ($1,$2,'validator-1',$3,'completed',NOW(),NOW(),0.60), - ($4,$2,'validator-2',$3,'completed',NOW(),NOW(),0.60) - """, - str(uuid.uuid4()), low_version, set_id, str(uuid.uuid4()) - ) - waiting_eval_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at) VALUES ($1,$2,'validator-3',$3,'waiting',NOW())", - waiting_eval_id, low_version, set_id - ) - - # Run prune - we need to call the backend function directly since the model method uses global db manager - from api.src.backend.entities import MinerAgentScored - from api.src.utils.config import PRUNE_THRESHOLD - - # Replicate the prune_low_waiting logic with our direct connection - top_agent = await MinerAgentScored.get_top_agent(db_conn) - - if top_agent: - # Calculate the threshold - threshold = top_agent.avg_score - PRUNE_THRESHOLD - - # Get current set_id for the query - for tests, the set_id is 1 - max_set_id = 1 - - # For this test, we need to refresh the materialized view first - await db_conn.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Find evaluations below threshold - low_score_evaluations = await db_conn.fetch(""" - SELECT e.evaluation_id, e.version_id, e.validator_hotkey, ass.final_score - FROM evaluations e - JOIN miner_agents ma ON e.version_id = ma.version_id - JOIN agent_scores ass ON e.version_id = ass.version_id AND e.set_id = ass.set_id - WHERE e.set_id = $1 - AND e.status = 'waiting' - AND ass.final_score IS NOT NULL - AND ass.final_score < $2 - AND ma.status NOT IN ('pruned', 'replaced') - """, max_set_id, threshold) - - if low_score_evaluations: - # Get unique version_ids to prune - version_ids_to_prune = list(set(eval['version_id'] for eval in low_score_evaluations)) - evaluation_ids_to_prune = [eval['evaluation_id'] for eval in low_score_evaluations] - - # Update evaluations to pruned status - await db_conn.execute( - "UPDATE evaluations SET status = 'pruned', finished_at = NOW() WHERE evaluation_id = ANY($1)", - evaluation_ids_to_prune - ) - - # Update agents to pruned status - await db_conn.execute( - "UPDATE miner_agents SET status = 'pruned' WHERE version_id = ANY($1)", - version_ids_to_prune - ) - else: - # If no top agent, just manually prune the low scoring agent since we know it should be pruned - await db_conn.execute("UPDATE evaluations SET status = 'pruned', finished_at = NOW() WHERE evaluation_id = $1", waiting_eval_id) - await db_conn.execute("UPDATE miner_agents SET status = 'pruned' WHERE version_id = $1", low_version) - - status = await db_conn.fetchval("SELECT status FROM 
evaluations WHERE evaluation_id = $1", waiting_eval_id) - assert status == 'pruned' - agent_status = await db_conn.fetchval("SELECT status FROM miner_agents WHERE version_id = $1", low_version) - assert agent_status == 'pruned' - finally: - await db_conn.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_prune_low_waiting_by_screener_score_integration(self): - """Test that prune_low_waiting prunes evaluations with low screener scores.""" - from api.src.models.evaluation import Evaluation - from api.src.utils.config import PRUNE_THRESHOLD - import os - import uuid - - # Create a direct database connection for this test to avoid event loop conflicts - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Reset relevant tables - await db_conn.execute("TRUNCATE evaluation_runs, evaluations, miner_agents, approved_version_ids, banned_hotkeys, top_agents RESTART IDENTITY CASCADE") - - # Add evaluation set for current set_id - set_id = 1 - await db_conn.execute( - "INSERT INTO evaluation_sets (set_id, type, swebench_instance_id) VALUES ($1, 'validator', 'test-instance') ON CONFLICT DO NOTHING", - set_id - ) - - # Create top agent with high validation scores for threshold calculation - top_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'miner_top','top_agent',1,NOW(),'scored')", - top_version, - ) - await db_conn.execute("INSERT INTO approved_version_ids (version_id, set_id) VALUES ($1, 1) ON CONFLICT DO NOTHING", top_version) - await db_conn.execute( - """ - INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at, score) - VALUES ($1,$2,'validator-1',$3,'completed',NOW(),NOW(),0.90), - ($4,$2,'validator-2',$3,'completed',NOW(),NOW(),0.90), - ($5,$2,'validator-5',$3,'completed',NOW(),NOW(),0.89) - """, - str(uuid.uuid4()), top_version, set_id, str(uuid.uuid4()), str(uuid.uuid4()) - ) - - # Create agent with good screener score (above threshold) - good_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'miner_good','good_agent',1,NOW(),'waiting')", - good_version, - ) - good_eval_id = str(uuid.uuid4()) - # Calculate good screener score dynamically - should be above threshold - top_agent_score = 0.9 # Match the top agent score from this test - threshold = top_agent_score - PRUNE_THRESHOLD - good_screener_score = threshold + 0.05 # 5% buffer above threshold - await db_conn.execute( - "INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, screener_score) VALUES ($1,$2,'validator-3',$3,'waiting',NOW(),$4)", - good_eval_id, good_version, set_id, good_screener_score - ) - - # Create agent with low screener score (below threshold) - low_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'miner_low','low_agent',1,NOW(),'waiting')", - low_version, - ) - low_eval_id = str(uuid.uuid4()) - # Calculate low screener score dynamically - should be below threshold - low_screener_score = threshold - 0.1 # 10% below threshold - await db_conn.execute( - "INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, screener_score) 
VALUES ($1,$2,'validator-4',$3,'waiting',NOW(),$4)", - low_eval_id, low_version, set_id, low_screener_score - ) - - # Ensure the top agent is properly set up and refresh materialized view - await db_conn.execute("UPDATE miner_agents SET status = 'scored' WHERE version_id = $1", top_version) - await db_conn.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call prune_low_waiting - await Evaluation.prune_low_waiting(db_conn) - - # Check that low scoring evaluation was pruned - low_status = await db_conn.fetchval("SELECT status FROM evaluations WHERE evaluation_id = $1", low_eval_id) - assert low_status == 'pruned', f"Expected low scoring evaluation to be pruned, but status is {low_status}" - - # Check that low scoring agent was pruned - low_agent_status = await db_conn.fetchval("SELECT status FROM miner_agents WHERE version_id = $1", low_version) - assert low_agent_status == 'pruned', f"Expected low scoring agent to be pruned, but status is {low_agent_status}" - - # Check that good scoring evaluation remains waiting - good_status = await db_conn.fetchval("SELECT status FROM evaluations WHERE evaluation_id = $1", good_eval_id) - assert good_status == 'waiting', f"Expected good scoring evaluation to remain waiting, but status is {good_status}" - - # Check that good scoring agent remains waiting - good_agent_status = await db_conn.fetchval("SELECT status FROM miner_agents WHERE version_id = $1", good_version) - assert good_agent_status == 'waiting', f"Expected good scoring agent to remain waiting, but status is {good_agent_status}" - - finally: - await db_conn.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_screener2_immediate_prune_integration(self): - """Screener-2 finish prunes agent below threshold and does not create validator evaluations.""" - from api.src.models.evaluation import Evaluation - from api.src.backend.queries.agents import get_top_agent - from api.src.utils.config import PRUNE_THRESHOLD - import os - - # Create a direct database connection for this test to avoid event loop conflicts - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Reset tables - await db_conn.execute("TRUNCATE evaluation_runs, evaluations, miner_agents, approved_version_ids, banned_hotkeys, top_agents RESTART IDENTITY CASCADE") - - set_id = 1 - # Create a top agent (approved) with final score 0.90 - top_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'miner_top','top',1,NOW(),'scored')", - top_version, - ) - await db_conn.execute("INSERT INTO approved_version_ids (version_id, set_id) VALUES ($1, 1) ON CONFLICT DO NOTHING", top_version) - await db_conn.execute( - """ - INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at, score) - VALUES ($1,$2,'validator-1',$3,'completed',NOW(),NOW(),0.90), - ($4,$2,'validator-2',$3,'completed',NOW(),NOW(),0.90) - """, - str(uuid.uuid4()), top_version, set_id, str(uuid.uuid4()) - ) - - # Candidate below threshold - low_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'miner_low','low',1,NOW(),'awaiting_screening_2')", - low_version, - ) - low_eval_id = str(uuid.uuid4()) - low_score = 0.60 - await db_conn.execute( - "INSERT INTO evaluations 
(evaluation_id, version_id, validator_hotkey, set_id, status, created_at, score) VALUES ($1,$2,'screener-2-test',$3,'waiting',NOW(),$4)", - low_eval_id, low_version, set_id, low_score - ) - low_eval = Evaluation( - evaluation_id=low_eval_id, - version_id=low_version, - validator_hotkey='screener-2-test', - set_id=set_id, - status=EvaluationStatus.waiting, - score=low_score, - ) - - # Manually replicate the finish logic that would prune the agent - # Update evaluation to completed - await db_conn.execute("UPDATE evaluations SET status = 'completed', finished_at = NOW() WHERE evaluation_id = $1", low_eval_id) - - # Check if score is below threshold and prune if needed - from api.src.backend.entities import MinerAgentScored - - # Refresh materialized view to get updated scores - await db_conn.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - top_agent = await MinerAgentScored.get_top_agent(db_conn) - - if top_agent and (top_agent.avg_score - low_score) > PRUNE_THRESHOLD: - # Score is too low, prune miner agent - await db_conn.execute("UPDATE miner_agents SET status = 'pruned' WHERE version_id = $1", low_version) - else: - # For this test, manually prune since we know the score is low - await db_conn.execute("UPDATE miner_agents SET status = 'pruned' WHERE version_id = $1", low_version) - - pruned = await db_conn.fetchval("SELECT status FROM miner_agents WHERE version_id = $1", low_version) - assert pruned == 'pruned' - # Ensure no validator evaluations created for this version_id - count_validator = await db_conn.fetchval( - """ - SELECT COUNT(*) FROM evaluations - WHERE version_id = $1 AND validator_hotkey NOT LIKE 'screener-%' AND validator_hotkey NOT LIKE 'i-0%' - """, - low_version, - ) - assert count_validator == 0 - finally: - await db_conn.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_screener_stage2_combined_score_and_validator_creation(self): - """Test finished screener stage 2 evaluation with combined score calculation and validator evaluation creation""" - from api.src.models.evaluation import Evaluation - from api.src.models.validator import Validator - from api.src.utils.config import SCREENING_1_THRESHOLD, SCREENING_2_THRESHOLD, PRUNE_THRESHOLD - import os - - # Create a direct database connection for this test - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Reset relevant tables - await db_conn.execute("TRUNCATE evaluation_runs, evaluations, miner_agents, approved_version_ids, banned_hotkeys, top_agents RESTART IDENTITY CASCADE") - - set_id = 1 - # Add evaluation sets for current set_id - await db_conn.execute( - "INSERT INTO evaluation_sets (set_id, type, swebench_instance_id) VALUES ($1, 'screener-1', 'test-instance-1'), ($1, 'screener-2', 'test-instance-2'), ($1, 'validator', 'test-instance-3') ON CONFLICT DO NOTHING", - set_id - ) - - # Create top agent for threshold calculation - top_agent_score = 0.90 - top_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'miner_top','top_agent',1,NOW(),'scored')", - top_version, - ) - await db_conn.execute("INSERT INTO approved_version_ids (version_id, set_id) VALUES ($1, 1) ON CONFLICT DO NOTHING", top_version) - # Need at least 2 validator evaluations for materialized view - await db_conn.execute( - """ - INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, 
set_id, status, created_at, finished_at, score) - VALUES ($1,$2,'validator-1',$3,'completed',NOW(),NOW(),$4), - ($5,$2,'validator-2',$3,'completed',NOW(),NOW(),$6), - ($7,$2,'validator-3',$3,'completed',NOW(),NOW(),$8) - """, - str(uuid.uuid4()), top_version, set_id, top_agent_score, - str(uuid.uuid4()), top_agent_score, - str(uuid.uuid4()), top_agent_score - ) - - # Create test agent that will go through both screening stages - test_version = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1,'test_miner','test_agent',1,NOW(),'awaiting_screening_1')", - test_version, - ) - - # Calculate dynamic scores based on top agent and thresholds - top_agent_score = 0.90 - threshold = top_agent_score - PRUNE_THRESHOLD - - # Create evaluation runs to test the new combined score calculation - # Stage 1: 4 out of 5 questions solved (80%) - stage1_solved = 4 - stage1_total = 5 - # Stage 2: 5 out of 5 questions solved (100%) - stage2_solved = 5 - stage2_total = 5 - # Combined: 9 out of 10 questions solved (90%) - expected_combined_score = (stage1_solved + stage2_solved) / (stage1_total + stage2_total) - - # Ensure combined score is above threshold (score gap should be <= PRUNE_THRESHOLD) - assert (top_agent_score - expected_combined_score) <= PRUNE_THRESHOLD, f"Test setup error: score gap {top_agent_score - expected_combined_score} should be <= PRUNE_THRESHOLD {PRUNE_THRESHOLD}" - - # 1. Create and complete stage 1 screening evaluation - stage1_eval_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at) VALUES ($1,$2,'screener-1-test',$3,'completed',NOW(),NOW())", - stage1_eval_id, test_version, set_id - ) - - # Create evaluation runs for stage 1 (4 solved, 1 not solved) - for i in range(stage1_total): - run_id = str(uuid.uuid4()) - solved = i < stage1_solved # First 4 are solved - await db_conn.execute( - "INSERT INTO evaluation_runs (run_id, evaluation_id, swebench_instance_id, solved, status, started_at) VALUES ($1,$2,$3,$4,'result_scored',NOW())", - run_id, stage1_eval_id, f"stage1-instance-{i+1}", solved - ) - - # Update agent status to awaiting_screening_2 (simulating stage 1 completion) - await db_conn.execute("UPDATE miner_agents SET status = 'awaiting_screening_2' WHERE version_id = $1", test_version) - - # 2. 
Create stage 2 screening evaluation - stage2_eval_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at) VALUES ($1,$2,'screener-2-test',$3,'waiting',NOW())", - stage2_eval_id, test_version, set_id - ) - - # Create evaluation runs for stage 2 (5 solved, 0 not solved) - for i in range(stage2_total): - run_id = str(uuid.uuid4()) - solved = i < stage2_solved # All 5 are solved - await db_conn.execute( - "INSERT INTO evaluation_runs (run_id, evaluation_id, swebench_instance_id, solved, status, started_at) VALUES ($1,$2,$3,$4,'result_scored',NOW())", - run_id, stage2_eval_id, f"stage2-instance-{i+1}", solved - ) - - # Create the Evaluation object and simulate finishing - stage2_eval = Evaluation( - evaluation_id=stage2_eval_id, - version_id=test_version, - validator_hotkey='screener-2-test', - set_id=set_id, - status=EvaluationStatus.waiting, - ) - - # Mock connected validators for testing - mock_validators = [ - Mock(hotkey="validator-1"), - Mock(hotkey="validator-2"), - Mock(hotkey="validator-3") - ] - - with patch.object(Validator, 'get_connected', return_value=mock_validators): - # Finish the stage 2 evaluation (this should calculate combined score and create validator evaluations) - result = await stage2_eval.finish(db_conn) - - # 3. Verify combined score calculation matches the new method - - # Check that validator evaluations were created with the combined screener score - validator_evaluations = await db_conn.fetch( - """ - SELECT evaluation_id, validator_hotkey, screener_score, status - FROM evaluations - WHERE version_id = $1 - AND validator_hotkey NOT LIKE 'screener-%' - ORDER BY validator_hotkey - """, - test_version - ) - - # 4. Verify all expected validator evaluations were created - assert len(validator_evaluations) == 3, f"Expected 3 validator evaluations, got {len(validator_evaluations)}" - - for eval_row in validator_evaluations: - assert abs(eval_row['screener_score'] - expected_combined_score) < 0.001, f"Expected combined score {expected_combined_score}, got {eval_row['screener_score']}" - assert eval_row['status'] == 'waiting', f"Expected evaluation status 'waiting', got {eval_row['status']}" - assert eval_row['validator_hotkey'] in ['validator-1', 'validator-2', 'validator-3'], f"Unexpected validator hotkey {eval_row['validator_hotkey']}" - - # 5. Verify agent status was updated to 'waiting' after stage 2 completion - agent_status = await db_conn.fetchval("SELECT status FROM miner_agents WHERE version_id = $1", test_version) - assert agent_status == 'waiting', f"Expected agent status 'waiting', got {agent_status}" - - # 6. Verify stage 2 evaluation was marked as completed - stage2_status = await db_conn.fetchval("SELECT status FROM evaluations WHERE evaluation_id = $1", stage2_eval_id) - assert stage2_status == 'completed', f"Expected stage 2 evaluation status 'completed', got {stage2_status}" - - # 7. Test the combined score is calculated correctly in database queries - # Verify we can retrieve the combined score from validator evaluations - retrieved_screener_scores = await db_conn.fetch( - """ - SELECT screener_score FROM evaluations - WHERE version_id = $1 - AND validator_hotkey NOT LIKE 'screener-%' - """, - test_version - ) - - for score_row in retrieved_screener_scores: - assert abs(score_row['screener_score'] - expected_combined_score) < 0.001, f"Retrieved screener score {score_row['screener_score']} doesn't match expected combined score {expected_combined_score}" - - # 8. 
Verify no pruning occurred (scores are acceptable) - pruned_evaluations = await db_conn.fetchval( - "SELECT COUNT(*) FROM evaluations WHERE version_id = $1 AND status = 'pruned'", - test_version - ) - assert pruned_evaluations == 0, "No evaluations should be pruned with acceptable scores" - - agent_status_final = await db_conn.fetchval("SELECT status FROM miner_agents WHERE version_id = $1", test_version) - assert agent_status_final != 'pruned', "Agent should not be pruned with acceptable combined score" - - finally: - await db_conn.close() - - - -class TestScoreCalculation: - """Test scoring logic and materialized view operations""" - - @pytest.mark.asyncio - async def test_24_hour_statistics_calculation(self): - """Test 24-hour statistics calculation""" - mock_conn = AsyncMock() - - # Mock max set_id and statistics - mock_conn.fetchrow.side_effect = [ - {'max_set_id': 10}, # Current max set_id - { # Statistics result - 'number_of_agents': 150, - 'agent_iterations_last_24_hours': 25, - 'top_agent_score': 0.923, - 'daily_score_improvement': 0.045 - } - ] - - result = await MinerAgentScored.get_24_hour_statistics(mock_conn) - - assert result['number_of_agents'] == 150 - assert result['agent_iterations_last_24_hours'] == 25 - assert result['top_agent_score'] == 0.923 - assert result['daily_score_improvement'] == 0.045 - - @pytest.mark.asyncio - async def test_24_hour_statistics_no_evaluation_sets(self): - """Test 24-hour statistics when no evaluation sets exist""" - mock_conn = AsyncMock() - - mock_conn.fetchrow.return_value = {'max_set_id': None} - mock_conn.fetchval.side_effect = [100, 15] # total agents, recent agents - - result = await MinerAgentScored.get_24_hour_statistics(mock_conn) - - assert result['number_of_agents'] == 100 - assert result['agent_iterations_last_24_hours'] == 15 - assert result['top_agent_score'] is None - assert result['daily_score_improvement'] == 0 - - @pytest.mark.asyncio - async def test_agent_summary_by_hotkey(self): - """Test agent summary retrieval by hotkey""" - mock_conn = AsyncMock() - - agent1_id = uuid.uuid4() - agent2_id = uuid.uuid4() - created_at = datetime.now(timezone.utc) - - mock_conn.fetch.return_value = [ - { - 'version_id': agent1_id, - 'miner_hotkey': 'test_hotkey', - 'agent_name': 'agent_v2', - 'version_num': 2, - 'created_at': created_at, - 'status': 'scored', - 'agent_summary': 'Latest version', - 'set_id': 5, - 'approved': True, - 'validator_count': 3, - 'score': 0.89 - }, - { - 'version_id': agent2_id, - 'miner_hotkey': 'test_hotkey', - 'agent_name': 'agent_v1', - 'version_num': 1, - 'created_at': created_at, - 'status': 'replaced', - 'agent_summary': 'Previous version', - 'set_id': 4, - 'approved': None, - 'validator_count': None, - 'score': None - } - ] - - result = await MinerAgentScored.get_agent_summary_by_hotkey(mock_conn, "test_hotkey") - - assert len(result) == 2 - assert result[0].version_num == 2 - assert result[0].score == 0.89 - assert result[0].approved is True - assert result[1].version_num == 1 - assert result[1].status == "replaced" - - @pytest.mark.asyncio - async def test_materialized_view_refresh(self): - """Test materialized view refresh operation""" - mock_conn = AsyncMock() - - await MinerAgentScored.refresh_materialized_view(mock_conn) - - mock_conn.execute.assert_called_once_with("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - -class TestGetAgentStatus: - """Test get_agent_status endpoint for approved and banned fields""" - - @pytest.mark.asyncio - @pytest.mark.integration - async def 
test_agent_status_approved_not_banned(self): - """Test agent status: approved but not banned""" - import uuid - import asyncpg - import httpx - - # Connect directly to the test database (same as server) - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Clean up tables - await db_conn.execute("DELETE FROM approved_version_ids WHERE version_id IN (SELECT version_id FROM miner_agents WHERE miner_hotkey = 'test_approved_not_banned')") - await db_conn.execute("DELETE FROM banned_hotkeys WHERE miner_hotkey = 'test_approved_not_banned'") - await db_conn.execute("DELETE FROM miner_agents WHERE miner_hotkey = 'test_approved_not_banned'") - - # Create test agent - version_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1, $2, $3, $4, NOW(), $5)", - version_id, 'test_approved_not_banned', 'test_agent', 1, 'scored' - ) - - # Add to approved list - await db_conn.execute( - "INSERT INTO approved_version_ids (version_id, set_id) VALUES ($1, $2)", - version_id, 1 - ) - - # Test via HTTP request to the running API server - async with httpx.AsyncClient() as client: - response = await client.get(f"http://localhost:8000/agents/{version_id}") - assert response.status_code == 200 - status = response.json() - - # Verify results - assert status['approved_at'] is not None, "Agent should be approved" - assert status['banned'] is False, "Agent should not be banned" - assert status['version_id'] == version_id - assert status['miner_hotkey'] == 'test_approved_not_banned' - finally: - await db_conn.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_agent_status_not_approved_not_banned(self): - """Test agent status: not approved and not banned""" - import uuid - import asyncpg - import httpx - - # Connect directly to the test database (same as server) - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Clean up tables - await db_conn.execute("DELETE FROM approved_version_ids WHERE version_id IN (SELECT version_id FROM miner_agents WHERE miner_hotkey = 'test_not_approved_not_banned')") - await db_conn.execute("DELETE FROM banned_hotkeys WHERE miner_hotkey = 'test_not_approved_not_banned'") - await db_conn.execute("DELETE FROM miner_agents WHERE miner_hotkey = 'test_not_approved_not_banned'") - - # Create test agent (but don't approve or ban) - version_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1, $2, $3, $4, NOW(), $5)", - version_id, 'test_not_approved_not_banned', 'test_agent', 1, 'waiting' - ) - - # Test via HTTP request to the running API server - async with httpx.AsyncClient() as client: - response = await client.get(f"http://localhost:8000/agents/{version_id}") - assert response.status_code == 200 - status = response.json() - - # Verify results - assert status['approved_at'] is None, "Agent should not be approved" - assert status['banned'] is False, "Agent should not be banned" - assert status['version_id'] == version_id - assert status['miner_hotkey'] == 'test_not_approved_not_banned' - finally: - await db_conn.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_agent_status_approved_and_banned(self): - """Test agent status: approved but also 
banned""" - import uuid - import asyncpg - import httpx - - # Connect directly to the test database (same as server) - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Clean up tables - await db_conn.execute("DELETE FROM approved_version_ids WHERE version_id IN (SELECT version_id FROM miner_agents WHERE miner_hotkey = 'test_approved_and_banned')") - await db_conn.execute("DELETE FROM banned_hotkeys WHERE miner_hotkey = 'test_approved_and_banned'") - await db_conn.execute("DELETE FROM miner_agents WHERE miner_hotkey = 'test_approved_and_banned'") - - # Create test agent - version_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1, $2, $3, $4, NOW(), $5)", - version_id, 'test_approved_and_banned', 'test_agent', 1, 'scored' - ) - - # Add to approved list - await db_conn.execute( - "INSERT INTO approved_version_ids (version_id, set_id) VALUES ($1, $2)", - version_id, 1 - ) - - # Add to banned list - await db_conn.execute( - "INSERT INTO banned_hotkeys (miner_hotkey, banned_reason) VALUES ($1, $2)", - 'test_approved_and_banned', 'Test ban reason' - ) - - # Test via HTTP request to the running API server - async with httpx.AsyncClient() as client: - response = await client.get(f"http://localhost:8000/agents/{version_id}") - assert response.status_code == 200 - status = response.json() - - # Verify results - assert status['approved_at'] is not None, "Agent should be approved" - assert status['banned'] is True, "Agent should be banned" - assert status['version_id'] == version_id - assert status['miner_hotkey'] == 'test_approved_and_banned' - finally: - await db_conn.close() - - @pytest.mark.asyncio - @pytest.mark.integration - async def test_agent_status_not_approved_but_banned(self): - """Test agent status: not approved but is banned""" - import uuid - import asyncpg - import httpx - - # Connect directly to the test database (same as server) - db_conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - try: - # Clean up tables - await db_conn.execute("DELETE FROM approved_version_ids WHERE version_id IN (SELECT version_id FROM miner_agents WHERE miner_hotkey = 'test_not_approved_but_banned')") - await db_conn.execute("DELETE FROM banned_hotkeys WHERE miner_hotkey = 'test_not_approved_but_banned'") - await db_conn.execute("DELETE FROM miner_agents WHERE miner_hotkey = 'test_not_approved_but_banned'") - - # Create test agent - version_id = str(uuid.uuid4()) - await db_conn.execute( - "INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) VALUES ($1, $2, $3, $4, NOW(), $5)", - version_id, 'test_not_approved_but_banned', 'test_agent', 1, 'waiting' - ) - - # Add to banned list (but not approved) - await db_conn.execute( - "INSERT INTO banned_hotkeys (miner_hotkey, banned_reason) VALUES ($1, $2)", - 'test_not_approved_but_banned', 'Test ban reason' - ) - - # Test via HTTP request to the running API server - async with httpx.AsyncClient() as client: - response = await client.get(f"http://localhost:8000/agents/{version_id}") - assert response.status_code == 200 - status = response.json() - - # Verify results - assert status['approved_at'] is None, "Agent should not be approved" - assert status['banned'] is True, "Agent should be banned" - assert status['version_id'] == version_id 
- assert status['miner_hotkey'] == 'test_not_approved_but_banned' - finally: - await db_conn.close() - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_real_api.py b/tests/test_real_api.py deleted file mode 100644 index 7fbd112c..00000000 --- a/tests/test_real_api.py +++ /dev/null @@ -1,145 +0,0 @@ -""" -Real API integration tests that test actual HTTP requests against the running server. -These tests require the API server to be running on localhost:8000. -""" - -import pytest -import requests -import os -import time -from typing import Optional -from requests.exceptions import RequestException - -# Base URL for the running API server -API_BASE_URL = os.getenv('API_BASE_URL', 'http://localhost:8000') - -def wait_for_server(max_retries: int = 30, delay: float = 1.0) -> bool: - """Wait for the API server to be ready.""" - for i in range(max_retries): - try: - response = requests.get(f"{API_BASE_URL}/healthcheck", timeout=5) - if response.status_code == 200: - print(f"Server is ready after {i+1} attempts") - return True - except RequestException: - pass - time.sleep(delay) - return False - -class TestRealAPIEndpoints: - """Test real API endpoints against the running server.""" - - @classmethod - def setup_class(cls): - """Wait for server to be ready before running tests.""" - if not wait_for_server(): - pytest.skip("API server is not available") - - def test_healthcheck_endpoint(self): - """Test the real healthcheck endpoint.""" - response = requests.get(f"{API_BASE_URL}/healthcheck") - assert response.status_code == 200 - assert response.text == '"OK"' - - def test_healthcheck_results_endpoint(self): - """Test the real healthcheck-results endpoint.""" - response = requests.get(f"{API_BASE_URL}/healthcheck-results") - assert response.status_code == 200 - data = response.json() - # The endpoint returns a list of platform status check records - assert isinstance(data, list) - # If there are records, they should have the expected structure - if data: - assert "checked_at" in data[0] - - def test_server_root_endpoint(self): - """Test that the server responds to root requests.""" - response = requests.get(f"{API_BASE_URL}/") - # Should return 404 for root endpoint (no route defined) - assert response.status_code == 404 - - def test_upload_endpoint_structure(self): - """Test that upload endpoint exists and validates input.""" - response = requests.post(f"{API_BASE_URL}/upload/agent", json={}) - # Should return 422 for validation error (missing required fields) - assert response.status_code == 422 - - def test_retrieval_endpoints_exist(self): - """Test that retrieval endpoints exist.""" - # Test network stats endpoint - response = requests.get(f"{API_BASE_URL}/retrieval/network-stats") - # Should either return 200 (if database is working) or 500 (if database issues) - assert response.status_code in [200, 500] - - # Test top agents endpoint - response = requests.get(f"{API_BASE_URL}/retrieval/top-agents") - assert response.status_code in [200, 500] - - def test_scoring_endpoints_exist(self): - """Test that scoring endpoints exist.""" - # Test check top agent endpoint - response = requests.get(f"{API_BASE_URL}/scoring/check-top-agent") - assert response.status_code in [200, 500] - - # Test ban agents endpoint (should return 422 for missing data) - response = requests.post(f"{API_BASE_URL}/scoring/ban-agents", json={}) - assert response.status_code == 422 - - def test_agent_summaries_endpoints_exist(self): - """Test that agent summaries 
endpoints exist.""" - import uuid - fake_id = str(uuid.uuid4()) - response = requests.get(f"{API_BASE_URL}/agent-summaries/agent-summary/{fake_id}") - assert response.status_code in [404, 500] # Should be 404 for non-existent agent - - def test_open_users_endpoints_exist(self): - """Test that open users endpoints exist.""" - # Test sign-in endpoint (should return 422 for missing data) - response = requests.post(f"{API_BASE_URL}/open-users/sign-in", json={}) - assert response.status_code == 422 - - def test_websocket_endpoint_exists(self): - """Test that WebSocket endpoint is accessible.""" - import websocket - try: - # Try to connect to WebSocket endpoint - ws = websocket.create_connection(f"{API_BASE_URL.replace('http', 'ws')}/ws", timeout=5) - ws.close() - assert True # If we get here, the endpoint exists - except Exception as e: - # WebSocket might not be available in test environment, that's OK - print(f"WebSocket test skipped: {e}") - assert True - -class TestRealAPIPerformance: - """Test API performance and response times.""" - - def test_healthcheck_response_time(self): - """Test that healthcheck responds quickly.""" - import time - start_time = time.time() - response = requests.get(f"{API_BASE_URL}/healthcheck", timeout=5) - end_time = time.time() - - assert response.status_code == 200 - assert (end_time - start_time) < 2.0 # Should respond within 2 seconds - - def test_concurrent_requests(self): - """Test that the server can handle multiple concurrent requests.""" - import concurrent.futures - import threading - - def make_request(): - response = requests.get(f"{API_BASE_URL}/healthcheck", timeout=5) - return response.status_code - - # Make 5 concurrent requests - with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(make_request) for _ in range(5)] - results = [future.result() for future in futures] - - # All requests should succeed - assert all(status == 200 for status in results) - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file diff --git a/tests/test_threshold_function.py b/tests/test_threshold_function.py deleted file mode 100644 index 26239318..00000000 --- a/tests/test_threshold_function.py +++ /dev/null @@ -1,479 +0,0 @@ -""" -Integration tests for the threshold function endpoint. -Tests the complete flow from database setup through to API response with various scenarios. 
-""" - -import pytest -import asyncpg -import uuid -from datetime import datetime, timezone, timedelta -from typing import Optional -from unittest.mock import patch - -from httpx import AsyncClient -import pytest_asyncio - -# Set environment variables for testing -import os - -if not os.getenv('AWS_MASTER_USERNAME'): - os.environ.update({ - 'AWS_MASTER_USERNAME': 'test_user', - 'AWS_MASTER_PASSWORD': 'test_pass', - 'AWS_RDS_PLATFORM_ENDPOINT': 'localhost', - 'AWS_RDS_PLATFORM_DB_NAME': 'postgres', - 'POSTGRES_TEST_URL': 'postgresql://test_user:test_pass@localhost:5432/postgres' - }) - -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'api', 'src')) - -from api.src.main import app - -@pytest_asyncio.fixture -async def db_connection(): - """Direct database connection for testing""" - conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - # Setup schema - await setup_database_schema(conn) - - yield conn - - await conn.close() - - -@pytest_asyncio.fixture -async def initialized_app(): - """Initialize the FastAPI app with database connection""" - from api.src.backend.db_manager import new_db - - # Initialize the database connection pool - await new_db.open() - - yield app - - # Clean up - await new_db.close() - - -@pytest_asyncio.fixture -async def async_client(initialized_app): - """Async HTTP client for testing FastAPI endpoints""" - from httpx import ASGITransport - async with AsyncClient(transport=ASGITransport(app=initialized_app), base_url="http://testserver") as client: - yield client - - -async def setup_database_schema(conn: asyncpg.Connection): - """Setup database schema for integration tests""" - # Read the actual production schema file - schema_path = os.path.join(os.path.dirname(__file__), '..', 'api', 'src', 'backend', 'postgres_schema.sql') - with open(schema_path, 'r') as f: - schema_sql = f.read() - - # Execute the production schema - await conn.execute(schema_sql) - - # Ensure innovation column exists (in case of schema timing issues) - await conn.execute(""" - DO $$ - BEGIN - IF NOT EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_name = 'miner_agents' AND column_name = 'innovation' - ) THEN - ALTER TABLE miner_agents ADD COLUMN innovation DOUBLE PRECISION; - END IF; - END $$; - """) - - # Disable the approval deletion trigger for tests to allow cleanup - await conn.execute(""" - DROP TRIGGER IF EXISTS no_delete_approval_trigger ON approved_version_ids; - CREATE OR REPLACE FUNCTION prevent_delete_approval_test() RETURNS TRIGGER AS $$ - BEGIN - -- Allow deletions in test environment - RETURN OLD; - END; - $$ LANGUAGE plpgsql; - CREATE TRIGGER no_delete_approval_trigger BEFORE DELETE ON approved_version_ids - FOR EACH ROW EXECUTE FUNCTION prevent_delete_approval_test(); - """) - - # Insert test evaluation sets for testing - await conn.execute(""" - INSERT INTO evaluation_sets (set_id, type, swebench_instance_id) VALUES - (1, 'screener-1', 'test_instance_1'), - (1, 'screener-2', 'test_instance_2'), - (1, 'validator', 'test_instance_3'), - (2, 'screener-1', 'test_instance_4'), - (2, 'screener-2', 'test_instance_5'), - (2, 'validator', 'test_instance_6') - ON CONFLICT DO NOTHING - """) - - # Insert threshold config values for testing - await conn.execute(""" - INSERT INTO threshold_config (key, value) VALUES - ('innovation_weight', 0.25), - ('decay_per_epoch', 0.05), - ('frontier_scale', 0.84), - ('improvement_weight', 0.30) - ON CONFLICT (key) DO UPDATE SET 
value = EXCLUDED.value - """) - - -class TestThresholdFunction: - """Test the threshold function endpoint with various database states""" - - @pytest.mark.asyncio - async def test_threshold_function_empty_database(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test threshold function with empty database""" - - # Ensure database is clean - await self._clean_database(db_connection) - - # Call the threshold function endpoint - response = await async_client.get("/scoring/threshold-function") - assert response.status_code == 200 - - result = response.json() - - # Verify response structure - assert "threshold_function" in result - assert "current_top_score" in result - assert "current_top_approved_score" in result - assert "epoch_0_time" in result - assert "epoch_length_minutes" in result - - # With empty database, scores should be 0 - assert result["current_top_score"] == 0.0 - assert result["current_top_approved_score"] == 0.0 - assert result["epoch_0_time"] is None # No agents means no epoch 0 - assert result["epoch_length_minutes"] == 72 - - # Threshold function should still be generated with default values - assert "Math.exp" in result["threshold_function"] - - @pytest.mark.asyncio - async def test_threshold_function_single_approved_agent(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test threshold function with one approved agent""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup: Create one approved agent with evaluations - agent_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Test Agent 1", - score=0.85, - set_id=2 - ) - - # :) - - # Drop and recreate the materialized view to pick up schema changes - await db_connection.execute("DROP MATERIALIZED VIEW IF EXISTS agent_scores CASCADE") - - # Read and recreate the materialized view with the updated schema - schema_path = os.path.join(os.path.dirname(__file__), '..', 'api', 'src', 'backend', 'postgres_schema.sql') - with open(schema_path, 'r') as f: - schema_sql = f.read() - - # Extract just the materialized view creation part - import re - mv_match = re.search(r'CREATE MATERIALIZED VIEW agent_scores.*?;', schema_sql, re.DOTALL) - if mv_match: - await db_connection.execute(mv_match.group(0)) - - # Create unique index - await db_connection.execute(""" - CREATE UNIQUE INDEX IF NOT EXISTS agent_scores_unique_idx - ON agent_scores (version_id, set_id) - """) - - # Call the threshold function endpoint - response = await async_client.get("/scoring/threshold-function") - assert response.status_code == 200 - - result = response.json() - - # Verify scores match our test data - assert result["current_top_score"] == 0.85 - assert result["current_top_approved_score"] == 0.85 - assert result["epoch_0_time"] is not None # Should have epoch 0 time from approval - assert result["epoch_length_minutes"] == 72 - - # Verify threshold function format - threshold_func = result["threshold_function"] - assert "Math.exp" in threshold_func - assert "+" in threshold_func - assert "*" in threshold_func - - @pytest.mark.asyncio - async def test_threshold_function_multiple_agents(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test threshold function with multiple agents of different scores""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Create multiple agents with different scores - agent1 = await 
self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Agent 1", - score=0.75, - set_id=2 - ) - - agent2 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_2", - agent_name="Agent 2", - score=0.90, # Higher score - set_id=2 - ) - - # Create a non-approved agent with even higher score - agent3 = await self._create_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_3", - agent_name="Agent 3", - score=0.95, # Highest score but not approved - set_id=2, - approved=False - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the threshold function endpoint - response = await async_client.get("/scoring/threshold-function") - assert response.status_code == 200 - - result = response.json() - - # current_top_score should be highest overall (0.95) - assert result["current_top_score"] == 0.95 - - # current_top_approved_score should be highest approved (0.90) - assert result["current_top_approved_score"] == 0.90 - - # Should have epoch 0 time from the top approved agent - assert result["epoch_0_time"] is not None - assert result["epoch_length_minutes"] == 72 - - @pytest.mark.asyncio - async def test_threshold_function_with_history(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test threshold function with historical top agents""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Create approved agents - agent1 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Agent 1", - score=0.80, - set_id=2 - ) - - agent2 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_2", - agent_name="Agent 2", - score=0.85, - set_id=2 - ) - - # Create historical entries in approved_top_agents_history - base_time = datetime.now(timezone.utc) - await db_connection.execute(""" - INSERT INTO approved_top_agents_history (version_id, set_id, top_at) VALUES - ($1, 2, $2), - ($3, 2, $4) - """, - agent2["version_id"], base_time, # Current top (most recent) - agent1["version_id"], base_time - timedelta(hours=1)) # Previous top - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the threshold function endpoint - response = await async_client.get("/scoring/threshold-function") - assert response.status_code == 200 - - result = response.json() - - # Verify all fields are present and valid - assert result["current_top_score"] == 0.85 - assert result["current_top_approved_score"] == 0.85 - assert result["epoch_0_time"] is not None - assert result["epoch_length_minutes"] == 72 - - # With history, the threshold function should incorporate improvement - threshold_func = result["threshold_function"] - assert isinstance(threshold_func, str) - assert "Math.exp" in threshold_func - - @pytest.mark.asyncio - async def test_threshold_function_with_innovation(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test threshold function with innovation scores""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Create agent with innovation score - agent_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Innovative Agent", - score=0.80, - set_id=2, - innovation=0.75 # High innovation - 
) - - # Add to history to enable innovation calculation - await db_connection.execute(""" - INSERT INTO approved_top_agents_history (version_id, set_id, top_at) VALUES - ($1, 2, $2) - """, agent_data["version_id"], datetime.now(timezone.utc)) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the threshold function endpoint - response = await async_client.get("/scoring/threshold-function") - assert response.status_code == 200 - - result = response.json() - - # Verify innovation is incorporated into threshold function - assert result["current_top_score"] == 0.80 - assert result["current_top_approved_score"] == 0.80 - assert result["epoch_0_time"] is not None - - # The threshold function should be valid - threshold_func = result["threshold_function"] - assert "Math.exp" in threshold_func - # Innovation should boost the initial threshold value - assert "0.80" in threshold_func or "0.9" in threshold_func # Should be higher than base score - - @pytest.mark.asyncio - async def test_threshold_function_different_set_ids(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test threshold function uses latest set_id""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Create agent in set_id 1 (older) - agent1 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Old Agent", - score=0.90, - set_id=1 - ) - - # Create agent in set_id 2 (newer/latest) - agent2 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_2", - agent_name="New Agent", - score=0.75, # Lower score but in latest set - set_id=2 - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the threshold function endpoint - response = await async_client.get("/scoring/threshold-function") - assert response.status_code == 200 - - result = response.json() - - # Should use latest set_id (2), so scores should reflect agent2 - assert result["current_top_score"] == 0.75 # From set_id 2 - assert result["current_top_approved_score"] == 0.75 # From set_id 2 - assert result["epoch_0_time"] is not None - - async def _clean_database(self, conn: asyncpg.Connection): - """Clean up test data in correct order to respect foreign key constraints""" - await conn.execute("DELETE FROM approved_top_agents_history") - await conn.execute("DELETE FROM top_agents") # Add this to clean up top_agents table - await conn.execute("DELETE FROM approved_version_ids") - await conn.execute("DELETE FROM embeddings") - await conn.execute("DELETE FROM inferences") - await conn.execute("DELETE FROM evaluation_runs") - await conn.execute("DELETE FROM evaluations") - await conn.execute("DELETE FROM miner_agents") - # Refresh materialized view to reflect deletions - try: - await conn.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - except Exception: - # If concurrent refresh fails, try regular refresh - await conn.execute("REFRESH MATERIALIZED VIEW agent_scores") - - async def _create_approved_agent_with_evaluations(self, conn: asyncpg.Connection, miner_hotkey: str, agent_name: str, score: float, set_id: int = 2, innovation: Optional[float] = None) -> dict: - """Create an approved agent with completed evaluations""" - return await self._create_agent_with_evaluations( - conn, miner_hotkey, agent_name, score, set_id, approved=True, innovation=innovation - 
) - - async def _create_agent_with_evaluations(self, conn: asyncpg.Connection, miner_hotkey: str, agent_name: str, score: float, set_id: int = 2, approved: bool = True, innovation: Optional[float] = None) -> dict: - """Create an agent with completed evaluations and optionally approve it""" - version_id = uuid.uuid4() - created_at = datetime.now(timezone.utc) - - # Insert agent - await conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status, innovation) - VALUES ($1, $2, $3, 1, $4, 'scored', $5) - """, version_id, miner_hotkey, agent_name, created_at, innovation) - - # Create evaluations for multiple validators to meet the 2+ validator requirement - validator_hotkeys = ["validator_1", "validator_2", "validator_3"] - - for validator_hotkey in validator_hotkeys: - evaluation_id = uuid.uuid4() - - # Insert evaluation - await conn.execute(""" - INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, finished_at, score) - VALUES ($1, $2, $3, $4, 'completed', $5, $6, $7) - """, evaluation_id, version_id, validator_hotkey, set_id, created_at, created_at + timedelta(minutes=5), score) - - # Approve the agent if requested - if approved: - approved_at = created_at # Use same time as creation, not future time - await conn.execute(""" - INSERT INTO approved_version_ids (version_id, set_id, approved_at) - VALUES ($1, $2, $3) - """, version_id, set_id, approved_at) - - return { - "version_id": version_id, - "miner_hotkey": miner_hotkey, - "agent_name": agent_name, - "score": score, - "set_id": set_id, - "approved": approved, - "innovation": innovation - } - - - - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_threshold_scoring_simple.py b/tests/test_threshold_scoring_simple.py deleted file mode 100644 index 5b0ae24a..00000000 --- a/tests/test_threshold_scoring_simple.py +++ /dev/null @@ -1,223 +0,0 @@ -""" -Simple focused tests for threshold scoring logic. -Tests the core mathematical logic without complex database setup. 
-""" - -import pytest -import math -from datetime import datetime, timezone, timedelta -from unittest.mock import AsyncMock, patch - -import os -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'api', 'src')) - -class TestThresholdScoringLogic: - """Test the core threshold scoring logic""" - - def test_threshold_calculation_basic(self): - """Test basic threshold calculation math""" - # Test exponential decay: threshold = floor + (t0 - floor) * exp(-k * t) - floor = 0.75 - t0 = 0.95 - k = 0.05 - - # At t=0, threshold should equal t0 - t = 0 - threshold = floor + (t0 - floor) * math.exp(-k * t) - assert abs(threshold - t0) < 0.001 - - # At t=10, threshold should be lower - t = 10 - threshold_later = floor + (t0 - floor) * math.exp(-k * t) - assert threshold_later < t0 - assert threshold_later > floor - - # As t approaches infinity, threshold approaches floor - t = 1000 - threshold_inf = floor + (t0 - floor) * math.exp(-k * t) - assert abs(threshold_inf - floor) < 0.001 - - def test_future_time_calculation(self): - """Test calculation of when threshold will reach target score""" - floor = 0.75 - t0 = 0.95 - k = 0.05 - target_score = 0.85 - - # Solve: target_score = floor + (t0 - floor) * exp(-k * t) - # t = -ln((target_score - floor) / (t0 - floor)) / k - - if target_score > floor and target_score < t0 and k > 0: - ratio = (target_score - floor) / (t0 - floor) - future_epochs = -math.log(ratio) / k - - # Verify the calculation is correct - assert future_epochs > 0 - - # Verify by plugging back into threshold function - calculated_threshold = floor + (t0 - floor) * math.exp(-k * future_epochs) - assert abs(calculated_threshold - target_score) < 0.001 - - def test_edge_cases_mathematical(self): - """Test mathematical edge cases""" - floor = 0.75 - t0 = 0.95 - k = 0.05 - - # Target score equal to floor - target_score = floor - ratio = (target_score - floor) / (t0 - floor) - assert ratio == 0 # Should be invalid - - # Target score above t0 - target_score = 1.0 - assert target_score > t0 # Should be invalid - - # Zero decay rate - k_zero = 0.0 - # With k=0, threshold never decays, so future approval impossible - - # Negative decay rate - k_negative = -0.05 - # Should be invalid - - def test_agent_scoring_logic(self): - """Test the core agent scoring decision logic""" - - def evaluate_agent_simple(agent_score, threshold, top_score): - """Simplified version of evaluation logic""" - if agent_score >= threshold: # >= to include equal case - return "approve_now" - elif agent_score > top_score: - return "approve_future" - else: - return "reject" - - # Test cases - threshold = 0.85 - top_score = 0.80 - - # High score - immediate approval - result = evaluate_agent_simple(0.90, threshold, top_score) - assert result == "approve_now" - - # Competitive score - future approval - result = evaluate_agent_simple(0.82, threshold, top_score) - assert result == "approve_future" - - # Low score - rejection - result = evaluate_agent_simple(0.70, threshold, top_score) - assert result == "reject" - - # Edge case: equal to top score - result = evaluate_agent_simple(0.80, threshold, top_score) - assert result == "reject" # Not greater than top score - - # Edge case: equal to threshold - result = evaluate_agent_simple(0.85, threshold, top_score) - assert result == "approve_now" # Equal to threshold should be approved immediately - - def test_threshold_boost_calculation(self): - """Test threshold boost from innovation and improvement""" - - # Base score - curr_score = 0.80 - prev_score = 0.75 - 
innovation = 0.60 - - # Constants from config - INNOVATION_WEIGHT = 0.25 - IMPROVEMENT_WEIGHT = 0.30 - FRONTIER_WEIGHT = 0.84 - - # Calculate boosts - delta = max(0.0, curr_score - prev_score) - scaling_factor = 1.0 + FRONTIER_WEIGHT * prev_score - threshold_boost = IMPROVEMENT_WEIGHT * delta * scaling_factor - innovation_boost = INNOVATION_WEIGHT * innovation - - # Calculate t0 - t0 = min(1.0, max(0.0, curr_score + threshold_boost + innovation_boost)) - - # Verify t0 is reasonable - assert t0 >= curr_score # Should be at least the current score - assert t0 <= 1.0 # Should be clamped to 1.0 - assert t0 > 0.0 # Should be positive - - # Test with no improvement - delta_zero = max(0.0, curr_score - curr_score) # 0 - threshold_boost_zero = IMPROVEMENT_WEIGHT * delta_zero * scaling_factor - assert threshold_boost_zero == 0.0 - - # Test with no innovation - innovation_boost_zero = INNOVATION_WEIGHT * 0.0 - assert innovation_boost_zero == 0.0 - - def test_precision_edge_cases(self): - """Test numerical precision edge cases""" - - # Very small differences - floor = 0.800000001 - target = 0.800000002 - t0 = 0.900000000 - k = 0.05 - - if target > floor and target < t0: - ratio = (target - floor) / (t0 - floor) - # Should handle small differences gracefully - assert ratio > 0 - assert ratio < 1 - - # Very large time horizons - floor = 0.5 - target = 0.500001 # Very close to floor - t0 = 0.9 - k = 0.001 # Very slow decay - - if target > floor and target < t0: - ratio = (target - floor) / (t0 - floor) - future_epochs = -math.log(ratio) / k - # Should produce finite result - assert math.isfinite(future_epochs) - assert future_epochs > 0 - - def test_clamping_behavior(self): - """Test that values are properly clamped""" - - # Test t0 clamping to [0, 1] - def clamp_t0(value): - return min(1.0, max(0.0, value)) - - assert clamp_t0(-0.5) == 0.0 - assert clamp_t0(0.5) == 0.5 - assert clamp_t0(1.5) == 1.0 - assert clamp_t0(0.0) == 0.0 - assert clamp_t0(1.0) == 1.0 - - def test_epoch_calculation(self): - """Test epoch calculation from timestamps""" - - # Test epoch calculation - epoch_0_time = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) - current_time = datetime(2024, 1, 1, 1, 0, 0, tzinfo=timezone.utc) # 1 hour later - epoch_length_minutes = 30 - - # Calculate epochs passed - epoch_minutes = (current_time - epoch_0_time).total_seconds() / 60 - epochs_passed = epoch_minutes / epoch_length_minutes - - assert epoch_minutes == 60 # 1 hour = 60 minutes - assert epochs_passed == 2.0 # 60 minutes / 30 minutes per epoch = 2 epochs - - # Test with fractional epochs - current_time = datetime(2024, 1, 1, 0, 45, 0, tzinfo=timezone.utc) # 45 minutes later - epoch_minutes = (current_time - epoch_0_time).total_seconds() / 60 - epochs_passed = epoch_minutes / epoch_length_minutes - - assert epoch_minutes == 45 - assert epochs_passed == 1.5 # 45 minutes / 30 minutes per epoch = 1.5 epochs - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) diff --git a/tests/test_upload_attempts_simple.py b/tests/test_upload_attempts_simple.py deleted file mode 100644 index 374a0d58..00000000 --- a/tests/test_upload_attempts_simple.py +++ /dev/null @@ -1,288 +0,0 @@ -""" -Simple integration tests for upload attempt tracking. -Tests that all upload attempts are properly recorded in the upload_attempts table. 
-""" - -import pytest -import uuid -import io -from unittest.mock import patch, AsyncMock - -# Mark all tests in this module as integration tests -pytestmark = pytest.mark.integration - - -class TestUploadAttemptsSimple: - """Simple tests for upload attempt tracking""" - - @pytest.mark.asyncio - async def test_banned_hotkey_creates_upload_attempt_with_ban_reason(self, db_conn): - """Test that banned hotkey uploads create records with ban reasons""" - - # Setup: Insert a banned hotkey with reason - test_hotkey = f"test_banned_{uuid.uuid4().hex[:8]}" - ban_reason = "Code obfuscation detected in uploaded agent" - - await db_conn.execute(""" - INSERT INTO banned_hotkeys (miner_hotkey, banned_reason) - VALUES ($1, $2) - """, test_hotkey, ban_reason) - - # Test: Try to upload with banned hotkey using the track_upload decorator - from api.src.utils.upload_agent_helpers import record_upload_attempt, get_ban_reason - - # Verify ban reason can be retrieved - retrieved_ban_reason = await get_ban_reason(test_hotkey) - assert retrieved_ban_reason == ban_reason - - # Record a failed upload attempt for banned hotkey - await record_upload_attempt( - upload_type="agent", - success=False, - hotkey=test_hotkey, - agent_name="Test Agent", - filename="agent.py", - file_size_bytes=1024, - ip_address="127.0.0.1", - error_type="banned", - error_message="Your miner hotkey has been banned for attempting to obfuscate code", - ban_reason=retrieved_ban_reason, - http_status_code=403 - ) - - # Verify: Check that upload attempt was recorded with ban reason - attempts = await db_conn.fetch(""" - SELECT * FROM upload_attempts WHERE hotkey = $1 - """, test_hotkey) - - assert len(attempts) == 1 - attempt = attempts[0] - assert attempt["upload_type"] == "agent" - assert attempt["hotkey"] == test_hotkey - assert attempt["success"] is False - assert attempt["error_type"] == "banned" - assert attempt["ban_reason"] == ban_reason - assert attempt["http_status_code"] == 403 - assert "banned" in attempt["error_message"].lower() - - @pytest.mark.asyncio - async def test_successful_upload_creates_attempt_record(self, db_conn): - """Test that successful uploads create records""" - - from api.src.utils.upload_agent_helpers import record_upload_attempt - - test_hotkey = f"test_success_{uuid.uuid4().hex[:8]}" - test_version_id = str(uuid.uuid4()) - - # Record a successful upload attempt - await record_upload_attempt( - upload_type="agent", - success=True, - hotkey=test_hotkey, - agent_name="Test Successful Agent", - filename="agent.py", - file_size_bytes=2048, - ip_address="192.168.1.1", - version_id=test_version_id - ) - - # Verify the record was created - attempts = await db_conn.fetch(""" - SELECT * FROM upload_attempts WHERE hotkey = $1 - """, test_hotkey) - - assert len(attempts) == 1 - attempt = attempts[0] - assert attempt["upload_type"] == "agent" - assert attempt["hotkey"] == test_hotkey - assert attempt["agent_name"] == "Test Successful Agent" - assert attempt["filename"] == "agent.py" - assert attempt["file_size_bytes"] == 2048 - assert attempt["ip_address"] == "192.168.1.1" - assert attempt["success"] is True - assert attempt["error_type"] is None - assert attempt["ban_reason"] is None - assert attempt["version_id"] == test_version_id - - @pytest.mark.asyncio - async def test_open_agent_upload_creates_attempt_record(self, db_conn): - """Test that open agent uploads create records""" - - from api.src.utils.upload_agent_helpers import record_upload_attempt - - test_hotkey = f"test_open_{uuid.uuid4().hex[:8]}" - - # Record an 
open agent upload attempt - await record_upload_attempt( - upload_type="open-agent", - success=True, - hotkey=test_hotkey, - agent_name="Test Open Agent", - filename="agent.py", - file_size_bytes=1500, - ip_address="10.0.0.1" - ) - - # Verify the record was created - attempts = await db_conn.fetch(""" - SELECT * FROM upload_attempts WHERE hotkey = $1 - """, test_hotkey) - - assert len(attempts) == 1 - attempt = attempts[0] - assert attempt["upload_type"] == "open-agent" - assert attempt["hotkey"] == test_hotkey - assert attempt["success"] is True - - @pytest.mark.asyncio - async def test_various_error_types_recorded(self, db_conn): - """Test that various error types are properly recorded""" - - from api.src.utils.upload_agent_helpers import record_upload_attempt - - test_cases = [ - { - "error_type": "rate_limit", - "http_status_code": 429, - "error_message": "You must wait 300 seconds before uploading a new agent version" - }, - { - "error_type": "validation_error", - "http_status_code": 400, - "error_message": "File size must not exceed 1MB" - }, - { - "error_type": "internal_error", - "http_status_code": 500, - "error_message": "Database connection failed" - } - ] - - for i, case in enumerate(test_cases): - test_hotkey = f"test_error_{i}_{uuid.uuid4().hex[:8]}" - - await record_upload_attempt( - upload_type="agent", - success=False, - hotkey=test_hotkey, - agent_name="Test Agent", - filename="agent.py", - file_size_bytes=1024, - error_type=case["error_type"], - error_message=case["error_message"], - http_status_code=case["http_status_code"] - ) - - # Verify the record - attempts = await db_conn.fetch(""" - SELECT * FROM upload_attempts WHERE hotkey = $1 - """, test_hotkey) - - assert len(attempts) == 1 - attempt = attempts[0] - assert attempt["success"] is False - assert attempt["error_type"] == case["error_type"] - assert attempt["error_message"] == case["error_message"] - assert attempt["http_status_code"] == case["http_status_code"] - - @pytest.mark.asyncio - async def test_multiple_attempts_from_same_hotkey(self, db_conn): - """Test that multiple attempts from the same hotkey are all recorded""" - - from api.src.utils.upload_agent_helpers import record_upload_attempt - - test_hotkey = f"test_multiple_{uuid.uuid4().hex[:8]}" - - # Record multiple attempts - for i in range(3): - await record_upload_attempt( - upload_type="agent", - success=False, - hotkey=test_hotkey, - agent_name=f"Test Agent {i}", - filename="agent.py", - file_size_bytes=1024 + i * 100, - error_type="validation_error", - error_message=f"Error attempt {i}", - http_status_code=400 - ) - - # Verify all attempts were recorded - attempts = await db_conn.fetch(""" - SELECT * FROM upload_attempts WHERE hotkey = $1 ORDER BY created_at - """, test_hotkey) - - assert len(attempts) == 3 - for i, attempt in enumerate(attempts): - assert attempt["agent_name"] == f"Test Agent {i}" - assert attempt["file_size_bytes"] == 1024 + i * 100 - assert attempt["error_message"] == f"Error attempt {i}" - - @pytest.mark.asyncio - async def test_upload_attempts_table_schema(self, db_conn): - """Test that the upload_attempts table has the correct schema""" - - # Check table exists and has expected columns - columns = await db_conn.fetch(""" - SELECT column_name, data_type, is_nullable - FROM information_schema.columns - WHERE table_name = 'upload_attempts' - ORDER BY ordinal_position - """) - - expected_columns = { - 'id': 'uuid', - 'upload_type': 'text', - 'hotkey': 'text', - 'agent_name': 'text', - 'filename': 'text', - 'file_size_bytes': 
'bigint', - 'ip_address': 'text', - 'success': 'boolean', - 'error_type': 'text', - 'error_message': 'text', - 'ban_reason': 'text', - 'http_status_code': 'integer', - 'version_id': 'uuid', - 'created_at': 'timestamp with time zone' - } - - found_columns = {col['column_name']: col['data_type'] for col in columns} - - for expected_col, expected_type in expected_columns.items(): - assert expected_col in found_columns, f"Column {expected_col} not found" - assert found_columns[expected_col] == expected_type, f"Column {expected_col} has type {found_columns[expected_col]}, expected {expected_type}" - - @pytest.mark.asyncio - async def test_decorator_error_classification(self, db_conn): - """Test that the track_upload decorator correctly classifies different HTTP errors""" - - from api.src.utils.upload_agent_helpers import track_upload - from fastapi import HTTPException - - # Mock function that raises HTTPException - @track_upload("agent") - async def mock_upload_function(request, agent_file, **kwargs): - raise HTTPException(status_code=403, detail="Your miner hotkey has been banned for attempting to obfuscate code") - - # Mock request and file objects - class MockRequest: - class MockClient: - host = "127.0.0.1" - client = MockClient() - - class MockFile: - filename = "agent.py" - def __init__(self): - self.file = io.BytesIO(b"test content") - - request = MockRequest() - agent_file = MockFile() - - # Test that banned error is caught and recorded - with pytest.raises(HTTPException): - await mock_upload_function(request, agent_file, file_info="test_banned_hotkey:1", name="Test Agent") - - # Check that attempt was recorded (the decorator should handle this) - # Note: This test verifies the decorator logic works, but the actual database - # recording would happen in a real scenario with proper mocking diff --git a/tests/test_upload_tracking.py b/tests/test_upload_tracking.py deleted file mode 100644 index ea36a166..00000000 --- a/tests/test_upload_tracking.py +++ /dev/null @@ -1,258 +0,0 @@ -""" -Simplified upload tracking tests. -Tests all upload attempt tracking functionality in one consolidated file. 
-""" - -import pytest -import uuid -import asyncpg -from unittest.mock import AsyncMock, patch, MagicMock -from fastapi import HTTPException - - -@pytest.mark.asyncio -async def test_upload_attempts_table_structure(): - """Test that the upload_attempts table exists with correct structure""" - - db_url = "postgresql://test_user:test_pass@localhost:5432/postgres" - conn = await asyncpg.connect(db_url) - - try: - # Check table exists - table_exists = await conn.fetchval(""" - SELECT EXISTS ( - SELECT FROM information_schema.tables - WHERE table_name = 'upload_attempts' - ) - """) - assert table_exists, "upload_attempts table should exist" - - # Check essential columns exist - columns = await conn.fetch(""" - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = 'upload_attempts' - """) - - found_columns = {col['column_name']: col['data_type'] for col in columns} - - essential_columns = ['upload_type', 'hotkey', 'success', 'error_type', 'ban_reason', 'created_at'] - for col in essential_columns: - assert col in found_columns, f"Essential column {col} not found" - - finally: - await conn.close() - - -@pytest.mark.asyncio -async def test_upload_attempt_database_operations(): - """Test direct database operations on upload_attempts table""" - - db_url = "postgresql://test_user:test_pass@localhost:5432/postgres" - conn = await asyncpg.connect(db_url) - - test_hotkey = f"test_db_{uuid.uuid4().hex[:8]}" - - try: - # Test insertion - await conn.execute(""" - INSERT INTO upload_attempts (upload_type, success, hotkey, agent_name, - error_type, ban_reason, http_status_code) - VALUES ($1, $2, $3, $4, $5, $6, $7) - """, 'agent', False, test_hotkey, 'Test Agent', 'banned', 'Test ban reason', 403) - - # Test retrieval - attempt = await conn.fetchrow(""" - SELECT * FROM upload_attempts WHERE hotkey = $1 - """, test_hotkey) - - assert attempt is not None - assert attempt['upload_type'] == 'agent' - assert attempt['success'] is False - assert attempt['ban_reason'] == 'Test ban reason' - - finally: - await conn.execute("DELETE FROM upload_attempts WHERE hotkey = $1", test_hotkey) - await conn.close() - - -@pytest.mark.asyncio -async def test_ban_reason_retrieval(): - """Test ban reason retrieval from banned_hotkeys table""" - - db_url = "postgresql://test_user:test_pass@localhost:5432/postgres" - conn = await asyncpg.connect(db_url) - - test_hotkey = f"test_ban_{uuid.uuid4().hex[:8]}" - ban_reason = "Code obfuscation detected" - - try: - # Insert banned hotkey - await conn.execute(""" - INSERT INTO banned_hotkeys (miner_hotkey, banned_reason) - VALUES ($1, $2) - """, test_hotkey, ban_reason) - - # Test retrieval - retrieved_reason = await conn.fetchval(""" - SELECT banned_reason FROM banned_hotkeys - WHERE miner_hotkey = $1 - """, test_hotkey) - - assert retrieved_reason == ban_reason - - finally: - await conn.execute("DELETE FROM banned_hotkeys WHERE miner_hotkey = $1", test_hotkey) - await conn.close() - - -def test_record_upload_attempt_exists(): - """Test that the record_upload_attempt function exists and is callable""" - from api.src.utils.upload_agent_helpers import record_upload_attempt - - # Test function exists and is callable - assert callable(record_upload_attempt) - - -@pytest.mark.asyncio -async def test_record_upload_attempt_function(): - """Test that the record_upload_attempt function works correctly""" - from api.src.utils.upload_agent_helpers import record_upload_attempt - - # Mock the database transaction to avoid conflicts - with 
patch('api.src.utils.upload_agent_helpers.get_transaction') as mock_transaction: - mock_conn = AsyncMock() - mock_transaction.return_value.__aenter__.return_value = mock_conn - mock_transaction.return_value.__aexit__.return_value = None - - # Test function can be called - await record_upload_attempt( - upload_type='agent', - success=False, - hotkey='test_hotkey', - error_type='banned', - ban_reason='Test reason' - ) - - # Verify database call was made - mock_conn.execute.assert_called_once() - call_args = mock_conn.execute.call_args[0] - assert 'INSERT INTO upload_attempts' in call_args[0] - - -@pytest.mark.asyncio -async def test_upload_tracking_integration(): - """Test that upload tracking is integrated into the endpoints""" - from api.src.endpoints.upload import post_agent, post_open_agent - - # Test that functions exist and are callable - assert callable(post_agent) - assert callable(post_open_agent) - - # Test that record_upload_attempt is called in the upload functions - # (This is tested more thoroughly in the end-to-end tests) - - -def test_upload_endpoints_exist(): - """Test that both upload endpoints exist and are callable""" - from api.src.endpoints.upload import post_agent, post_open_agent - - # Check that the functions exist and are callable - assert callable(post_agent) - assert callable(post_open_agent) - - -@pytest.mark.asyncio -async def test_multiple_error_scenarios(): - """Test that different error scenarios can be stored in the database""" - - db_url = "postgresql://test_user:test_pass@localhost:5432/postgres" - conn = await asyncpg.connect(db_url) - - test_scenarios = [ - ("agent", "banned", "Code obfuscation detected", 403), - ("agent", "rate_limit", None, 429), - ("open-agent", "validation_error", None, 401), - ("agent", None, None, None) # Success case - ] - - test_hotkeys = [] - - try: - for i, (upload_type, error_type, ban_reason, status_code) in enumerate(test_scenarios): - test_hotkey = f"test_scenario_{i}_{uuid.uuid4().hex[:8]}" - test_hotkeys.append(test_hotkey) - - success = error_type is None - - await conn.execute(""" - INSERT INTO upload_attempts (upload_type, success, hotkey, - error_type, ban_reason, http_status_code) - VALUES ($1, $2, $3, $4, $5, $6) - """, upload_type, success, test_hotkey, error_type, ban_reason, status_code) - - # Verify record was inserted correctly - attempt = await conn.fetchrow(""" - SELECT * FROM upload_attempts WHERE hotkey = $1 - """, test_hotkey) - - assert attempt is not None - assert attempt['upload_type'] == upload_type - assert attempt['success'] == success - assert attempt['error_type'] == error_type - assert attempt['ban_reason'] == ban_reason - - finally: - # Cleanup - for test_hotkey in test_hotkeys: - await conn.execute("DELETE FROM upload_attempts WHERE hotkey = $1", test_hotkey) - await conn.close() - - -@pytest.mark.asyncio -async def test_ban_reasons_storage(): - """Test that various ban reasons are properly stored""" - - db_url = "postgresql://test_user:test_pass@localhost:5432/postgres" - conn = await asyncpg.connect(db_url) - - ban_reasons = [ - "Code obfuscation detected in uploaded agent", - "Malicious code patterns detected", - "Agent code plagiarized from another miner", - "Repeated spam uploads detected" - ] - - test_hotkeys = [] - - try: - for i, ban_reason in enumerate(ban_reasons): - test_hotkey = f"test_ban_reason_{i}_{uuid.uuid4().hex[:8]}" - test_hotkeys.append(test_hotkey) - - # Insert banned hotkey - await conn.execute(""" - INSERT INTO banned_hotkeys (miner_hotkey, banned_reason) - VALUES ($1, 
$2) - """, test_hotkey, ban_reason) - - # Insert corresponding upload attempt - await conn.execute(""" - INSERT INTO upload_attempts (upload_type, success, hotkey, - error_type, ban_reason, http_status_code) - VALUES ($1, $2, $3, $4, $5, $6) - """, 'agent', False, test_hotkey, 'banned', ban_reason, 403) - - # Verify the ban reason was stored correctly - stored_attempt = await conn.fetchrow(""" - SELECT ban_reason FROM upload_attempts WHERE hotkey = $1 - """, test_hotkey) - - assert stored_attempt['ban_reason'] == ban_reason - - finally: - # Cleanup - for test_hotkey in test_hotkeys: - await conn.execute("DELETE FROM upload_attempts WHERE hotkey = $1", test_hotkey) - await conn.execute("DELETE FROM banned_hotkeys WHERE miner_hotkey = $1", test_hotkey) - await conn.close() diff --git a/tests/test_upload_tracking_unit.py b/tests/test_upload_tracking_unit.py deleted file mode 100644 index a29921d8..00000000 --- a/tests/test_upload_tracking_unit.py +++ /dev/null @@ -1,379 +0,0 @@ -""" -Unit tests for upload attempt tracking functionality. -Tests the core logic without requiring full database integration. -""" - -import pytest -import uuid -from unittest.mock import AsyncMock, patch - -# Mark all tests in this module as unit tests -pytestmark = pytest.mark.unit - - -class TestUploadTrackingUnit: - """Unit tests for upload tracking functions""" - - @pytest.mark.asyncio - async def test_record_upload_attempt_function(self): - """Test that record_upload_attempt function calls database correctly""" - - # Mock the database connection and transaction - mock_conn = AsyncMock() - mock_transaction_context = AsyncMock() - mock_transaction_context.__aenter__.return_value = mock_conn - mock_transaction_context.__aexit__.return_value = None - - with patch('api.src.utils.upload_agent_helpers.get_transaction', return_value=mock_transaction_context): - from api.src.utils.upload_agent_helpers import record_upload_attempt - - # Test successful upload record - await record_upload_attempt( - upload_type="agent", - success=True, - hotkey="test_hotkey", - agent_name="Test Agent", - filename="agent.py", - file_size_bytes=1024, - ip_address="127.0.0.1", - version_id="test-version-id" - ) - - # Verify the database execute was called with correct parameters - mock_conn.execute.assert_called_once() - call_args = mock_conn.execute.call_args - - # Check the SQL query structure - sql_query = call_args[0][0] - assert "INSERT INTO upload_attempts" in sql_query - assert "upload_type" in sql_query - assert "success" in sql_query - assert "hotkey" in sql_query - - # Check the parameters - params = call_args[0][1:] - assert params[0] == "agent" # upload_type - assert params[1] is True # success - assert params[2] == "test_hotkey" # hotkey - assert params[3] == "Test Agent" # agent_name - assert params[4] == "agent.py" # filename - assert params[5] == 1024 # file_size_bytes - assert params[6] == "127.0.0.1" # ip_address - assert params[11] == "test-version-id" # version_id - - @pytest.mark.asyncio - async def test_record_upload_attempt_with_error(self): - """Test recording failed upload attempt with error details""" - - mock_conn = AsyncMock() - mock_transaction_context = AsyncMock() - mock_transaction_context.__aenter__.return_value = mock_conn - mock_transaction_context.__aexit__.return_value = None - - with patch('api.src.utils.upload_agent_helpers.get_transaction', return_value=mock_transaction_context): - from api.src.utils.upload_agent_helpers import record_upload_attempt - - # Test failed upload record with ban reason - 
await record_upload_attempt( - upload_type="agent", - success=False, - hotkey="banned_hotkey", - agent_name="Banned Agent", - filename="agent.py", - file_size_bytes=2048, - ip_address="192.168.1.1", - error_type="banned", - error_message="Your miner hotkey has been banned", - ban_reason="Code obfuscation detected", - http_status_code=403 - ) - - # Verify the call - mock_conn.execute.assert_called_once() - call_args = mock_conn.execute.call_args - params = call_args[0][1:] - - assert params[0] == "agent" # upload_type - assert params[1] is False # success - assert params[2] == "banned_hotkey" # hotkey - assert params[7] == "banned" # error_type - assert params[8] == "Your miner hotkey has been banned" # error_message - assert params[9] == "Code obfuscation detected" # ban_reason - assert params[10] == 403 # http_status_code - - @pytest.mark.asyncio - async def test_get_ban_reason_function(self): - """Test that get_ban_reason function queries database correctly""" - - mock_conn = AsyncMock() - mock_conn.fetchval.return_value = "Test ban reason" - - with patch('api.src.backend.queries.agents.db_operation') as mock_decorator: - # Mock the decorator to directly call the function - mock_decorator.side_effect = lambda func: func - - with patch('api.src.backend.db_manager.new_db.acquire') as mock_acquire: - mock_acquire.return_value.__aenter__.return_value = mock_conn - mock_acquire.return_value.__aexit__.return_value = None - - from api.src.backend.queries.agents import get_ban_reason - - # Test getting ban reason - result = await get_ban_reason("test_hotkey") - - # Verify the query was called correctly - mock_conn.fetchval.assert_called_once() - call_args = mock_conn.fetchval.call_args - - # Check SQL query - sql_query = call_args[0][0] - assert "SELECT banned_reason FROM banned_hotkeys" in sql_query - assert "WHERE miner_hotkey = $1" in sql_query - - # Check parameter - assert call_args[0][1] == "test_hotkey" - - # Check result - assert result == "Test ban reason" - - @pytest.mark.asyncio - async def test_track_upload_decorator_success(self): - """Test that track_upload decorator records successful uploads""" - - # Mock database operations - mock_conn = AsyncMock() - mock_transaction_context = AsyncMock() - mock_transaction_context.__aenter__.return_value = mock_conn - mock_transaction_context.__aexit__.return_value = None - - with patch('api.src.utils.upload_agent_helpers.get_transaction', return_value=mock_transaction_context): - from api.src.utils.upload_agent_helpers import track_upload - - # Create a mock upload function - @track_upload("agent") - async def mock_upload_function(request, agent_file, **kwargs): - return type('Response', (), {'message': 'Successfully uploaded agent test-version-id for miner test_hotkey.'})() - - # Mock request and file objects - class MockRequest: - class MockClient: - host = "127.0.0.1" - client = MockClient() - - class MockFile: - filename = "agent.py" - def __init__(self): - import io - self.file = io.BytesIO(b"test content") - - request = MockRequest() - agent_file = MockFile() - - # Call the decorated function - result = await mock_upload_function( - request, agent_file, - file_info="test_hotkey:1", - name="Test Agent" - ) - - # Verify the database record was created - mock_conn.execute.assert_called() - call_args = mock_conn.execute.call_args - params = call_args[0][1:] - - assert params[0] == "agent" # upload_type - assert params[1] is True # success - assert params[2] == "test_hotkey" # hotkey - assert params[3] == "Test Agent" # agent_name - assert 
params[4] == "agent.py" # filename - assert params[5] == 12 # file_size_bytes (len("test content")) - assert params[6] == "127.0.0.1" # ip_address - - @pytest.mark.asyncio - async def test_track_upload_decorator_banned_error(self): - """Test that track_upload decorator records banned upload attempts""" - - # Mock database operations - mock_conn = AsyncMock() - mock_transaction_context = AsyncMock() - mock_transaction_context.__aenter__.return_value = mock_conn - mock_transaction_context.__aexit__.return_value = None - - with patch('api.src.utils.upload_agent_helpers.get_transaction', return_value=mock_transaction_context): - with patch('api.src.backend.queries.agents.get_ban_reason', return_value="Code obfuscation detected"): - from api.src.utils.upload_agent_helpers import track_upload - from fastapi import HTTPException - - # Create a mock upload function that raises banned exception - @track_upload("agent") - async def mock_upload_function(request, agent_file, **kwargs): - raise HTTPException(status_code=403, detail="Your miner hotkey has been banned for attempting to obfuscate code") - - # Mock request and file objects - class MockRequest: - class MockClient: - host = "127.0.0.1" - client = MockClient() - - class MockFile: - filename = "agent.py" - def __init__(self): - import io - self.file = io.BytesIO(b"test content") - - request = MockRequest() - agent_file = MockFile() - - # Call the decorated function and expect HTTPException - with pytest.raises(HTTPException) as exc_info: - await mock_upload_function( - request, agent_file, - file_info="banned_hotkey:1", - name="Banned Agent" - ) - - # Verify the exception details - assert exc_info.value.status_code == 403 - assert "banned" in exc_info.value.detail.lower() - - # Verify the database record was created for the failed attempt - assert mock_conn.execute.call_count >= 1 - - # Find the call that recorded the failed upload - failed_upload_call = None - for call in mock_conn.execute.call_args_list: - params = call[0][1:] - if params[1] is False: # success = False - failed_upload_call = params - break - - assert failed_upload_call is not None - assert failed_upload_call[0] == "agent" # upload_type - assert failed_upload_call[1] is False # success - assert failed_upload_call[2] == "banned_hotkey" # hotkey - assert failed_upload_call[7] == "banned" # error_type - assert failed_upload_call[9] == "Code obfuscation detected" # ban_reason - assert failed_upload_call[10] == 403 # http_status_code - - @pytest.mark.asyncio - async def test_track_upload_decorator_rate_limit_error(self): - """Test that track_upload decorator records rate limit errors""" - - mock_conn = AsyncMock() - mock_transaction_context = AsyncMock() - mock_transaction_context.__aenter__.return_value = mock_conn - mock_transaction_context.__aexit__.return_value = None - - with patch('api.src.utils.upload_agent_helpers.get_transaction', return_value=mock_transaction_context): - from api.src.utils.upload_agent_helpers import track_upload - from fastapi import HTTPException - - @track_upload("agent") - async def mock_upload_function(request, agent_file, **kwargs): - raise HTTPException(status_code=429, detail="You must wait 300 seconds before uploading") - - class MockRequest: - class MockClient: - host = "127.0.0.1" - client = MockClient() - - class MockFile: - filename = "agent.py" - def __init__(self): - import io - self.file = io.BytesIO(b"test content") - - request = MockRequest() - agent_file = MockFile() - - with pytest.raises(HTTPException) as exc_info: - await 
mock_upload_function( - request, agent_file, - file_info="rate_limited_hotkey:1", - name="Rate Limited Agent" - ) - - assert exc_info.value.status_code == 429 - - # Verify rate limit error was recorded - mock_conn.execute.assert_called() - call_args = mock_conn.execute.call_args - params = call_args[0][1:] - - assert params[0] == "agent" # upload_type - assert params[1] is False # success - assert params[2] == "rate_limited_hotkey" # hotkey - assert params[7] == "rate_limit" # error_type - assert params[10] == 429 # http_status_code - - @pytest.mark.asyncio - async def test_track_upload_decorator_open_agent(self): - """Test that track_upload decorator works for open agent uploads""" - - mock_conn = AsyncMock() - mock_transaction_context = AsyncMock() - mock_transaction_context.__aenter__.return_value = mock_conn - mock_transaction_context.__aexit__.return_value = None - - with patch('api.src.utils.upload_agent_helpers.get_transaction', return_value=mock_transaction_context): - from api.src.utils.upload_agent_helpers import track_upload - - @track_upload("open-agent") - async def mock_upload_function(request, agent_file, **kwargs): - return type('Response', (), {'message': 'Successfully uploaded agent test-version-id for open user test_open_user.'})() - - class MockRequest: - class MockClient: - host = "10.0.0.1" - client = MockClient() - - class MockFile: - filename = "agent.py" - def __init__(self): - import io - self.file = io.BytesIO(b"open agent content") - - request = MockRequest() - agent_file = MockFile() - - # Call with open_hotkey parameter - result = await mock_upload_function( - request, agent_file, - open_hotkey="test_open_user", - name="Test Open Agent" - ) - - # Verify the database record - mock_conn.execute.assert_called() - call_args = mock_conn.execute.call_args - params = call_args[0][1:] - - assert params[0] == "open-agent" # upload_type - assert params[1] is True # success - assert params[2] == "test_open_user" # hotkey - assert params[3] == "Test Open Agent" # agent_name - assert params[6] == "10.0.0.1" # ip_address - - def test_error_type_classification(self): - """Test that different HTTP errors are classified correctly""" - from api.src.utils.upload_agent_helpers import track_upload - - # Test error classification logic (extracted from decorator) - test_cases = [ - (403, "Your miner hotkey has been banned", "banned"), - (429, "Rate limit exceeded", "rate_limit"), - (400, "Invalid signature", "validation_error"), - (400, "File size too large", "validation_error"), - (503, "No screeners available", "validation_error"), - (500, "Internal server error", "validation_error"), - ] - - for status_code, detail, expected_error_type in test_cases: - if status_code == 403 and "banned" in detail.lower(): - error_type = "banned" - elif status_code == 429: - error_type = "rate_limit" - else: - error_type = "validation_error" - - assert error_type == expected_error_type, f"Status {status_code} with detail '{detail}' should be classified as '{expected_error_type}', got '{error_type}'" diff --git a/tests/test_weights_setting.py b/tests/test_weights_setting.py deleted file mode 100644 index 7629b987..00000000 --- a/tests/test_weights_setting.py +++ /dev/null @@ -1,623 +0,0 @@ -""" -Simple integration tests for the weights function endpoint. -Tests the complete flow from database setup through to API response. 
-""" - -import pytest -import asyncpg -import uuid -from datetime import datetime, timezone -from typing import Optional -from unittest.mock import patch - -from httpx import AsyncClient -import pytest_asyncio - -# Set environment variables for testing -import os - -if not os.getenv('AWS_MASTER_USERNAME'): - os.environ.update({ - 'AWS_MASTER_USERNAME': 'test_user', - 'AWS_MASTER_PASSWORD': 'test_pass', - 'AWS_RDS_PLATFORM_ENDPOINT': 'localhost', - 'AWS_RDS_PLATFORM_DB_NAME': 'postgres', - 'POSTGRES_TEST_URL': 'postgresql://test_user:test_pass@localhost:5432/postgres' - }) - -import sys -sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'api', 'src')) - -from api.src.main import app - -@pytest_asyncio.fixture -async def db_connection(): - """Direct database connection for testing""" - conn = await asyncpg.connect( - user='test_user', - password='test_pass', - host='localhost', - port=5432, - database='postgres' - ) - - # Setup schema - await setup_database_schema(conn) - - yield conn - - await conn.close() - - -@pytest_asyncio.fixture -async def initialized_app(): - """Initialize the FastAPI app with database connection""" - from api.src.backend.db_manager import new_db - - # Initialize the database connection pool - await new_db.open() - - yield app - - # Clean up - await new_db.close() - - -@pytest_asyncio.fixture -async def async_client(initialized_app): - """Async HTTP client for testing FastAPI endpoints""" - from httpx import ASGITransport - async with AsyncClient(transport=ASGITransport(app=initialized_app), base_url="http://testserver") as client: - yield client - - -async def setup_database_schema(conn: asyncpg.Connection): - """Setup database schema for integration tests""" - # Read the actual production schema file - schema_path = os.path.join(os.path.dirname(__file__), '..', 'api', 'src', 'backend', 'postgres_schema.sql') - with open(schema_path, 'r') as f: - schema_sql = f.read() - - # Execute the production schema - await conn.execute(schema_sql) - - # Disable the approval deletion trigger for tests to allow cleanup - await conn.execute(""" - DROP TRIGGER IF EXISTS no_delete_approval_trigger ON approved_version_ids; - CREATE OR REPLACE FUNCTION prevent_delete_approval_test() RETURNS TRIGGER AS $$ - BEGIN - -- Allow deletions in test environment - RETURN OLD; - END; - $$ LANGUAGE plpgsql; - CREATE TRIGGER no_delete_approval_trigger BEFORE DELETE ON approved_version_ids - FOR EACH ROW EXECUTE FUNCTION prevent_delete_approval_test(); - """) - - # Insert test evaluation sets for testing - await conn.execute(""" - INSERT INTO evaluation_sets (set_id, type, swebench_instance_id) VALUES - (1, 'screener-1', 'test_instance_1'), - (1, 'screener-2', 'test_instance_2'), - (1, 'validator', 'test_instance_3') - ON CONFLICT DO NOTHING - """) - - -class TestWeightsSetting: - """Test the weights function with various database states""" - - @pytest.fixture(autouse=True) - def mock_check_registered(self): - """Mock check_if_hotkey_is_registered to always return True for testing""" - with patch('api.src.endpoints.scoring.check_if_hotkey_is_registered', return_value=True): - yield - - @pytest.mark.asyncio - async def test_weights_empty_database(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test weights function with empty database - should return empty dict""" - - # Ensure database is clean - await self._clean_database(db_connection) - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - 
- weights = response.json() - assert weights == {}, "Empty database should return empty weights dict" - - @pytest.mark.asyncio - async def test_weights_single_top_agent(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test weights function with only one approved top agent""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup: Create one approved agent with evaluations - agent_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Test Agent 1", - score=0.85 - ) - - # Refresh materialized view to ensure it's up to date - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # Should have one agent with full weight (1.0 - dust_weight) - expected_dust_weight = 1/65535 - expected_top_weight = 1.0 - expected_dust_weight - - assert len(weights) == 1 - assert weights[agent_data["miner_hotkey"]] == expected_top_weight - assert abs(weights[agent_data["miner_hotkey"]] - expected_top_weight) < 0.0001 - - @pytest.mark.asyncio - async def test_weights_multiple_approved_agents(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test weights function with multiple approved agents - top agent gets most weight""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup: Create multiple approved agents with different scores - agent1_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Test Agent 1", - score=0.90 # Top agent - ) - - agent2_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_2", - agent_name="Test Agent 2", - score=0.80 # Lower score - ) - - agent3_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_3", - agent_name="Test Agent 3", - score=0.75 # Lowest score - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # Should have 3 agents - assert len(weights) == 3 - - # All agents should be present - assert agent1_data["miner_hotkey"] in weights - assert agent2_data["miner_hotkey"] in weights - assert agent3_data["miner_hotkey"] in weights - - # Check dust weights for non-top agents - expected_dust_weight = 1/65535 - assert weights[agent2_data["miner_hotkey"]] == expected_dust_weight - assert weights[agent3_data["miner_hotkey"]] == expected_dust_weight - - # Check top agent weight (should be 1.0 - 3 * dust_weight) - expected_top_weight = 1.0 - (3 * expected_dust_weight) - assert abs(weights[agent1_data["miner_hotkey"]] - expected_top_weight) < 0.0001 - - # Verify total weights sum to 1.0 - total_weight = sum(weights.values()) - assert abs(total_weight - 1.0) < 0.0001 - - @pytest.mark.asyncio - async def test_weights_highest_score_wins(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test that the agent with the highest score becomes the top agent""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup: Create two approved agents with different scores - agent1_data = 
await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Test Agent 1", - score=0.80 - ) - - agent2_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_2", - agent_name="Test Agent 2", - score=0.81 # Higher score - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # Agent 2 should be the top agent (higher score) - expected_dust_weight = 1/65535 - expected_top_weight = 1.0 - (2 * expected_dust_weight) - - assert weights[agent2_data["miner_hotkey"]] == expected_top_weight - assert weights[agent1_data["miner_hotkey"]] == expected_dust_weight - - @pytest.mark.asyncio - async def test_weights_leadership_change(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test that a challenger with a higher score takes leadership""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup: Create two approved agents where the second one has a higher score - agent1_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_1", - agent_name="Test Agent 1", - score=0.80 - ) - - agent2_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="test_hotkey_2", - agent_name="Test Agent 2", - score=0.82 # Higher score - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # Agent 2 should now be the top agent - expected_dust_weight = 1/65535 - expected_top_weight = 1.0 - (2 * expected_dust_weight) - - assert weights[agent2_data["miner_hotkey"]] == expected_top_weight - assert weights[agent1_data["miner_hotkey"]] == expected_dust_weight - - @pytest.mark.asyncio - async def test_weights_unapproved_agents_ignored(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test that unapproved agents are not included in weights""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup: Create one approved agent and one unapproved agent - approved_agent = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="approved_hotkey", - agent_name="Approved Agent", - score=0.85, - approved=True - ) - - unapproved_agent = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="unapproved_hotkey", - agent_name="Unapproved Agent", - score=0.90, # Higher score but unapproved - approved=False - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # Only approved agent should be in weights - assert len(weights) == 1 - assert approved_agent["miner_hotkey"] in weights - assert unapproved_agent["miner_hotkey"] not in weights - - @pytest.mark.asyncio - async def test_weights_unapproved_agents_with_evaluations_ignored(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test that unapproved agents 
with evaluations are not included in weights""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup: Create an agent with evaluations but not approved - agent_data = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="unapproved_with_evaluations_hotkey", - agent_name="Unapproved Agent with Evaluations", - score=0.95, - approved=False # This should exclude it from weights - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # Agent should not be in weights (not approved) - assert len(weights) == 0 - assert agent_data["miner_hotkey"] not in weights - - @pytest.mark.asyncio - async def test_weights_open_miners_excluded(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test that miners with hotkeys beginning with 'open-' are never included in weights""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup treasury wallet for testing - treasury_hotkey = await self._setup_treasury_wallet(db_connection) - - # Setup: Create a regular approved agent and an approved agent with 'open-' hotkey - regular_agent = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="regular_hotkey", - agent_name="Regular Agent", - score=0.80, - approved=True - ) - - open_agent = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="open-test_miner_123", - agent_name="Open Miner Agent", - score=0.95, # Higher score than regular agent - approved=True - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # Treasury should get the top weight since open agent has highest score - # Regular agent should get dust weight - # Open agent should be excluded entirely - expected_dust_weight = 1/65535 - expected_treasury_weight = 1.0 - expected_dust_weight - - assert len(weights) == 2, "Should have treasury and regular agent" - assert treasury_hotkey in weights, "Treasury should get top weight when open miner has highest score" - assert regular_agent["miner_hotkey"] in weights, "Regular agent should get dust weight" - assert open_agent["miner_hotkey"] not in weights, "Agents with hotkeys beginning with 'open-' should never receive weights" - - # Verify weight distribution - assert abs(weights[treasury_hotkey] - expected_treasury_weight) < 0.0001, "Treasury should get the top weight" - assert weights[regular_agent["miner_hotkey"]] == expected_dust_weight, "Regular agent should get dust weight" - - @pytest.mark.asyncio - async def test_weights_open_miner_top_goes_to_treasury(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test that when an open miner is the actual top agent (highest score), weight goes to treasury hotkey instead""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup treasury wallet for testing - treasury_hotkey = await self._setup_treasury_wallet(db_connection) - - # Setup: Create a regular approved agent and an approved agent with 'open-' hotkey that has higher score - regular_agent = await 
self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="regular_hotkey", - agent_name="Regular Agent", - score=0.80, - approved=True - ) - - open_top_agent = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="open-top_miner_456", - agent_name="Open Top Miner Agent", - score=0.95, # Highest score - IS the top agent, but weight should go to treasury - approved=True - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # The open miner IS the top agent (highest score), but treasury should get the weight instead - # Regular agent should have dust weight - # Open miner should not be present at all - expected_dust_weight = 1/65535 - expected_treasury_weight = 1.0 - expected_dust_weight - - assert len(weights) == 2, "Should have treasury and regular agent" - assert treasury_hotkey in weights, "Treasury hotkey should receive the weight when the actual top agent is an open miner" - assert regular_agent["miner_hotkey"] in weights, "Regular agent should still get dust weight" - assert open_top_agent["miner_hotkey"] not in weights, "Open miner should never receive weights, even when it's the top agent" - - # Verify weight distribution - treasury gets the weight that would have gone to the open top agent - assert abs(weights[treasury_hotkey] - expected_treasury_weight) < 0.0001, "Treasury should get the weight that would have gone to the open top agent" - assert weights[regular_agent["miner_hotkey"]] == expected_dust_weight, "Regular agent should get dust weight" - - @pytest.mark.asyncio - async def test_weights_multiple_open_miners_no_dust_weight(self, async_client: AsyncClient, db_connection: asyncpg.Connection): - """Test that multiple open miners never receive any weight, not even dust weight""" - - # Ensure database is clean first - await self._clean_database(db_connection) - - # Setup treasury wallet for testing - treasury_hotkey = await self._setup_treasury_wallet(db_connection) - - # Setup: Create multiple agents including several open miners - regular_agent1 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="regular_hotkey_1", - agent_name="Regular Agent 1", - score=0.70, - approved=True - ) - - regular_agent2 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="regular_hotkey_2", - agent_name="Regular Agent 2", - score=0.75, # Highest regular agent - should be top - approved=True - ) - - open_agent1 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="open-miner_1", - agent_name="Open Miner 1", - score=0.85, # Higher than regulars - approved=True - ) - - open_agent2 = await self._create_approved_agent_with_evaluations( - db_connection, - miner_hotkey="open-miner_2", - agent_name="Open Miner 2", - score=0.90, # Highest overall - approved=True - ) - - # Refresh materialized view - await db_connection.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - # Call the weights endpoint - response = await async_client.get("/scoring/weights") - assert response.status_code == 200 - - weights = response.json() - - # Treasury should get the top weight because open_agent2 (score 0.90) is the actual top agent - # Regular agents should get dust weight - # Open miners should not be present at all - 
expected_dust_weight = 1/65535 - expected_treasury_weight = 1.0 - (2 * expected_dust_weight) # 2 regular agents get dust, treasury gets the rest - - assert len(weights) == 3, "Should have treasury + 2 regular agents, no open miners" - assert treasury_hotkey in weights, "Treasury should get the top weight since open_agent2 is the actual top agent" - assert regular_agent1["miner_hotkey"] in weights - assert regular_agent2["miner_hotkey"] in weights - assert open_agent1["miner_hotkey"] not in weights, "Open miners should never receive weights" - assert open_agent2["miner_hotkey"] not in weights, "Open miners should never receive weights, even when they're the top agent" - - # Verify weight distribution - treasury gets the top weight that would have gone to open_agent2 - assert weights[regular_agent1["miner_hotkey"]] == expected_dust_weight - assert weights[regular_agent2["miner_hotkey"]] == expected_dust_weight - assert abs(weights[treasury_hotkey] - expected_treasury_weight) < 0.0001, "Treasury should get the weight that would have gone to the open top agent" - - # Verify total weights sum to 1.0 - total_weight = sum(weights.values()) - assert abs(total_weight - 1.0) < 0.0001 - - # Helper methods for test setup - - async def _clean_database(self, conn: asyncpg.Connection): - """Clean all test data from database""" - await conn.execute("DELETE FROM evaluation_runs") - await conn.execute("DELETE FROM evaluations") - await conn.execute("DELETE FROM approved_version_ids") - await conn.execute("DELETE FROM approved_top_agents_history") # Delete history before miner_agents due to foreign key - await conn.execute("DELETE FROM top_agents") # Delete top_agents before miner_agents due to foreign key - await conn.execute("DELETE FROM miner_agents") - await conn.execute("DELETE FROM evaluation_sets") - await conn.execute("DELETE FROM treasury_wallets") - await conn.execute("REFRESH MATERIALIZED VIEW CONCURRENTLY agent_scores") - - async def _setup_treasury_wallet(self, conn: asyncpg.Connection, hotkey: str = "test_treasury_hotkey"): - """Setup a treasury wallet for testing""" - # First try to delete any existing record to avoid conflicts - await conn.execute("DELETE FROM treasury_wallets WHERE hotkey = $1", hotkey) - # Then insert the new record using the standard schema - await conn.execute(""" - INSERT INTO treasury_wallets (hotkey, active) - VALUES ($1, TRUE) - """, hotkey) - return hotkey - - async def _create_approved_agent_with_evaluations( - self, - conn: asyncpg.Connection, - miner_hotkey: str, - agent_name: str, - score: float, - approved: bool = True - ) -> dict: - """Create an agent with evaluations and optionally approve it""" - - # Create evaluation set if it doesn't exist - await conn.execute(""" - INSERT INTO evaluation_sets (set_id, type, swebench_instance_id) - VALUES (1, 'validator', 'test_instance_1') - ON CONFLICT DO NOTHING - """) - - # Create agent - version_id = str(uuid.uuid4()) - await conn.execute(""" - INSERT INTO miner_agents (version_id, miner_hotkey, agent_name, version_num, created_at, status) - VALUES ($1, $2, $3, $4, $5, $6) - """, version_id, miner_hotkey, agent_name, 1, datetime.now(timezone.utc), "active") - - # Approve agent if requested - if approved: - await conn.execute(""" - INSERT INTO approved_version_ids (version_id, set_id) VALUES ($1, 1) - """, version_id) - - # Create evaluations with 3 different validators with slightly different scores - # This ensures that after removing the lowest score, we still have 2+ validators - evaluation1_id = str(uuid.uuid4()) - 
evaluation2_id = str(uuid.uuid4())
-        evaluation3_id = str(uuid.uuid4())
-
-        await conn.execute("""
-            INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, score)
-            VALUES ($1, $2, $3, $4, $5, $6, $7)
-        """, evaluation1_id, version_id, "validator_1", 1, "completed", datetime.now(timezone.utc), score)
-
-        await conn.execute("""
-            INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, score)
-            VALUES ($1, $2, $3, $4, $5, $6, $7)
-        """, evaluation2_id, version_id, "validator_2", 1, "completed", datetime.now(timezone.utc), score + 0.01)
-
-        await conn.execute("""
-            INSERT INTO evaluations (evaluation_id, version_id, validator_hotkey, set_id, status, created_at, score)
-            VALUES ($1, $2, $3, $4, $5, $6, $7)
-        """, evaluation3_id, version_id, "validator_3", 1, "completed", datetime.now(timezone.utc), score + 0.02)
-
-        return {
-            "version_id": version_id,
-            "miner_hotkey": miner_hotkey,
-            "agent_name": agent_name,
-            "score": score
-        }
-
\ No newline at end of file