joefarrington/viso_jax

viso_jax

GPU-accelerated value iteration and simulation for perishable inventory control using JAX

New package: MDPax

In January 2025, we released a full Python package to help other researchers take advantage of the findings in this work: MDPax. See the GitHub repository and documentation for more details.

Introduction

This repository provides the code to support the paper Going faster to see further: GPU-accelerated value iteration and simulation for perishable inventory control using JAX by Farrington et al (arXiv preprint).

The paper considers three perishable inventory management scenarios from recent studies. It demonstrates that, by using the Python library JAX to exploit the parallel processing capabilities of modern GPUs, value iteration can find the optimal policy for problems that were previously considered infeasible or intractable.
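As a rough illustration of the idea, the sketch below runs value iteration in JAX on a small, randomly generated MDP. It is not the repository's implementation: the dense transition tensor P, the reward matrix R, the function names and the toy problem size are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def bellman_update(V, P, R, gamma):
    """One Bellman backup over all states and actions at once."""
    # Q[s, a] = R[s, a] + gamma * sum_t P[a, s, t] * V[t]
    Q = R + gamma * jnp.einsum("ast,t->sa", P, V)
    return Q.max(axis=1), Q.argmax(axis=1)


def value_iteration(P, R, gamma=0.95, tol=1e-4, max_iter=10_000):
    """Iterate the Bellman backup until the largest change is at most tol."""

    def cond(carry):
        _, delta, i = carry
        return jnp.logical_and(delta > tol, i < max_iter)

    def body(carry):
        V, _, i = carry
        V_new, _ = bellman_update(V, P, R, gamma)
        return V_new, jnp.max(jnp.abs(V_new - V)), i + 1

    V0 = jnp.zeros(R.shape[0])
    init = (V0, jnp.array(jnp.inf, dtype=V0.dtype), jnp.array(0, dtype=jnp.int32))
    V, _, n_iter = jax.lax.while_loop(cond, body, init)
    _, policy = bellman_update(V, P, R, gamma)
    return V, policy, n_iter


# Toy problem (hypothetical): 50 states, 4 actions, random dynamics and rewards.
key_p, key_r = jax.random.split(jax.random.PRNGKey(0))
n_states, n_actions = 50, 4
P = jax.random.uniform(key_p, (n_actions, n_states, n_states))
P = P / P.sum(axis=-1, keepdims=True)  # each row becomes a probability distribution
R = jax.random.uniform(key_r, (n_states, n_actions))

V, policy, n_iter = jax.jit(value_iteration)(P, R)
print("iterations:", int(n_iter))
```

The whole loop is compiled with jax.jit, so every backup is a single array operation that can run on a GPU when one is available. A dense transition tensor only fits in memory for small state spaces, which is why this is a sketch of the technique and not a recipe for the scenarios in the paper.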

The notebooks directory contains a Google Colab notebook that can be used to reproduce the experiments on a free cloud-based GPU.

Open In Colab

Scenarios

Scenario A

Based on "Reward shaping to improve the performance of deep reinforcement learning in perishable inventory management" by De Moor et al (2022)

Referred to as de_moor_perishable in the directory structure and configuration files.

Scenario B

Based on the two-product scenario in "On computing optimal policies in perishable inventory control using value iteration" by Hendrix et al (2019)

Additional experimental settings are taken from Ortega et al (2019).

Referred to as hendrix_perishable_substitution_two_product in the directory structure and configuration files.

Scenario C

Adapted from the scenario in Chapter 6 of "Data-driven modelling and control of hospital blood inventory" by Mirjalili (2022)

Referred to as mirjalili_perishable_platelets in the directory structure and configuration files.

Additional scenarios

The single-product scenario described by Hendrix et al (2019) is included in the repository as hendrix_single_product.

Installation

To use JAX with Nvidia GPU acceleration, you must first install CUDA and cuDNN. See the JAX installation instructions for further details.

Python dependencies are listed in pyproject.toml. We use uv for dependency management.

viso_jax and its Python dependencies can be installed using the code snippet below. This snippet assumes that you have uv installed.

git clone https://github.com/joefarrington/viso_jax.git
cd viso_jax
uv venv
source .venv/bin/activate
uv pip install .

Once installation is complete, you can test that JAX recognises an accelerator (GPU or TPU) by running the following snippet:

uv run pytest tests -m "jax"
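For a quick manual check outside the test suite, a short script along the following lines lists the devices JAX can see. This is a minimal sketch that is not part of the repository. If it reports only a cpu platform, JAX has not found a GPU or TPU.

```python
import jax

# jax.devices() returns the devices of the default backend.
devices = jax.devices()
print("default backend:", jax.default_backend())
for d in devices:
    print(d.id, d.platform, d.device_kind)
```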

Reproducing experiments with Colab

The Colab notebook reproduce_viso_jax_experiments.ipynb in the notebooks directory includes a form for each of Scenarios A, B and C in the paper. These forms can be used to reproduce the value iteration and simulation optimization experiments on a cloud-based GPU.

The Colab notebook also includes an Advanced section with brief interactive tutorials demonstrating how to run experiments with different settings using the command line.

The notebook was last tested on 2023-03-17. Changes to the Colab virtual machine may affect performance and/or lead to incompatibilities. Please raise an issue if you encounter problems when using the notebook. The Colab release notes will detail any changes that have been made to the Colab service since the notebook was last tested.

Running experiments using the command line

We used the shell scripts in the directory bash_scripts to run our experiments. There is one script corresponding to each results table in the paper.

Scenario A

  • Table 3: run_all_scenario_a.sh

Scenario B

  • Table 4: run_all_scenario_b_hendrix_settings.sh
  • Table 5: run_all_scenario_b_ortega_settings.sh

Scenario C

  • Table 6: run_all_scenario_c.sh

By default, the value iteration experiments save results in viso_jax/value_iteration/outputs/{name of bash script}/{date}/{time}/m{x}/{exp name}, and the simulation optimization experiments save results in viso_jax/simopt/outputs/{name of bash script}/{date}/{time}/m{x}/{exp name}. Alternative output paths can be specified in the bash script by changing the command line argument hydra.run.dir passed to the scripts run_value_iteration.py and run_optuna_simopt.py respectively.

Details from the experiments used to populate the results tables are saved in a YAML file, output_info.yaml. This file also includes the KPI information used to populate the results tables in the appendices. By default, all experiments save the Hydra configuration details and a log file. Value iteration experiments may save checkpoints of the value function and, by default, save the final policy and the final estimate of the value function (policy.csv, V.csv). By default, simulation optimization experiments save a record of the trials run by Optuna (trials.csv) and the row corresponding to the best trial (best_trial.csv).
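To gather results after a batch of runs, one option is to walk the output directory and pick up every folder that contains output_info.yaml. The helpers below are a minimal sketch using only the standard library. The function names find_runs and read_csv_rows are hypothetical, and nothing is assumed about the columns of the CSV files, which are returned as raw rows.

```python
import csv
from pathlib import Path


def find_runs(outputs_root):
    """Return every directory under outputs_root that contains output_info.yaml."""
    return sorted(p.parent for p in Path(outputs_root).rglob("output_info.yaml"))


def read_csv_rows(path):
    """Read a CSV file (e.g. policy.csv or trials.csv) as a list of string rows."""
    with open(path, newline="") as f:
        return list(csv.reader(f))
```

For example, find_runs("viso_jax/value_iteration/outputs") would list the experiment directories saved under the default value iteration output path, and read_csv_rows(run_dir / "policy.csv") would load the saved policy for one of them.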

Tests

The test suite is intended to be high level, comparing outputs of our methods (e.g. policies, heuristic policy parameters and mean returns) to those reported in the original papers. Some tests may fail with an out-of-memory error if run on a GPU with less than 12GB of VRAM, because the maximum batch sizes (for value iteration) and the number of rollouts to perform in parallel (for policy evaluation) were set for the Nvidia RTX 3060 used during development.
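The role of a maximum batch size can be illustrated with a generic pattern: split the states into fixed-size batches, vectorize within each batch, and process the batches one after another so that only one batch of intermediate results is held in memory at a time. The sketch below shows this pattern under stated assumptions. batched_apply and toy_update are hypothetical names, and the repository's own batching code may differ.

```python
import jax
import jax.numpy as jnp


def batched_apply(fn, states, batch_size):
    """Apply fn to each row of states, batch_size rows at a time."""
    n = states.shape[0]
    n_pad = (-n) % batch_size
    # Pad by repeating the last state so the array splits evenly into batches.
    padded = jnp.concatenate([states, jnp.repeat(states[-1:], n_pad, axis=0)])
    batches = padded.reshape(-1, batch_size, *states.shape[1:])
    # lax.map runs the batches sequentially; vmap vectorizes within a batch.
    out = jax.lax.map(jax.vmap(fn), batches)
    # Flatten the batch axis and drop the results for the padding rows.
    return out.reshape(-1, *out.shape[2:])[:n]


def toy_update(state):
    """Stand-in for a per-state update (hypothetical)."""
    return jnp.sum(state**2)


states = jnp.arange(30.0).reshape(10, 3)
batched = batched_apply(toy_update, states, batch_size=4)
full = jax.vmap(toy_update)(states)
```

Here batched and full hold the same values. A smaller batch_size lowers peak memory use at the cost of more sequential steps, which is the trade-off to adjust on a GPU with less memory.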

Recommended resources

JAX

Getting started

The JAX documentation includes JAX 101, a set of interactive introductory tutorials. We also recommend reading JAX - The Sharp Bits to understand key differences between NumPy and JAX.

The Awesome JAX GitHub repository contains links to a wide variety of Python libraries and projects based on JAX.

Value iteration

Thomas J Sargent and John Stachurski provide an interactive tutorial that implements value iteration in JAX for an economics problem and compares the speed of two NumPy-based approaches with GPU-accelerated JAX.

Hydra

We specified the configurations of our experiments using Hydra, which supports composable configuration files and provides a command line interface for overriding configuration items.

Gymnax

We created reinforcement learning environments for each scenario using Gymnax. Gymnax provides an API similar to the OpenAI Gym API, but allows simulated rollouts to run in parallel on a GPU. This is particularly helpful for simulation optimization because it allows many possible parameters for the heuristic policies (e.g. a base-stock policy) to be evaluated at the same time, each on many parallel rollouts.
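The pattern of evaluating many policy parameters, each on many rollouts, can be sketched in plain JAX with two nested calls to jax.vmap. The example below does not use Gymnax or the repository's environments. It assumes a toy single-product model with zero lead time, lost sales, no perishability and made-up prices and costs, purely to show the shape of the computation.

```python
import jax
import jax.numpy as jnp


def rollout_return(key, base_stock, n_periods=100, mean_demand=4.0):
    """Total reward of one rollout under an order-up-to policy (toy model)."""
    demands = jax.random.poisson(key, mean_demand, (n_periods,))

    def step(stock, demand):
        order = jnp.maximum(base_stock - stock, 0)  # order up to base_stock
        available = stock + order  # zero lead time (assumption)
        sold = jnp.minimum(available, demand)  # unmet demand is lost
        leftover = available - sold
        # Hypothetical economics: price 5, unit order cost 3, holding cost 1.
        reward = 5.0 * sold - 3.0 * order - 1.0 * leftover
        return leftover, reward

    _, rewards = jax.lax.scan(step, jnp.int32(0), demands)
    return rewards.sum()


# Inner vmap: many rollouts (keys) for one parameter value.
# Outer vmap: many parameter values, all sharing the same keys.
over_keys = jax.vmap(rollout_return, in_axes=(0, None))
evaluate = jax.jit(jax.vmap(over_keys, in_axes=(None, 0)))

keys = jax.random.split(jax.random.PRNGKey(0), 256)
levels = jnp.arange(0, 11)  # candidate base-stock levels 0..10
returns = evaluate(keys, levels)  # shape (n_levels, n_rollouts)
mean_returns = returns.mean(axis=1)
best_level = int(levels[jnp.argmax(mean_returns)])
print("best base-stock level in this toy model:", best_level)
```

Reusing the same random keys for every candidate level means all candidates face the same demand sequences, which reduces noise when comparing them.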

Optuna

We used Optuna to search the parameter spaces for heuristic policies in our simulation optimization experiments.
