Skip to content

Commit

Permalink
add mpi4jax & numpyro (__-){
Browse files Browse the repository at this point in the history
  • Loading branch information
Marmaduke Woodman committed Apr 8, 2023
1 parent 1413dfc commit 96701ce
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
RUN apt-get update
RUN apt-get install -y python3-pip
RUN pip3 install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
RUN pip3 install jupyterlab brian2 matplotlib joblib scipy vbjax
RUN pip3 install jupyterlab brian2 matplotlib joblib scipy vbjax mpi4jax numpyro
CMD python3 -c 'import jax; print(jax.numpy.zeros(32).device())'

0 comments on commit 96701ce

Please sign in to comment.