From 6b69a9e218f2ca4e071056c73705c6f91b035893 Mon Sep 17 00:00:00 2001
From: Emily Miller
Date: Mon, 25 Oct 2021 16:44:06 -0700
Subject: [PATCH] 496 docs review (#150)

* add denspose and update save dir
* commit two autogenerated files
* changelog edits
* contribute edits
* WIP edits for models
* finish edits for models and separate out densepose
* full list of parks
* config options
* vscode format
* move config guide into tutorials
* extra options
* windows not tested
* tutorial edits
* remove extra nb
* put save path back
* quickstart edits
* vscode formatting
* finish quickstart edits
* train tutorial
* add template section
* logging
* ffpmeg install in readme
* finetuning
* capitalization
* remove ffmpmeg
* date
* add densepose
* tweak
* simplify
* fix densepose video link
* caps and tensorboard
* add help
* remove pythong piece since this is focused on yaml files
* more tweaks
* edit history not change log
* copy edits
* fix changelog
* table
* table
* typo
* alphabetize
* table bug
* table edits
* Simplify save_dir and some directory -> dir renames (#151)
* wip renames
* renames in docs
* readme
* data dir renamme in docs
* rename in code from data_directory to data_dir
* maintaining update
* fix capitalization
* further updates
* tweak
* do not overwrite
* add overwrite save dir
* add overwrite save dir to config
* update configs with all info
* use full train configuration
* only upload if does not exist
* tests for save
* overwrite param
* better set up and test for overwrite
* docs
* update docs with overwrite
* from overwrite_save_dir to overwrite
* missed rename
* remove machine specific from vlc
* unindent so test actually runs
* check for local and cached checkpoints
* should be and
* write out predict config before preds start like we do for train config
* update all configs and use only first 10 digits of hash
* dry run check after save is configured; more robust test
* reorder
* show save directory
* copy edits
* update template
* fix test
* lower case for consistency
* fix test
---
 .github/MAINTAINING.md | 2 +-
 HISTORY.md | 14 +-
 Makefile | 2 +-
 README.md | 15 +-
 docs/docs/api-reference/densepose_config.md | 3 +
 docs/docs/api-reference/densepose_manager.md | 3 +
 docs/docs/changelog/index.md | 0
 docs/docs/configurations.md | 120 ++---
 docs/docs/contribute/index.md | 16 +-
 docs/docs/debugging.md | 14 +-
 docs/docs/extra-options.md | 40 +-
 docs/docs/install.md | 10 +-
 .../zamba_config_diagram.png} | Bin
 docs/docs/models/denspose.md | 101 +++++
 docs/docs/models/index.md | 334 --------
 docs/docs/models/species-detection.md | 318 +++++++++++++
 docs/docs/predict-tutorial.md | 41 +-
 docs/docs/python-tutorial.ipynb | 423 ------
 docs/docs/quickstart.md | 51 ++-
 docs/docs/train-tutorial.md | 83 ++--
 docs/docs/yaml-config.md | 116 +---
 docs/mkdocs.yml | 24 +-
 templates/european.yaml | 6 +-
 templates/slowfast.yaml | 4 +-
 templates/time_distributed.yaml | 12 +-
 tests/assets/sample_predict_config.yaml | 2 +-
 tests/assets/sample_train_config.yaml | 2 +-
 tests/conftest.py | 4 +-
 tests/test_cli.py | 23 +-
 tests/test_config.py | 105 +++--
 tests/test_densepose.py | 10 +-
 tests/test_model_manager.py | 20 +-
 zamba/cli.py | 47 +-
 zamba/models/config.py | 157 +++---
 zamba/models/densepose/config.py | 20 +-
 zamba/models/model_manager.py | 35 +-
 .../official_models/european/config.yaml | 29 +-
 .../european/predict_configuration.yaml | 12 +-
 .../european/train_configuration.yaml | 10 +-
 .../official_models/slowfast/config.yaml | 42 +-
 .../slowfast/predict_configuration.yaml | 6 +-
 .../slowfast/train_configuration.yaml | 6 +-
 .../time_distributed/config.yaml | 42 +-
 .../predict_configuration.yaml | 12 +-
 .../time_distributed/train_configuration.yaml | 10 +-
 zamba/models/publish_models.py | 62 ++-
 zamba/models/utils.py | 7 +-
 47 files changed, 1071 insertions(+), 1344 deletions(-)
 create mode 100644 docs/docs/api-reference/densepose_config.md
 create mode 100644 docs/docs/api-reference/densepose_manager.md
 delete mode 100644 docs/docs/changelog/index.md
 rename docs/docs/{config_diagram.png => media/zamba_config_diagram.png} (100%)
 create mode 100644 docs/docs/models/denspose.md
 delete mode 100644 docs/docs/models/index.md
 create mode 100644 docs/docs/models/species-detection.md
 delete mode 100644 docs/docs/python-tutorial.ipynb

diff --git a/.github/MAINTAINING.md b/.github/MAINTAINING.md
index 88b2bdda..657f6f8c 100644
--- a/.github/MAINTAINING.md
+++ b/.github/MAINTAINING.md
@@ -113,7 +113,7 @@ make publish_models
 
 This will generate a public file name for each model based on the config hash and upload the model weights to the three DrivenData public S3 buckets. It will also generate a folder in `zamba/models/official_models/{your_model_name}` that contains the official config as well as reference yaml and json files. You should PR everything in this folder.
 
-Lastly, you need to update the template in `templates`. The template should contain all the same info as the model's `config.yaml`, plus placeholders for `data_directory` and `labels` in `train_config`, and `data_directory`, `filepaths`, and `checkpoint` in `predict_config`.
+Lastly, you need to update the template in `templates`. The template should contain all the same info as the model's `config.yaml`, plus placeholders for `data_dir` and `labels` in `train_config`, and `data_dir`, `filepaths`, and `checkpoint` in `predict_config`.
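For reference, a skeleton of that template structure might look like the following. This is a sketch only; the placeholder values are illustrative, and the full templates in `templates/` are the source of truth:

```yaml
train_config:
  # data_dir: YOUR_DATA_DIR_HERE
  # labels: YOUR_LABELS_CSV_HERE
  model_name: your_model_name
  # ... remaining options copied from the model's config.yaml

video_loader_config:
  # ... video loading options copied from the model's config.yaml

predict_config:
  # data_dir: YOUR_DATA_DIR_HERE
  # or
  # filepaths: YOUR_FILEPATH_CSV_HERE
  # checkpoint: YOUR_CHECKPOINT_PATH_HERE
  model_name: your_model_name
```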
### New model checklist

diff --git a/HISTORY.md b/HISTORY.md
index 218bb272..9912d643 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,8 +1,8 @@
-# zamba Changelog
+# `zamba` changelog
 
-## v2

[Remainder of the HISTORY.md hunk and the deleted contents of docs/docs/python-tutorial.ipynb omitted: the notebook was autogenerated JSON whose output cells contained an S3 permissions traceback (403 Forbidden on HeadObject, AccessDenied on ListObjects) raised while downloading model weights via boto3/cloudpathlib.]
diff --git a/docs/docs/quickstart.md b/docs/docs/quickstart.md
index c861a049..5fce19d0 100644
--- a/docs/docs/quickstart.md
+++ b/docs/docs/quickstart.md
@@ -18,7 +18,7 @@ macOS, this can be done in the terminal (⌘+space, "Terminal"). On Windows, thi
 
 ## How do I organize my videos for `zamba`?
 
-You can input the path to a directory of videos or specify a list of file paths. `zamba` supports the same video formats as FFmpeg, [which are listed here](https://www.ffmpeg.org/general.html#Supported-File-Formats_002c-Codecs-or-Features). Any videos that fail a set of FFmpeg checks will be skipped during inference or training.
+You can specify the path to a directory of videos or provide a list of filepaths in a `.csv` file. `zamba` supports the same video formats as FFmpeg, [which are listed here](https://www.ffmpeg.org/general.html#Supported-File-Formats_002c-Codecs-or-Features). Any videos that fail a set of FFmpeg checks will be skipped during inference or training.
 
 For example, say we have a directory of videos called `example_vids` that we want to generate predictions for using `zamba`. Let's list the videos:
 
@@ -64,8 +64,9 @@ To generate and save predictions for your videos using the default settings, run
 $ zamba predict --data-dir example_vids/
 ```
 
-`zamba` will output a `.csv` file with rows labeled by each video filename and columns for each class (ie. species). The default prediction will store all class probabilities, so that cell (i,j) is *the probability that animal j is present in video i.* Comprehensive predictions are helpful when a single video contains multiple species.
-Predictions will be saved to `zamba_predictions.csv` in the current working directory by default. You can save out predictions under a different name or in a different folder using the `--save-path` argument.
+`zamba` will output a `.csv` file with rows labeled by each video filename and columns for each class (i.e., species). The default prediction will store all class probabilities, so that cell `(i,j)` is *the probability that animal `j` is present in video `i`.* Comprehensive predictions are helpful when a single video contains multiple species.
+
+Predictions will be saved to `zamba_predictions.csv` in the current working directory by default. You can save out predictions to a different folder using the `--save-dir` argument.
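For example, here is a minimal sketch of inspecting those saved probabilities with pandas (the `zamba_predictions.csv` filename and the `filepath` index column are `zamba`'s defaults):

```python
import pandas as pd

# load the predictions written by `zamba predict`;
# rows are videos, columns are per-class probabilities
preds = pd.read_csv("zamba_predictions.csv", index_col="filepath")

# the most likely species for each video is the column with the highest value
print(preds.idxmax(axis=1))
```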
Adding the argument `--output-class-names` will simplify the predictions to return only the *most likely* animal in each video:

@@ -88,8 +89,8 @@ $ zamba predict --data-dir example_vids/ --model slowfast
 
 You can continue training one of the [models](models/species-detection.md) that ships with `zamba` by either:
 
-* Fine-tuning with additional labeled videos where the species are included in the list of [`zamba` class labels](models/index.md#species-classes)
-* Fine-tuning with labeled videos that include new species
+* Finetuning with additional labeled videos where the species are included in the list of [`zamba` class labels](models/species-detection.md#species-classes)
+* Finetuning with labeled videos that include new species
 
 In either case, the commands for training are the same. Say that we have labels for the videos in the `example_vids` folder saved in `example_labels.csv`. To train a model, run:
 
@@ -108,31 +109,32 @@ eleph.MP4,elephant
 leopard.MP4,leopard
 ```
 
-By default, the trained model and additional training output will be saved to a folder in the current working directory called `zamba_{model_name}`. For example, a model finetuned from the provided `time_distributed` model will be saved in `zamba_time_distributed`.
+By default, the trained model and additional training output will be saved to a `version_n` folder in the current working directory. For example,
 
 ```console
 $ zamba train --data-dir example_vids/ --labels example_labels.csv
 
-$ ls zamba_time_distributed
-time_distributed.ckpt
+$ ls version_0/
+hparams.yaml
+time_distributed.ckpt
+train_configuration.yaml
+val_metrics.json
 ...
 ```
 
 ## Downloading model weights
 
-**`zamba` needs to download the "weights" files for the neural networks that it uses to make predictions. On first run it will download ~200-500 MB of files with these weights depending which model you choose.**
-Once a model's weights are downloaded, the tool will use the local version and will not need to perform this download again. If you are not in the US, we recommend running the above command with the additional flag either `--weight_download_region eu` or `--weight_download_region asia` depending on your location. The closer you are to the server the faster the downloads will be.
+**`zamba` needs to download the "weights" files for the models it uses to make predictions. On first run, it will download ~200-500 MB of files with these weights, depending on which model you choose.**
+Once a model's weights are downloaded, `zamba` will use the local version and will not need to perform this download again. If you are not in the United States, we recommend running the above command with either the additional flag `--weight_download_region eu` or `--weight_download_region asia`, depending on your location. The closer you are to the server, the faster the downloads will be.
 
 ## Getting help
 
-Once zamba is installed, you can see more details of each function with `--help`.
+Once `zamba` is installed, you can see more details of each function with `--help`. To get help with `zamba predict`:
 
 ```console
-$ zamba predict --help
-
 Usage: zamba predict [OPTIONS]
 
   Identify species in a video.
 
@@ -159,11 +161,13 @@ Options:
                                   specified, will use all GPUs found on
                                   machine.
   --batch-size INTEGER            Batch size to use for training.
-  --save / --no-save              Whether to save out predictions to a csv
-                                  file. If you want to specify the location of
-                                  the csv, use save_path instead.
-  --save-path PATH                Full path for prediction CSV file. Any
-                                  needed parent directories will be created.
+  --save / --no-save              Whether to save out predictions. If you want
+                                  to specify the output directory, use
+                                  save_dir instead.
+  --save-dir PATH                 An optional directory in which to save the
+                                  model predictions and configuration yaml.
+                                  Defaults to the current working directory if
+                                  save is True.
   --dry-run / --no-dry-run        Runs one batch of inference to check for
                                   bugs.
   --config PATH                   Specify options using yaml configuration
@@ -190,6 +194,8 @@ Options:
                                   loaded prior to inference. Only use if
                                   you're very confident all your videos can
                                   be loaded.
+  -o, --overwrite                 Overwrite outputs in the save directory if
+                                  they exist.
  -y, --yes                        Skip confirmation of configuration and
                                   proceed right to prediction.
  --help                           Show this message and exit.
@@ -225,11 +231,10 @@ Options:
                                   machine.
   --dry-run / --no-dry-run        Runs one batch of train and validation to
                                   check for bugs.
-  --save-dir PATH                 Directory in which to save model checkpoint
-                                  and configuration file. If not specified,
-                                  will save to a folder called
-                                  'zamba_{model_name}' in your working
-                                  directory.
+  --save-dir PATH                 An optional directory in which to save the
+                                  model checkpoint and configuration file. If
+                                  not specified, will save to a `version_n`
+                                  folder in your working directory.
   --num-workers INTEGER           Number of subprocesses to use for data
                                   loading.
   --weight-download-region [us|eu|asia]

diff --git a/docs/docs/train-tutorial.md b/docs/docs/train-tutorial.md
index d13977e2..dfdd0dbd 100644
--- a/docs/docs/train-tutorial.md
+++ b/docs/docs/train-tutorial.md
@@ -1,4 +1,4 @@
-# User Tutorial: Training a Model on Labaled Videos
+# User tutorial: Training a model on labeled videos
 
 This section walks through how to train a model using `zamba`. If you are new to `zamba` and just want to classify some videos as soon as possible, see the [Quickstart](quickstart.md) guide.
 
@@ -9,14 +9,14 @@ This tutorial goes over the steps for using `zamba` if:
 
 `zamba` can run two types of model training:
 
-* Fine-tuning a model with labels that are a subset of the possible [zamba labels](models/index.md#species-classes)
-* Fine-tuning a model to predict an entirely new set of labels
+* Finetuning a model with labels that are a subset of the possible [zamba labels](models/species-detection.md#species-classes)
+* Finetuning a model to predict an entirely new set of labels
 
 The process is the same for both cases.
 
 ## Basic usage: command line interface
 
-Say that we want to finetune the `time_distributed` model based on the videos in `example_vids` and the labels in `example_labels.csv`. 
+Say that we want to finetune the `time_distributed` model based on the videos in `example_vids` and the labels in `example_labels.csv`.
 
 Minimum example for training in the command line:
 
@@ -43,7 +43,7 @@ leopard.MP4,leopard
 
 ## Basic usage: Python package
 
-Say that we want to finetune the `time_distributed` model based on the videos in `example_vids` and the labels in `example_labels.csv`. 
+Say that we want to finetune the `time_distributed` model based on the videos in `example_vids` and the labels in `example_labels.csv`.
 
 Minimum example for training using the Python package:
 
 ```python
 from zamba.models.model_manager import train_model
 from zamba.models.config import TrainConfig
 
 train_config = TrainConfig(
-    data_directory="example_vids/", labels="example_labels.csv"
+    data_dir="example_vids/", labels="example_labels.csv"
 )
 
 train_model(train_config=train_config)
 ```
 
-The only two arguments that can be passed to `train_model` are `train_config` and (optionally) `video_loader_config`. The first step is to instantiate [`TrainConfig`](configurations.md#training-arguments). Optionally, you can also specify video loading arguments by instantiating and passing in [`VideoLoaderConfig`](configurations.md#video-loading-arguments).
+The only two arguments that can be passed to `train_model` are `train_config` and (optionally) `video_loader_config`. The first step is to instantiate [`TrainConfig`](configurations.md#training-arguments). Optionally, you can also specify video loading arguments by instantiating and passing in [`VideoLoaderConfig`](configurations.md#video-loading-arguments).
 
 ### Required arguments
 
-To run `train_model` in Python, you must specify both `data_directory` and `labels` when `TrainConfig` is instantiated.
+To run `train_model` in Python, you must specify both `data_dir` and `labels` when `TrainConfig` is instantiated.
 
-* **`data_directory (DirectoryPath)`:** Path to the folder containing your videos.
+* **`data_dir (DirectoryPath)`:** Path to the folder containing your videos.
 * **`labels (FilePath or pd.DataFrame)`:** Either the path to a CSV file with labels for training, or a dataframe of the training labels. There must be columns for `filename` and `label`.
 
-For detailed explanations of all possible configuration arguments, see [All Optional Arguments](configurations.md).
+For detailed explanations of all possible configuration arguments, see [All Configuration Options](configurations.md).
 
 ## Default behavior
 
-By default, the [`time_distributed`](models/index.md#time-distributed) model will be used as a starting point. The newly trained model will be saved to a folder in the current working directory called `zamba_{model_name}`. For example, a model finetuned from the provided `time_distributed` model will be saved in `zamba_time_distributed`.
-
-`zamba_time_distributed` contains:
-
-* `train_configuration.yaml`: The full model configuration used to generate the given model, including `video_loader_config` and `train_config`. To continue training using the same configuration, or to train another model using the same configuration, you can pass in `train_configurations.yaml` (see [Specifying Model Configurations with a YAML File](yaml-config.md)).
-* `hparams.yaml`: Model hyperparameters. For example, the YAML file below tells us that the model was trained with a learning rate (`lr`) of 0.001:
-  ```yaml
-  $ cat zamba_time_distributed/hparams.yaml
-
-  lr: 0.001
-  model_class: TimeDistributedEfficientNetMultiLayerHead
-  num_frames: 16
-  scheduler: MultiStepLR
-  scheduler_params:
-    gamma: 0.5
-    milestones:
-    - 3
-    verbose: true
-  species:
-  - species_blank
-  - species_chimpanzee_bonobo
-  - species_elephant
-  - species_leopard
+By default, the [`time_distributed`](models/species-detection.md#time-distributed) model will be used as a starting point. You can specify where the outputs should be saved with `--save-dir`. If no save directory is specified, `zamba` will write out incremental `version_n` folders to your current working directory. For example, a model finetuned from the provided `time_distributed` model (the default) will be saved in `version_0`.
+
+`version_0` contains:
+
+* `train_configuration.yaml`: The full model configuration used to generate the given model, including `video_loader_config` and `train_config`. To continue training using the same configuration, or to train another model using the same configuration, you can pass in `train_configuration.yaml` (see [Using YAML configuration files](yaml-config.md)) along with the `labels` filepath.
+* `hparams.yaml`: Model hyperparameters. These are included in the checkpoint file as well.
+* `time_distributed.ckpt`: Model checkpoint. You can continue training from this checkpoint by passing it to `zamba train` with the `--checkpoint` flag:
+  ```console
+  $ zamba train --checkpoint version_0/time_distributed.ckpt --data-dir example_vids/ --labels example_labels.csv
   ```
-* `time_distributed.ckpt`: Model checkpoint. The model checkpoint also includes both the model configuration in `train_configuration.yaml` and the model hyperparameters in `hparams.yaml`. You can continue training from this checkpoint by passing it to `zamba train` with the `--checkpoint` flag:
-  ```console
-  $ zamba train --checkpoint time_distributed.ckpt --data-dir example_vids/ --labels example_labels.csv
   ```
+* `events.out.tfevents.1632250686.ip-172-31-15-179.14229.0`: [TensorBoard](https://www.tensorflow.org/tensorboard/get_started) logs. You can view these with TensorBoard:
+  ```console
+  $ tensorboard --logdir version_0/
+  ```
-* `events.out.tfevents.1632250686.ip-172-31-15-179.14229.0`: [TensorBoard](https://www.tensorflow.org/tensorboard/get_started) logs
-* `test_metrics.json`: The model's performance on the test subset
 * `val_metrics.json`: The model's performance on the validation subset
+* `test_metrics.json`: The model's performance on the test (holdout) subset
 * `splits.csv`: Which files were used for training, validation, and as a holdout set. If split is specified in the labels file passed to training, `splits.csv` will not be saved out.
 
 ## Step-by-step tutorial
 
-### 1. Specify the path to your videos 
+### 1. Specify the path to your videos
 
 Save all of your videos in a folder.
 
 * They can be in nested directories within the folder.
-* Your videos should all be saved in formats that are suppored by FFmpeg, [which are listed here](https://www.ffmpeg.org/general.html#Supported-File-Formats_002c-Codecs-or-Features). Any videos that fail a set of FFmpeg checks will be skipped during inference or training.
+* Your videos should all be saved in formats that are supported by FFmpeg, [which are listed here](https://www.ffmpeg.org/general.html#Supported-File-Formats_002c-Codecs-or-Features). Any videos that fail a set of FFmpeg checks will be skipped during inference or training. By default, `zamba` will look for files with the following suffixes: `.avi`, `.mp4`, `.asf`. To use other video suffixes that are supported by FFmpeg, set your `VIDEO_SUFFIXES` environment variable.
 
 Add the path to your video folder with `--data-dir`. For example, if your videos are in a folder called `example_vids`, add `--data-dir example_vids/` to your command.
 
@@ -122,10 +107,10 @@
 
 === "Python"
     ```python
-    from zamba.models.model_manager import train_model
     from zamba.models.config import TrainConfig
+    from zamba.models.model_manager import train_model
 
-    train_config = TrainConfig(data_directory='example_vids/')
+    train_config = TrainConfig(data_dir='example_vids/')
     train_model(train_config=train_config)
     ```
    Note that the above will not run yet because labels are not specified.
 
@@ -156,7 +141,7 @@ Add the path to your labels with `--labels`.
 For example, if your videos are in `example_vids` and your labels are in `example_labels.csv`:
 
    ```python
    labels_dataframe = pd.read_csv('example_labels.csv', index_col='filepath')
    train_config = TrainConfig(
-        data_directory='example_vids/', labels=labels_dataframe
+        data_dir='example_vids/', labels=labels_dataframe
    )
    train_model(train_config=train_config)
    ```
 
 Your labels may be included in the list of [`zamba` class labels](models/index.m
 
@@ -167,11 +152,13 @@
 
 #### Completely new labels
 
-You can also train a model to predict completely new labels - the world is your oyster! (We'd love to see a model trained to predict oysters.) If this is the case, the model architecture will replace the final [neural network](https://www.youtube.com/watch?v=aircAruvnKk&t=995s) layer with a new head that predicts *your* labels instead of those that ship with `zamba`. [Backpropogation](https://www.youtube.com/watch?v=Ilg3gGewQ5U) will continue from that point with the new head. This process is called [transfer learning](https://keras.io/guides/transfer_learning/).
+You can also train a model to predict completely new labels - the world is your oyster! (We'd love to see a model trained to predict oysters.) If this is the case, the model architecture will replace the final [neural network](https://www.youtube.com/watch?v=aircAruvnKk&t=995s) layer with a new head that predicts *your* labels instead of those that ship with `zamba`.
 
 ### 3. Choose a model for training
 
-If your videos contain species common to central or west Africa, use the [`time_distributed` model](models/index.md#time-distributed). If they contain species common to western Europe, use the [`european` model](models/index.md#european). We do not recommend using the [`slowfast` model](models/index.md#slowfast) for training because it is much more computationally intensive and slower to run.
+Any of the models that ship with `zamba` can be trained. If you're training on entirely new species or new ecologies, we recommend starting with the [`time_distributed` model](models/species-detection.md#time-distributed), as this model is less computationally intensive than the [`slowfast` model](models/species-detection.md#slowfast).
+
+However, if you're tuning a model to a subset of species (e.g. a `european_beaver` or `blank` model), use the model that was trained on data that is most similar to your new data.
 
 Add the model name to your command with `--model`. The `time_distributed` model will be used if no model is specified. For example, if you want to continue training the `european` model based on the videos in `example_euro_vids` and the labels in `example_euro_labels.csv`:
 
@@ -182,7 +169,7 @@
 
 === "Python"
    ```python
    train_config = TrainConfig(
-        data_directory="example_euro_vids/",
+        data_dir="example_euro_vids/",
        labels="example_euro_labels.csv",
        model_name="european",
    )
    ```
 
 ### 4. Specify any additional parameters
 
-And there's so much more! You can also do things like specify your region for faster model download (`--weight-download-region`), start training from a saved model checkpoint (`--checkpoint`), or specify a different path where your model should be saved (`--save-directory`). To read about a few common considerations, see the [Guide to Common Optional Parameters](extra-options.md) page.
+And there's so much more! You can also do things like specify your region for faster model download (`--weight-download-region`), start training from a saved model checkpoint (`--checkpoint`), or specify a different path where your model should be saved (`--save-dir`). To read about a few common considerations, see the [Guide to Common Optional Parameters](extra-options.md) page.
 ### 5. Test your configuration with a dry run
 
-Before kicking off the full model training, we recommend testing your code with a "dry run". This will run one training and validation batch for one epoch to quickly detect any bugs. See the [Debugging](debugging.md) page for details.
\ No newline at end of file
+Before kicking off the full model training, we recommend testing your code with a "dry run". This will run one training and validation batch for one epoch to quickly detect any bugs. See the [Debugging](debugging.md) page for details.

diff --git a/docs/docs/yaml-config.md b/docs/docs/yaml-config.md
index f8273daa..f965073d 100644
--- a/docs/docs/yaml-config.md
+++ b/docs/docs/yaml-config.md
@@ -1,4 +1,4 @@
-# Using YAML Configuration Files
+# Using YAML configuration files
 
 In both the command line and the Python module, options for video loading, training, and prediction can be set by passing a YAML file instead of passing arguments directly. YAML files (`.yml` or `.yaml`) are commonly used to serialize data in an easily readable way.
 
@@ -12,15 +12,15 @@ video_loader_config:
   total_frames: 16
   # other video loading parameters
 
-predict_config:
+train_config:
   model_name: time_distributed
-  data_directoty: example_vids/
+  data_dir: example_vids/
+  labels: example_labels.csv
   # other training parameters, eg. batch_size
 
-train_config:
+predict_config:
   model_name: time_distributed
-  data_directory: example_vids/
-  labels: example_labels.csv
-  # other training parameters, eg. batch_size
+  data_dir: example_vids/
+  # other prediction parameters, eg. batch_size
 ```
 
@@ -39,15 +39,13 @@ predict_config:
 
 ## Required arguments
 
-Either `predict_config` or `train_config` is required, based on whether you will be running inference or training a model. See [All Optional Arguments](configurations.md) for a full list of what can be specified under each class. To run inference, either `data_directory` or `filepaths` must be specified. To train a model, both `data_directory` and `labels` must be specified.
+Either `predict_config` or `train_config` is required, based on whether you will be running inference or training a model. See [All Configuration Options](configurations.md) for a full list of what can be specified under each class. To run inference, `data_dir` and/or `filepaths` must be specified. To train a model, `labels` must be specified.
 
-In `video_loader_config`, you must specify at least `model_input_height`, `model_input_width`, and `total_frames`.
+In `video_loader_config`, you must specify at least `model_input_height`, `model_input_width`, and `total_frames`. While this is the minimum required, we strongly recommend being intentional in your choice of frame selection method: `total_frames` by itself will just take the first `n` frames. For a full list of frame selection methods, see the section on [Video loading arguments](configurations.md#video-loading-arguments); an example is sketched below the list.
 
 * For `time_distributed` or `european`, `total_frames` must be 16
 * For `slowfast`, `total_frames` must be 32
 
-See the [Available Models](models/index.md) page for more details on each model's requirements.
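For example, a minimal sketch of a more intentional frame selection setup; the values here simply mirror the `time_distributed` defaults shown in the templates later on this page:

```yaml
video_loader_config:
  model_input_height: 240
  model_input_width: 426
  total_frames: 16
  ensure_total_frames: true
  # select the 16 highest-scoring frames (those most likely to
  # contain an animal) rather than the first 16 frames of the video
  megadetector_lite_config:
    confidence: 0.25
    fill_mode: score_sorted
    n_frames: 16
```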
-
 ## Command line interface
 
 A YAML configuration file can be passed to the command line interface with the `--config` argument. For example, say the example configuration above is saved as `example_config.yaml`. To run prediction:
 
@@ -66,36 +64,11 @@ The main API for zamba is the [`ModelManager` class](api-reference/models-model_
 from zamba.models.manager import ModelManager
 ```
 
-The `ModelManager` class is used by `zamba`’s command line interface to handle preprocessing the filenames, loading the videos, serving them to the model, and saving predictions. Therefore any functionality available to the command line interface is accessible via the `ModelManager` class.
+The `ModelManager` class is used by `zamba`’s command line interface to handle preprocessing the filenames, loading the videos, training the model, performing inference, and saving predictions. Therefore any functionality available to the command line interface is accessible via the `ModelManager` class.
 
 To instantiate the `ModelManager` based on a configuration file saved at `test_config.yaml`:
 
 ```python
 >>> manager = ModelManager.from_yaml('test_config.yaml')
->>> manager.config
-
-ModelConfig(
-    video_loader_config=VideoLoaderConfig(crop_bottom_pixels=None, i_frames=False,
-        scene_threshold=None, megadetector_lite_config=None,
-        model_input_height=240, model_input_width=426,
-        total_frames=16, ensure_total_frames=True,
-        fps=None, early_bias=False, frame_indices=None,
-        evenly_sample_total_frames=False, pix_fmt='rgb24'
-    ),
-    train_config=None,
-    predict_config=PredictConfig(data_directory=PosixPath('vids'),
-        filepaths=                        filepath
-        0    /home/ubuntu/vids/eleph.MP4
-        1    /home/ubuntu/vids/leopard.MP4
-        2    /home/ubuntu/vids/blank.MP4
-        3    /home/ubuntu/vids/chimp.MP4,
-        checkpoint='zamba_time_distributed.ckpt',
-        model_params=ModelParams(scheduler=None, scheduler_params=None),
-        model_name='time_distributed', species=None,
-        gpus=1, num_workers=3, batch_size=8,
-        save=True, dry_run=False, proba_threshold=None,
-        output_class_names=False, weight_download_region='us',
-        cache_dir=None, skip_load_validation=False)
-    )
 ```
 
 We can now run inference or model training without specifying any additional parameters, because they are already associated with our instance of the `ModelManager` class.
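A minimal sketch of doing so; this assumes the manager exposes `train` and `predict` methods mirroring the CLI commands, which is how the CLI itself drives the manager:

```python
from zamba.models.manager import ModelManager

# build the manager from the same YAML file the CLI accepts
manager = ModelManager.from_yaml("test_config.yaml")

# run inference (or manager.train() to train) using only the
# parameters already stored on the manager's configuration
manager.predict()
```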
To run inference or training:

@@ -116,16 +89,6 @@ For example, the default configuration for the [`time_distributed` model](models
 
 ```yaml
 train_config:
-  model_name: time_distributed
-  backbone_finetune_config:
-    backbone_initial_ratio_lr: 0.01
-    multiplier: 1
-    pre_train_bn: True
-    train_bn: False
-    unfreeze_backbone_at_epoch: 3
-    verbose: True
-  early_stopping_config:
-    patience: 5
   scheduler_config:
     scheduler: MultiStepLR
     scheduler_params:
@@ -133,65 +96,32 @@ train_config:
       milestones:
        - 3
      verbose: true
-
+  model_name: time_distributed
+  backbone_finetune_config:
+    backbone_initial_ratio_lr: 0.01
+    multiplier: 1
+    pre_train_bn: true
+    train_bn: false
+    unfreeze_backbone_at_epoch: 3
+    verbose: true
+  early_stopping_config:
+    patience: 5
 video_loader_config:
   model_input_height: 240
   model_input_width: 426
   crop_bottom_pixels: 50
   fps: 4
   total_frames: 16
-  ensure_total_frames: True
+  ensure_total_frames: true
   megadetector_lite_config:
     confidence: 0.25
     fill_mode: score_sorted
     n_frames: 16
-
 predict_config:
   model_name: time_distributed
+public_checkpoint: time_distributed_9e710aa8c92d25190a64b3b04b9122bdcb456982.ckpt
 ```
 
-For reference, the below shows how to specify the same video loading and training parameters using only the Python package:
-
-```python
-from zamba.data.video import VideoLoaderConfig
-from zamba.models.config import TrainConfig
-from zamba.models.model_manager import train_model
-
-video_loader_config = VideoLoaderConfig(
-    model_input_height=240,
-    model_input_width=426,
-    crop_bottom_pixels=50,
-    fps=4,
-    total_frames=16,
-    ensure_total_frames=True,
-    megadetector_lite_config={
-        "confidence": 0.25,
-        "fill_mode": "score_sorted",
-        "n_frames": 16,
-    },
-)
-
-train_config = TrainConfig(
-    # data_directory=YOUR_DATA_DIRECTORY_HERE,
-    # labels=YOUR_LABELS_CSV_HERE,
-    model_name="time_distributed",
-    backbone_finetune_config={
-        "backbone_initial_ratio_lr": 0.01,
-        "unfreeze_backbone_at_epoch": 3,
-        "verbose": True,
-        "pre_train_bn": True,
-        "train_bn": False,
-        "multiplier": 1,
-    },
-    early_stopping_config={"patience": 5},
-    scheduler_config={
-        "scheduler": "MultiStepLR",
-        "scheduler_params": {"gamma": 0.5, "milestones": 3, "verbose": True,},
-    },
-)
-
-train_model(
-    train_config=train_config,
-    video_loader_config=video_loader_config,
-)
-```
\ No newline at end of file
+## Templates
+
+To make modifying existing models easier, we've set up the official models as templates in the [`templates` folder](https://github.com/drivendataorg/zamba/tree/master/templates). Just fill in your data directory and labels, make any desired tweaks to the model config, and then kick off some [training](train-tutorial.md). Happy modeling!
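One plausible workflow, sketched below; it assumes `zamba train` accepts `--config` just as `zamba predict` does, and the filename `my_model.yaml` is arbitrary:

```console
$ cp templates/time_distributed.yaml my_model.yaml
$ # edit my_model.yaml: set data_dir and labels under train_config
$ zamba train --config my_model.yaml
```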
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 319c8a9f..61fcf87b 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -14,35 +14,39 @@ extra_css:
 
 nav:
   - Home: "index.md"
-  - "Getting Started":
-      - Installing Zamba: "install.md"
+  - "Getting started":
+      - Installing zamba: "install.md"
       - Quickstart: "quickstart.md"
   - "User Tutorials":
-      - Classifying Unlabeled Videos: "predict-tutorial.md"
-      - Training a Model on Labeled Videos: "train-tutorial.md"
+      - Classifying unlabeled videos: "predict-tutorial.md"
+      - Training a model on labeled videos: "train-tutorial.md"
       - Debugging: "debugging.md"
+      - Guide to common optional parameters: "extra-options.md"
   - "Available Models":
-      - "models/index.md"
+      - Species detection: "models/species-detection.md"
+      - DensePose: "models/denspose.md"
   - "Advanced Options":
-      - All Optional Arguments: "configurations.md"
+      - All configuration options: "configurations.md"
       - Using YAML configuration files: "yaml-config.md"
-      - Guide to Common Optional Parameters: "extra-options.md"
-  - "Contribute to zamba":
+  - "Contribute to zamba":
      - "contribute/index.md"
  - "Changelog":
-      - "changelog/index.md"
+      - "changelog.md"
  - API Reference:
      - zamba.data:
          - zamba.data.metadata: "api-reference/data-metadata.md"
          - zamba.data.video: "api-reference/data-video.md"
      - zamba.models:
-          - zamba.models.yolox_models: "api-reference/models-yolox_models.md"
          - zamba.models.config: "api-reference/models-config.md"
+          - zamba.models.densepose:
+              - zamba.models.densepose.config: "api-reference/densepose_config.md"
+              - zamba.models.densepose.densepose_manager: "api-reference/densepose_manager.md"
          - zamba.models.efficientnet_models: "api-reference/models-efficientnet_models.md"
          - zamba.models.megadetector_lite_yolox: "api-reference/models-megadetector_lite_yolox.md"
          - zamba.models.model_manager: "api-reference/models-model_manager.md"
          - zamba.models.slowfast_models: "api-reference/models-slowfast_models.md"
          - zamba.models.utils: "api-reference/models-utils.md"
+          - zamba.models.yolox_models: "api-reference/models-yolox_models.md"
      - zamba.pytorch:
          - zamba.pytorch.dataloaders: "api-reference/pytorch-dataloaders.md"
          - zamba.pytorch.finetuning: "api-reference/pytorch-finetuning.md"

diff --git a/templates/european.yaml b/templates/european.yaml
index fff5681b..9448ef16 100644
--- a/templates/european.yaml
+++ b/templates/european.yaml
@@ -1,5 +1,5 @@
 train_config:
-  # data_directory: YOUR_DATA_DIRECTORY_HERE
+  # data_dir: YOUR_DATA_DIR_HERE
   # labels: YOUR_LABELS_CSV_HERE
   model_name: european
   backbone_finetune_config:
@@ -9,6 +9,8 @@ train_config:
     train_bn: false
     unfreeze_backbone_at_epoch: 15
     verbose: true
+  early_stopping_config:
+    patience: 3
 
 video_loader_config:
   model_input_height: 240
@@ -23,7 +25,7 @@ video_loader_config:
     n_frames: 16
 
 predict_config:
-  # data_directory: YOUR_DATA_DIRECTORY_HERE
+  # data_dir: YOUR_DATA_DIR_HERE
   # or
   # filepaths: YOUR_FILEPATH_CSV_HERE
   model_name: european

diff --git a/templates/slowfast.yaml b/templates/slowfast.yaml
index 046b6a3d..518c7e6b 100644
--- a/templates/slowfast.yaml
+++ b/templates/slowfast.yaml
@@ -1,5 +1,5 @@
 train_config:
-  # data_directory: YOUR_DATA_DIRECTORY_HERE
+  # data_dir: YOUR_DATA_DIR_HERE
   # labels: YOUR_LABELS_CSV_HERE
   model_name: slowfast
   backbone_finetune_config:
@@ -32,7 +32,7 @@ video_loader_config:
     n_frames: 32
 
 predict_config:
-  # data_directory: YOUR_DATA_DIRECTORY_HERE
+  # data_dir: YOUR_DATA_DIR_HERE
   # or
   # filepaths: YOUR_FILEPATH_CSV_HERE
   model_name: slowfast

diff --git a/templates/time_distributed.yaml b/templates/time_distributed.yaml
b/templates/time_distributed.yaml index cc036f66..f89662fb 100644 --- a/templates/time_distributed.yaml +++ b/templates/time_distributed.yaml @@ -1,14 +1,14 @@ train_config: - # data_directory: YOUR_DATA_DIRECTORY_HERE + # data_dir: YOUR_DATA_DIR_HERE # labels: YOUR_LABELS_CSV_HERE model_name: time_distributed backbone_finetune_config: backbone_initial_ratio_lr: 0.01 multiplier: 1 - pre_train_bn: True - train_bn: False + pre_train_bn: true + train_bn: false unfreeze_backbone_at_epoch: 3 - verbose: True + verbose: true early_stopping_config: patience: 5 scheduler_config: @@ -25,14 +25,14 @@ video_loader_config: crop_bottom_pixels: 50 fps: 4 total_frames: 16 - ensure_total_frames: True + ensure_total_frames: true megadetector_lite_config: confidence: 0.25 fill_mode: score_sorted n_frames: 16 predict_config: - # data_directory: YOUR_DATA_DIRECTORY_HERE + # data_dir: YOUR_DATA_DIR_HERE # or # filepaths: YOUR_FILEPATH_CSV_HERE model_name: time_distributed diff --git a/tests/assets/sample_predict_config.yaml b/tests/assets/sample_predict_config.yaml index bfa3603e..af081300 100644 --- a/tests/assets/sample_predict_config.yaml +++ b/tests/assets/sample_predict_config.yaml @@ -7,5 +7,5 @@ video_loader_config: total_frames: 16 predict_config: - data_directory: tests/assets/videos + data_dir: tests/assets/videos model_name: time_distributed diff --git a/tests/assets/sample_train_config.yaml b/tests/assets/sample_train_config.yaml index 1ec30b1b..b3b16215 100644 --- a/tests/assets/sample_train_config.yaml +++ b/tests/assets/sample_train_config.yaml @@ -7,7 +7,7 @@ video_loader_config: total_frames: 16 train_config: - data_directory: tests/assets/videos + data_dir: tests/assets/videos labels: tests/assets/labels.csv model_name: time_distributed predict_all_zamba_species: True diff --git a/tests/conftest.py b/tests/conftest.py index ca814de7..3e21dd5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -161,14 +161,14 @@ def dummy_train_config(labels_absolute_path, dummy_checkpoint, tmp_path_factory) tmp_path = tmp_path_factory.mktemp("dummy-model-dir") return DummyTrainConfig( labels=labels_absolute_path, - data_directory=TEST_VIDEOS_DIR, + data_dir=TEST_VIDEOS_DIR, model_name="dummy", checkpoint=dummy_checkpoint, max_epochs=1, batch_size=1, auto_lr_find=False, num_workers=2, - save_directory=tmp_path / "my_model", + save_dir=tmp_path / "my_model", skip_load_validation=True, ) diff --git a/tests/test_cli.py b/tests/test_cli.py index 78f3c7c8..fcd45302 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -145,13 +145,21 @@ def test_predict_specific_options(mocker, minimum_valid_predict, tmp_path): # n ) assert result.exit_code == 0 + # test save overwrite + (tmp_path / "zamba_predictions.csv").touch() + result = runner.invoke( + app, + minimum_valid_predict + ["--output-class-names", "--save-dir", str(tmp_path), "-o"], + ) + assert result.exit_code == 0 + def test_actual_prediction_on_single_video(tmp_path): # noqa: F811 data_dir = tmp_path / "videos" data_dir.mkdir() shutil.copy(TEST_VIDEOS_DIR / "data" / "raw" / "benjamin" / "04250002.MP4", data_dir) - save_path = tmp_path / "zamba" / "my_preds.csv" + save_dir = tmp_path / "zamba" result = runner.invoke( app, @@ -162,17 +170,20 @@ def test_actual_prediction_on_single_video(tmp_path): # noqa: F811 "--config", str(ASSETS_DIR / "sample_predict_config.yaml"), "--yes", - "--save-path", - str(save_path), + "--save-dir", + str(save_dir), ], ) assert result.exit_code == 0 # check preds file got saved out - assert save_path.exists() + assert
save_dir.exists() # check config got saved out too - assert (save_path.parent / "predict_configuration.yaml").exists() + assert (save_dir / "predict_configuration.yaml").exists() assert ( - pd.read_csv(save_path, index_col="filepath").idxmax(axis=1).values[0] == "monkey_prosimian" + pd.read_csv(save_dir / "zamba_predictions.csv", index_col="filepath") + .idxmax(axis=1) + .values[0] + == "monkey_prosimian" ) diff --git a/tests/test_config.py b/tests/test_config.py index 3b3abfc4..5d4cb0d6 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -20,7 +20,7 @@ def test_train_data_dir_only(): with pytest.raises(ValidationError) as error: - TrainConfig(data_directory=TEST_VIDEOS_DIR) + TrainConfig(data_dir=TEST_VIDEOS_DIR) # labels is missing assert error.value.errors() == [ {"loc": ("labels",), "msg": "field required", "type": "value_error.missing"} @@ -29,19 +29,19 @@ def test_train_data_dir_only(): def test_train_data_dir_and_labels(tmp_path, labels_relative_path, labels_absolute_path): # correct data dir - config = TrainConfig(data_directory=TEST_VIDEOS_DIR, labels=labels_relative_path) - assert config.data_directory is not None + config = TrainConfig(data_dir=TEST_VIDEOS_DIR, labels=labels_relative_path) + assert config.data_dir is not None assert config.labels is not None # data dir ignored if absolute path provided in filepath - config = TrainConfig(data_directory=tmp_path, labels=labels_absolute_path) - assert config.data_directory is not None + config = TrainConfig(data_dir=tmp_path, labels=labels_absolute_path) + assert config.data_dir is not None assert config.labels is not None assert not config.labels.filepath.str.startswith(str(tmp_path)).any() # incorrect data dir with relative filepaths with pytest.raises(ValidationError) as error: - TrainConfig(data_directory=ASSETS_DIR, labels=labels_relative_path) + TrainConfig(data_dir=ASSETS_DIR, labels=labels_relative_path) assert "None of the video filepaths exist" in error.value.errors()[0]["msg"] @@ -51,8 +51,8 @@ def test_train_labels_only(labels_absolute_path): def test_predict_data_dir_only(): - config = PredictConfig(data_directory=TEST_VIDEOS_DIR) - assert config.data_directory == TEST_VIDEOS_DIR + config = PredictConfig(data_dir=TEST_VIDEOS_DIR) + assert config.data_dir == TEST_VIDEOS_DIR assert isinstance(config.filepaths, pd.DataFrame) assert sorted(config.filepaths.filepath.values) == sorted( [str(f) for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()] @@ -62,14 +62,14 @@ def test_predict_data_dir_only(): def test_predict_data_dir_and_filepaths(labels_absolute_path, labels_relative_path): # correct data dir - config = PredictConfig(data_directory=TEST_VIDEOS_DIR, filepaths=labels_relative_path) - assert config.data_directory is not None + config = PredictConfig(data_dir=TEST_VIDEOS_DIR, filepaths=labels_relative_path) + assert config.data_dir is not None assert config.filepaths is not None assert config.filepaths.filepath.str.startswith(str(TEST_VIDEOS_DIR)).all() # incorrect data dir with pytest.raises(ValidationError) as error: - PredictConfig(data_directory=ASSETS_DIR, filepaths=labels_relative_path) + PredictConfig(data_dir=ASSETS_DIR, filepaths=labels_relative_path) assert "None of the video filepaths exist" in error.value.errors()[0]["msg"] @@ -203,18 +203,16 @@ def test_labels_with_invalid_split(labels_absolute_path): def test_labels_no_splits(labels_no_splits, tmp_path): - config = TrainConfig( - data_directory=TEST_VIDEOS_DIR, labels=labels_no_splits, save_directory=tmp_path - ) + config = 
TrainConfig(data_dir=TEST_VIDEOS_DIR, labels=labels_no_splits, save_dir=tmp_path) assert set(config.labels.split.unique()) == set(("holdout", "train", "val")) def test_labels_split_proportions(labels_no_splits, tmp_path): config = TrainConfig( - data_directory=TEST_VIDEOS_DIR, + data_dir=TEST_VIDEOS_DIR, labels=labels_no_splits, split_proportions={"a": 3, "b": 1}, - save_directory=tmp_path, + save_dir=tmp_path, ) assert config.labels.split.value_counts().to_dict() == {"a": 14, "b": 5} @@ -230,9 +228,18 @@ def test_from_scratch(labels_absolute_path): assert "If from_scratch=True, model_name cannot be None." == error.value.errors()[0]["msg"] -def test_predict_dry_run_and_save(labels_absolute_path, caplog): - PredictConfig(filepaths=labels_absolute_path, dry_run=True, save=True) - assert "Cannot save when predicting with dry_run=True. Setting save=False." in caplog.text +def test_predict_dry_run_and_save(labels_absolute_path, caplog, tmp_path): + config = PredictConfig(filepaths=labels_absolute_path, dry_run=True, save=True) + assert ( + "Cannot save when predicting with dry_run=True. Setting save=False and save_dir=None." + in caplog.text + ) + assert not config.save + assert config.save_dir is None + + config = PredictConfig(filepaths=labels_absolute_path, dry_run=True, save_dir=tmp_path) + assert not config.save + assert config.save_dir is None def test_predict_filepaths_with_duplicates(labels_absolute_path, tmp_path, caplog): @@ -257,40 +264,60 @@ def test_model_cache_dir(labels_absolute_path, tmp_path): def test_predict_save(labels_absolute_path, tmp_path, dummy_trained_model_checkpoint): - # if save is True, use default save path + # if save is True, save in current working directory config = PredictConfig(filepaths=labels_absolute_path, skip_load_validation=True) - assert config.save == Path.cwd() / "zamba_predictions.csv" + assert config.save_dir == Path.cwd() config = PredictConfig(filepaths=labels_absolute_path, save=False, skip_load_validation=True) assert config.save is False + assert config.save_dir is None - # use checkpoint directory if checkpoint exists + # if save_dir is specified, set save to True config = PredictConfig( filepaths=labels_absolute_path, + save=False, + save_dir=tmp_path / "my_dir", skip_load_validation=True, - checkpoint=dummy_trained_model_checkpoint, ) - assert config.save == Path(dummy_trained_model_checkpoint).parent / "zamba_predictions.csv" + assert config.save is True + # save dir gets created + assert (tmp_path / "my_dir").exists() + + # empty save dir does not error + save_dir = tmp_path / "save_dir" + save_dir.mkdir() - # case does not matter as long as it's a csv config = PredictConfig( - filepaths=labels_absolute_path, save=tmp_path / "zamba/my_model/my_predictions.CSV" + filepaths=labels_absolute_path, save_dir=save_dir, skip_load_validation=True ) - assert config.save == tmp_path / "zamba/my_model/my_predictions.CSV" - - # cannot pass in directories, must specify full path - with pytest.raises(ValueError) as error: - PredictConfig( - filepaths=labels_absolute_path, save="zamba/my_model/", skip_load_validation=True + assert config.save_dir == save_dir + + # save dir with prediction csv or yaml will error + for pred_file in [ + (save_dir / "zamba_predictions.csv"), + (save_dir / "predict_configuration.yaml"), + ]: + # just takes one of the two files to raise error + pred_file.touch() + with pytest.raises(ValueError) as error: + PredictConfig( + filepaths=labels_absolute_path, save_dir=save_dir, skip_load_validation=True + ) + assert ( + 
f"zamba_predictions.csv and/or predict_configuration.yaml already exist in {save_dir}. If you would like to overwrite, set overwrite=True" + == error.value.errors()[0]["msg"] ) - assert "Save path must end with .csv" in error.value.errors()[0]["msg"] + pred_file.unlink() - # cannot use path that already exists - save_path = tmp_path / "pred.csv" - save_path.touch() - with pytest.raises(ValueError) as error: - PredictConfig(filepaths=labels_absolute_path, save=save_path, skip_load_validation=True) - assert "already exists" in error.value.errors()[0]["msg"] + # can overwrite + pred_file.touch() + config = PredictConfig( + filepaths=labels_absolute_path, + save_dir=save_dir, + skip_load_validation=True, + overwrite=True, + ) + assert config.save_dir == save_dir def test_validate_scheduler(labels_absolute_path): diff --git a/tests/test_densepose.py b/tests/test_densepose.py index 48e84bb0..ca26a219 100644 --- a/tests/test_densepose.py +++ b/tests/test_densepose.py @@ -122,8 +122,8 @@ def test_denseposeconfig(model, tmp_path): output_type="bananas", render_output=True, embeddings_in_json=False, - data_directory=ASSETS_DIR / "densepose_tests", - save_path=tmp_path, + data_dir=ASSETS_DIR / "densepose_tests", + save_dir=tmp_path, ) dpc = DensePoseConfig( @@ -131,13 +131,13 @@ def test_denseposeconfig(model, tmp_path): output_type="segmentation" if model == "animals" else "chimp_anatomy", render_output=True, embeddings_in_json=False, - data_directory=ASSETS_DIR / "densepose_tests", - save_path=tmp_path, + data_dir=ASSETS_DIR / "densepose_tests", + save_dir=tmp_path, ) dpc.run_model() - # ensure all outputs are saved in save_path + # ensure all outputs are saved in save_dir assert (tmp_path / "chimp_denspose_video.mp4").exists() assert (tmp_path / "chimp_denspose_labels.json").exists() diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py index 21948250..682a032d 100644 --- a/tests/test_model_manager.py +++ b/tests/test_model_manager.py @@ -75,14 +75,14 @@ def test_save_metrics_less_than_two_classes( trainer = train_model( train_config=DummyTrainConfig( labels=labels, - data_directory=TEST_VIDEOS_DIR, + data_dir=TEST_VIDEOS_DIR, model_name="dummy", checkpoint=dummy_checkpoint, max_epochs=1, batch_size=1, auto_lr_find=False, num_workers=2, - save_directory=tmp_path / "my_model", + save_dir=tmp_path / "my_model", skip_load_validation=True, ), video_loader_config=dummy_video_loader_config, @@ -125,22 +125,22 @@ def test_save_configuration(dummy_trainer): } -def test_train_save_directory(dummy_trainer): +def test_train_save_dir(dummy_trainer): assert Path(dummy_trainer.logger.root_dir).name == "my_model" assert Path(dummy_trainer.logger.log_dir).name == "version_0" -def test_train_save_directory_overwrite( +def test_train_save_dir_overwrite( labels_absolute_path, dummy_checkpoint, tmp_path, dummy_video_loader_config ): config = DummyTrainConfig( labels=labels_absolute_path, - data_directory=TEST_VIDEOS_DIR, + data_dir=TEST_VIDEOS_DIR, model_name="dummy", checkpoint=dummy_checkpoint, - save_directory=tmp_path / "my_model", + save_dir=tmp_path / "my_model", skip_load_validation=True, - overwrite_save_directory=True, + overwrite=True, max_epochs=1, batch_size=1, auto_lr_find=False, @@ -151,9 +151,9 @@ def test_train_save_directory_overwrite( train_config=config, video_loader_config=dummy_video_loader_config ) - assert Path(overwrite_trainer.logger.log_dir).resolve() == config.save_directory.resolve() + assert Path(overwrite_trainer.logger.log_dir).resolve() == config.save_dir.resolve() - 
assert not any([f.name.startswith("version_") for f in config.save_directory.iterdir()]) + assert not any([f.name.startswith("version_") for f in config.save_dir.iterdir()]) # when training from checkpoint, model_name is None so get PTL default ckpt name for f in [ @@ -162,7 +162,7 @@ def test_train_save_directory_overwrite( "val_metrics.json", "epoch=0-step=10.ckpt", ]: - assert (config.save_directory / f).exists() + assert (config.save_dir / f).exists() @pytest.mark.parametrize("model_name", ["time_distributed", "slowfast", "european"]) diff --git a/zamba/cli.py b/zamba/cli.py index 6f1b4f6d..aa212425 100644 --- a/zamba/cli.py +++ b/zamba/cli.py @@ -53,7 +53,7 @@ def train( ), save_dir: Path = typer.Option( None, - help="Directory in which to save model checkpoint and configuration file. If not specified, will save to a folder called 'zamba_{model_name}' in your working directory.", + help="An optional directory in which to save the model checkpoint and configuration file. If not specified, will save to a `version_n` folder in your working directory.", ), num_workers: int = typer.Option( None, @@ -95,7 +95,7 @@ def train( # override if any command line arguments are passed if data_dir is not None: - train_dict["data_directory"] = data_dir + train_dict["data_dir"] = data_dir if labels is not None: train_dict["labels"] = labels @@ -116,7 +116,7 @@ def train( train_dict["dry_run"] = dry_run if save_dir is not None: - train_dict["save_directory"] = save_dir + train_dict["save_dir"] = save_dir if num_workers is not None: train_dict["num_workers"] = num_workers @@ -154,17 +154,17 @@ def train( msg = f"""The following configuration will be used for training: Config file: {config_file} - Data directory: {data_dir if data_dir is not None else config_dict["train_config"].get("data_directory")} + Data directory: {data_dir if data_dir is not None else config_dict["train_config"].get("data_dir")} Labels csv: {labels if labels is not None else config_dict["train_config"].get("labels")} Species: {species} Model name: {config.train_config.model_name} Checkpoint: {checkpoint if checkpoint is not None else config_dict["train_config"].get("checkpoint")} - Weight download region: {config.train_config.weight_download_region} Batch size: {config.train_config.batch_size} Number of workers: {config.train_config.num_workers} GPUs: {config.train_config.gpus} Dry run: {config.train_config.dry_run} - Save directory: {config.train_config.save_directory} + Save directory: {config.train_config.save_dir} + Weight download region: {config.train_config.weight_download_region} """ if yes: @@ -203,11 +203,12 @@ def predict( batch_size: int = typer.Option(None, help="Batch size to use for training."), save: bool = typer.Option( None, - help="Whether to save out predictions to a csv file. If you want to specify the location of the csv, use save_path instead.", + help="Whether to save out predictions. If you want to specify the output directory, use save_dir instead.", ), - save_path: Path = typer.Option( + save_dir: Path = typer.Option( None, - help="Full path for prediction CSV file. Any needed parent directories will be created.", + help="An optional directory in which to save the model predictions and configuration yaml. " + "Defaults to the current working directory if save is True.", ), dry_run: bool = typer.Option(None, help="Runs one batch of inference to check for bugs."), config: Path = typer.Option( @@ -238,6 +239,9 @@ def predict( None, help="Skip check that verifies all videos can be loaded prior to inference. 
Only use if you're very confident all your videos can be loaded.", ), + overwrite: bool = typer.Option( + None, "--overwrite", "-o", help="Overwrite outputs in the save directory if they exist." + ), yes: bool = typer.Option( False, "--yes", @@ -271,7 +275,7 @@ def predict( # override if any command line arguments are passed if data_dir is not None: - predict_dict["data_directory"] = data_dir + predict_dict["data_dir"] = data_dir if filepaths is not None: predict_dict["filepaths"] = filepaths @@ -294,9 +298,9 @@ def predict( if save is not None: predict_dict["save"] = save - # save path takes precedence over save - if save_path is not None: - predict_dict["save"] = save_path + # save_dir takes precedence over save + if save_dir is not None: + predict_dict["save_dir"] = save_dir if proba_threshold is not None: predict_dict["proba_threshold"] = proba_threshold @@ -313,6 +317,9 @@ def predict( if skip_load_validation is not None: predict_dict["skip_load_validation"] = skip_load_validation + if overwrite is not None: + predict_dict["overwrite"] = overwrite + try: manager = ModelManager( ModelConfig( @@ -329,7 +336,7 @@ def predict( msg = f"""The following configuration will be used for inference: Config file: {config_file} - Data directory: {data_dir if data_dir is not None else config_dict["predict_config"].get("data_directory")} + Data directory: {data_dir if data_dir is not None else config_dict["predict_config"].get("data_dir")} Filepath csv: {filepaths if filepaths is not None else config_dict["predict_config"].get("filepaths")} Model: {config.predict_config.model_name} Checkpoint: {checkpoint if checkpoint is not None else config_dict["predict_config"].get("checkpoint")} @@ -337,7 +344,7 @@ def predict( Number of workers: {config.predict_config.num_workers} GPUs: {config.predict_config.gpus} Dry run: {config.predict_config.dry_run} - Save: {config.predict_config.save} + Save directory: {config.predict_config.save_dir} Proba threshold: {config.predict_config.proba_threshold} Output class names: {config.predict_config.output_class_names} Weight download region: {config.predict_config.weight_download_region} @@ -387,7 +394,7 @@ def densepose( filepaths: Path = typer.Option( None, exists=True, help="Path to csv containing `filepath` column with videos." ), - save_path: Path = typer.Option( + save_dir: Path = typer.Option( None, help="An optional directory for saving the output. 
Defaults to the current working directory.", ), @@ -447,13 +454,13 @@ def densepose( # override if any command line arguments are passed if data_dir is not None: - predict_dict["data_directory"] = data_dir + predict_dict["data_dir"] = data_dir if filepaths is not None: predict_dict["filepaths"] = filepaths - if save_path is not None: - predict_dict["save_path"] = save_path + if save_dir is not None: + predict_dict["save_dir"] = save_dir if weight_download_region is not None: predict_dict["weight_download_region"] = weight_download_region @@ -478,7 +485,7 @@ def densepose( Config file: {config_file} Output type: {densepose_config.output_type} Render output: {densepose_config.render_output} - Data directory: {data_dir if data_dir is not None else config_dict.get("data_directory")} + Data directory: {data_dir if data_dir is not None else config_dict.get("data_dir")} Filepath csv: {filepaths if filepaths is not None else config_dict.get("filepaths")} Weight download region: {densepose_config.weight_download_region} Cache directory: {densepose_config.cache_dir} diff --git a/zamba/models/config.py b/zamba/models/config.py index faba49dc..d12f13d0 100644 --- a/zamba/models/config.py +++ b/zamba/models/config.py @@ -78,14 +78,14 @@ def validate_model_cache_dir(model_cache_dir: Optional[Path]): def check_files_exist_and_load( - df: pd.DataFrame, data_directory: DirectoryPath, skip_load_validation: bool + df: pd.DataFrame, data_dir: DirectoryPath, skip_load_validation: bool ): """Check whether files in file list exist and can be loaded with ffmpeg. Warn and skip files that don't exist or can't be loaded. Args: df (pd.DataFrame): DataFrame with a "filepath" column - data_directory (Path): Data folder to prepend if filepath is not an + data_dir (Path): Data folder to prepend if filepath is not an absolute path. skip_load_validation (bool): Skip ffprobe check that verifies all videos can be loaded. @@ -94,7 +94,7 @@ def check_files_exist_and_load( pd.DataFrame: DataFrame with valid and loadable videos. """ # update filepath column to prepend data_dir if filepath column is not an absolute path - data_dir = Path(data_directory).resolve() + data_dir = Path(data_dir).resolve() df["filepath"] = str(data_dir) / df.filepath.path # we can have multiple rows per file with labels so limit just to one row per file for these checks @@ -109,7 +109,7 @@ def check_files_exist_and_load( # if no files exist if len(invalid_files) == len(files_df): raise ValueError( - f"None of the video filepaths exist. Are you sure they're specified correctly? Here's an example invalid path: {invalid_files.filepath.values[0]}. Either specify absolute filepaths in the csv or provide filepaths relative to `data_directory`." + f"None of the video filepaths exist. Are you sure they're specified correctly? Here's an example invalid path: {invalid_files.filepath.values[0]}. Either specify absolute filepaths in the csv or provide filepaths relative to `data_dir`." ) # if at least some files exist @@ -280,10 +280,10 @@ class TrainConfig(ZambaBaseModel): Args: labels (FilePath or pandas DataFrame): Path to a CSV or pandas DataFrame containing labels for training, with one row per label. There must be - columns called 'filepath' (absolute or relative to the data_directory) and + columns called 'filepath' (absolute or relative to the data_dir) and 'label', and optionally columns called 'split' ("train", "val", or "holdout") and 'site'. Labels must be specified to train a model. 
- data_directory (DirectoryPath): Path to a directory containing training + data_dir (DirectoryPath): Path to a directory containing training videos. Defaults to the working directory. model_name (str, optional): Name of the model to use for training. Options are: time_distributed, slowfast, european. Defaults to time_distributed. @@ -326,16 +326,15 @@ class TrainConfig(ZambaBaseModel): split_proportions (dict): Proportions used to divide data into training, validation, and holdout sets if a "split" column is not included in labels. Defaults to "train": 3, "val": 1, "holdout": 1. - save_directory (Path, optional): Path to a directory where training files + save_dir (Path, optional): Path to a directory where training files will be saved. Files include the best model checkpoint (``model_name``.ckpt), training configuration (configuration.yaml), Tensorboard logs (events.out.tfevents...), test metrics (test_metrics.json), validation metrics (val_metrics.json), and model hyperparameters (hparams.yml). - If None, a folder is created in the working directory called - "zamba_``model_name``". Defaults to None. + If not specified, files are saved to a `version_n` folder in the working directory. + overwrite (bool): If True, will save outputs in `save_dir` overwriting if those + exist. If False, will create auto-incremented `version_n` folder in `save_dir` + with model outputs. Defaults to False. skip_load_validation (bool): Skip ffprobe check, which verifies that all videos can be loaded and skips files that cannot be loaded. Defaults to False. @@ -351,7 +350,7 @@ """ labels: Union[FilePath, pd.DataFrame] - data_directory: DirectoryPath = Path.cwd() + data_dir: DirectoryPath = Path.cwd() checkpoint: Optional[FilePath] = None scheduler_config: Optional[Union[str, SchedulerConfig]] = "default" model_name: Optional[ModelEnum] = ModelEnum.time_distributed @@ -365,8 +364,8 @@ early_stopping_config: Optional[EarlyStoppingConfig] = EarlyStoppingConfig() weight_download_region: RegionEnum = "us" split_proportions: Optional[Dict[str, int]] = {"train": 3, "val": 1, "holdout": 1} - save_directory: Path = Path.cwd() - overwrite_save_directory: bool = False + save_dir: Path = Path.cwd() + overwrite: bool = False skip_load_validation: bool = False from_scratch: bool = False predict_all_zamba_species: bool = True @@ -466,7 +465,7 @@ def validate_filepaths_and_labels(cls, values): # check that all videos exist and can be loaded values["labels"] = check_files_exist_and_load( df=labels, - data_directory=values["data_directory"], + data_dir=values["data_dir"], skip_load_validation=values["skip_load_validation"], ) return values @@ -508,49 +507,20 @@ def preprocess_labels(cls, values): ) logger.info( - f"Writing out split information to {values['save_directory'] / 'splits.csv'}." + f"Writing out split information to {values['save_dir'] / 'splits.csv'}." ) # create the directory to save if we need to.
- values["save_directory"].mkdir(parents=True, exist_ok=True) + values["save_dir"].mkdir(parents=True, exist_ok=True) labels.reset_index()[["filepath", "split"]].drop_duplicates().to_csv( - values["save_directory"] / "splits.csv", index=False + values["save_dir"] / "splits.csv", index=False ) # filepath becomes column instead of index values["labels"] = labels.reset_index() return values - def get_model_only_params(self): - """Return only params that are not data or machine specific. - Used for generating official configs. - """ - train_config = self.dict() - - # remove data and machine specific params - for key in [ - "labels", - "data_directory", - "dry_run", - "batch_size", - "auto_lr_find", - "gpus", - "num_workers", - "max_epochs", - "weight_download_region", - "split_proportions", - "save_directory", - "overwrite_save_directory", - "skip_load_validation", - "from_scratch", - "model_cache_dir", - "predict_all_zamba_species", - ]: - train_config.pop(key) - - return train_config - class PredictConfig(ZambaBaseModel): """ @@ -558,10 +528,10 @@ class PredictConfig(ZambaBaseModel): Args: filepaths (FilePath): Path to a CSV containing videos for inference, with - one row per video in the data_directory. There must be a column called - 'filepath' (absolute or relative to the data_directory). If None, uses - all files in data_directory. Defaults to None. - data_directory (DirectoryPath): Path to a directory containing videos for + one row per video in the data_dir. There must be a column called + 'filepath' (absolute or relative to the data_dir). If None, uses + all files in data_dir. Defaults to None. + data_dir (DirectoryPath): Path to a directory containing videos for inference. Defaults to the working directory. model_name (str, optional): Name of the model to use for inference. Options are: time_distributed, slowfast, european. Defaults to time_distributed. @@ -574,9 +544,13 @@ class PredictConfig(ZambaBaseModel): that the data will be loaded in the main process. The maximum value is the number of CPUs in the system. Defaults to 3. batch_size (int): Batch size to use for inference. Defaults to 2. - save (bool or Path): Path to a CSV to save predictions. If True is passed, - "zamba_predictions.csv" is written to the current working directory. - If False is passed, predictions are not saved. Defaults to True. + save (bool): Whether to save out predictions. If False, predictions are + not saved. Defaults to True. + save_dir (Path, optional): An optional directory in which to save the model + predictions and configuration yaml. If no save_dir is specified and save=True, + outputs will be written to the current working directory. Defaults to None. + overwrite (bool): If True, overwrite outputs in save_dir if they exist. + Defaults to False. dry_run (bool): Perform inference on a single batch for testing. Predictions will not be saved. Defaults to False. proba_threshold (float, optional): Probability threshold for classification. @@ -599,14 +573,16 @@ class PredictConfig(ZambaBaseModel): default cache directory. Defaults to None. 
""" - data_directory: DirectoryPath = Path.cwd() + data_dir: DirectoryPath = Path.cwd() filepaths: Optional[FilePath] = None checkpoint: Optional[FilePath] = None model_name: Optional[ModelEnum] = ModelEnum.time_distributed gpus: int = GPUS_AVAILABLE num_workers: int = 3 batch_size: int = 2 - save: Union[bool, Path] = True + save: bool = True + save_dir: Optional[Path] = None + overwrite: bool = False dry_run: bool = False proba_threshold: Optional[float] = None output_class_names: bool = False @@ -622,42 +598,45 @@ class PredictConfig(ZambaBaseModel): @root_validator(skip_on_failure=True) def validate_dry_run_and_save(cls, values): - if values["dry_run"] and (values["save"] is not False): - logger.warning("Cannot save when predicting with dry_run=True. Setting save=False.") + if values["dry_run"] and ( + (values["save"] is not False) or (values["save_dir"] is not None) + ): + logger.warning( + "Cannot save when predicting with dry_run=True. Setting save=False and save_dir=None." + ) values["save"] = False + values["save_dir"] = None return values @root_validator(skip_on_failure=True) - def validate_save(cls, values): - # do this check before we look up checkpoints based on model name so we can see if checkpoint is None + def validate_save_dir(cls, values): + save_dir = values["save_dir"] save = values["save"] - checkpoint = values["checkpoint"] - # if False, no predictions will be written out - if save is False: - return values + # if no save_dir but save is True, use current working directory + if save_dir is None and save: + save_dir = Path.cwd() - else: - # if save=True and we have a local checkpoint, save in checkpoint directory - if save is True and checkpoint is not None: - save = checkpoint.parent / "zamba_predictions.csv" - - # else, save to current working directory - elif save is True and checkpoint is None: - save = Path.cwd() / "zamba_predictions.csv" - - # validate save path - if isinstance(save, Path): - if save.suffix.lower() != ".csv": - raise ValueError("Save path must end with .csv") - elif save.exists(): - raise ValueError(f"Save path {save} already exists.") - else: - values["save"] = save + if save_dir is not None: + # check if files exist + if ( + (save_dir / "zamba_predictions.csv").exists() + or (save_dir / "predict_configuration.yaml").exists() + ) and not values["overwrite"]: + raise ValueError( + f"zamba_predictions.csv and/or predict_configuration.yaml already exist in {save_dir}. If you would like to overwrite, set overwrite=True" + ) + + # make a directory if needed + save_dir.mkdir(parents=True, exist_ok=True) + + # set save to True if save_dir is set + if not save: + save = True - # create any needed parent directories - save.parent.mkdir(parents=True, exist_ok=True) + values["save_dir"] = save_dir + values["save"] = save return values @@ -686,12 +665,12 @@ def get_filepaths(cls, values): contains files with valid suffixes. 
""" if values["filepaths"] is None: - logger.info(f"Getting files in {values['data_directory']}.") + logger.info(f"Getting files in {values['data_dir']}.") files = [] new_suffixes = [] # iterate over all files in data directory - for f in values["data_directory"].rglob("*"): + for f in values["data_dir"].rglob("*"): if f.is_file(): # keep just files with supported suffixes if f.suffix.lower() in VIDEO_SUFFIXES: @@ -705,9 +684,9 @@ def get_filepaths(cls, values): ) if len(files) == 0: - raise ValueError(f"No video files found in {values['data_directory']}.") + raise ValueError(f"No video files found in {values['data_dir']}.") - logger.info(f"Found {len(files)} videos in {values['data_directory']}.") + logger.info(f"Found {len(files)} videos in {values['data_dir']}.") values["filepaths"] = pd.DataFrame(files, columns=["filepath"]) return values @@ -733,7 +712,7 @@ def validate_files(cls, values): values["filepaths"] = check_files_exist_and_load( df=files_df, - data_directory=values["data_directory"], + data_dir=values["data_dir"], skip_load_validation=values["skip_load_validation"], ) return values diff --git a/zamba/models/densepose/config.py b/zamba/models/densepose/config.py index ab12ca6f..65f2f480 100644 --- a/zamba/models/densepose/config.py +++ b/zamba/models/densepose/config.py @@ -34,10 +34,10 @@ class DensePoseConfig(ZambaBaseModel): Defaults to False. embeddings_in_json (bool): Whether to save the embeddings matrices in the json of the DensePose result. Setting to True can result in large json files. Defaults to False. - data_directory (Path): Where to find the files listed in filepaths (or where to look if + data_dir (Path): Where to find the files listed in filepaths (or where to look if filepaths is not provided). filepaths (Path, optional): Path to a CSV file with a list of filepaths to process. - save_path (Path, optional): Directory for where to save the output files; + save_dir (Path, optional): Directory for where to save the output files; defaults to os.getcwd(). cache_dir (Path, optional): Path for downloading and saving model weights. Defaults to env var `MODEL_CACHE_DIR` or the OS app cache dir. @@ -49,9 +49,9 @@ class DensePoseConfig(ZambaBaseModel): output_type: DensePoseOutputEnum render_output: bool = False embeddings_in_json: bool = False - data_directory: Path + data_dir: Path filepaths: Optional[Path] = None - save_path: Optional[Path] = None + save_dir: Optional[Path] = None cache_dir: Optional[Path] = None weight_download_region: RegionEnum = RegionEnum("us") @@ -71,7 +71,7 @@ def run_model(self): else: raise Exception(f"invalid {self.output_type}") - output_dir = Path(os.getcwd()) if self.save_path is None else self.save_path + output_dir = Path(os.getcwd()) if self.save_dir is None else self.save_dir dpm = DensePoseManager( model, model_cache_dir=self.cache_dir, download_region=self.weight_download_region @@ -112,12 +112,12 @@ def get_filepaths(cls, values): contains files with valid suffixes. 
""" if values["filepaths"] is None: - logger.info(f"Getting files in {values['data_directory']}.") + logger.info(f"Getting files in {values['data_dir']}.") files = [] new_suffixes = [] # iterate over all files in data directory - for f in values["data_directory"].rglob("*"): + for f in values["data_dir"].rglob("*"): if f.is_file(): # keep just files with supported suffixes if f.suffix.lower() in VIDEO_SUFFIXES: @@ -131,9 +131,9 @@ def get_filepaths(cls, values): ) if len(files) == 0: - raise ValueError(f"No video files found in {values['data_directory']}.") + raise ValueError(f"No video files found in {values['data_dir']}.") - logger.info(f"Found {len(files)} videos in {values['data_directory']}.") + logger.info(f"Found {len(files)} videos in {values['data_dir']}.") values["filepaths"] = pd.DataFrame(files, columns=["filepath"]) return values @@ -159,7 +159,7 @@ def validate_files(cls, values): values["filepaths"] = check_files_exist_and_load( df=files_df, - data_directory=values["data_directory"], + data_dir=values["data_dir"], skip_load_validation=True, ) return values diff --git a/zamba/models/model_manager.py b/zamba/models/model_manager.py index 35eb5b81..e367b3c5 100644 --- a/zamba/models/model_manager.py +++ b/zamba/models/model_manager.py @@ -81,7 +81,8 @@ def instantiate_model( hparams = yaml.safe_load(f) else: - if not (model_cache_dir / checkpoint).exists(): + # download if neither local checkpoint nor cached checkpoint exist + if not checkpoint.exists() and not (model_cache_dir / checkpoint).exists(): logger.info("Downloading weights for model.") checkpoint = download_weights( filename=str(checkpoint), @@ -225,16 +226,12 @@ def train_model( validate_species(model, data_module) - train_config.save_directory.mkdir(parents=True, exist_ok=True) + train_config.save_dir.mkdir(parents=True, exist_ok=True) # add folder version_n that auto increments if we are not overwriting - tensorboard_version = ( - train_config.save_directory.name if train_config.overwrite_save_directory else None - ) + tensorboard_version = train_config.save_dir.name if train_config.overwrite else None tensorboard_save_dir = ( - train_config.save_directory.parent - if train_config.overwrite_save_directory - else train_config.save_directory + train_config.save_dir.parent if train_config.overwrite else train_config.save_dir ) tensorboard_logger = TensorBoardLogger( @@ -245,9 +242,7 @@ def train_model( ) logging_and_save_dir = ( - tensorboard_logger.log_dir - if not train_config.overwrite_save_directory - else train_config.save_directory + tensorboard_logger.log_dir if not train_config.overwrite else train_config.save_dir ) model_checkpoint = ModelCheckpoint( @@ -381,6 +376,13 @@ def predict_model( "video_loader_config": json.loads(video_loader_config.json()), } + if predict_config.save is not False: + + config_path = predict_config.save_dir / "predict_configuration.yaml" + logger.info(f"Writing out full configuration to {config_path}.") + with config_path.open("w") as fp: + yaml.dump(configuration, fp) + dataloader = data_module.predict_dataloader() logger.info("Starting prediction...") probas = trainer.predict(model=model, dataloaders=dataloader) @@ -401,14 +403,9 @@ def predict_model( if predict_config.save is not False: - config_path = predict_config.save.parent / "predict_configuration.yaml" - config_path.parent.mkdir(exist_ok=True, parents=True) - logger.info(f"Writing out full configuration to {config_path}.") - with config_path.open("w") as fp: - yaml.dump(configuration, fp) - - logger.info(f"Saving out 
predictions to {predict_config.save}.") - with predict_config.save.open("w") as fp: + preds_path = predict_config.save_dir / "zamba_predictions.csv" + logger.info(f"Saving out predictions to {preds_path}.") + with preds_path.open("w") as fp: df.to_csv(fp, index=True) return df diff --git a/zamba/models/official_models/european/config.yaml b/zamba/models/official_models/european/config.yaml index 0b4177e0..66c46f99 100644 --- a/zamba/models/official_models/european/config.yaml +++ b/zamba/models/official_models/european/config.yaml @@ -1,6 +1,4 @@ train_config: - scheduler_config: default - model_name: european backbone_finetune_config: backbone_initial_ratio_lr: 0.01 multiplier: 1 @@ -9,21 +7,36 @@ train_config: unfreeze_backbone_at_epoch: 15 verbose: true early_stopping_config: + mode: max monitor: val_macro_f1 patience: 3 verbose: true - mode: max + model_name: european + scheduler_config: default video_loader_config: - model_input_height: 240 - model_input_width: 426 crop_bottom_pixels: 50 - fps: 4 - total_frames: 16 + early_bias: false ensure_total_frames: true + evenly_sample_total_frames: false + fps: 4.0 + frame_indices: null + frame_selection_height: null + frame_selection_width: null + i_frames: false megadetector_lite_config: confidence: 0.25 fill_mode: score_sorted + image_height: 416 + image_width: 416 n_frames: 16 + nms_threshold: 0.45 + seed: 55 + sort_by_time: true + model_input_height: 240 + model_input_width: 426 + pix_fmt: rgb24 + scene_threshold: null + total_frames: 16 predict_config: model_name: european -public_checkpoint: european_0c69da8a888c499411deaa040a91d76546ddf78a.ckpt +public_checkpoint: european_0a80dc77bf.ckpt diff --git a/zamba/models/official_models/european/predict_configuration.yaml b/zamba/models/official_models/european/predict_configuration.yaml index 3e4eb57b..be82b672 100644 --- a/zamba/models/official_models/european/predict_configuration.yaml +++ b/zamba/models/official_models/european/predict_configuration.yaml @@ -2,16 +2,18 @@ inference_start_time: '2021-10-15T03:46:08.256520' model_class: TimeDistributedEfficientNet predict_config: batch_size: 2 - cache_dir: /home/ubuntu/.cache/zamba checkpoint: experiments/european_td_dev_base/version_0/time_distributed.ckpt - data_directory: /home/ubuntu/zamba-algorithms + data_dir: /home/ubuntu/zamba-algorithms dry_run: false gpus: 1 - model_name: time_distributed + model_cache_dir: /home/ubuntu/.cache/zamba + model_name: null num_workers: 3 output_class_names: false + overwrite: false proba_threshold: null - save: experiments/european_td_dev_base/version_0/zamba_predictions.csv + save: true + save_dir: experiments/european_td_dev_base/version_0 skip_load_validation: true weight_download_region: us species: @@ -27,6 +29,8 @@ species: - weasel - wild_boar video_loader_config: + cache_dir: /tmp/zamba_cache + cleanup_cache: false crop_bottom_pixels: 50 early_bias: false ensure_total_frames: true diff --git a/zamba/models/official_models/european/train_configuration.yaml b/zamba/models/official_models/european/train_configuration.yaml index 329aabe7..38eb9622 100644 --- a/zamba/models/official_models/european/train_configuration.yaml +++ b/zamba/models/official_models/european/train_configuration.yaml @@ -23,9 +23,8 @@ train_config: unfreeze_backbone_at_epoch: 15 verbose: true batch_size: 2 - cache_dir: /home/ubuntu/.cache/zamba checkpoint: data/results/experiments/td_dev_set_full_size_mdlite/results/version_1/time_distributed_zamba.ckpt - data_directory: /home/ubuntu/zamba-algorithms + data_dir: 
/home/ubuntu/zamba-algorithms dry_run: false early_stopping_config: mode: max @@ -35,17 +34,20 @@ train_config: from_scratch: false gpus: 1 max_epochs: null + model_cache_dir: /home/ubuntu/.cache/zamba model_name: null num_workers: 3 - overwrite_save_directory: false + overwrite: false predict_all_zamba_species: true - save_directory: experiments/european_td_dev_base + save_dir: experiments/european_td_dev_base scheduler_config: default skip_load_validation: true split_proportions: null weight_download_region: us training_start_time: '2021-10-13T16:46:39.593515' video_loader_config: + cache_dir: /tmp/zamba_cache + cleanup_cache: false crop_bottom_pixels: 50 early_bias: false ensure_total_frames: true diff --git a/zamba/models/official_models/slowfast/config.yaml b/zamba/models/official_models/slowfast/config.yaml index 48baa85e..f77a4595 100644 --- a/zamba/models/official_models/slowfast/config.yaml +++ b/zamba/models/official_models/slowfast/config.yaml @@ -1,12 +1,4 @@ train_config: - scheduler_config: - scheduler: MultiStepLR - scheduler_params: - gamma: 0.5 - milestones: - - 1 - verbose: true - model_name: slowfast backbone_finetune_config: backbone_initial_ratio_lr: 0.01 multiplier: 10 @@ -15,18 +7,42 @@ train_config: unfreeze_backbone_at_epoch: 3 verbose: true early_stopping_config: + mode: max + monitor: val_macro_f1 patience: 5 + verbose: true + model_name: slowfast + scheduler_config: + scheduler: MultiStepLR + scheduler_params: + gamma: 0.5 + milestones: + - 1 + verbose: true video_loader_config: - model_input_height: 240 - model_input_width: 426 crop_bottom_pixels: 50 - fps: 8 - total_frames: 32 + early_bias: false ensure_total_frames: true + evenly_sample_total_frames: false + fps: 8.0 + frame_indices: null + frame_selection_height: null + frame_selection_width: null + i_frames: false megadetector_lite_config: confidence: 0.25 fill_mode: score_sorted + image_height: 416 + image_width: 416 n_frames: 32 + nms_threshold: 0.45 + seed: 55 + sort_by_time: true + model_input_height: 240 + model_input_width: 426 + pix_fmt: rgb24 + scene_threshold: null + total_frames: 32 predict_config: model_name: slowfast -public_checkpoint: slowfast_501182d969aabf49805829a2b09ed8078b4255a3.ckpt +public_checkpoint: slowfast_3c9d5d0c72.ckpt diff --git a/zamba/models/official_models/slowfast/predict_configuration.yaml b/zamba/models/official_models/slowfast/predict_configuration.yaml index b46a417b..c8b6ef36 100644 --- a/zamba/models/official_models/slowfast/predict_configuration.yaml +++ b/zamba/models/official_models/slowfast/predict_configuration.yaml @@ -3,15 +3,17 @@ model_class: SlowFast predict_config: batch_size: 2 checkpoint: experiments/slowfast_small_set_full_size_mdlite/version_2/slowfast.ckpt - data_directory: /home/ubuntu/zamba-algorithms + data_dir: /home/ubuntu/zamba-algorithms dry_run: false gpus: 1 model_cache_dir: /home/ubuntu/.cache/zamba model_name: null num_workers: 3 output_class_names: false + overwrite: false proba_threshold: null - save: experiments/slowfast_small_set_full_size_mdlite/version_2/zamba_predictions.csv + save: true + save_dir: experiments/slowfast_small_set_full_size_mdlite/version_2 skip_load_validation: true weight_download_region: us species: diff --git a/zamba/models/official_models/slowfast/train_configuration.yaml b/zamba/models/official_models/slowfast/train_configuration.yaml index 70cc2434..653014ce 100644 --- a/zamba/models/official_models/slowfast/train_configuration.yaml +++ b/zamba/models/official_models/slowfast/train_configuration.yaml @@ -45,7 
+45,7 @@ train_config: verbose: true batch_size: 2 checkpoint: null - data_directory: /home/ubuntu/zamba-algorithms + data_dir: /home/ubuntu/zamba-algorithms dry_run: false early_stopping_config: mode: max @@ -58,9 +58,9 @@ train_config: model_cache_dir: /home/ubuntu/.cache/zamba model_name: slowfast num_workers: 3 - overwrite_save_directory: false + overwrite: false predict_all_zamba_species: true - save_directory: experiments/slowfast_small_set_full_size_mdlite + save_dir: experiments/slowfast_small_set_full_size_mdlite scheduler_config: scheduler: MultiStepLR scheduler_params: diff --git a/zamba/models/official_models/time_distributed/config.yaml b/zamba/models/official_models/time_distributed/config.yaml index aa161965..0adcfe0c 100644 --- a/zamba/models/official_models/time_distributed/config.yaml +++ b/zamba/models/official_models/time_distributed/config.yaml @@ -1,12 +1,4 @@ train_config: - scheduler_config: - scheduler: MultiStepLR - scheduler_params: - gamma: 0.5 - milestones: - - 3 - verbose: true - model_name: time_distributed backbone_finetune_config: backbone_initial_ratio_lr: 0.01 multiplier: 1 @@ -15,18 +7,42 @@ train_config: unfreeze_backbone_at_epoch: 3 verbose: true early_stopping_config: + mode: max + monitor: val_macro_f1 patience: 5 + verbose: true + model_name: time_distributed + scheduler_config: + scheduler: MultiStepLR + scheduler_params: + gamma: 0.5 + milestones: + - 3 + verbose: true video_loader_config: - model_input_height: 240 - model_input_width: 426 crop_bottom_pixels: 50 - fps: 4 - total_frames: 16 + early_bias: false ensure_total_frames: true + evenly_sample_total_frames: false + fps: 4.0 + frame_indices: null + frame_selection_height: null + frame_selection_width: null + i_frames: false megadetector_lite_config: confidence: 0.25 fill_mode: score_sorted + image_height: 416 + image_width: 416 n_frames: 16 + nms_threshold: 0.45 + seed: 55 + sort_by_time: true + model_input_height: 240 + model_input_width: 426 + pix_fmt: rgb24 + scene_threshold: null + total_frames: 16 predict_config: model_name: time_distributed -public_checkpoint: time_distributed_9e710aa8c92d25190a64b3b04b9122bdcb456982.ckpt +public_checkpoint: time_distributed_7f74686b7b.ckpt diff --git a/zamba/models/official_models/time_distributed/predict_configuration.yaml b/zamba/models/official_models/time_distributed/predict_configuration.yaml index ae39bbdf..9c08eab1 100644 --- a/zamba/models/official_models/time_distributed/predict_configuration.yaml +++ b/zamba/models/official_models/time_distributed/predict_configuration.yaml @@ -2,16 +2,18 @@ inference_start_time: '2021-09-30T10:55:29.100719' model_class: TimeDistributedEfficientNet predict_config: batch_size: 1 - cache_dir: /home/ubuntu/.cache/zamba checkpoint: experiments/td_small_set_full_size_mdlite/version_1/time_distributed.ckpt - data_directory: /home/ubuntu/zamba-algorithms + data_dir: /home/ubuntu/zamba-algorithms dry_run: false gpus: 1 - model_name: time_distributed + model_cache_dir: /home/ubuntu/.cache/zamba + model_name: null num_workers: 5 output_class_names: false + overwrite: false proba_threshold: null - save: experiments/td_small_set_full_size_mdlite/version_1/zamba_predictions.csv + save: true + save_dir: experiments/td_small_set_full_size_mdlite/version_1 skip_load_validation: true weight_download_region: us species: @@ -48,6 +50,8 @@ species: - small_cat - wild_dog_jackal video_loader_config: + cache_dir: /tmp/zamba_cache + cleanup_cache: false crop_bottom_pixels: 50 early_bias: false ensure_total_frames: true diff --git 
a/zamba/models/official_models/time_distributed/train_configuration.yaml b/zamba/models/official_models/time_distributed/train_configuration.yaml index 427260c5..d4627cc6 100644 --- a/zamba/models/official_models/time_distributed/train_configuration.yaml +++ b/zamba/models/official_models/time_distributed/train_configuration.yaml @@ -44,9 +44,8 @@ train_config: unfreeze_backbone_at_epoch: 3 verbose: true batch_size: 1 - cache_dir: /home/ubuntu/.cache/zamba checkpoint: null - data_directory: /home/ubuntu/zamba-algorithms + data_dir: /home/ubuntu/zamba-algorithms dry_run: false early_stopping_config: mode: max @@ -56,11 +55,12 @@ train_config: from_scratch: true gpus: 1 max_epochs: null + model_cache_dir: /home/ubuntu/.cache/zamba model_name: time_distributed num_workers: 5 - overwrite_save_directory: false + overwrite: false predict_all_zamba_species: true - save_directory: experiments/td_small_set_full_size_mdlite + save_dir: experiments/td_small_set_full_size_mdlite scheduler_config: scheduler: MultiStepLR scheduler_params: @@ -73,6 +73,8 @@ train_config: weight_download_region: us training_start_time: '2021-09-29T23:50:09.687298' video_loader_config: + cache_dir: /tmp/zamba_cache + cleanup_cache: false crop_bottom_pixels: 50 early_bias: false ensure_total_frames: true diff --git a/zamba/models/publish_models.py b/zamba/models/publish_models.py index a32ff4e1..de731f7e 100644 --- a/zamba/models/publish_models.py +++ b/zamba/models/publish_models.py @@ -8,10 +8,47 @@ import yaml from zamba import MODELS_DIRECTORY -from zamba.models.config import WEIGHT_LOOKUP, ModelEnum, TrainConfig +from zamba.models.config import WEIGHT_LOOKUP, ModelEnum from zamba.models.densepose import MODELS as DENSEPOSE_MODELS +def get_model_only_params(full_configuration, subset="train_config"): + """Return only params that are not data or machine specific. + Used for generating official configs. + """ + if subset == "train_config": + config = full_configuration[subset] + for key in [ + "data_dir", + "dry_run", + "batch_size", + "auto_lr_find", + "gpus", + "num_workers", + "max_epochs", + "weight_download_region", + "split_proportions", + "save_dir", + "overwrite", + "skip_load_validation", + "from_scratch", + "model_cache_dir", + "predict_all_zamba_species", + ]: + config.pop(key) + + elif subset == "video_loader_config": + config = full_configuration[subset] + + if "megadetector_lite_config" in config.keys(): + config["megadetector_lite_config"].pop("device") + + for key in ["cache_dir", "cleanup_cache"]: + config.pop(key) + + return config + + def publish_model(model_name, trained_model_dir): """ Creates the files for the model folder in `official_models` and uploads the model to the three @@ -51,12 +88,13 @@ def publish_model(model_name, trained_model_dir): # prepare config for use in official models dir logger.info("Preparing official config file.") - config_yaml = MODELS_DIRECTORY / model_name / "config.yaml" - with config_yaml.open() as f: - config_dict = yaml.safe_load(f) + # start with full train configuration + with (MODELS_DIRECTORY / model_name / "train_configuration.yaml").open() as f: + train_configuration_full_dict = yaml.safe_load(f) - train_config = TrainConfig.construct(**config_dict["train_config"]).get_model_only_params() + # get limited train config + train_config = get_model_only_params(train_configuration_full_dict, subset="train_config") # e.g. 
european model is trained from a checkpoint; we want to expose final model # (model_name: european) not the base checkpoint @@ -66,18 +104,21 @@ def publish_model(model_name, trained_model_dir): official_config = dict( train_config=train_config, - video_loader_config=config_dict["video_loader_config"], + video_loader_config=get_model_only_params( + train_configuration_full_dict, subset="video_loader_config" + ), predict_config=dict(model_name=model_name), ) # hash train_configuration to generate public filename for model - hash_str = hashlib.sha1(str(config_dict["train_config"]).encode("utf-8")).hexdigest() + hash_str = hashlib.sha1(str(train_configuration_full_dict).encode("utf-8")).hexdigest()[:10] public_file_name = f"{model_name}_{hash_str}.ckpt" # add that to official config official_config["public_checkpoint"] = public_file_name # write out official config + config_yaml = MODELS_DIRECTORY / model_name / "config.yaml" logger.info(f"Writing out to {config_yaml}") with config_yaml.open("w") as f: yaml.dump(official_config, f, sort_keys=False) @@ -91,8 +132,11 @@ def upload_to_all_public_buckets(file, public_file_name): public_checkpoint = S3Path( f"s3://drivendata-public-assets{bucket}/zamba_official_models/{public_file_name}" ) - logger.info(f"Uploading {file} to {public_checkpoint}") - public_checkpoint.upload_from(file, force_overwrite_to_cloud=True) + if public_checkpoint.exists(): + logger.info(f"Skipping since {public_checkpoint} exists.") + else: + logger.info(f"Uploading {file} to {public_checkpoint}") + public_checkpoint.upload_from(file, force_overwrite_to_cloud=True) if __name__ == "__main__": diff --git a/zamba/models/utils.py b/zamba/models/utils.py index 7fc3112f..52042843 100644 --- a/zamba/models/utils.py +++ b/zamba/models/utils.py @@ -36,7 +36,10 @@ def download_weights( def get_model_checkpoint_filename(model_name): - config_file = MODELS_DIRECTORY / f"{model_name}/config.yaml" + if isinstance(model_name, Enum): + model_name = model_name.value + + config_file = MODELS_DIRECTORY / model_name / "config.yaml" with config_file.open() as f: config_dict = yaml.safe_load(f) - return config_dict["public_checkpoint"] + return Path(config_dict["public_checkpoint"])
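Since the publishing flow above now derives public checkpoint names from a truncated SHA-1 of the full train configuration, here is a self-contained sketch of just that naming step. The configuration dict below is a stand-in, so the digest will not match any real checkpoint:

```python
import hashlib

model_name = "time_distributed"

# stand-in for the dict loaded from train_configuration.yaml
train_configuration_full_dict = {"train_config": {"model_name": model_name}}

# sha1 the stringified configuration and keep only the first 10 hex digits
hash_str = hashlib.sha1(
    str(train_configuration_full_dict).encode("utf-8")
).hexdigest()[:10]

public_file_name = f"{model_name}_{hash_str}.ckpt"
print(public_file_name)  # time_distributed_<10-digit hash>.ckpt
```

Truncating to ten digits keeps the published filenames short while still tying each checkpoint to the exact configuration that produced it.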