diff --git a/images/pytorch-benchmarks/Dockerfile b/images/pytorch-benchmarks/Dockerfile index 18663eb..50ccab7 100644 --- a/images/pytorch-benchmarks/Dockerfile +++ b/images/pytorch-benchmarks/Dockerfile @@ -1,12 +1,13 @@ -FROM nvcr.io/nvidia/pytorch:22.12-py3 +FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime -# RUN pip install -U pip -# RUN pip install torchvision torchaudio -RUN pip uninstall torch -y && pip install --no-cache torch torchvision torchaudio +RUN apt update && apt install -y git +# TODO: Checkout certain commit here +RUN git clone https://github.com/pytorch/benchmark +WORKDIR /workspace/benchmark +# Add other benchmarks here? +RUN python install.py alexnet resnet50 llama -# RUN git clone https://github.com/pytorch/benchmark -# WORKDIR /workspace/benchmark -# # Add other benchmarks here? -# RUN python install.py alexnet resnet50 llama +# PyTorch install.py pins numpy=1.21.2 but this breaks numba so update both here +RUN pip install -U numpy numba -# COPY run-benchmark.sh /usr/local/bin/ \ No newline at end of file +COPY run-benchmark.sh /usr/local/bin/ \ No newline at end of file