From fa972684a08ce3b2567c74ddac6875c621df1de0 Mon Sep 17 00:00:00 2001 From: yaswanth169 Date: Wed, 28 Jan 2026 00:50:14 +0530 Subject: [PATCH 1/2] feat: Add AlphaFold3 structure prediction tool - Add biomni/tool/structure_prediction.py with functions for: - Single protein structure prediction - Protein-protein complex modeling - Protein-DNA/RNA interaction prediction - Protein-ligand binding prediction - Batch processing and confidence analysis - Add comprehensive unit tests (33 tests) - Add documentation in docs/structure_prediction.md --- biomni/tool/structure_prediction.py | 790 ++++++++++++++++++++++++++++ docs/structure_prediction.md | 183 +++++++ tests/test_structure_prediction.py | 406 ++++++++++++++ 3 files changed, 1379 insertions(+) create mode 100644 biomni/tool/structure_prediction.py create mode 100644 docs/structure_prediction.md create mode 100644 tests/test_structure_prediction.py diff --git a/biomni/tool/structure_prediction.py b/biomni/tool/structure_prediction.py new file mode 100644 index 00000000..66250367 --- /dev/null +++ b/biomni/tool/structure_prediction.py @@ -0,0 +1,790 @@ +"""Structure prediction tools using AlphaFold3 for protein complex modeling. + +This module provides functions for predicting protein structures and their +interactions with other proteins, nucleic acids, and small molecules using +the AlphaFold Server API. +""" + +import hashlib +import json +import os +import time +from typing import Any + +import requests + +ALPHAFOLD_SERVER_URL = "https://alphafoldserver.com" +ALPHAFOLD_API_URL = f"{ALPHAFOLD_SERVER_URL}/api" + + +def _get_api_token() -> str | None: + """Retrieve the AlphaFold Server API token from environment.""" + return os.environ.get("ALPHAFOLD_API_TOKEN") + + +def _make_api_request( + endpoint: str, + method: str = "GET", + data: dict | None = None, + headers: dict | None = None, + timeout: int = 30, +) -> dict[str, Any]: + """Make an authenticated request to the AlphaFold Server API.""" + token = _get_api_token() + + if headers is None: + headers = {} + + headers["Content-Type"] = "application/json" + if token: + headers["Authorization"] = f"Bearer {token}" + + url = f"{ALPHAFOLD_API_URL}/{endpoint.lstrip('/')}" + + try: + if method.upper() == "GET": + response = requests.get(url, headers=headers, timeout=timeout) + elif method.upper() == "POST": + response = requests.post(url, headers=headers, json=data, timeout=timeout) + else: + return {"success": False, "error": f"Unsupported method: {method}"} + + response.raise_for_status() + + try: + result = response.json() + except ValueError: + result = {"raw_text": response.text} + + return {"success": True, "result": result} + + except requests.exceptions.Timeout: + return {"success": False, "error": "Request timed out"} + except requests.exceptions.RequestException as e: + error_msg = str(e) + if hasattr(e, "response") and e.response is not None: + try: + error_data = e.response.json() + error_msg = error_data.get("error", error_data.get("message", str(e))) + except ValueError: + error_msg = e.response.text[:500] if e.response.text else str(e) + return {"success": False, "error": error_msg} + + +def _validate_sequence(sequence: str) -> tuple[bool, str]: + """Validate a protein sequence contains only valid amino acid codes.""" + valid_aa = set("ACDEFGHIKLMNPQRSTVWY") + sequence_upper = sequence.upper().replace(" ", "").replace("\n", "") + + invalid_chars = set(sequence_upper) - valid_aa + if invalid_chars: + return False, f"Invalid characters in sequence: {invalid_chars}" + + if len(sequence_upper) < 10: + return False, "Sequence too short (minimum 10 residues)" + + if len(sequence_upper) > 2000: + return False, "Sequence too long (maximum 2000 residues for API)" + + return True, sequence_upper + + +def _validate_nucleic_acid(sequence: str, na_type: str = "DNA") -> tuple[bool, str]: + """Validate a nucleic acid sequence.""" + if na_type.upper() == "DNA": + valid_bases = set("ATCGN") + else: + valid_bases = set("AUCGN") + + sequence_upper = sequence.upper().replace(" ", "").replace("\n", "") + invalid_chars = set(sequence_upper) - valid_bases + + if invalid_chars: + return False, f"Invalid characters for {na_type}: {invalid_chars}" + + if len(sequence_upper) < 5: + return False, "Sequence too short (minimum 5 bases)" + + return True, sequence_upper + + +def _generate_job_id(sequences: list[str]) -> str: + """Generate a unique job ID based on input sequences.""" + content = "".join(sequences) + return hashlib.md5(content.encode()).hexdigest()[:12] + + +def _poll_job_status( + job_id: str, + max_wait_seconds: int = 600, + poll_interval: int = 10, +) -> dict[str, Any]: + """Poll for job completion status.""" + elapsed = 0 + + while elapsed < max_wait_seconds: + result = _make_api_request(f"jobs/{job_id}") + + if not result["success"]: + return result + + status = result["result"].get("status", "unknown") + + if status == "completed": + return {"success": True, "result": result["result"]} + elif status == "failed": + error = result["result"].get("error", "Job failed without error message") + return {"success": False, "error": error} + elif status in ("pending", "running"): + time.sleep(poll_interval) + elapsed += poll_interval + else: + return {"success": False, "error": f"Unknown job status: {status}"} + + return {"success": False, "error": f"Job timed out after {max_wait_seconds} seconds"} + + +def predict_protein_structure( + sequence: str, + save_path: str | None = None, + wait_for_result: bool = True, + timeout_seconds: int = 600, +) -> dict[str, Any]: + """Predict the 3D structure of a single protein sequence. + + Parameters + ---------- + sequence : str + Amino acid sequence in single-letter code (e.g., "MKFLILLFNILC...") + save_path : str, optional + Path to save the predicted structure file (PDB format) + wait_for_result : bool + If True, wait for prediction to complete. If False, return job ID immediately + timeout_seconds : int + Maximum time to wait for prediction in seconds + + Returns + ------- + dict + Dictionary containing: + - success: bool indicating if prediction succeeded + - structure_path: path to saved PDB file (if save_path provided) + - pdb_content: PDB file content as string + - confidence: overall confidence metrics + - job_id: ID for tracking the prediction job + + Examples + -------- + >>> result = predict_protein_structure("MKFLILLFNILCLFPVLAADNH...") + >>> print(result["confidence"]["plddt_mean"]) + + """ + is_valid, validated_seq = _validate_sequence(sequence) + if not is_valid: + return {"success": False, "error": validated_seq} + + job_data = { + "sequences": [{"type": "protein", "sequence": validated_seq}], + "model": "alphafold3", + } + + submit_result = _make_api_request("predict", method="POST", data=job_data) + + if not submit_result["success"]: + return submit_result + + job_id = submit_result["result"].get("job_id") + + if not job_id: + return {"success": False, "error": "No job ID returned from server"} + + if not wait_for_result: + return {"success": True, "job_id": job_id, "status": "submitted"} + + poll_result = _poll_job_status(job_id, max_wait_seconds=timeout_seconds) + + if not poll_result["success"]: + return poll_result + + job_result = poll_result["result"] + pdb_content = job_result.get("pdb_content", "") + + response = { + "success": True, + "job_id": job_id, + "pdb_content": pdb_content, + "confidence": { + "plddt_mean": job_result.get("plddt_mean"), + "ptm": job_result.get("ptm"), + }, + } + + if save_path and pdb_content: + os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) + with open(save_path, "w") as f: + f.write(pdb_content) + response["structure_path"] = os.path.abspath(save_path) + + return response + + +def predict_protein_complex( + sequences: list[str], + chain_names: list[str] | None = None, + save_path: str | None = None, + wait_for_result: bool = True, + timeout_seconds: int = 900, +) -> dict[str, Any]: + """Predict the 3D structure of a protein complex from multiple sequences. + + Parameters + ---------- + sequences : list[str] + List of protein sequences, one per chain + chain_names : list[str], optional + Names for each chain (defaults to A, B, C, ...) + save_path : str, optional + Path to save the predicted complex structure (PDB format) + wait_for_result : bool + If True, wait for prediction to complete + timeout_seconds : int + Maximum time to wait for prediction + + Returns + ------- + dict + Dictionary containing: + - success: bool + - structure_path: path to saved PDB file + - pdb_content: PDB content as string + - confidence: per-chain and interface confidence metrics + - interface_contacts: predicted inter-chain contacts + + Examples + -------- + >>> seqs = ["MKFLILLFNILCLFPVLAADNH...", "MALTEVNPKKYIPGTKMIFAG..."] + >>> result = predict_protein_complex(seqs, chain_names=["Receptor", "Ligand"]) + + """ + if not sequences or len(sequences) < 2: + return {"success": False, "error": "At least 2 sequences required for complex prediction"} + + validated_sequences = [] + for i, seq in enumerate(sequences): + is_valid, validated_seq = _validate_sequence(seq) + if not is_valid: + return {"success": False, "error": f"Chain {i+1}: {validated_seq}"} + validated_sequences.append(validated_seq) + + if chain_names is None: + chain_names = [chr(ord("A") + i) for i in range(len(sequences))] + + sequence_data = [ + {"type": "protein", "sequence": seq, "chain_id": name} + for seq, name in zip(validated_sequences, chain_names) + ] + + job_data = { + "sequences": sequence_data, + "model": "alphafold3", + "predict_interface": True, + } + + submit_result = _make_api_request("predict", method="POST", data=job_data) + + if not submit_result["success"]: + return submit_result + + job_id = submit_result["result"].get("job_id") + + if not wait_for_result: + return {"success": True, "job_id": job_id, "status": "submitted"} + + poll_result = _poll_job_status(job_id, max_wait_seconds=timeout_seconds) + + if not poll_result["success"]: + return poll_result + + job_result = poll_result["result"] + pdb_content = job_result.get("pdb_content", "") + + response = { + "success": True, + "job_id": job_id, + "pdb_content": pdb_content, + "confidence": { + "plddt_mean": job_result.get("plddt_mean"), + "ptm": job_result.get("ptm"), + "iptm": job_result.get("iptm"), + "chain_confidences": job_result.get("chain_confidences", {}), + }, + "interface_contacts": job_result.get("interface_contacts", []), + } + + if save_path and pdb_content: + os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) + with open(save_path, "w") as f: + f.write(pdb_content) + response["structure_path"] = os.path.abspath(save_path) + + return response + + +def predict_protein_nucleic_acid_complex( + protein_sequences: list[str], + nucleic_acid_sequence: str, + nucleic_acid_type: str = "DNA", + save_path: str | None = None, + wait_for_result: bool = True, + timeout_seconds: int = 900, +) -> dict[str, Any]: + """Predict structure of a protein-DNA or protein-RNA complex. + + Parameters + ---------- + protein_sequences : list[str] + One or more protein sequences + nucleic_acid_sequence : str + DNA or RNA sequence + nucleic_acid_type : str + Either "DNA" or "RNA" + save_path : str, optional + Path to save structure + wait_for_result : bool + Wait for completion + timeout_seconds : int + Maximum wait time + + Returns + ------- + dict + Prediction results including structure and confidence metrics + + Examples + -------- + >>> protein = "MKFLILLFNILCLFPVLAADNH..." + >>> dna = "ATCGATCGATCGATCG" + >>> result = predict_protein_nucleic_acid_complex([protein], dna, "DNA") + + """ + if nucleic_acid_type.upper() not in ("DNA", "RNA"): + return {"success": False, "error": "nucleic_acid_type must be 'DNA' or 'RNA'"} + + validated_proteins = [] + for i, seq in enumerate(protein_sequences): + is_valid, validated_seq = _validate_sequence(seq) + if not is_valid: + return {"success": False, "error": f"Protein {i+1}: {validated_seq}"} + validated_proteins.append(validated_seq) + + is_valid, validated_na = _validate_nucleic_acid(nucleic_acid_sequence, nucleic_acid_type) + if not is_valid: + return {"success": False, "error": validated_na} + + sequence_data = [ + {"type": "protein", "sequence": seq, "chain_id": chr(ord("A") + i)} + for i, seq in enumerate(validated_proteins) + ] + + na_chain_id = chr(ord("A") + len(validated_proteins)) + sequence_data.append({ + "type": nucleic_acid_type.lower(), + "sequence": validated_na, + "chain_id": na_chain_id, + }) + + job_data = { + "sequences": sequence_data, + "model": "alphafold3", + "predict_interface": True, + } + + submit_result = _make_api_request("predict", method="POST", data=job_data) + + if not submit_result["success"]: + return submit_result + + job_id = submit_result["result"].get("job_id") + + if not wait_for_result: + return {"success": True, "job_id": job_id, "status": "submitted"} + + poll_result = _poll_job_status(job_id, max_wait_seconds=timeout_seconds) + + if not poll_result["success"]: + return poll_result + + job_result = poll_result["result"] + pdb_content = job_result.get("pdb_content", "") + + response = { + "success": True, + "job_id": job_id, + "pdb_content": pdb_content, + "confidence": { + "plddt_mean": job_result.get("plddt_mean"), + "ptm": job_result.get("ptm"), + "iptm": job_result.get("iptm"), + }, + "protein_na_contacts": job_result.get("interface_contacts", []), + } + + if save_path and pdb_content: + os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) + with open(save_path, "w") as f: + f.write(pdb_content) + response["structure_path"] = os.path.abspath(save_path) + + return response + + +def predict_protein_ligand_complex( + protein_sequence: str, + ligand_smiles: str, + save_path: str | None = None, + wait_for_result: bool = True, + timeout_seconds: int = 900, +) -> dict[str, Any]: + """Predict structure of a protein-small molecule complex. + + Parameters + ---------- + protein_sequence : str + Protein sequence in single-letter amino acid code + ligand_smiles : str + Small molecule in SMILES format + save_path : str, optional + Path to save structure + wait_for_result : bool + Wait for completion + timeout_seconds : int + Maximum wait time + + Returns + ------- + dict + Prediction results including: + - structure_path: saved PDB file path + - pdb_content: structure as string + - binding_site: predicted binding residues + - binding_affinity: predicted binding strength (if available) + + Examples + -------- + >>> protein = "MKFLILLFNILCLFPVLAADNH..." + >>> erlotinib = "COCCOc1cc2ncnc(Nc3cccc(c3)C#C)c2cc1OCCOC" + >>> result = predict_protein_ligand_complex(protein, erlotinib) + + """ + is_valid, validated_seq = _validate_sequence(protein_sequence) + if not is_valid: + return {"success": False, "error": validated_seq} + + if not ligand_smiles or len(ligand_smiles) < 2: + return {"success": False, "error": "Invalid SMILES string"} + + sequence_data = [ + {"type": "protein", "sequence": validated_seq, "chain_id": "A"}, + {"type": "ligand", "smiles": ligand_smiles, "chain_id": "L"}, + ] + + job_data = { + "sequences": sequence_data, + "model": "alphafold3", + "predict_binding": True, + } + + submit_result = _make_api_request("predict", method="POST", data=job_data) + + if not submit_result["success"]: + return submit_result + + job_id = submit_result["result"].get("job_id") + + if not wait_for_result: + return {"success": True, "job_id": job_id, "status": "submitted"} + + poll_result = _poll_job_status(job_id, max_wait_seconds=timeout_seconds) + + if not poll_result["success"]: + return poll_result + + job_result = poll_result["result"] + pdb_content = job_result.get("pdb_content", "") + + response = { + "success": True, + "job_id": job_id, + "pdb_content": pdb_content, + "confidence": { + "plddt_mean": job_result.get("plddt_mean"), + "ptm": job_result.get("ptm"), + }, + "binding_site": job_result.get("binding_residues", []), + "binding_affinity": job_result.get("predicted_affinity"), + } + + if save_path and pdb_content: + os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) + with open(save_path, "w") as f: + f.write(pdb_content) + response["structure_path"] = os.path.abspath(save_path) + + return response + + +def get_job_status(job_id: str) -> dict[str, Any]: + """Check the status of a submitted prediction job. + + Parameters + ---------- + job_id : str + Job ID returned from a prediction function + + Returns + ------- + dict + Job status including: + - status: "pending", "running", "completed", or "failed" + - progress: percentage complete (if available) + - result: prediction results (if completed) + + """ + result = _make_api_request(f"jobs/{job_id}") + + if not result["success"]: + return result + + job_data = result["result"] + + return { + "success": True, + "job_id": job_id, + "status": job_data.get("status", "unknown"), + "progress": job_data.get("progress"), + "created_at": job_data.get("created_at"), + "result": job_data if job_data.get("status") == "completed" else None, + } + + +def download_structure( + job_id: str, + output_path: str, + file_format: str = "pdb", +) -> dict[str, Any]: + """Download the structure file from a completed prediction job. + + Parameters + ---------- + job_id : str + Job ID of a completed prediction + output_path : str + Path to save the structure file + file_format : str + Output format: "pdb" or "cif" + + Returns + ------- + dict + Download result with file path + + """ + if file_format.lower() not in ("pdb", "cif"): + return {"success": False, "error": "Format must be 'pdb' or 'cif'"} + + result = _make_api_request(f"jobs/{job_id}/structure", timeout=60) + + if not result["success"]: + return result + + structure_content = result["result"].get(f"{file_format}_content", "") + + if not structure_content: + return {"success": False, "error": f"No {file_format.upper()} content available"} + + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) + + with open(output_path, "w") as f: + f.write(structure_content) + + return { + "success": True, + "file_path": os.path.abspath(output_path), + "format": file_format, + "size_bytes": len(structure_content), + } + + +def analyze_structure_confidence(pdb_path: str) -> dict[str, Any]: + """Analyze confidence scores from a predicted structure file. + + Parameters + ---------- + pdb_path : str + Path to a PDB file from AlphaFold prediction + + Returns + ------- + dict + Confidence analysis including: + - residue_plddt: per-residue pLDDT scores + - mean_plddt: average pLDDT + - confident_regions: regions with pLDDT > 70 + - low_confidence_regions: regions with pLDDT < 50 + + """ + if not os.path.exists(pdb_path): + return {"success": False, "error": f"File not found: {pdb_path}"} + + residue_plddt = {} + + try: + with open(pdb_path, "r") as f: + for line in f: + if line.startswith("ATOM"): + chain = line[21].strip() + res_num = int(line[22:26].strip()) + b_factor = float(line[60:66].strip()) + + key = f"{chain}:{res_num}" + if key not in residue_plddt: + residue_plddt[key] = b_factor + except Exception as e: + return {"success": False, "error": f"Failed to parse PDB: {str(e)}"} + + if not residue_plddt: + return {"success": False, "error": "No residues found in PDB file"} + + scores = list(residue_plddt.values()) + mean_plddt = sum(scores) / len(scores) + + confident_regions = [k for k, v in residue_plddt.items() if v > 70] + low_confidence_regions = [k for k, v in residue_plddt.items() if v < 50] + + return { + "success": True, + "file": pdb_path, + "total_residues": len(residue_plddt), + "mean_plddt": round(mean_plddt, 2), + "confident_residue_count": len(confident_regions), + "low_confidence_residue_count": len(low_confidence_regions), + "quality_assessment": ( + "High confidence" if mean_plddt > 70 else + "Medium confidence" if mean_plddt > 50 else + "Low confidence" + ), + "residue_scores": residue_plddt, + } + + +def batch_predict_structures( + jobs: list[dict[str, Any]], + output_dir: str | None = None, + parallel: bool = True, + max_concurrent: int = 5, +) -> dict[str, Any]: + """Submit multiple structure prediction jobs. + + Parameters + ---------- + jobs : list[dict] + List of job specifications, each containing: + - type: "protein", "complex", "protein_dna", or "protein_ligand" + - sequences: sequence data for the job + - name: optional name for the job + output_dir : str, optional + Directory to save all output structures + parallel : bool + Whether to submit jobs in parallel + max_concurrent : int + Maximum concurrent jobs if parallel=True + + Returns + ------- + dict + Batch results including job IDs and status for each submission + + Examples + -------- + >>> jobs = [ + ... {"type": "protein", "sequences": ["MKFL..."], "name": "protein1"}, + ... {"type": "complex", "sequences": ["MKFL...", "MALT..."], "name": "complex1"}, + ... ] + >>> result = batch_predict_structures(jobs, output_dir="./structures") + + """ + if not jobs: + return {"success": False, "error": "No jobs provided"} + + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + results = [] + + for i, job in enumerate(jobs): + job_type = job.get("type", "protein") + sequences = job.get("sequences", []) + name = job.get("name", f"job_{i+1}") + + save_path = None + if output_dir: + save_path = os.path.join(output_dir, f"{name}.pdb") + + try: + if job_type == "protein": + if not sequences: + result = {"success": False, "error": "No sequence provided"} + else: + result = predict_protein_structure( + sequences[0], + save_path=save_path, + wait_for_result=False, + ) + elif job_type == "complex": + result = predict_protein_complex( + sequences, + save_path=save_path, + wait_for_result=False, + ) + elif job_type == "protein_dna": + proteins = job.get("proteins", sequences[:-1] if len(sequences) > 1 else []) + na_seq = job.get("nucleic_acid", sequences[-1] if sequences else "") + na_type = job.get("nucleic_acid_type", "DNA") + result = predict_protein_nucleic_acid_complex( + proteins, + na_seq, + na_type, + save_path=save_path, + wait_for_result=False, + ) + elif job_type == "protein_ligand": + protein = sequences[0] if sequences else "" + ligand = job.get("ligand_smiles", "") + result = predict_protein_ligand_complex( + protein, + ligand, + save_path=save_path, + wait_for_result=False, + ) + else: + result = {"success": False, "error": f"Unknown job type: {job_type}"} + + result["name"] = name + results.append(result) + + except Exception as e: + results.append({ + "success": False, + "name": name, + "error": str(e), + }) + + successful = sum(1 for r in results if r.get("success", False)) + + return { + "success": successful > 0, + "total_jobs": len(jobs), + "submitted": successful, + "failed": len(jobs) - successful, + "jobs": results, + } diff --git a/docs/structure_prediction.md b/docs/structure_prediction.md new file mode 100644 index 00000000..786abd8e --- /dev/null +++ b/docs/structure_prediction.md @@ -0,0 +1,183 @@ +# AlphaFold3 Structure Prediction Tool + +> Predict protein structures and molecular complexes using state-of-the-art AlphaFold3. + +## Overview + +This module provides a programmatic interface to AlphaFold3 for predicting: + +- **Single protein structures** - 3D atomic coordinates from amino acid sequences +- **Protein-protein complexes** - Multi-chain assemblies with interface predictions +- **Protein-DNA/RNA complexes** - Nucleic acid binding and interaction modeling +- **Protein-ligand complexes** - Small molecule binding site and affinity prediction + +## Installation + +The module is included in Biomni. Set your API token: + +```bash +export ALPHAFOLD_API_TOKEN="your_token_here" +``` + +## Quick Start + +```python +from biomni.tool.structure_prediction import ( + predict_protein_structure, + predict_protein_complex, + predict_protein_nucleic_acid_complex, + predict_protein_ligand_complex, +) + +# Single protein +result = predict_protein_structure( + sequence="MKFLILLFNILCLFPVLAADNHGVGPQGAS", + save_path="./protein.pdb" +) + +# Protein complex +result = predict_protein_complex( + sequences=["MKFLILLFNILC...", "MALTEVNPKKY..."], + chain_names=["ChainA", "ChainB"] +) + +# Protein-DNA +result = predict_protein_nucleic_acid_complex( + protein_sequences=["MKFLILLFNILC..."], + nucleic_acid_sequence="ATCGATCGATCG", + nucleic_acid_type="DNA" +) + +# Protein-ligand +result = predict_protein_ligand_complex( + protein_sequence="MKFLILLFNILC...", + ligand_smiles="CCO" +) +``` + +## API Reference + +### predict_protein_structure + +Predict the 3D structure of a single protein. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `sequence` | str | Amino acid sequence (10-2000 residues) | +| `save_path` | str | Optional path to save PDB file | +| `wait_for_result` | bool | Wait for completion (default: True) | +| `timeout_seconds` | int | Max wait time (default: 600) | + +**Returns:** Dictionary with `pdb_content`, `confidence` scores, and `job_id`. + +--- + +### predict_protein_complex + +Predict multi-chain protein complex structures. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `sequences` | list[str] | List of protein sequences (≥2) | +| `chain_names` | list[str] | Optional names for each chain | +| `save_path` | str | Optional path to save PDB file | + +**Returns:** Dictionary with structure, `iptm` interface score, and `interface_contacts`. + +--- + +### predict_protein_nucleic_acid_complex + +Predict protein-DNA or protein-RNA complex structures. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `protein_sequences` | list[str] | One or more protein sequences | +| `nucleic_acid_sequence` | str | DNA or RNA sequence | +| `nucleic_acid_type` | str | "DNA" or "RNA" | + +--- + +### predict_protein_ligand_complex + +Predict protein-small molecule binding. + +| Parameter | Type | Description | +|-----------|------|-------------| +| `protein_sequence` | str | Target protein sequence | +| `ligand_smiles` | str | Ligand in SMILES format | + +**Returns:** Dictionary with `binding_site` residues and `binding_affinity`. + +--- + +### batch_predict_structures + +Submit multiple prediction jobs. + +```python +from biomni.tool.structure_prediction import batch_predict_structures + +jobs = [ + {"type": "protein", "sequences": ["MKFL..."], "name": "kinase"}, + {"type": "complex", "sequences": ["MKFL...", "MALT..."], "name": "dimer"}, + {"type": "protein_ligand", "sequences": ["MKFL..."], "ligand_smiles": "CCO", "name": "binding"}, +] + +result = batch_predict_structures(jobs, output_dir="./structures") +``` + +--- + +### analyze_structure_confidence + +Analyze pLDDT confidence scores from a predicted structure. + +```python +from biomni.tool.structure_prediction import analyze_structure_confidence + +result = analyze_structure_confidence("./protein.pdb") +print(f"Mean pLDDT: {result['mean_plddt']}") +print(f"Quality: {result['quality_assessment']}") +``` + +## Confidence Metrics + +| Metric | Description | +|--------|-------------| +| **pLDDT** | Per-residue confidence (0-100). >70 = confident, <50 = low confidence | +| **pTM** | Predicted TM-score for overall fold accuracy | +| **ipTM** | Interface TM-score for complex predictions | + +## Async Usage + +For long-running predictions: + +```python +# Submit without waiting +result = predict_protein_structure(sequence, wait_for_result=False) +job_id = result["job_id"] + +# Check status later +from biomni.tool.structure_prediction import get_job_status, download_structure + +status = get_job_status(job_id) +if status["status"] == "completed": + download_structure(job_id, "./output.pdb") +``` + +## Error Handling + +All functions return a dictionary with a `success` boolean: + +```python +result = predict_protein_structure("INVALID123") +if not result["success"]: + print(f"Error: {result['error']}") +``` + +## Limitations + +- Maximum sequence length: 2000 residues per chain +- Minimum sequence length: 10 residues +- API rate limits apply based on your token tier diff --git a/tests/test_structure_prediction.py b/tests/test_structure_prediction.py new file mode 100644 index 00000000..8d513dcd --- /dev/null +++ b/tests/test_structure_prediction.py @@ -0,0 +1,406 @@ +"""Tests for the structure prediction module.""" + +import os +import tempfile +from unittest import mock + +import pytest + +from biomni.tool.structure_prediction import ( + _generate_job_id, + _validate_nucleic_acid, + _validate_sequence, + analyze_structure_confidence, + batch_predict_structures, + download_structure, + get_job_status, + predict_protein_complex, + predict_protein_ligand_complex, + predict_protein_nucleic_acid_complex, + predict_protein_structure, +) + + +class TestSequenceValidation: + """Tests for sequence validation functions.""" + + def test_validate_valid_protein_sequence(self): + """Valid protein sequences should pass validation.""" + seq = "MKFLILLFNILCLFPVLAADNHGVGPQGAS" + is_valid, result = _validate_sequence(seq) + assert is_valid is True + assert result == seq.upper() + + def test_validate_sequence_with_whitespace(self): + """Sequences with whitespace should be cleaned.""" + seq = "MKFL ILLF\nNILC LFPV" + is_valid, result = _validate_sequence(seq) + assert is_valid is True + assert " " not in result + assert "\n" not in result + + def test_validate_sequence_lowercase(self): + """Lowercase sequences should be converted to uppercase.""" + seq = "mkflillfnilclfpvlaadnh" + is_valid, result = _validate_sequence(seq) + assert is_valid is True + assert result == seq.upper() + + def test_validate_sequence_invalid_characters(self): + """Sequences with invalid characters should fail.""" + seq = "MKFLILLFNILC123LFPV" + is_valid, result = _validate_sequence(seq) + assert is_valid is False + assert "Invalid characters" in result + + def test_validate_sequence_too_short(self): + """Sequences shorter than 10 residues should fail.""" + seq = "MKFLI" + is_valid, result = _validate_sequence(seq) + assert is_valid is False + assert "too short" in result + + def test_validate_sequence_too_long(self): + """Sequences longer than 2000 residues should fail.""" + seq = "M" * 2001 + is_valid, result = _validate_sequence(seq) + assert is_valid is False + assert "too long" in result + + def test_validate_dna_sequence(self): + """Valid DNA sequences should pass validation.""" + seq = "ATCGATCGATCG" + is_valid, result = _validate_nucleic_acid(seq, "DNA") + assert is_valid is True + assert result == seq + + def test_validate_rna_sequence(self): + """Valid RNA sequences should pass validation.""" + seq = "AUCGAUCGAUCG" + is_valid, result = _validate_nucleic_acid(seq, "RNA") + assert is_valid is True + assert result == seq + + def test_validate_dna_invalid_base(self): + """DNA with U base should fail.""" + seq = "ATCGUATCG" + is_valid, result = _validate_nucleic_acid(seq, "DNA") + assert is_valid is False + assert "Invalid characters" in result + + def test_validate_nucleic_acid_too_short(self): + """Nucleic acid sequences shorter than 5 bases should fail.""" + seq = "ATCG" + is_valid, result = _validate_nucleic_acid(seq, "DNA") + assert is_valid is False + assert "too short" in result + + +class TestJobIdGeneration: + """Tests for job ID generation.""" + + def test_generate_job_id_deterministic(self): + """Same sequences should produce same job ID.""" + seqs = ["MKFLILLFNILCLFPV", "MALTEVNPKKYIPGTK"] + id1 = _generate_job_id(seqs) + id2 = _generate_job_id(seqs) + assert id1 == id2 + + def test_generate_job_id_different_for_different_seqs(self): + """Different sequences should produce different job IDs.""" + id1 = _generate_job_id(["MKFLILLFNILCLFPV"]) + id2 = _generate_job_id(["MALTEVNPKKYIPGTK"]) + assert id1 != id2 + + +class TestPredictProteinStructure: + """Tests for single protein structure prediction.""" + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_predict_protein_success(self, mock_request): + """Successful prediction should return structure data.""" + mock_request.side_effect = [ + {"success": True, "result": {"job_id": "test123"}}, + {"success": True, "result": {"status": "completed", "pdb_content": "ATOM...", "plddt_mean": 85.5, "ptm": 0.9}}, + ] + + result = predict_protein_structure("MKFLILLFNILCLFPVLAADNH") + + assert result["success"] is True + assert result["job_id"] == "test123" + assert "pdb_content" in result + assert result["confidence"]["plddt_mean"] == 85.5 + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_predict_protein_saves_file(self, mock_request): + """Prediction should save PDB file when path provided.""" + mock_request.side_effect = [ + {"success": True, "result": {"job_id": "test123"}}, + {"success": True, "result": {"status": "completed", "pdb_content": "ATOM 1 CA ALA", "plddt_mean": 85.5}}, + ] + + with tempfile.TemporaryDirectory() as tmpdir: + save_path = os.path.join(tmpdir, "test.pdb") + result = predict_protein_structure("MKFLILLFNILCLFPVLAADNH", save_path=save_path) + + assert result["success"] is True + assert os.path.exists(save_path) + with open(save_path) as f: + assert f.read() == "ATOM 1 CA ALA" + + def test_predict_protein_invalid_sequence(self): + """Invalid sequence should return error without API call.""" + result = predict_protein_structure("INVALID123") + assert result["success"] is False + assert "Invalid characters" in result["error"] + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_predict_protein_no_wait(self, mock_request): + """No-wait mode should return job ID immediately.""" + mock_request.return_value = {"success": True, "result": {"job_id": "test123"}} + + result = predict_protein_structure("MKFLILLFNILCLFPVLAADNH", wait_for_result=False) + + assert result["success"] is True + assert result["status"] == "submitted" + assert result["job_id"] == "test123" + mock_request.assert_called_once() + + +class TestPredictProteinComplex: + """Tests for protein complex prediction.""" + + def test_predict_complex_single_sequence_fails(self): + """Complex prediction with single sequence should fail.""" + result = predict_protein_complex(["MKFLILLFNILCLFPVLAADNH"]) + assert result["success"] is False + assert "At least 2 sequences" in result["error"] + + def test_predict_complex_empty_list_fails(self): + """Complex prediction with empty list should fail.""" + result = predict_protein_complex([]) + assert result["success"] is False + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_predict_complex_success(self, mock_request): + """Successful complex prediction should return interface data.""" + mock_request.side_effect = [ + {"success": True, "result": {"job_id": "complex123"}}, + {"success": True, "result": { + "status": "completed", + "pdb_content": "ATOM...", + "plddt_mean": 80.0, + "iptm": 0.85, + "interface_contacts": [{"chain_a": "A", "chain_b": "B", "residue_a": 10, "residue_b": 25}], + }}, + ] + + result = predict_protein_complex( + ["MKFLILLFNILCLFPVLAADNH", "MALTEVNPKKYIPGTKMIFAG"], + chain_names=["Receptor", "Ligand"] + ) + + assert result["success"] is True + assert result["confidence"]["iptm"] == 0.85 + assert len(result["interface_contacts"]) > 0 + + +class TestPredictProteinNucleicAcid: + """Tests for protein-nucleic acid complex prediction.""" + + def test_predict_invalid_na_type(self): + """Invalid nucleic acid type should fail.""" + result = predict_protein_nucleic_acid_complex( + ["MKFLILLFNILCLFPVLAADNH"], + "ATCGATCG", + "INVALID" + ) + assert result["success"] is False + assert "DNA" in result["error"] or "RNA" in result["error"] + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_predict_protein_dna_success(self, mock_request): + """Successful protein-DNA prediction should work.""" + mock_request.side_effect = [ + {"success": True, "result": {"job_id": "dna123"}}, + {"success": True, "result": {"status": "completed", "pdb_content": "ATOM...", "plddt_mean": 75.0}}, + ] + + result = predict_protein_nucleic_acid_complex( + ["MKFLILLFNILCLFPVLAADNH"], + "ATCGATCGATCGATCG", + "DNA" + ) + + assert result["success"] is True + assert result["job_id"] == "dna123" + + +class TestPredictProteinLigand: + """Tests for protein-ligand complex prediction.""" + + def test_predict_empty_smiles_fails(self): + """Empty SMILES string should fail.""" + result = predict_protein_ligand_complex("MKFLILLFNILCLFPVLAADNH", "") + assert result["success"] is False + assert "SMILES" in result["error"] + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_predict_protein_ligand_success(self, mock_request): + """Successful protein-ligand prediction should return binding data.""" + mock_request.side_effect = [ + {"success": True, "result": {"job_id": "ligand123"}}, + {"success": True, "result": { + "status": "completed", + "pdb_content": "ATOM...", + "binding_residues": [45, 67, 89, 112], + "predicted_affinity": -8.5, + }}, + ] + + result = predict_protein_ligand_complex( + "MKFLILLFNILCLFPVLAADNH", + "COCCOc1cc2ncnc(Nc3cccc(c3)C#C)c2cc1OCCOC" + ) + + assert result["success"] is True + assert len(result["binding_site"]) > 0 + assert result["binding_affinity"] == -8.5 + + +class TestGetJobStatus: + """Tests for job status checking.""" + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_get_job_status_running(self, mock_request): + """Running job should return pending status.""" + mock_request.return_value = { + "success": True, + "result": {"status": "running", "progress": 45} + } + + result = get_job_status("test123") + + assert result["success"] is True + assert result["status"] == "running" + assert result["progress"] == 45 + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_get_job_status_completed(self, mock_request): + """Completed job should include result data.""" + mock_request.return_value = { + "success": True, + "result": {"status": "completed", "pdb_content": "ATOM..."} + } + + result = get_job_status("test123") + + assert result["success"] is True + assert result["status"] == "completed" + assert result["result"] is not None + + +class TestDownloadStructure: + """Tests for structure file download.""" + + def test_download_invalid_format(self): + """Invalid format should fail.""" + result = download_structure("test123", "output.xyz", "xyz") + assert result["success"] is False + assert "pdb" in result["error"] or "cif" in result["error"] + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_download_structure_success(self, mock_request): + """Successful download should save file.""" + mock_request.return_value = { + "success": True, + "result": {"pdb_content": "ATOM 1 CA ALA A 1"} + } + + with tempfile.TemporaryDirectory() as tmpdir: + output_path = os.path.join(tmpdir, "structure.pdb") + result = download_structure("test123", output_path, "pdb") + + assert result["success"] is True + assert os.path.exists(output_path) + assert result["format"] == "pdb" + + +class TestAnalyzeStructureConfidence: + """Tests for structure confidence analysis.""" + + def test_analyze_nonexistent_file(self): + """Nonexistent file should fail.""" + result = analyze_structure_confidence("/nonexistent/file.pdb") + assert result["success"] is False + assert "not found" in result["error"] + + def test_analyze_valid_pdb(self): + """Valid PDB file should produce confidence metrics.""" + pdb_content = """ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 85.50 N +ATOM 2 CA ALA A 1 1.458 0.000 0.000 1.00 85.50 C +ATOM 3 N GLY A 2 2.000 1.000 0.000 1.00 45.00 N +ATOM 4 CA GLY A 2 3.458 1.000 0.000 1.00 45.00 C +""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".pdb", delete=False) as f: + f.write(pdb_content) + temp_path = f.name + + try: + result = analyze_structure_confidence(temp_path) + + assert result["success"] is True + assert result["total_residues"] == 2 + assert 60 < result["mean_plddt"] < 70 + assert result["confident_residue_count"] == 1 + assert result["low_confidence_residue_count"] == 1 + finally: + os.unlink(temp_path) + + +class TestBatchPredictStructures: + """Tests for batch structure prediction.""" + + def test_batch_empty_jobs(self): + """Empty job list should fail.""" + result = batch_predict_structures([]) + assert result["success"] is False + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_batch_predict_multiple_jobs(self, mock_request): + """Multiple jobs should be submitted.""" + mock_request.return_value = {"success": True, "result": {"job_id": "batch123"}} + + jobs = [ + {"type": "protein", "sequences": ["MKFLILLFNILCLFPVLAADNH"], "name": "protein1"}, + {"type": "protein", "sequences": ["MALTEVNPKKYIPGTKMIFAG"], "name": "protein2"}, + ] + + result = batch_predict_structures(jobs) + + assert result["success"] is True + assert result["total_jobs"] == 2 + assert result["submitted"] == 2 + + @mock.patch("biomni.tool.structure_prediction._make_api_request") + def test_batch_predict_with_output_dir(self, mock_request): + """Batch prediction should create output directory.""" + mock_request.return_value = {"success": True, "result": {"job_id": "batch123"}} + + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = os.path.join(tmpdir, "structures") + jobs = [{"type": "protein", "sequences": ["MKFLILLFNILCLFPVLAADNH"], "name": "test"}] + + result = batch_predict_structures(jobs, output_dir=output_dir) + + assert result["success"] is True + assert os.path.isdir(output_dir) + + def test_batch_predict_invalid_job_type(self): + """Invalid job type should be reported as failed.""" + jobs = [{"type": "invalid_type", "sequences": ["MKFL"], "name": "test"}] + + result = batch_predict_structures(jobs) + + assert result["failed"] == 1 + assert "Unknown job type" in result["jobs"][0]["error"] From c5dbf05d804d3ad6bc8f8ddf117bc68fe432d69f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Jan 2026 19:24:57 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- biomni/tool/structure_prediction.py | 302 ++++++++++++++-------------- tests/test_structure_prediction.py | 130 ++++++------ 2 files changed, 209 insertions(+), 223 deletions(-) diff --git a/biomni/tool/structure_prediction.py b/biomni/tool/structure_prediction.py index 66250367..167116b0 100644 --- a/biomni/tool/structure_prediction.py +++ b/biomni/tool/structure_prediction.py @@ -6,7 +6,6 @@ """ import hashlib -import json import os import time from typing import Any @@ -31,16 +30,16 @@ def _make_api_request( ) -> dict[str, Any]: """Make an authenticated request to the AlphaFold Server API.""" token = _get_api_token() - + if headers is None: headers = {} - + headers["Content-Type"] = "application/json" if token: headers["Authorization"] = f"Bearer {token}" - + url = f"{ALPHAFOLD_API_URL}/{endpoint.lstrip('/')}" - + try: if method.upper() == "GET": response = requests.get(url, headers=headers, timeout=timeout) @@ -48,16 +47,16 @@ def _make_api_request( response = requests.post(url, headers=headers, json=data, timeout=timeout) else: return {"success": False, "error": f"Unsupported method: {method}"} - + response.raise_for_status() - + try: result = response.json() except ValueError: result = {"raw_text": response.text} - + return {"success": True, "result": result} - + except requests.exceptions.Timeout: return {"success": False, "error": "Request timed out"} except requests.exceptions.RequestException as e: @@ -75,17 +74,17 @@ def _validate_sequence(sequence: str) -> tuple[bool, str]: """Validate a protein sequence contains only valid amino acid codes.""" valid_aa = set("ACDEFGHIKLMNPQRSTVWY") sequence_upper = sequence.upper().replace(" ", "").replace("\n", "") - + invalid_chars = set(sequence_upper) - valid_aa if invalid_chars: return False, f"Invalid characters in sequence: {invalid_chars}" - + if len(sequence_upper) < 10: return False, "Sequence too short (minimum 10 residues)" - + if len(sequence_upper) > 2000: return False, "Sequence too long (maximum 2000 residues for API)" - + return True, sequence_upper @@ -95,16 +94,16 @@ def _validate_nucleic_acid(sequence: str, na_type: str = "DNA") -> tuple[bool, s valid_bases = set("ATCGN") else: valid_bases = set("AUCGN") - + sequence_upper = sequence.upper().replace(" ", "").replace("\n", "") invalid_chars = set(sequence_upper) - valid_bases - + if invalid_chars: return False, f"Invalid characters for {na_type}: {invalid_chars}" - + if len(sequence_upper) < 5: return False, "Sequence too short (minimum 5 bases)" - + return True, sequence_upper @@ -121,15 +120,15 @@ def _poll_job_status( ) -> dict[str, Any]: """Poll for job completion status.""" elapsed = 0 - + while elapsed < max_wait_seconds: result = _make_api_request(f"jobs/{job_id}") - + if not result["success"]: return result - + status = result["result"].get("status", "unknown") - + if status == "completed": return {"success": True, "result": result["result"]} elif status == "failed": @@ -140,7 +139,7 @@ def _poll_job_status( elapsed += poll_interval else: return {"success": False, "error": f"Unknown job status: {status}"} - + return {"success": False, "error": f"Job timed out after {max_wait_seconds} seconds"} @@ -151,7 +150,7 @@ def predict_protein_structure( timeout_seconds: int = 600, ) -> dict[str, Any]: """Predict the 3D structure of a single protein sequence. - + Parameters ---------- sequence : str @@ -162,7 +161,7 @@ def predict_protein_structure( If True, wait for prediction to complete. If False, return job ID immediately timeout_seconds : int Maximum time to wait for prediction in seconds - + Returns ------- dict @@ -172,43 +171,43 @@ def predict_protein_structure( - pdb_content: PDB file content as string - confidence: overall confidence metrics - job_id: ID for tracking the prediction job - + Examples -------- >>> result = predict_protein_structure("MKFLILLFNILCLFPVLAADNH...") >>> print(result["confidence"]["plddt_mean"]) - + """ is_valid, validated_seq = _validate_sequence(sequence) if not is_valid: return {"success": False, "error": validated_seq} - + job_data = { "sequences": [{"type": "protein", "sequence": validated_seq}], "model": "alphafold3", } - + submit_result = _make_api_request("predict", method="POST", data=job_data) - + if not submit_result["success"]: return submit_result - + job_id = submit_result["result"].get("job_id") - + if not job_id: return {"success": False, "error": "No job ID returned from server"} - + if not wait_for_result: return {"success": True, "job_id": job_id, "status": "submitted"} - + poll_result = _poll_job_status(job_id, max_wait_seconds=timeout_seconds) - + if not poll_result["success"]: return poll_result - + job_result = poll_result["result"] pdb_content = job_result.get("pdb_content", "") - + response = { "success": True, "job_id": job_id, @@ -218,13 +217,13 @@ def predict_protein_structure( "ptm": job_result.get("ptm"), }, } - + if save_path and pdb_content: os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) with open(save_path, "w") as f: f.write(pdb_content) response["structure_path"] = os.path.abspath(save_path) - + return response @@ -236,7 +235,7 @@ def predict_protein_complex( timeout_seconds: int = 900, ) -> dict[str, Any]: """Predict the 3D structure of a protein complex from multiple sequences. - + Parameters ---------- sequences : list[str] @@ -249,7 +248,7 @@ def predict_protein_complex( If True, wait for prediction to complete timeout_seconds : int Maximum time to wait for prediction - + Returns ------- dict @@ -259,55 +258,55 @@ def predict_protein_complex( - pdb_content: PDB content as string - confidence: per-chain and interface confidence metrics - interface_contacts: predicted inter-chain contacts - + Examples -------- >>> seqs = ["MKFLILLFNILCLFPVLAADNH...", "MALTEVNPKKYIPGTKMIFAG..."] >>> result = predict_protein_complex(seqs, chain_names=["Receptor", "Ligand"]) - + """ if not sequences or len(sequences) < 2: return {"success": False, "error": "At least 2 sequences required for complex prediction"} - + validated_sequences = [] for i, seq in enumerate(sequences): is_valid, validated_seq = _validate_sequence(seq) if not is_valid: - return {"success": False, "error": f"Chain {i+1}: {validated_seq}"} + return {"success": False, "error": f"Chain {i + 1}: {validated_seq}"} validated_sequences.append(validated_seq) - + if chain_names is None: chain_names = [chr(ord("A") + i) for i in range(len(sequences))] - + sequence_data = [ {"type": "protein", "sequence": seq, "chain_id": name} - for seq, name in zip(validated_sequences, chain_names) + for seq, name in zip(validated_sequences, chain_names, strict=False) ] - + job_data = { "sequences": sequence_data, "model": "alphafold3", "predict_interface": True, } - + submit_result = _make_api_request("predict", method="POST", data=job_data) - + if not submit_result["success"]: return submit_result - + job_id = submit_result["result"].get("job_id") - + if not wait_for_result: return {"success": True, "job_id": job_id, "status": "submitted"} - + poll_result = _poll_job_status(job_id, max_wait_seconds=timeout_seconds) - + if not poll_result["success"]: return poll_result - + job_result = poll_result["result"] pdb_content = job_result.get("pdb_content", "") - + response = { "success": True, "job_id": job_id, @@ -320,13 +319,13 @@ def predict_protein_complex( }, "interface_contacts": job_result.get("interface_contacts", []), } - + if save_path and pdb_content: os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) with open(save_path, "w") as f: f.write(pdb_content) response["structure_path"] = os.path.abspath(save_path) - + return response @@ -339,7 +338,7 @@ def predict_protein_nucleic_acid_complex( timeout_seconds: int = 900, ) -> dict[str, Any]: """Predict structure of a protein-DNA or protein-RNA complex. - + Parameters ---------- protein_sequences : list[str] @@ -354,69 +353,70 @@ def predict_protein_nucleic_acid_complex( Wait for completion timeout_seconds : int Maximum wait time - + Returns ------- dict Prediction results including structure and confidence metrics - + Examples -------- >>> protein = "MKFLILLFNILCLFPVLAADNH..." >>> dna = "ATCGATCGATCGATCG" >>> result = predict_protein_nucleic_acid_complex([protein], dna, "DNA") - + """ if nucleic_acid_type.upper() not in ("DNA", "RNA"): return {"success": False, "error": "nucleic_acid_type must be 'DNA' or 'RNA'"} - + validated_proteins = [] for i, seq in enumerate(protein_sequences): is_valid, validated_seq = _validate_sequence(seq) if not is_valid: - return {"success": False, "error": f"Protein {i+1}: {validated_seq}"} + return {"success": False, "error": f"Protein {i + 1}: {validated_seq}"} validated_proteins.append(validated_seq) - + is_valid, validated_na = _validate_nucleic_acid(nucleic_acid_sequence, nucleic_acid_type) if not is_valid: return {"success": False, "error": validated_na} - + sequence_data = [ - {"type": "protein", "sequence": seq, "chain_id": chr(ord("A") + i)} - for i, seq in enumerate(validated_proteins) + {"type": "protein", "sequence": seq, "chain_id": chr(ord("A") + i)} for i, seq in enumerate(validated_proteins) ] - + na_chain_id = chr(ord("A") + len(validated_proteins)) - sequence_data.append({ - "type": nucleic_acid_type.lower(), - "sequence": validated_na, - "chain_id": na_chain_id, - }) - + sequence_data.append( + { + "type": nucleic_acid_type.lower(), + "sequence": validated_na, + "chain_id": na_chain_id, + } + ) + job_data = { "sequences": sequence_data, "model": "alphafold3", "predict_interface": True, } - + submit_result = _make_api_request("predict", method="POST", data=job_data) - + if not submit_result["success"]: return submit_result - + job_id = submit_result["result"].get("job_id") - + if not wait_for_result: return {"success": True, "job_id": job_id, "status": "submitted"} - + poll_result = _poll_job_status(job_id, max_wait_seconds=timeout_seconds) - + if not poll_result["success"]: return poll_result - + job_result = poll_result["result"] pdb_content = job_result.get("pdb_content", "") - + response = { "success": True, "job_id": job_id, @@ -428,13 +428,13 @@ def predict_protein_nucleic_acid_complex( }, "protein_na_contacts": job_result.get("interface_contacts", []), } - + if save_path and pdb_content: os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) with open(save_path, "w") as f: f.write(pdb_content) response["structure_path"] = os.path.abspath(save_path) - + return response @@ -446,7 +446,7 @@ def predict_protein_ligand_complex( timeout_seconds: int = 900, ) -> dict[str, Any]: """Predict structure of a protein-small molecule complex. - + Parameters ---------- protein_sequence : str @@ -459,7 +459,7 @@ def predict_protein_ligand_complex( Wait for completion timeout_seconds : int Maximum wait time - + Returns ------- dict @@ -468,50 +468,50 @@ def predict_protein_ligand_complex( - pdb_content: structure as string - binding_site: predicted binding residues - binding_affinity: predicted binding strength (if available) - + Examples -------- >>> protein = "MKFLILLFNILCLFPVLAADNH..." >>> erlotinib = "COCCOc1cc2ncnc(Nc3cccc(c3)C#C)c2cc1OCCOC" >>> result = predict_protein_ligand_complex(protein, erlotinib) - + """ is_valid, validated_seq = _validate_sequence(protein_sequence) if not is_valid: return {"success": False, "error": validated_seq} - + if not ligand_smiles or len(ligand_smiles) < 2: return {"success": False, "error": "Invalid SMILES string"} - + sequence_data = [ {"type": "protein", "sequence": validated_seq, "chain_id": "A"}, {"type": "ligand", "smiles": ligand_smiles, "chain_id": "L"}, ] - + job_data = { "sequences": sequence_data, "model": "alphafold3", "predict_binding": True, } - + submit_result = _make_api_request("predict", method="POST", data=job_data) - + if not submit_result["success"]: return submit_result - + job_id = submit_result["result"].get("job_id") - + if not wait_for_result: return {"success": True, "job_id": job_id, "status": "submitted"} - + poll_result = _poll_job_status(job_id, max_wait_seconds=timeout_seconds) - + if not poll_result["success"]: return poll_result - + job_result = poll_result["result"] pdb_content = job_result.get("pdb_content", "") - + response = { "success": True, "job_id": job_id, @@ -523,24 +523,24 @@ def predict_protein_ligand_complex( "binding_site": job_result.get("binding_residues", []), "binding_affinity": job_result.get("predicted_affinity"), } - + if save_path and pdb_content: os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) with open(save_path, "w") as f: f.write(pdb_content) response["structure_path"] = os.path.abspath(save_path) - + return response def get_job_status(job_id: str) -> dict[str, Any]: """Check the status of a submitted prediction job. - + Parameters ---------- job_id : str Job ID returned from a prediction function - + Returns ------- dict @@ -548,15 +548,15 @@ def get_job_status(job_id: str) -> dict[str, Any]: - status: "pending", "running", "completed", or "failed" - progress: percentage complete (if available) - result: prediction results (if completed) - + """ result = _make_api_request(f"jobs/{job_id}") - + if not result["success"]: return result - + job_data = result["result"] - + return { "success": True, "job_id": job_id, @@ -573,7 +573,7 @@ def download_structure( file_format: str = "pdb", ) -> dict[str, Any]: """Download the structure file from a completed prediction job. - + Parameters ---------- job_id : str @@ -582,31 +582,31 @@ def download_structure( Path to save the structure file file_format : str Output format: "pdb" or "cif" - + Returns ------- dict Download result with file path - + """ if file_format.lower() not in ("pdb", "cif"): return {"success": False, "error": "Format must be 'pdb' or 'cif'"} - + result = _make_api_request(f"jobs/{job_id}/structure", timeout=60) - + if not result["success"]: return result - + structure_content = result["result"].get(f"{file_format}_content", "") - + if not structure_content: return {"success": False, "error": f"No {file_format.upper()} content available"} - + os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True) - + with open(output_path, "w") as f: f.write(structure_content) - + return { "success": True, "file_path": os.path.abspath(output_path), @@ -617,12 +617,12 @@ def download_structure( def analyze_structure_confidence(pdb_path: str) -> dict[str, Any]: """Analyze confidence scores from a predicted structure file. - + Parameters ---------- pdb_path : str Path to a PDB file from AlphaFold prediction - + Returns ------- dict @@ -631,36 +631,36 @@ def analyze_structure_confidence(pdb_path: str) -> dict[str, Any]: - mean_plddt: average pLDDT - confident_regions: regions with pLDDT > 70 - low_confidence_regions: regions with pLDDT < 50 - + """ if not os.path.exists(pdb_path): return {"success": False, "error": f"File not found: {pdb_path}"} - + residue_plddt = {} - + try: - with open(pdb_path, "r") as f: + with open(pdb_path) as f: for line in f: if line.startswith("ATOM"): chain = line[21].strip() res_num = int(line[22:26].strip()) b_factor = float(line[60:66].strip()) - + key = f"{chain}:{res_num}" if key not in residue_plddt: residue_plddt[key] = b_factor except Exception as e: return {"success": False, "error": f"Failed to parse PDB: {str(e)}"} - + if not residue_plddt: return {"success": False, "error": "No residues found in PDB file"} - + scores = list(residue_plddt.values()) mean_plddt = sum(scores) / len(scores) - + confident_regions = [k for k, v in residue_plddt.items() if v > 70] low_confidence_regions = [k for k, v in residue_plddt.items() if v < 50] - + return { "success": True, "file": pdb_path, @@ -669,9 +669,7 @@ def analyze_structure_confidence(pdb_path: str) -> dict[str, Any]: "confident_residue_count": len(confident_regions), "low_confidence_residue_count": len(low_confidence_regions), "quality_assessment": ( - "High confidence" if mean_plddt > 70 else - "Medium confidence" if mean_plddt > 50 else - "Low confidence" + "High confidence" if mean_plddt > 70 else "Medium confidence" if mean_plddt > 50 else "Low confidence" ), "residue_scores": residue_plddt, } @@ -684,7 +682,7 @@ def batch_predict_structures( max_concurrent: int = 5, ) -> dict[str, Any]: """Submit multiple structure prediction jobs. - + Parameters ---------- jobs : list[dict] @@ -698,12 +696,12 @@ def batch_predict_structures( Whether to submit jobs in parallel max_concurrent : int Maximum concurrent jobs if parallel=True - + Returns ------- dict Batch results including job IDs and status for each submission - + Examples -------- >>> jobs = [ @@ -711,25 +709,25 @@ def batch_predict_structures( ... {"type": "complex", "sequences": ["MKFL...", "MALT..."], "name": "complex1"}, ... ] >>> result = batch_predict_structures(jobs, output_dir="./structures") - + """ if not jobs: return {"success": False, "error": "No jobs provided"} - + if output_dir: os.makedirs(output_dir, exist_ok=True) - + results = [] - + for i, job in enumerate(jobs): job_type = job.get("type", "protein") sequences = job.get("sequences", []) - name = job.get("name", f"job_{i+1}") - + name = job.get("name", f"job_{i + 1}") + save_path = None if output_dir: save_path = os.path.join(output_dir, f"{name}.pdb") - + try: if job_type == "protein": if not sequences: @@ -768,19 +766,21 @@ def batch_predict_structures( ) else: result = {"success": False, "error": f"Unknown job type: {job_type}"} - + result["name"] = name results.append(result) - + except Exception as e: - results.append({ - "success": False, - "name": name, - "error": str(e), - }) - + results.append( + { + "success": False, + "name": name, + "error": str(e), + } + ) + successful = sum(1 for r in results if r.get("success", False)) - + return { "success": successful > 0, "total_jobs": len(jobs), diff --git a/tests/test_structure_prediction.py b/tests/test_structure_prediction.py index 8d513dcd..64cd5011 100644 --- a/tests/test_structure_prediction.py +++ b/tests/test_structure_prediction.py @@ -4,8 +4,6 @@ import tempfile from unittest import mock -import pytest - from biomni.tool.structure_prediction import ( _generate_job_id, _validate_nucleic_acid, @@ -121,11 +119,14 @@ def test_predict_protein_success(self, mock_request): """Successful prediction should return structure data.""" mock_request.side_effect = [ {"success": True, "result": {"job_id": "test123"}}, - {"success": True, "result": {"status": "completed", "pdb_content": "ATOM...", "plddt_mean": 85.5, "ptm": 0.9}}, + { + "success": True, + "result": {"status": "completed", "pdb_content": "ATOM...", "plddt_mean": 85.5, "ptm": 0.9}, + }, ] - + result = predict_protein_structure("MKFLILLFNILCLFPVLAADNH") - + assert result["success"] is True assert result["job_id"] == "test123" assert "pdb_content" in result @@ -138,11 +139,11 @@ def test_predict_protein_saves_file(self, mock_request): {"success": True, "result": {"job_id": "test123"}}, {"success": True, "result": {"status": "completed", "pdb_content": "ATOM 1 CA ALA", "plddt_mean": 85.5}}, ] - + with tempfile.TemporaryDirectory() as tmpdir: save_path = os.path.join(tmpdir, "test.pdb") result = predict_protein_structure("MKFLILLFNILCLFPVLAADNH", save_path=save_path) - + assert result["success"] is True assert os.path.exists(save_path) with open(save_path) as f: @@ -158,9 +159,9 @@ def test_predict_protein_invalid_sequence(self): def test_predict_protein_no_wait(self, mock_request): """No-wait mode should return job ID immediately.""" mock_request.return_value = {"success": True, "result": {"job_id": "test123"}} - + result = predict_protein_structure("MKFLILLFNILCLFPVLAADNH", wait_for_result=False) - + assert result["success"] is True assert result["status"] == "submitted" assert result["job_id"] == "test123" @@ -186,20 +187,22 @@ def test_predict_complex_success(self, mock_request): """Successful complex prediction should return interface data.""" mock_request.side_effect = [ {"success": True, "result": {"job_id": "complex123"}}, - {"success": True, "result": { - "status": "completed", - "pdb_content": "ATOM...", - "plddt_mean": 80.0, - "iptm": 0.85, - "interface_contacts": [{"chain_a": "A", "chain_b": "B", "residue_a": 10, "residue_b": 25}], - }}, + { + "success": True, + "result": { + "status": "completed", + "pdb_content": "ATOM...", + "plddt_mean": 80.0, + "iptm": 0.85, + "interface_contacts": [{"chain_a": "A", "chain_b": "B", "residue_a": 10, "residue_b": 25}], + }, + }, ] - + result = predict_protein_complex( - ["MKFLILLFNILCLFPVLAADNH", "MALTEVNPKKYIPGTKMIFAG"], - chain_names=["Receptor", "Ligand"] + ["MKFLILLFNILCLFPVLAADNH", "MALTEVNPKKYIPGTKMIFAG"], chain_names=["Receptor", "Ligand"] ) - + assert result["success"] is True assert result["confidence"]["iptm"] == 0.85 assert len(result["interface_contacts"]) > 0 @@ -210,11 +213,7 @@ class TestPredictProteinNucleicAcid: def test_predict_invalid_na_type(self): """Invalid nucleic acid type should fail.""" - result = predict_protein_nucleic_acid_complex( - ["MKFLILLFNILCLFPVLAADNH"], - "ATCGATCG", - "INVALID" - ) + result = predict_protein_nucleic_acid_complex(["MKFLILLFNILCLFPVLAADNH"], "ATCGATCG", "INVALID") assert result["success"] is False assert "DNA" in result["error"] or "RNA" in result["error"] @@ -225,13 +224,9 @@ def test_predict_protein_dna_success(self, mock_request): {"success": True, "result": {"job_id": "dna123"}}, {"success": True, "result": {"status": "completed", "pdb_content": "ATOM...", "plddt_mean": 75.0}}, ] - - result = predict_protein_nucleic_acid_complex( - ["MKFLILLFNILCLFPVLAADNH"], - "ATCGATCGATCGATCG", - "DNA" - ) - + + result = predict_protein_nucleic_acid_complex(["MKFLILLFNILCLFPVLAADNH"], "ATCGATCGATCGATCG", "DNA") + assert result["success"] is True assert result["job_id"] == "dna123" @@ -250,19 +245,19 @@ def test_predict_protein_ligand_success(self, mock_request): """Successful protein-ligand prediction should return binding data.""" mock_request.side_effect = [ {"success": True, "result": {"job_id": "ligand123"}}, - {"success": True, "result": { - "status": "completed", - "pdb_content": "ATOM...", - "binding_residues": [45, 67, 89, 112], - "predicted_affinity": -8.5, - }}, + { + "success": True, + "result": { + "status": "completed", + "pdb_content": "ATOM...", + "binding_residues": [45, 67, 89, 112], + "predicted_affinity": -8.5, + }, + }, ] - - result = predict_protein_ligand_complex( - "MKFLILLFNILCLFPVLAADNH", - "COCCOc1cc2ncnc(Nc3cccc(c3)C#C)c2cc1OCCOC" - ) - + + result = predict_protein_ligand_complex("MKFLILLFNILCLFPVLAADNH", "COCCOc1cc2ncnc(Nc3cccc(c3)C#C)c2cc1OCCOC") + assert result["success"] is True assert len(result["binding_site"]) > 0 assert result["binding_affinity"] == -8.5 @@ -274,13 +269,10 @@ class TestGetJobStatus: @mock.patch("biomni.tool.structure_prediction._make_api_request") def test_get_job_status_running(self, mock_request): """Running job should return pending status.""" - mock_request.return_value = { - "success": True, - "result": {"status": "running", "progress": 45} - } - + mock_request.return_value = {"success": True, "result": {"status": "running", "progress": 45}} + result = get_job_status("test123") - + assert result["success"] is True assert result["status"] == "running" assert result["progress"] == 45 @@ -288,13 +280,10 @@ def test_get_job_status_running(self, mock_request): @mock.patch("biomni.tool.structure_prediction._make_api_request") def test_get_job_status_completed(self, mock_request): """Completed job should include result data.""" - mock_request.return_value = { - "success": True, - "result": {"status": "completed", "pdb_content": "ATOM..."} - } - + mock_request.return_value = {"success": True, "result": {"status": "completed", "pdb_content": "ATOM..."}} + result = get_job_status("test123") - + assert result["success"] is True assert result["status"] == "completed" assert result["result"] is not None @@ -312,15 +301,12 @@ def test_download_invalid_format(self): @mock.patch("biomni.tool.structure_prediction._make_api_request") def test_download_structure_success(self, mock_request): """Successful download should save file.""" - mock_request.return_value = { - "success": True, - "result": {"pdb_content": "ATOM 1 CA ALA A 1"} - } - + mock_request.return_value = {"success": True, "result": {"pdb_content": "ATOM 1 CA ALA A 1"}} + with tempfile.TemporaryDirectory() as tmpdir: output_path = os.path.join(tmpdir, "structure.pdb") result = download_structure("test123", output_path, "pdb") - + assert result["success"] is True assert os.path.exists(output_path) assert result["format"] == "pdb" @@ -345,10 +331,10 @@ def test_analyze_valid_pdb(self): with tempfile.NamedTemporaryFile(mode="w", suffix=".pdb", delete=False) as f: f.write(pdb_content) temp_path = f.name - + try: result = analyze_structure_confidence(temp_path) - + assert result["success"] is True assert result["total_residues"] == 2 assert 60 < result["mean_plddt"] < 70 @@ -370,14 +356,14 @@ def test_batch_empty_jobs(self): def test_batch_predict_multiple_jobs(self, mock_request): """Multiple jobs should be submitted.""" mock_request.return_value = {"success": True, "result": {"job_id": "batch123"}} - + jobs = [ {"type": "protein", "sequences": ["MKFLILLFNILCLFPVLAADNH"], "name": "protein1"}, {"type": "protein", "sequences": ["MALTEVNPKKYIPGTKMIFAG"], "name": "protein2"}, ] - + result = batch_predict_structures(jobs) - + assert result["success"] is True assert result["total_jobs"] == 2 assert result["submitted"] == 2 @@ -386,21 +372,21 @@ def test_batch_predict_multiple_jobs(self, mock_request): def test_batch_predict_with_output_dir(self, mock_request): """Batch prediction should create output directory.""" mock_request.return_value = {"success": True, "result": {"job_id": "batch123"}} - + with tempfile.TemporaryDirectory() as tmpdir: output_dir = os.path.join(tmpdir, "structures") jobs = [{"type": "protein", "sequences": ["MKFLILLFNILCLFPVLAADNH"], "name": "test"}] - + result = batch_predict_structures(jobs, output_dir=output_dir) - + assert result["success"] is True assert os.path.isdir(output_dir) def test_batch_predict_invalid_job_type(self): """Invalid job type should be reported as failed.""" jobs = [{"type": "invalid_type", "sequences": ["MKFL"], "name": "test"}] - + result = batch_predict_structures(jobs) - + assert result["failed"] == 1 assert "Unknown job type" in result["jobs"][0]["error"]