diff --git a/RoseTTAFold2.ipynb b/RoseTTAFold2.ipynb
index db6f0729..53361622 100644
--- a/RoseTTAFold2.ipynb
+++ b/RoseTTAFold2.ipynb
@@ -50,20 +50,36 @@
{
"cell_type": "code",
"source": [
- "#@title setup **RoseTTAFold2** (~1m)\n",
"%%time\n",
+ "#@title setup **RoseTTAFold2** (~1m)\n",
+ "params = \"RF2_jan24\" # @param [\"RF2_apr23\",\"RF2_jan24\"]\n",
+ "\n",
"import os, time, sys\n",
- "if not os.path.isfile(\"RF2_apr23.pt\"):\n",
+ "os.environ[\"PYTORCH_CUDA_ALLOC_CONF\"] = \"max_split_size_mb:512\"\n",
+ "\n",
+ "if params == \"RF2_jan24\" and not os.path.isfile(\"RF2_jan24.tgz\"):\n",
+ " # send param download into background\n",
+ " os.system(\"(apt-get install aria2; aria2c -q -x 16 https://files.ipd.uw.edu/dimaio/RF2_jan24.tgz) &\")\n",
+ "\n",
+ "if params == \"RF2_apr23\" and not os.path.isfile(\"RF2_apr23.tgz\"):\n",
" # send param download into background\n",
- " os.system(\"(apt-get install aria2; aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/RF2_apr23.pt) &\")\n",
+ " os.system(\"(apt-get install aria2; aria2c -q -x 16 https://files.ipd.uw.edu/dimaio/RF2_apr23.tgz) &\")\n",
"\n",
"if not os.path.isdir(\"RoseTTAFold2\"):\n",
" print(\"install RoseTTAFold2\")\n",
- " os.system(\"git clone https://github.com/sokrypton/RoseTTAFold2.git\")\n",
- " os.system(\"pip -q install py3Dmol\")\n",
- " os.system(\"pip install dgl -f https://data.dgl.ai/wheels/cu121/repo.html\")\n",
- " os.system(\"cd RoseTTAFold2/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install .\")\n",
- " os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py\")\n",
+ " os.system(\"git clone --branch 'dimaio/gpu_mem_efficiency' https://github.com/uw-ipd/RoseTTAFold2.git\")\n",
+ " os.system(\"pip install py3Dmol\")\n",
+ "\n",
+ " # 17Mar2024: adding --no-dependencies to avoid installing nvidia-cuda-* dependencies\n",
+ " os.system(\"pip install --no-dependencies dgl==2.0.0 -f https://data.dgl.ai/wheels/cu121/repo.html\")\n",
+ " os.system(\"pip install --no-dependencies e3nn==0.3.3 opt_einsum_fx\")\n",
+ " os.system(\"cd RoseTTAFold2/SE3Transformer; pip install .\")\n",
+ "\n",
+ " os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/main/colabfold/colabfold.py -O colabfold_utils.py\")\n",
+ "\n",
+ " #os.system(\"pip install dgl -f https://data.dgl.ai/wheels/cu121/repo.html\")\n",
+ " #os.system(\"cd RoseTTAFold2/SE3Transformer; pip -q install --no-cache-dir -r requirements.txt; pip -q install .\")\n",
+ " #os.system(\"wget https://raw.githubusercontent.com/sokrypton/ColabFold/beta/colabfold/mmseqs/api.py\")\n",
"\n",
" # install hhsuite\n",
" print(\"install hhsuite\")\n",
@@ -71,11 +87,14 @@
" os.system(f\"curl -fsSL https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-SSE2-Linux.tar.gz | tar xz -C hhsuite/\")\n",
"\n",
"\n",
- "if os.path.isfile(f\"RF2_apr23.pt.aria2\"):\n",
+ "if os.path.isfile(f\"{params}.tgz.aria2\"):\n",
" print(\"downloading RoseTTAFold2 params\")\n",
- " while os.path.isfile(f\"RF2_apr23.pt.aria2\"):\n",
+ " while os.path.isfile(f\"{params}.tgz.aria2\"):\n",
" time.sleep(5)\n",
"\n",
+ "if not os.path.isfile(f\"{params}.pt\"):\n",
+ " os.system(f\"tar -zxvf {params}.tgz\")\n",
+ "\n",
"if not \"IMPORTED\" in dir():\n",
" if 'RoseTTAFold2/network' not in sys.path:\n",
" os.environ[\"DGLBACKEND\"] = \"pytorch\"\n",
@@ -87,7 +106,8 @@
" from google.colab import files\n",
" import numpy as np\n",
" from parsers import parse_a3m\n",
- " from api import run_mmseqs2\n",
+ " #from api import run_mmseqs2\n",
+ " from colabfold_utils import run_mmseqs2\n",
" import py3Dmol\n",
" import torch\n",
" from string import ascii_uppercase, ascii_lowercase\n",
@@ -103,7 +123,7 @@
"if not \"pred\" in dir():\n",
" from predict import Predictor\n",
" print(\"compile RoseTTAFold2\")\n",
- " model_params = \"RF2_apr23.pt\"\n",
+ " model_params = f\"{params}.pt\"\n",
" if (torch.cuda.is_available()):\n",
" pred = Predictor(model_params, torch.device(\"cuda:0\"))\n",
" else:\n",
@@ -114,6 +134,10 @@
" unique_seqs = list(OrderedDict.fromkeys(seq_list))\n",
" return unique_seqs\n",
"\n",
+ "def run_mmseqs2_wrapper(*args, **kwargs):\n",
+ " kwargs['user_agent'] = \"colabfold/rosettafold2\"\n",
+ " return run_mmseqs2(*args, **kwargs)\n",
+ "\n",
"def get_msa(seq, jobname, cov=50, id=90, max_msa=2048,\n",
" mode=\"unpaired_paired\"):\n",
"\n",
@@ -133,7 +157,7 @@
" os.makedirs(path, exist_ok=True)\n",
" if mode in [\"paired\",\"unpaired_paired\"] and len(u_seqs) > 1:\n",
" print(\"getting paired MSA\")\n",
- " out_paired = run_mmseqs2(u_seqs, f\"{path}/\", use_pairing=True)\n",
+ " out_paired = run_mmseqs2_wrapper(u_seqs, f\"{path}/\", use_pairing=True)\n",
" headers, sequences = [],[]\n",
" for a3m_lines in out_paired:\n",
" n = -1\n",
@@ -163,7 +187,7 @@
"\n",
" if len(msa) < max_msa and (mode in [\"unpaired\",\"unpaired_paired\"] or len(u_seqs) == 1):\n",
" print(\"getting unpaired MSA\")\n",
- " out = run_mmseqs2(u_seqs,f\"{path}/\")\n",
+ " out = run_mmseqs2_wrapper(u_seqs,f\"{path}/\")\n",
" Ls = [len(seq) for seq in u_seqs]\n",
" sub_idx = []\n",
" sub_msa = []\n",
@@ -199,10 +223,24 @@
],
"metadata": {
"id": "ymCHu14w17wF",
- "cellView": "form"
+ "cellView": "form",
+ "outputId": "d273d9af-d478-4d1f-d13b-681133264e04",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "compile RoseTTAFold2\n",
+ "CPU times: user 9.78 s, sys: 981 ms, total: 10.8 s\n",
+ "Wall time: 30.5 s\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
@@ -257,7 +295,11 @@
"lengths = [len(s) for s in sequences]\n",
"\n",
"# TODO\n",
- "subcrop = 1000 if sum(lengths) > 1400 else -1\n",
+ "#subcrop = 1000 if sum(lengths) > 1400 else -1\n",
+ "\n",
+ "subcrop = -1\n",
+ "topk = 1536\n",
+ "\n",
"\n",
"sequence = \"/\".join(sequences)\n",
"jobname = jobname+\"_\"+symm+\"_\"+get_hash(sequence)[:5]\n",
@@ -304,6 +346,7 @@
" nseqs=max_msa,\n",
" nseqs_full=max_extra_msa,\n",
" subcrop=subcrop,\n",
+ " topk=topk,\n",
" is_training=use_dropout)\n",
" plddt = np.load(npz)[\"lddt\"].mean()\n",
" if best_plddt is None or plddt > best_plddt:\n",
@@ -312,10 +355,46 @@
],
"metadata": {
"cellView": "form",
- "id": "_oJTZGgdeKkO"
+ "id": "_oJTZGgdeKkO",
+ "outputId": "6d418113-3aea-47d0-b08e-af78a336a7c7",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "jobname: test_X1_a5e17\n",
+ "lengths: [59]\n",
+ "getting unpaired MSA\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "COMPLETE: 100%|██████████| 150/150 [elapsed: 00:01 remaining: 00:00]\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "N=2048 L=59\n",
+ "recycle 0 plddt 0.882 pae 2.725 rmsd 13.466\n",
+ "recycle 1 plddt 0.902 pae 2.400 rmsd 0.578\n",
+ "recycle 2 plddt 0.904 pae 2.387 rmsd 0.214\n",
+ "recycle 3 plddt 0.905 pae 2.375 rmsd 0.140\n",
+ "recycle 4 plddt 0.904 pae 2.418 rmsd 0.097\n",
+ "recycle 5 plddt 0.905 pae 2.393 rmsd 0.066\n",
+ "recycle 6 plddt 0.905 pae 2.400 rmsd 0.052\n",
+ "runtime=14.29 vram=0.76\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
@@ -362,10 +441,81 @@
],
"metadata": {
"cellView": "form",
- "id": "53wdd2WX70o_"
+ "id": "53wdd2WX70o_",
+ "outputId": "2b431e52-456b-43cf-cb03-832b31bc12fd",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 948
+ }
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "application/3dmoljs_load.v0": "
\n
3Dmol.js failed to load for some reason. Please check your browser console for error messages.
\n
\n",
+ "text/html": [
+ "\n",
+ "
3Dmol.js failed to load for some reason. Please check your browser console for error messages.
\n",
+ "
\n",
+ ""
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "