From b33988507db3ad68150916959f33b77db9b75605 Mon Sep 17 00:00:00 2001
From: Martin Steinegger
Date: Mon, 18 Nov 2024 00:29:01 +0100
Subject: [PATCH] First try Boltz notebook

---
 boltz1.ipynb | 371 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 371 insertions(+)
 create mode 100644 boltz1.ipynb

diff --git a/boltz1.ipynb b/boltz1.ipynb
new file mode 100644
index 00000000..a3f274e4
--- /dev/null
+++ b/boltz1.ipynb
@@ -0,0 +1,371 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "machine_shape": "hm",
+      "gpuType": "A100",
+      "authorship_tag": "ABX9TyMG+xMrC8CD3APVXWbRvd84",
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "\"Open"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Input protein sequence(s), then hit `Runtime` -> `Run all`\n",
+        "from google.colab import files\n",
+        "import os\n",
+        "import re\n",
+        "import hashlib\n",
+        "import random\n",
+        "from string import ascii_uppercase\n",
+        "\n",
+        "# Function to add a hash to the jobname\n",
+        "def add_hash(x, y):\n",
+        "    return x + \"_\" + hashlib.sha1(y.encode()).hexdigest()[:5]\n",
+        "\n",
+        "# User inputs\n",
+        "query_sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK' #@param {type:\"string\"}\n",
+        "#@markdown - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetero-oligomers). For example **PI...SK:PI...SK** for a homodimer\n",
+        "ligand_input = 'N[C@@H](Cc1ccc(O)cc1)C(=O)O' #@param {type:\"string\"}\n",
+        "#@markdown - Use `:` to specify multiple ligands as SMILES strings\n",
+        "jobname = 'test' #@param {type:\"string\"}\n",
+        "\n",
+        "# Clean up the query sequence and jobname\n",
+        "query_sequence = \"\".join(query_sequence.split())\n",
+        "basejobname = \"\".join(jobname.split())\n",
+        "basejobname = re.sub(r'\\W+', '', basejobname)\n",
+        "jobname = add_hash(basejobname, query_sequence)\n",
+        "\n",
+        "# Check if a directory with jobname exists\n",
+        "def check(folder):\n",
+        "    return not os.path.exists(folder)\n",
+        "\n",
+        "if not check(jobname):\n",
+        "    n = 0\n",
+        "    while not check(f\"{jobname}_{n}\"):\n",
+        "        n += 1\n",
+        "    jobname = f\"{jobname}_{n}\"\n",
+        "\n",
+        "# Make directory to save results\n",
+        "os.makedirs(jobname, exist_ok=True)\n",
+        "\n",
+        "from string import ascii_uppercase\n",
+        "\n",
+        "# Split sequences on chain breaks\n",
+        "protein_sequences = query_sequence.strip().split(':')\n",
+        "ligand_sequences = ligand_input.strip().split(':')\n",
+        "\n",
+        "# Initialize chain labels starting from 'A'\n",
+        "chain_labels = iter(ascii_uppercase)\n",
+        "\n",
+        "fasta_entries = []\n",
+        "csv_entries = []\n",
+        "chain_label_to_seq_id = {}\n",
+        "\n",
+        "# Process protein sequences\n",
+        "for i, seq in enumerate(protein_sequences):\n",
+        "    seq = seq.strip()\n",
+        "    if not seq:\n",
+        "        continue  # Skip empty sequences\n",
+        "    chain_label = next(chain_labels)\n",
+        "    seq_id = f\"{jobname}_{i}\"\n",
+        "    chain_label_to_seq_id[chain_label] = seq_id\n",
+        "    # For CSV file (for ColabFold)\n",
+        "    csv_entries.append((seq_id, seq))\n",
+        "    # For FASTA file\n",
+        "    msa_path = os.path.join(jobname, f\"{seq_id}.a3m\")\n",
+        "    header = f\">{chain_label}|protein|{msa_path}\"\n",
+        "    sequence = seq\n",
+        "    fasta_entries.append((header, sequence))\n",
+        "\n",
+        "# Process ligand sequences (assumed to be SMILES strings)\n",
+        "for lig in ligand_sequences:\n",
+        "    lig = lig.strip()\n",
+        "    if not lig:\n",
+        "        continue  # Skip empty ligands\n",
+        "    chain_label = next(chain_labels)\n",
+        "    lig_type = 'smiles'\n",
+        "    header = f\">{chain_label}|{lig_type}\"\n",
+        "    sequence = lig\n",
+        "    fasta_entries.append((header, sequence))\n",
+        "\n",
+        "# Write the CSV file for ColabFold\n",
+        "queries_path = os.path.join(jobname, f\"{jobname}.csv\")\n",
+        "with open(queries_path, \"w\") as text_file:\n",
+        "    text_file.write(\"id,sequence\\n\")\n",
+        "    for seq_id, seq in csv_entries:\n",
+        "        text_file.write(f\"{seq_id},{seq}\\n\")\n",
+        "\n",
+        "# Write the FASTA file\n",
+        "queries_fasta = os.path.join(jobname, f\"{jobname}.fasta\")\n",
+        "with open(queries_fasta, 'w') as f:\n",
+        "    for header, sequence in fasta_entries:\n",
+        "        f.write(f\"{header}\\n{sequence}\\n\")\n",
+        "\n",
+        "# Optionally, print the output for verification\n",
+        "#print(f\"Generated FASTA file '{queries_fasta}':\\n\")\n",
+        "#for header, sequence in fasta_entries:\n",
+        "#    print(f\"{header}\\n{sequence}\\n\")\n"
+      ],
+      "metadata": {
+        "cellView": "form",
+        "id": "AcYvVeDESi2a"
+      },
+      "execution_count": 35,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "execution_count": 36,
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "cellView": "form",
+        "id": "4eXNO1JJHYrB",
+        "outputId": "c9b764cb-37e7-4ec0-8519-b488163d3b40"
+      },
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "CPU times: user 30 µs, sys: 6 µs, total: 36 µs\n",
+            "Wall time: 39.1 µs\n"
+          ]
+        }
+      ],
+      "source": [
+        "#@title Install dependencies\n",
+        "%%time\n",
+        "import os\n",
+        "if not os.path.isfile(\"COLABFOLD_READY\"):\n",
+        "    print(\"installing colabfold...\")\n",
+        "    os.system(\"pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'\")\n",
+        "    if os.environ.get('TPU_NAME', False) != False:\n",
+        "        os.system(\"pip uninstall -y jax jaxlib\")\n",
+        "        os.system(\"pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\")\n",
+        "    os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold\")\n",
+        "    os.system(\"ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold\")\n",
+        "    os.system(\"touch COLABFOLD_READY\")\n",
+        "\n",
+        "if not os.path.isfile(\"BOLTZ_READY\"):\n",
+        "    os.system(\"pip install -q --no-warn-conflicts boltz\")\n",
+        "    os.system(\"touch BOLTZ_READY\")"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Generate MSA with ColabFold\n",
+        "!colabfold_batch \"{queries_path}\" \"{jobname}\" --msa-only"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "cellView": "form",
+        "id": "4aFDR4IhRe6y",
+        "outputId": "e32be2a0-2192-4385-9948-fafe3180dea4"
+      },
+      "execution_count": 37,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "2024-11-17 23:21:37,398 Running colabfold 1.5.5 (c21e1768d18e3608e6e6d99c97134317e7e41c75)\n",
+            "\n",
+            "WARNING: You are welcome to use the default MSA server, however keep in mind that it's a\n",
+            "limited shared resource only capable of processing a few thousand MSAs per day. Please\n",
+            "submit jobs only from a single IP address. We reserve the right to limit access to the\n",
+            "server case-by-case when usage exceeds fair use. If you require more MSAs: You can \n",
+            "precompute all MSAs with `colabfold_search` or host your own API and pass it to `--host-url`\n",
+            "\n",
+            "2024-11-17 23:21:39,576 Running on GPU\n",
+            "2024-11-17 23:21:40,309 Found 4 citations for tools or databases\n",
+            "2024-11-17 23:21:40,310 Query 1/1: test_a5e17_1_0 (length 59)\n",
+            "COMPLETE: 100% 150/150 [00:22<00:00, 6.53it/s] \n",
+            "2024-11-17 23:22:03,313 Saved test_a5e17_1/test_a5e17_1_0.pickle\n",
+            "2024-11-17 23:22:04,055 Done\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Predict structure using Boltz\n",
+        "!boltz predict --out_dir \"{jobname}\" \"{jobname}/{jobname}.fasta\""
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "cellView": "form",
+        "id": "bgaBXxXtIAu9",
+        "outputId": "7cb18790-83f7-42db-d3f1-1043025520b0"
+      },
+      "execution_count": 38,
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "Downloading data and model to /root/.boltz. You may change this by setting the --cache flag.\n",
+            "Checking input data.\n",
+            "Processing input data.\n",
+            "100% 1/1 [00:00<00:00, 16.14it/s]\n",
+            "GPU available: True (cuda), used: True\n",
+            "TPU available: False, using: 0 TPU cores\n",
+            "HPU available: False, using: 0 HPUs\n",
+            "You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
+            "2024-11-17 23:22:42.568213: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+            "2024-11-17 23:22:42.619224: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+            "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+            "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+            "Predicting DataLoader 0: 100% 1/1 [00:10<00:00, 10.87s/it]Number of failed examples: 0\n",
+            "Predicting DataLoader 0: 100% 1/1 [00:10<00:00, 10.87s/it]\n"
+          ]
+        }
+      ]
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "#@title Download results\n",
+        "# Import necessary modules\n",
+        "import os\n",
+        "import zipfile\n",
+        "from google.colab import files\n",
+        "import glob\n",
+        "\n",
+        "# Ensure 'jobname' variable is defined\n",
+        "# jobname = 'test_abcde'  # Uncomment and set if not already defined\n",
+        "\n",
+        "# Name of the zip file\n",
+        "zip_filename = f\"results_{jobname}.zip\"\n",
+        "\n",
+        "# Create a zip file and add the specified files without preserving directory structure\n",
+        "with zipfile.ZipFile(zip_filename, 'w') as zipf:\n",
+        "    coverage_png_files = glob.glob(os.path.join(jobname, '*_coverage.png'))\n",
+        "    a3m_files = glob.glob(os.path.join(jobname, '*.a3m'))\n",
+        "    for file in coverage_png_files + a3m_files:\n",
+        "        arcname = os.path.basename(file)  # Use only the file name\n",
+        "        zipf.write(file, arcname=arcname)\n",
+        "\n",
+        "    cif_files = glob.glob(os.path.join(jobname, f'boltz_results_{jobname}', 'predictions', jobname, '*.cif'))\n",
+        "    for file in cif_files:\n",
+        "        arcname = os.path.basename(file)  # Use only the file name\n",
+        "        zipf.write(file, arcname=arcname)\n",
+        "\n",
+        "    hparams_file = os.path.join(jobname, f'boltz_results_{jobname}', 'lightning_logs', 'version_0', 'hparams.yaml')\n",
+        "    if os.path.exists(hparams_file):\n",
+        "        arcname = os.path.basename(hparams_file)  # Use only the file name\n",
+        "        zipf.write(hparams_file, arcname=arcname)\n",
+        "    else:\n",
+        "        print(f\"Warning: {hparams_file} not found.\")\n",
+        "\n",
+        "# Download the zip file\n",
+        "files.download(zip_filename)\n"
+      ],
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 17
+        },
+        "cellView": "form",
+        "id": "jdSBSTOpaULF",
+        "outputId": "921ee9cb-9d73-48b7-de0b-0ec29d73d2bf"
+      },
+      "execution_count": 42,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              ""
+            ],
+            "application/javascript": [
+              "\n",
+              "    async function download(id, filename, size) {\n",
+              "      if (!google.colab.kernel.accessAllowed) {\n",
+              "        return;\n",
+              "      }\n",
+              "      const div = document.createElement('div');\n",
+              "      const label = document.createElement('label');\n",
+              "      label.textContent = `Downloading \"${filename}\": `;\n",
+              "      div.appendChild(label);\n",
+              "      const progress = document.createElement('progress');\n",
+              "      progress.max = size;\n",
+              "      div.appendChild(progress);\n",
+              "      document.body.appendChild(div);\n",
+              "\n",
+              "      const buffers = [];\n",
+              "      let downloaded = 0;\n",
+              "\n",
+              "      const channel = await google.colab.kernel.comms.open(id);\n",
+              "      // Send a message to notify the kernel that we're ready.\n",
+              "      channel.send({})\n",
+              "\n",
+              "      for await (const message of channel.messages) {\n",
+              "        // Send a message to notify the kernel that we're ready.\n",
+              "        channel.send({})\n",
+              "        if (message.buffers) {\n",
+              "          for (const buffer of message.buffers) {\n",
+              "            buffers.push(buffer);\n",
+              "            downloaded += buffer.byteLength;\n",
+              "            progress.value = downloaded;\n",
+              "          }\n",
+              "        }\n",
+              "      }\n",
+              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
+              "      const a = document.createElement('a');\n",
+              "      a.href = window.URL.createObjectURL(blob);\n",
+              "      a.download = filename;\n",
+              "      div.appendChild(a);\n",
+              "      a.click();\n",
+              "      div.remove();\n",
+              "    }\n",
+              "  "
+            ]
+          },
+          "metadata": {}
+        },
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/plain": [
+              ""
+            ],
+            "application/javascript": [
+              "download(\"download_bb74d9f7-a559-405a-b95e-4444e730cef3\", \"results_test_a5e17_1.zip\", 1119339)"
+            ]
+          },
+          "metadata": {}
+        }
+      ]
+    }
+  ]
+}
\ No newline at end of file