Uses JAX/Haiku to regress a patient's age from their chest X-ray, and implements RandAugment in JAX (borrowing from others' code).
See this post, where I walk through the more interesting bits of code and the training curves, iterating toward progressively better results.
Note: you need to use tensorflow-datasets yourself to package the Kaggle NIH chest X-ray dataset into something usable. Seek out the relevant documentation for creating your own tfds datasets; here is a good place to start. tfds is a worthwhile skill for anyone interested in image processing, and plenty of learning resources already exist.
Example code for creating a tfds dataset can be found in:
./tfdatasets/age_nih/age_nih_dataset_builder.py
Once the dataset has been downloaded, build it by running tfds build from the terminal in the enclosing directory, age_nih.
Dependencies:
- jax (install with GPU support)
- haiku
- optax
- tensorflow-datasets
- jmp (optional; if it is not installed, comment out the relevant parts of trainRegression.py)
- numpy
- dm-tree
- abseil-py
- PIL (Pillow)
- GPUtil
- keras_cv (optional; if it is not installed, comment out the relevant parts of preprocess.py)
- tensorflow
- dm-pix
- imax