kjabon/image_regression

image_regression

Use JAX/Haiku to regress patient age from x-ray images. Also implements RandAugment in JAX (borrowing from others' code).

blog post

See this post, where I walk through the more interesting bits of code and the training curves, working up to iteratively better results.

dataset preparation

Note: you need to package the Kaggle NIH x-ray dataset into something usable yourself, using tensorflow-datasets (tfds). Seek out the relevant documentation for creating your own tfds datasets. Here is a good place to start. For anyone interested in image processing, tfds is a worthwhile skill, and many learning resources already exist.

Example code for creating a tfds dataset is in ./tfdatasets/age_nih/age_nih_dataset_builder.py. Once the dataset has been downloaded, run tfds build from the terminal in the enclosing directory, age_nih.
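At its core, a tfds builder has to yield (key, example) pairs, one per image. Below is a minimal standard-library sketch of that step, not the code in age_nih_dataset_builder.py. The column names `Image Index` and `Patient Age`, the `images` directory, and the function name `generate_examples` are all assumptions for illustration; check them against the metadata CSV in your download.

```python
import csv
import io


def generate_examples(metadata_csv, image_dir="images"):
    """Yield (key, example) pairs, the shape a tfds builder's example generator produces.

    metadata_csv: an open text file (or any iterable of lines) with a header row.
    Column names below are assumptions about the Kaggle metadata file.
    """
    reader = csv.DictReader(metadata_csv)
    for row in reader:
        name = row["Image Index"]          # assumed column: image file name
        age = int(row["Patient Age"])      # assumed column: age in years
        # The file name doubles as a unique key for the example.
        yield name, {"image": f"{image_dir}/{name}", "age": age}


# Tiny in-memory stand-in for the real metadata file (made-up rows).
sample = io.StringIO(
    "Image Index,Patient Age\n"
    "00000001_000.png,58\n"
    "00000002_000.png,81\n"
)
examples = list(generate_examples(sample))
print(examples[0])
```

In a real builder the same loop would live inside the builder class, and the `image` entry would point at the downloaded file on disk.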

requirements

  • jax (install with gpu support)
  • haiku (the dm-haiku package)
  • optax
  • tensorflow-datasets
  • jmp (optional, you can comment the relevant bits out in trainRegression.py)
  • numpy
  • dm-tree
  • abseil-py (the absl-py package)
  • PIL (the Pillow package)
  • GPUtil
  • keras_cv (optional, can comment out the relevant bits in preprocess.py)
  • tensorflow
  • dm-pix
  • imax
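
Several of these packages import under a different name than the one they install under. Here is a small sketch for checking which of the requirements are importable in the current environment. The helper name `missing_modules` is mine, and the import names listed are my understanding of each package, not something taken from this repository.

```python
import importlib.util


def missing_modules(names):
    """Return the module names from `names` that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names corresponding to the requirements list above.
REQUIRED = [
    "jax", "haiku", "optax", "tensorflow_datasets", "numpy", "tree",
    "absl", "PIL", "GPUtil", "tensorflow", "dm_pix", "imax",
]
OPTIONAL = ["jmp", "keras_cv"]

print("missing required:", missing_modules(REQUIRED))
print("missing optional:", missing_modules(OPTIONAL))
```

This only checks that each module can be located. It does not verify versions or that jax was installed with gpu support.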
