Commit

First try Boltz notebook
martin-steinegger committed Nov 17, 2024
1 parent c21e176 commit b339885
Showing 1 changed file with 371 additions and 0 deletions.
371 changes: 371 additions & 0 deletions boltz1.ipynb
@@ -0,0 +1,371 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "A100",
"authorship_tag": "ABX9TyMG+xMrC8CD3APVXWbRvd84",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/boltz1.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"#@title Input protein sequence(s), then hit `Runtime` -> `Run all`\n",
"from google.colab import files\n",
"import os\n",
"import re\n",
"import hashlib\n",
"import random\n",
"from string import ascii_uppercase\n",
"\n",
"# Function to add a hash to the jobname\n",
"def add_hash(x, y):\n",
" return x + \"_\" + hashlib.sha1(y.encode()).hexdigest()[:5]\n",
"\n",
"# User inputs\n",
"query_sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK' #@param {type:\"string\"}\n",
"#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetro-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
"ligand_input = 'N[C@@H](Cc1ccc(O)cc1)C(=O)O' #@param {type:\"string\"}\n",
"#@markdown - Use `:` to specify multiple ligands as smile strings\n",
"jobname = 'test' #@param {type:\"string\"}\n",
"\n",
"# Clean up the query sequence and jobname\n",
"query_sequence = \"\".join(query_sequence.split())\n",
"basejobname = \"\".join(jobname.split())\n",
"basejobname = re.sub(r'\\W+', '', basejobname)\n",
"jobname = add_hash(basejobname, query_sequence)\n",
"\n",
"# Check if a directory with jobname exists\n",
"def check(folder):\n",
" return not os.path.exists(folder)\n",
"\n",
"if not check(jobname):\n",
" n = 0\n",
" while not check(f\"{jobname}_{n}\"):\n",
" n += 1\n",
" jobname = f\"{jobname}_{n}\"\n",
"\n",
"# Make directory to save results\n",
"os.makedirs(jobname, exist_ok=True)\n",
"\n",
"from string import ascii_uppercase\n",
"\n",
"# Split sequences on chain breaks\n",
"protein_sequences = query_sequence.strip().split(':')\n",
"ligand_sequences = ligand_input.strip().split(':')\n",
"\n",
"# Initialize chain labels starting from 'A'\n",
"chain_labels = iter(ascii_uppercase)\n",
"\n",
"fasta_entries = []\n",
"csv_entries = []\n",
"chain_label_to_seq_id = {}\n",
"\n",
"# Process protein sequences\n",
"for i, seq in enumerate(protein_sequences):\n",
" seq = seq.strip()\n",
" if not seq:\n",
" continue # Skip empty sequences\n",
" chain_label = next(chain_labels)\n",
" seq_id = f\"{jobname}_{i}\"\n",
" chain_label_to_seq_id[chain_label] = seq_id\n",
" # For CSV file (for ColabFold)\n",
" csv_entries.append((seq_id, seq))\n",
" # For FASTA file\n",
" msa_path = os.path.join(jobname, f\"{seq_id}.a3m\")\n",
" header = f\">{chain_label}|protein|{msa_path}\"\n",
" sequence = seq\n",
" fasta_entries.append((header, sequence))\n",
"\n",
"# Process ligand sequences (assumed to be SMILES strings)\n",
"for lig in ligand_sequences:\n",
" lig = lig.strip()\n",
" if not lig:\n",
" continue # Skip empty ligands\n",
" chain_label = next(chain_labels)\n",
" lig_type = 'smiles'\n",
" header = f\">{chain_label}|{lig_type}\"\n",
" sequence = lig\n",
" fasta_entries.append((header, sequence))\n",
"\n",
"# Write the CSV file for ColabFold\n",
"queries_path = os.path.join(jobname, f\"{jobname}.csv\")\n",
"with open(queries_path, \"w\") as text_file:\n",
" text_file.write(\"id,sequence\\n\")\n",
" for seq_id, seq in csv_entries:\n",
" text_file.write(f\"{seq_id},{seq}\\n\")\n",
"\n",
"# Write the FASTA file\n",
"queries_fasta = os.path.join(jobname, f\"{jobname}.fasta\")\n",
"with open(queries_fasta, 'w') as f:\n",
" for header, sequence in fasta_entries:\n",
" f.write(f\"{header}\\n{sequence}\\n\")\n",
"\n",
"# Optionally, print the output for verification\n",
"#print(f\"Generated FASTA file '{queries_fasta}':\\n\")\n",
"#for header, sequence in fasta_entries:\n",
"# print(f\"{header}\\n{sequence}\\n\")\n"
],
"metadata": {
"cellView": "form",
"id": "AcYvVeDESi2a"
},
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"cellView": "form",
"id": "4eXNO1JJHYrB",
"outputId": "c9b764cb-37e7-4ec0-8519-b488163d3b40"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"CPU times: user 30 µs, sys: 6 µs, total: 36 µs\n",
"Wall time: 39.1 µs\n"
]
}
],
"source": [
"#@title Install dependencies\n",
"%%time\n",
"import os\n",
"if not os.path.isfile(\"COLABFOLD_READY\"):\n",
" print(\"installing colabfold...\")\n",
" os.system(\"pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'\")\n",
" if os.environ.get('TPU_NAME', False) != False:\n",
" os.system(\"pip uninstall -y jax jaxlib\")\n",
" os.system(\"pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\")\n",
" os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold\")\n",
" os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold\")\n",
" os.system(\"touch COLABFOLD_READY\")\n",
"\n",
"if not os.path.isfile(\"BOLZ_READY\"):\n",
" os.system(\"pip install -q --no-warn-conflicts boltz\")\n",
" os.system(\"touch BOLZ_READY\")"
]
},
{
"cell_type": "code",
"source": [
"#@title Generate MSA with ColabFold\n",
"!colabfold_batch \"{queries_path}\" \"{jobname}\" --msa-only"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"cellView": "form",
"id": "4aFDR4IhRe6y",
"outputId": "e32be2a0-2192-4385-9948-fafe3180dea4"
},
"execution_count": 37,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2024-11-17 23:21:37,398 Running colabfold 1.5.5 (c21e1768d18e3608e6e6d99c97134317e7e41c75)\n",
"\n",
"WARNING: You are welcome to use the default MSA server, however keep in mind that it's a\n",
"limited shared resource only capable of processing a few thousand MSAs per day. Please\n",
"submit jobs only from a single IP address. We reserve the right to limit access to the\n",
"server case-by-case when usage exceeds fair use. If you require more MSAs: You can \n",
"precompute all MSAs with `colabfold_search` or host your own API and pass it to `--host-url`\n",
"\n",
"2024-11-17 23:21:39,576 Running on GPU\n",
"2024-11-17 23:21:40,309 Found 4 citations for tools or databases\n",
"2024-11-17 23:21:40,310 Query 1/1: test_a5e17_1_0 (length 59)\n",
"COMPLETE: 100% 150/150 [00:22<00:00, 6.53it/s] \n",
"2024-11-17 23:22:03,313 Saved test_a5e17_1/test_a5e17_1_0.pickle\n",
"2024-11-17 23:22:04,055 Done\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"#@title Predict structure using boltz\n",
"!boltz predict --out_dir \"{jobname}\" \"{jobname}/{jobname}.fasta\""
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"cellView": "form",
"id": "bgaBXxXtIAu9",
"outputId": "7cb18790-83f7-42db-d3f1-1043025520b0"
},
"execution_count": 38,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading data and model to /root/.boltz. You may change this by setting the --cache flag.\n",
"Checking input data.\n",
"Processing input data.\n",
"100% 1/1 [00:00<00:00, 16.14it/s]\n",
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n",
"You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
"2024-11-17 23:22:42.568213: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
"2024-11-17 23:22:42.619224: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
"Predicting DataLoader 0: 100% 1/1 [00:10<00:00, 10.87s/it]Number of failed examples: 0\n",
"Predicting DataLoader 0: 100% 1/1 [00:10<00:00, 10.87s/it]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"#@title Download results\n",
"# Import necessary modules\n",
"import os\n",
"import zipfile\n",
"from google.colab import files\n",
"import glob\n",
"\n",
"# Ensure 'jobname' variable is defined\n",
"# jobname = 'test_abcde' # Uncomment and set if not already defined\n",
"\n",
"# Name of the zip file\n",
"zip_filename = f\"results_{jobname}.zip\"\n",
"\n",
"# Create a zip file and add the specified files without preserving directory structure\n",
"with zipfile.ZipFile(zip_filename, 'w') as zipf:\n",
" coverage_png_files = glob.glob(os.path.join(jobname, '*_coverage.png'))\n",
" a3m_files = glob.glob(os.path.join(jobname, '*.a3m'))\n",
" for file in coverage_png_files + a3m_files:\n",
" arcname = os.path.basename(file) # Use only the file name\n",
" zipf.write(file, arcname=arcname)\n",
"\n",
" cif_files = glob.glob(os.path.join(jobname, f'boltz_results_{jobname}', 'predictions', jobname, '*.cif'))\n",
" for file in cif_files:\n",
" arcname = os.path.basename(file) # Use only the file name\n",
" zipf.write(file, arcname=arcname)\n",
"\n",
" hparams_file = os.path.join(jobname, f'boltz_results_{jobname}', 'lightning_logs', 'version_0', 'hparams.yaml')\n",
" if os.path.exists(hparams_file):\n",
" arcname = os.path.basename(hparams_file) # Use only the file name\n",
" zipf.write(hparams_file, arcname=arcname)\n",
" else:\n",
" print(f\"Warning: {hparams_file} not found.\")\n",
"\n",
"# Download the zip file\n",
"files.download(zip_filename)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
},
"cellView": "form",
"id": "jdSBSTOpaULF",
"outputId": "921ee9cb-9d73-48b7-de0b-0ec29d73d2bf"
},
"execution_count": 42,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"\n",
" async function download(id, filename, size) {\n",
" if (!google.colab.kernel.accessAllowed) {\n",
" return;\n",
" }\n",
" const div = document.createElement('div');\n",
" const label = document.createElement('label');\n",
" label.textContent = `Downloading \"${filename}\": `;\n",
" div.appendChild(label);\n",
" const progress = document.createElement('progress');\n",
" progress.max = size;\n",
" div.appendChild(progress);\n",
" document.body.appendChild(div);\n",
"\n",
" const buffers = [];\n",
" let downloaded = 0;\n",
"\n",
" const channel = await google.colab.kernel.comms.open(id);\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
"\n",
" for await (const message of channel.messages) {\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
" if (message.buffers) {\n",
" for (const buffer of message.buffers) {\n",
" buffers.push(buffer);\n",
" downloaded += buffer.byteLength;\n",
" progress.value = downloaded;\n",
" }\n",
" }\n",
" }\n",
" const blob = new Blob(buffers, {type: 'application/binary'});\n",
" const a = document.createElement('a');\n",
" a.href = window.URL.createObjectURL(blob);\n",
" a.download = filename;\n",
" div.appendChild(a);\n",
" a.click();\n",
" div.remove();\n",
" }\n",
" "
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<IPython.core.display.Javascript object>"
],
"application/javascript": [
"download(\"download_bb74d9f7-a559-405a-b95e-4444e730cef3\", \"results_test_a5e17_1.zip\", 1119339)"
]
},
"metadata": {}
}
]
}
]
}
