diff --git a/containers/Dockerfile b/containers/Dockerfile new file mode 100644 index 00000000..39aa2481 --- /dev/null +++ b/containers/Dockerfile @@ -0,0 +1,7 @@ +ARG PY_VERSION=3.8 +FROM docker.io/python:${PY_VERSION} + +RUN pip install jaxlib jax elegy dataget matplotlib typer +WORKDIR /usr/src/app +COPY examples/ /usr/src/app/ + diff --git a/containers/GPU/Dockerfile b/containers/GPU/Dockerfile new file mode 100644 index 00000000..2466a6ed --- /dev/null +++ b/containers/GPU/Dockerfile @@ -0,0 +1,18 @@ +FROM docker.io/nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + software-properties-common && \ + add-apt-repository ppa:deadsnakes -y && \ + apt-get update && apt-get install -y --no-install-recommends \ + python3.8-dev \ + python3.8-distutils \ + curl \ + && curl -Lk "https://bootstrap.pypa.io/get-pip.py" | python3.8 && \ + rm -rf /var/lib/apt/lists/* + +ENV BASE_URL="https://storage.googleapis.com/jax-releases" +RUN python3.8 -m pip install --upgrade $BASE_URL/cuda102/jaxlib-0.1.51-cp38-none-manylinux2010_x86_64.whl +RUN python3.8 -m pip install --upgrade jax elegy dataget matplotlib typer +ENV XLA_PYTHON_CLIENT_ALLOCATOR="platform" + +COPY examples/ /usr/src/app/