diff --git a/.github/workflows/basic-tests-linux.yml b/.github/workflows/basic-tests-linux.yml
index 86636a3c..02990163 100644
--- a/.github/workflows/basic-tests-linux.yml
+++ b/.github/workflows/basic-tests-linux.yml
@@ -32,27 +32,24 @@ jobs:
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
- pip install uv
- uv venv --python=python3.10
- source .venv/bin/activate
- uv pip install pytest nbval
- if [ -f requirements.txt ]; then uv pip install -r requirements.txt; fi
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ uv python install 3.10
+ uv add . --dev
uv pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
- uv pip install pytest
+ uv add pytest-ruff nbval
- name: Test Selected Python Scripts
run: |
source .venv/bin/activate
- pytest setup/02_installing-python-libraries/tests.py
- pytest ch04/01_main-chapter-code/tests.py
- pytest ch05/01_main-chapter-code/tests.py
- pytest ch05/07_gpt_to_llama/tests/tests.py
- pytest ch06/01_main-chapter-code/tests.py
+ pytest --ruff setup/02_installing-python-libraries/tests.py
+ pytest --ruff ch04/01_main-chapter-code/tests.py
+ pytest --ruff ch05/01_main-chapter-code/tests.py
+ pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+ pytest --ruff ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
run: |
source .venv/bin/activate
- pytest --nbval ch02/01_main-chapter-code/dataloader.ipynb
- pytest --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
- pytest --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
+ pytest --ruff --nbval ch02/01_main-chapter-code/dataloader.ipynb
+ pytest --ruff --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
+ pytest --ruff --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
diff --git a/.github/workflows/basic-tests-macos.yml b/.github/workflows/basic-tests-macos.yml
index 02f1e66c..5de2bd47 100644
--- a/.github/workflows/basic-tests-macos.yml
+++ b/.github/workflows/basic-tests-macos.yml
@@ -32,27 +32,24 @@ jobs:
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
- pip install uv
- uv venv --python=python3.10
- source .venv/bin/activate
- uv pip install pytest nbval
- if [ -f requirements.txt ]; then uv pip install -r requirements.txt; fi
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ uv python install 3.10
+ uv add . --dev
uv pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
- uv pip install pytest
+ uv add pytest-ruff nbval
- name: Test Selected Python Scripts
run: |
source .venv/bin/activate
- pytest setup/02_installing-python-libraries/tests.py
- pytest ch04/01_main-chapter-code/tests.py
- pytest ch05/01_main-chapter-code/tests.py
- pytest ch05/07_gpt_to_llama/tests/tests.py
- pytest ch06/01_main-chapter-code/tests.py
+ pytest --ruff setup/02_installing-python-libraries/tests.py
+ pytest --ruff ch04/01_main-chapter-code/tests.py
+ pytest --ruff ch05/01_main-chapter-code/tests.py
+ pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+ pytest --ruff ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
run: |
source .venv/bin/activate
- pytest --nbval ch02/01_main-chapter-code/dataloader.ipynb
- pytest --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
- pytest --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
+ pytest --ruff --nbval ch02/01_main-chapter-code/dataloader.ipynb
+ pytest --ruff --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
+ pytest --ruff --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
diff --git a/.github/workflows/basic-tests-old-pytorch.yml b/.github/workflows/basic-tests-old-pytorch.yml
index ea6dc896..210d2972 100644
--- a/.github/workflows/basic-tests-old-pytorch.yml
+++ b/.github/workflows/basic-tests-old-pytorch.yml
@@ -35,28 +35,25 @@ jobs:
- name: Install dependencies
run: |
- python -m pip install --upgrade pip setuptools wheel
- pip install uv
- uv venv --python=python3.10
- source .venv/bin/activate
- uv pip install pytest nbval
- uv pip install torch==${{ matrix.pytorch-version }}
- uv pip install -r requirements.txt
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ uv python install 3.10
+ uv add . --dev
uv pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
- uv pip install pytest
+ uv add torch==${{ matrix.pytorch-version }}
+ uv add pytest-ruff nbval
- name: Test Selected Python Scripts
run: |
source .venv/bin/activate
- pytest setup/02_installing-python-libraries/tests.py
- pytest ch04/01_main-chapter-code/tests.py
- pytest ch05/01_main-chapter-code/tests.py
- pytest ch05/07_gpt_to_llama/tests/tests.py
- pytest ch06/01_main-chapter-code/tests.py
+ pytest --ruff setup/02_installing-python-libraries/tests.py
+ pytest --ruff ch04/01_main-chapter-code/tests.py
+ pytest --ruff ch05/01_main-chapter-code/tests.py
+ pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+ pytest --ruff ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
run: |
source .venv/bin/activate
- pytest --nbval ch02/01_main-chapter-code/dataloader.ipynb
- pytest --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
- pytest --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
+ pytest --ruff --nbval ch02/01_main-chapter-code/dataloader.ipynb
+ pytest --ruff --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
+ pytest --ruff --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
diff --git a/.github/workflows/basic-tests-pytorch-rc.yml b/.github/workflows/basic-tests-pytorch-rc.yml
index a073acf5..e91d67d2 100644
--- a/.github/workflows/basic-tests-pytorch-rc.yml
+++ b/.github/workflows/basic-tests-pytorch-rc.yml
@@ -31,28 +31,25 @@ jobs:
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
- pip install uv
- uv venv --python=python3.10
- source .venv/bin/activate
- uv pip install pytest nbval
- if [ -f requirements.txt ]; then uv pip install -r requirements.txt; fi
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ uv python install 3.10
+ uv add . --dev
uv pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
+ uv add pytest-ruff nbval
uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
- uv pip install pytest
- name: Test Selected Python Scripts
run: |
source .venv/bin/activate
- pytest setup/02_installing-python-libraries/tests.py
- pytest ch04/01_main-chapter-code/tests.py
- pytest ch05/01_main-chapter-code/tests.py
- pytest ch05/07_gpt_to_llama/tests/tests.py
- pytest ch06/01_main-chapter-code/tests.py
+ pytest --ruff setup/02_installing-python-libraries/tests.py
+ pytest --ruff ch04/01_main-chapter-code/tests.py
+ pytest --ruff ch05/01_main-chapter-code/tests.py
+ pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+ pytest --ruff ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks
run: |
source .venv/bin/activate
- pytest --nbval ch02/01_main-chapter-code/dataloader.ipynb
- pytest --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
- pytest --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
+ pytest --ruff --nbval ch02/01_main-chapter-code/dataloader.ipynb
+ pytest --ruff --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
+ pytest --ruff --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
diff --git a/.github/workflows/basic-tests-windows-pip.yml b/.github/workflows/basic-tests-windows-pip.yml
new file mode 100644
index 00000000..4a31b630
--- /dev/null
+++ b/.github/workflows/basic-tests-windows-pip.yml
@@ -0,0 +1,57 @@
+name: Code tests (Windows pip)
+
+on:
+ push:
+ branches: [ main ]
+ paths:
+ - '**/*.py'
+ - '**/*.ipynb'
+ - '**/*.yaml'
+ - '**/*.yml'
+ - '**/*.sh'
+ pull_request:
+ branches: [ main ]
+ paths:
+ - '**/*.py'
+ - '**/*.ipynb'
+ - '**/*.yaml'
+ - '**/*.yml'
+ - '**/*.sh'
+
+jobs:
+ test:
+ runs-on: windows-latest
+
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+
+ - name: Install dependencies
+ shell: pwsh
+ run: |
+ pip install --upgrade pip
+ pip install -r requirements.txt
+ pip install tensorflow-io-gcs-filesystem==0.31.0 # Explicit for Windows
+ pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
+ pip install pytest-ruff nbval
+
+ - name: Run Python Tests
+ shell: pwsh
+ run: |
+ pytest --ruff setup/02_installing-python-libraries/tests.py
+ pytest --ruff ch04/01_main-chapter-code/tests.py
+ pytest --ruff ch05/01_main-chapter-code/tests.py
+ pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+ pytest --ruff ch06/01_main-chapter-code/tests.py
+
+ - name: Run Jupyter Notebook Tests
+ shell: pwsh
+ run: |
+ pytest --ruff --nbval ch02/01_main-chapter-code/dataloader.ipynb
+ pytest --ruff --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
+ pytest --ruff --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
diff --git a/.github/workflows/basic-tests-windows-uv.yml.disabled b/.github/workflows/basic-tests-windows-uv.yml.disabled
new file mode 100644
index 00000000..9dd542d4
--- /dev/null
+++ b/.github/workflows/basic-tests-windows-uv.yml.disabled
@@ -0,0 +1,57 @@
+name: Code tests (Windows)
+
+on:
+ push:
+ branches: [ main ]
+ paths:
+ - '**/*.py'
+ - '**/*.ipynb'
+ - '**/*.yaml'
+ - '**/*.yml'
+ - '**/*.sh'
+ pull_request:
+ branches: [ main ]
+ paths:
+ - '**/*.py'
+ - '**/*.ipynb'
+ - '**/*.yaml'
+ - '**/*.yml'
+ - '**/*.sh'
+
+jobs:
+ test:
+ runs-on: windows-latest
+
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v4
+
+ - name: Install uv
+ shell: pwsh
+ run: |
+ Invoke-WebRequest -Uri "https://astral.sh/uv/install.ps1" -OutFile "uv_install.ps1"
+ & .\uv_install.ps1
+
+ - name: Install dependencies with uv
+ shell: pwsh
+ run: |
+ uv venv --python=python3.10
+ uv pip install tensorflow-io-gcs-filesystem==0.31.0 # Explicit for Windows
+ uv pip install -r requirements.txt
+ uv pip install pytest-ruff
+
+ - name: Run Python Tests
+ shell: pwsh
+ run: |
+ uv run pytest --ruff setup/02_installing-python-libraries/tests.py
+ uv run pytest --ruff ch04/01_main-chapter-code/tests.py
+ uv run pytest --ruff ch05/01_main-chapter-code/tests.py
+ uv run pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
+ uv run pytest --ruff ch06/01_main-chapter-code/tests.py
+
+ - name: Run Jupyter Notebook Tests
+ shell: pwsh
+ run: |
+ uv run pytest --ruff --nbval ch02/01_main-chapter-code/dataloader.ipynb
+ uv run pytest --ruff --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
+ uv run pytest --ruff --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
diff --git a/.github/workflows/basic-tests-windows.yml b/.github/workflows/basic-tests-windows.yml
deleted file mode 100644
index 43efbd75..00000000
--- a/.github/workflows/basic-tests-windows.yml
+++ /dev/null
@@ -1,57 +0,0 @@
-name: Code tests (Windows)
-
-on:
- push:
- branches: [ main ]
- paths:
- - '**/*.py' # Run workflow for changes in Python files
- - '**/*.ipynb'
- - '**/*.yaml'
- - '**/*.yml'
- - '**/*.sh'
- pull_request:
- branches: [ main ]
- paths:
- - '**/*.py'
- - '**/*.ipynb'
- - '**/*.yaml'
- - '**/*.yml'
- - '**/*.sh'
-
-jobs:
- test:
- runs-on: windows-latest
-
- steps:
- - name: Checkout Code
- uses: actions/checkout@v4
-
- - name: Set up Python
- uses: actions/setup-python@v5
- with:
- python-version: '3.10'
-
- - name: Install dependencies
- shell: bash
- run: |
- python -m pip install --upgrade pip
- pip install pytest nbval
- if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- pip install matplotlib==3.9.0
- pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt
-
- - name: Test Selected Python Scripts
- shell: bash
- run: |
- pytest setup/02_installing-python-libraries/tests.py
- pytest ch04/01_main-chapter-code/tests.py
- pytest ch05/01_main-chapter-code/tests.py
- pytest ch05/07_gpt_to_llama/tests/tests.py
- pytest ch06/01_main-chapter-code/tests.py
-
- - name: Validate Selected Jupyter Notebooks
- shell: bash
- run: |
- pytest --nbval ch02/01_main-chapter-code/dataloader.ipynb
- pytest --nbval ch03/01_main-chapter-code/multihead-attention.ipynb
- pytest --nbval ch02/04_bonus_dataloader-intuition/dataloader-intuition.ipynb
\ No newline at end of file
diff --git a/.github/workflows/check-links.yml b/.github/workflows/check-links.yml
index 2e47b7dc..b9985584 100644
--- a/.github/workflows/check-links.yml
+++ b/.github/workflows/check-links.yml
@@ -22,18 +22,16 @@ jobs:
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
- pip install uv
- uv venv --python=python3.10
- source .venv/bin/activate
- uv pip install pytest pytest-check-links
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ uv python install 3.10
+ uv add . --dev
+ uv add pytest-ruff pytest-check-links
# Current version of retry doesn't work well if there are broken non-URL links
# pip install pytest pytest-check-links pytest-retry
- uv pip install pytest pytest-check-links
- name: Check links
run: |
source .venv/bin/activate
- pytest --check-links ./ --check-links-ignore "https://platform.openai.com/*" --check-links-ignore "https://openai.com/*" --check-links-ignore "https://arena.lmsys.org" --check-links-ignore https://unsloth.ai/blog/gradient --check-links-ignore "https://www.reddit.com/r/*" --check-links-ignore "https://code.visualstudio.com/*" --check-links-ignore https://arxiv.org/* --check-links-ignore "https://ai.stanford.edu/~amaas/data/sentiment/"
+ pytest --ruff --check-links ./ --check-links-ignore "https://platform.openai.com/*" --check-links-ignore "https://openai.com/*" --check-links-ignore "https://arena.lmsys.org" --check-links-ignore https://unsloth.ai/blog/gradient --check-links-ignore "https://www.reddit.com/r/*" --check-links-ignore "https://code.visualstudio.com/*" --check-links-ignore https://arxiv.org/* --check-links-ignore "https://ai.stanford.edu/~amaas/data/sentiment/"
# pytest --check-links ./ --check-links-ignore "https://platform.openai.com/*" --check-links-ignore "https://arena.lmsys.org" --retries 2 --retry-delay 5
diff --git a/.github/workflows/check-spelling-errors.yml b/.github/workflows/check-spelling-errors.yml
index 2f5cbc86..3edd99cf 100644
--- a/.github/workflows/check-spelling-errors.yml
+++ b/.github/workflows/check-spelling-errors.yml
@@ -22,11 +22,10 @@ jobs:
- name: Install codespell
run: |
- python -m pip install --upgrade pip
- pip install uv
- uv venv --python=python3.10
- source .venv/bin/activate
- uv pip install codespell
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ uv python install 3.10
+ uv add . --dev
+ uv add codespell
- name: Run codespell
run: |
diff --git a/.github/workflows/pep8-linter.yml b/.github/workflows/pep8-linter.yml
index f632e312..2b4723cb 100644
--- a/.github/workflows/pep8-linter.yml
+++ b/.github/workflows/pep8-linter.yml
@@ -15,15 +15,14 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: '3.10'
- - name: Install flake8
+ - name: Install ruff (a faster flake 8 equivalent)
run: |
- python -m pip install --upgrade pip
- pip install uv
- uv venv --python=python3.10
- source .venv/bin/activate
- uv pip install flake8
+ curl -LsSf https://astral.sh/uv/install.sh | sh
+ uv python install 3.10
+ uv add . --dev
+ uv add ruff
- - name: Run flake8 with exceptions
+ - name: Run ruff with exceptions
run: |
source .venv/bin/activate
- flake8 . --max-line-length=140 --ignore=W504,E402,E731,C406,E741,E722,E226 --exclude .venv
+ ruff check .
diff --git a/appendix-E/01_main-chapter-code/appendix-E.ipynb b/appendix-E/01_main-chapter-code/appendix-E.ipynb
index 6122baaf..4d248911 100644
--- a/appendix-E/01_main-chapter-code/appendix-E.ipynb
+++ b/appendix-E/01_main-chapter-code/appendix-E.ipynb
@@ -226,7 +226,6 @@
"outputs": [],
"source": [
"import torch\n",
- "from torch.utils.data import Dataset\n",
"import tiktoken\n",
"from previous_chapters import SpamDataset\n",
"\n",
@@ -1518,7 +1517,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch03/01_main-chapter-code/exercise-solutions.ipynb b/ch03/01_main-chapter-code/exercise-solutions.ipynb
index d41aa5e5..b0537b9e 100644
--- a/ch03/01_main-chapter-code/exercise-solutions.ipynb
+++ b/ch03/01_main-chapter-code/exercise-solutions.ipynb
@@ -64,8 +64,6 @@
"metadata": {},
"outputs": [],
"source": [
- "import torch\n",
- "\n",
"inputs = torch.tensor(\n",
" [[0.43, 0.15, 0.89], # Your (x^1)\n",
" [0.55, 0.87, 0.66], # journey (x^2)\n",
@@ -341,7 +339,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
index 76f7aaf4..12525fe2 100644
--- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
+++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb
@@ -944,6 +944,7 @@
"## 9) Using PyTorch's FlexAttention\n",
"\n",
"- See [FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention](https://pytorch.org/blog/flexattention/) to learn more about FlexAttention\n",
+ "- FlexAttention caveat: It currently doesn't support dropout\n",
"- This is supported starting from PyTorch 2.5, which you can install on a CPU machine via\n",
"\n",
" ```bash\n",
@@ -1029,7 +1030,7 @@
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
- " use_dropout = 0. if not self.training else self.dropout\n",
+ " # use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
@@ -1967,7 +1968,7 @@
"provenance": []
},
"kernelspec": {
- "display_name": "pt",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -1981,7 +1982,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.9"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch04/01_main-chapter-code/ch04.ipynb b/ch04/01_main-chapter-code/ch04.ipynb
index f04c582c..2f25aca1 100644
--- a/ch04/01_main-chapter-code/ch04.ipynb
+++ b/ch04/01_main-chapter-code/ch04.ipynb
@@ -38,19 +38,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "matplotlib version: 3.9.0\n",
- "torch version: 2.4.0\n",
- "tiktoken version: 0.7.0\n"
+ "matplotlib version: 3.10.0\n",
+ "torch version: 2.6.0\n",
+ "tiktoken version: 0.9.0\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
- "import matplotlib\n",
- "import tiktoken\n",
- "import torch\n",
- "\n",
"print(\"matplotlib version:\", version(\"matplotlib\"))\n",
"print(\"torch version:\", version(\"torch\"))\n",
"print(\"tiktoken version:\", version(\"tiktoken\"))"
@@ -1540,7 +1536,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch04/01_main-chapter-code/exercise-solutions.ipynb b/ch04/01_main-chapter-code/exercise-solutions.ipynb
index 5e4c49a8..7f514d0d 100644
--- a/ch04/01_main-chapter-code/exercise-solutions.ipynb
+++ b/ch04/01_main-chapter-code/exercise-solutions.ipynb
@@ -45,7 +45,6 @@
"source": [
"from importlib.metadata import version\n",
"\n",
- "import torch\n",
"print(\"torch version:\", version(\"torch\"))"
]
},
@@ -452,7 +451,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch05/02_alternative_weight_loading/weight-loading-hf-safetensors.ipynb b/ch05/02_alternative_weight_loading/weight-loading-hf-safetensors.ipynb
index 91a10207..30ddb083 100644
--- a/ch05/02_alternative_weight_loading/weight-loading-hf-safetensors.ipynb
+++ b/ch05/02_alternative_weight_loading/weight-loading-hf-safetensors.ipynb
@@ -95,7 +95,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from previous_chapters import GPTModel, generate_text_simple"
+ "from previous_chapters import GPTModel"
]
},
{
@@ -242,7 +242,6 @@
"outputs": [],
"source": [
"import torch\n",
- "from previous_chapters import GPTModel\n",
"\n",
"\n",
"gpt = GPTModel(BASE_CONFIG)\n",
@@ -306,7 +305,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch05/02_alternative_weight_loading/weight-loading-hf-transformers.ipynb b/ch05/02_alternative_weight_loading/weight-loading-hf-transformers.ipynb
index f267bafd..c632b254 100644
--- a/ch05/02_alternative_weight_loading/weight-loading-hf-transformers.ipynb
+++ b/ch05/02_alternative_weight_loading/weight-loading-hf-transformers.ipynb
@@ -217,8 +217,8 @@
" gpt.trf_blocks[b].norm2.scale = assign_check(gpt.trf_blocks[b].norm2.scale, d[f\"h.{b}.ln_2.weight\"])\n",
" gpt.trf_blocks[b].norm2.shift = assign_check(gpt.trf_blocks[b].norm2.shift, d[f\"h.{b}.ln_2.bias\"])\n",
" \n",
- " gpt.final_norm.scale = assign_check(gpt.final_norm.scale, d[f\"ln_f.weight\"])\n",
- " gpt.final_norm.shift = assign_check(gpt.final_norm.shift, d[f\"ln_f.bias\"])\n",
+ " gpt.final_norm.scale = assign_check(gpt.final_norm.scale, d[\"ln_f.weight\"])\n",
+ " gpt.final_norm.shift = assign_check(gpt.final_norm.shift, d[\"ln_f.bias\"])\n",
" gpt.out_head.weight = assign_check(gpt.out_head.weight, d[\"wte.weight\"])"
]
},
@@ -293,7 +293,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
index 4e211ba0..2f227228 100644
--- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
+++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb
@@ -1114,7 +1114,6 @@
},
"outputs": [],
"source": [
- "import os\n",
"from pathlib import Path\n",
"\n",
"import tiktoken\n",
@@ -2633,7 +2632,7 @@
"source": [
"weights_file = hf_hub_download(\n",
" repo_id=\"meta-llama/Llama-3.2-1B\",\n",
- " filename=f\"model.safetensors\",\n",
+ " filename=\"model.safetensors\",\n",
" local_dir=\"Llama-3.2-1B\"\n",
")\n",
"current_weights = load_file(weights_file)\n",
@@ -2747,7 +2746,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb
index d108df3a..3e49d5c1 100644
--- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb
+++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb
@@ -993,7 +993,7 @@
"if LLAMA_SIZE_STR == \"1B\":\n",
" weights_file = hf_hub_download(\n",
" repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n",
- " filename=f\"model.safetensors\",\n",
+ " filename=\"model.safetensors\",\n",
" local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n",
" )\n",
" combined_weights = load_file(weights_file)\n",
@@ -1213,7 +1213,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
diff --git a/ch06/01_main-chapter-code/ch06.ipynb b/ch06/01_main-chapter-code/ch06.ipynb
index 96749477..04721f5a 100644
--- a/ch06/01_main-chapter-code/ch06.ipynb
+++ b/ch06/01_main-chapter-code/ch06.ipynb
@@ -79,28 +79,6 @@
"
"
]
},
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "946c3e56-b04b-4b0f-b35f-b485ce5b28df",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Utility to prevent certain cells from being executed twice\n",
- "\n",
- "from IPython.core.magic import register_line_cell_magic\n",
- "\n",
- "executed_cells = set()\n",
- "\n",
- "@register_line_cell_magic\n",
- "def run_once(line, cell):\n",
- " if line not in executed_cells:\n",
- " get_ipython().run_cell(cell)\n",
- " executed_cells.add(line)\n",
- " else:\n",
- " print(f\"Cell '{line}' has already been executed.\")"
- ]
- },
{
"cell_type": "markdown",
"id": "3a84cf35-b37f-4c15-8972-dfafc9fadc1c",
@@ -450,9 +428,6 @@
}
],
"source": [
- "%%run_once balance_df\n",
- "\n",
- "\n",
"def create_balanced_dataset(df):\n",
" \n",
" # Count the instances of \"spam\"\n",
@@ -490,7 +465,6 @@
},
"outputs": [],
"source": [
- "%%run_once label_mapping\n",
"balanced_df[\"Label\"] = balanced_df[\"Label\"].map({\"ham\": 0, \"spam\": 1}) "
]
},
diff --git a/ch06/03_bonus_imdb-classification/sklearn-baseline.ipynb b/ch06/03_bonus_imdb-classification/sklearn-baseline.ipynb
index 9b00d048..4529c816 100644
--- a/ch06/03_bonus_imdb-classification/sklearn-baseline.ipynb
+++ b/ch06/03_bonus_imdb-classification/sklearn-baseline.ipynb
@@ -190,13 +190,13 @@
" \n",
" # Calculating accuracy and balanced accuracy\n",
" accuracy_train = accuracy_score(y_train, y_pred_train)\n",
- " balanced_accuracy_train = balanced_accuracy_score(y_train, y_pred_train)\n",
+ " # balanced_accuracy_train = balanced_accuracy_score(y_train, y_pred_train)\n",
" \n",
" accuracy_val = accuracy_score(y_val, y_pred_val)\n",
- " balanced_accuracy_val = balanced_accuracy_score(y_val, y_pred_val)\n",
+ " # balanced_accuracy_val = balanced_accuracy_score(y_val, y_pred_val)\n",
"\n",
" accuracy_test = accuracy_score(y_test, y_pred_test)\n",
- " balanced_accuracy_test = balanced_accuracy_score(y_test, y_pred_test)\n",
+ " # balanced_accuracy_test = balanced_accuracy_score(y_test, y_pred_test)\n",
" \n",
" # Printing the results\n",
" print(f\"Training Accuracy: {accuracy_train*100:.2f}%\")\n",
@@ -269,7 +269,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch07/01_main-chapter-code/exercise-solutions.ipynb b/ch07/01_main-chapter-code/exercise-solutions.ipynb
index b054203f..8dfe0539 100644
--- a/ch07/01_main-chapter-code/exercise-solutions.ipynb
+++ b/ch07/01_main-chapter-code/exercise-solutions.ipynb
@@ -144,12 +144,11 @@
]
},
{
- "cell_type": "code",
- "execution_count": 3,
- "id": "17f1a42c-7cc0-4746-8a6d-3a4cb37e2ca1",
+ "cell_type": "markdown",
+ "id": "81f0d9c8-8f41-4455-b9ae-6b17de610cc3",
"metadata": {},
- "outputs": [],
"source": [
+ "```python\n",
"import tiktoken\n",
"from torch.utils.data import Dataset\n",
"\n",
@@ -178,7 +177,8 @@
" return len(self.data)\n",
"\n",
"\n",
- "tokenizer = tiktoken.get_encoding(\"gpt2\")"
+ "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
+ "```"
]
},
{
@@ -1017,7 +1017,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.4"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch07/03_model-evaluation/llm-instruction-eval-openai.ipynb b/ch07/03_model-evaluation/llm-instruction-eval-openai.ipynb
index 01de4fd1..c099b1b2 100644
--- a/ch07/03_model-evaluation/llm-instruction-eval-openai.ipynb
+++ b/ch07/03_model-evaluation/llm-instruction-eval-openai.ipynb
@@ -170,7 +170,7 @@
" return response.choices[0].message.content\n",
"\n",
"\n",
- "prompt = f\"Respond with 'hello world' if you got this message.\"\n",
+ "prompt = \"Respond with 'hello world' if you got this message.\"\n",
"run_chatgpt(prompt, client)"
]
},
@@ -563,7 +563,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.11"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/ch07/05_dataset-generation/reflection-gpt4.ipynb b/ch07/05_dataset-generation/reflection-gpt4.ipynb
index 00b538b1..ce24d27e 100644
--- a/ch07/05_dataset-generation/reflection-gpt4.ipynb
+++ b/ch07/05_dataset-generation/reflection-gpt4.ipynb
@@ -200,7 +200,7 @@
" return response.choices[0].message.content\n",
"\n",
"\n",
- "prompt = f\"Respond with 'hello world' if you got this message.\"\n",
+ "prompt = \"Respond with 'hello world' if you got this message.\"\n",
"run_chatgpt(prompt, client)"
]
},
@@ -1058,7 +1058,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.6"
+ "version": "3.10.16"
}
},
"nbformat": 4,
diff --git a/pyproject.toml b/pyproject.toml
index ad8c5a83..130b7ac0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,12 +9,12 @@ dependencies = [
"jupyterlab>=4.0",
"tiktoken>=0.5.1",
"matplotlib>=3.7.1",
- "tensorflow>=2.18.0",
+ "tensorflow>=2.18.0; sys_platform != \"win32\"",
+ "tensorflow-cpu>=2.18.0; sys_platform == \"win32\"",
"tqdm>=4.66.1",
"numpy>=1.26,<2.1",
"pandas>=2.2.1",
- "psutil>=5.9.5",
- "packaging>=24.2",
+ "pip>=25.0.1",
]
[tool.setuptools.packages]
@@ -27,3 +27,14 @@ llms-from-scratch = { workspace = true }
dev = [
"llms-from-scratch",
]
+
+[tool.ruff]
+line-length = 140
+
+[tool.ruff.lint]
+exclude = [".venv"]
+# Ignored rules (W504 removed)
+ignore = [
+ "C406", "E226", "E402", "E702", "E703",
+ "E722", "E731", "E741"
+]
diff --git a/requirements.txt b/requirements.txt
index f1bbb7b7..60d486a6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,10 @@
-torch >= 2.3.0 # all
-jupyterlab >= 4.0 # all
-tiktoken >= 0.5.1 # ch02; ch04; ch05
-matplotlib >= 3.7.1 # ch04; ch05
-tensorflow >= 2.18.0 # ch05
-tqdm >= 4.66.1 # ch05; ch07
-numpy >= 1.26, < 2.1 # dependency of several other libraries like torch and pandas
-pandas >= 2.2.1 # ch06
-psutil >= 5.9.5 # ch07; already installed automatically as dependency of torch
+torch >= 2.3.0 # all
+jupyterlab >= 4.0 # all
+tiktoken >= 0.5.1 # ch02; ch04; ch05
+matplotlib >= 3.7.1 # ch04; ch05
+tensorflow>=2.18.0; sys_platform != "win32" # ch05 (non-Windows)
+tensorflow-cpu>=2.18.0; sys_platform == "win32" # ch05 (Windows)
+tqdm >= 4.66.1 # ch05; ch07
+numpy >= 1.26, < 2.1 # dependency of several other libraries like torch and pandas
+pandas >= 2.2.1 # ch06
+psutil >= 5.9.5 # ch07; already installed automatically as dependency of torch
diff --git a/setup/01_optional-python-setup-preferences/README.md b/setup/01_optional-python-setup-preferences/README.md
index 312461a0..a7a08fee 100644
--- a/setup/01_optional-python-setup-preferences/README.md
+++ b/setup/01_optional-python-setup-preferences/README.md
@@ -22,7 +22,7 @@ This section guides you through the Python setup and package installation proced
>
> If you prefer the native `uv` commands, refer to the [./native-uv.md tutorial](./native-uv.md). I also recommend checking the official [`uv` documentation](https://docs.astral.sh/uv/).
>
-> While `uv add` offers speed advantages, I find `uv pip` slightly more user-friendly, making it a good starting point for beginners. However, if you're new to Python package management, the native `uv` interface is also a great way to learn.
+> While `uv add` offers additional speed advantages, I think that `uv pip` is slightly more user-friendly, making it a good starting point for beginners. However, if you're new to Python package management, the native `uv` interface is also a great opportunity to learn it from the start. It's also how I use `uv` now, but I realize it the barrier to entry is a bit higher if you are coming from `pip` and `conda`.
@@ -146,6 +146,10 @@ uv pip install -U -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/r
+> [!NOTE]
+> If you have problems with the following commands above due to certain dependencies (for example, if you are using Windows), you can always fall back to using regular pip:
+> `pip install -r requirements.txt`
+
**Finalizing the setup**
diff --git a/setup/01_optional-python-setup-preferences/native-uv.md b/setup/01_optional-python-setup-preferences/native-uv.md
index 1074b893..66bb1989 100644
--- a/setup/01_optional-python-setup-preferences/native-uv.md
+++ b/setup/01_optional-python-setup-preferences/native-uv.md
@@ -2,7 +2,7 @@
This tutorial is an alternative to *Option 1: Using uv* in the [README.md](./README.md) document for those who prefer `uv`'s native commands over the `uv pip` interface. While `uv pip` is faster than pure `pip`, `uv`'s native interface is even faster than `uv pip` as it has less overhead and doesn't have to handle legacy support for PyPy package dependency management.
-The table below provides a comparison of the speeds of different dependency and package management approaches. The speed comparison specifically refers to package dependency resolution during installation, not the runtime performance of the installed packages. Note that ackage installation is a one-time process for this project, so it is reasonable to choose the preferred approach by overall convenience, not just installation speed.
+The table below provides a comparison of the speeds of different dependency and package management approaches. The speed comparison specifically refers to package dependency resolution during installation, not the runtime performance of the installed packages. Note that package installation is a one-time process for this project, so it is reasonable to choose the preferred approach by overall convenience, not just installation speed.
| Command | Speed Comparison |
@@ -74,9 +74,15 @@ To install all required packages from a `pyproject.toml` file (such as the one l
uv add . --dev
```
+> [!NOTE]
+> If you have problems with the following commands above due to certain dependencies (for example, if you are using Windows), you can always fall back to regular pip:
+> `uv add pip`
+> `uv run python -m pip install -U -r requirements.txt`
+
+
-Note that the `uv add` command above will create a separate virtual environment via the `.venv` subfolder.
+Note that the `uv add` command above will create a separate virtual environment via the `.venv` subfolder. (In case you want to delete your virtual environment to start from scratch, you can simply delete the `.venv` folder.)
You can install new packages, that are not specified in the `pyproject.toml` via `uv add`, for example:
@@ -84,26 +90,34 @@ You can install new packages, that are not specified in the `pyproject.toml` via
uv add packaging
```
-
-## Optional: Manage virtual environments manually
+And you can remove packages via `uv remove`, for example,
-Alternatively, you can still install the dependencies directly from the repository using `uv pip install`. Note that this requires creating and activating the virtual environment manually:
+```bash
+uv remove packaging
+```
+
+
+
+
+## 3. Run Python code
-**1. Create a new virtual environment**
+Your environment should now be ready to run the code in the repository.
-Run the following command to manually create a new virtual environment, which will be saved via a new `.venv` subfolder:
+Optionally, you can run an environment check by executing the `python_environment_check.py` script in this repository:
```bash
-uv venv --python=python3.10
+uv run python setup/02_installing-python-libraries/python_environment_check.py
```
-
-**2. Activate virtual environment**
-Next, we need to activate this new virtual environment.
+
+
+
+
+Or, if you don't want to type `uv run python` ever time you execute code, manually activate the virtual environment first.
On macOS/Linux:
@@ -117,40 +131,38 @@ On Windows (PowerShell):
.venv\Scripts\activate
```
-
-
-**3. Install dependencies**
+Then, run:
-Finally, we can install dependencies from a remote location using the `uv pip` interface:
```bash
-uv pip install -U -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/requirements.txt
+python setup/02_installing-python-libraries/python_environment_check.py
```
-
-
-
-## 4. Run Python code
-
-**Finalizing the setup**
-
-Your environment should now be ready to run the code in the repository.
+**Launching JupyterLab**
-Optionally, you can run an environment check by executing the `python_environment_check.py` script in this repository:
+You can launch a JupyterLab instance via:
```bash
-uv run python setup/02_installing-python-libraries/python_environment_check.py
+uv run jupyter lab
```
+**Skipping the `uv run` command**
+If you find typing `uv run` cumbersome and want to run scripts via
-
+```bash
+python script.py
+```
+and launch JupyterLab via
+```bash
+juputer lab
+```
-Or, if you don't want to type `uv run python` ever time you execute code, manually activate the virtual environment first.
+instead, you can activated the environment manually.
On macOS/Linux:
@@ -164,26 +176,52 @@ On Windows (PowerShell):
.venv\Scripts\activate
```
-Then, run:
+
+
+## Optional: Manage virtual environments manually
+
+Alternatively, you can still install the dependencies directly from the repository using `uv pip install`. But note that this doesn't record dependencies in a `uv.lock` file as `uv add` does. Also, it requires creating and activating the virtual environment manually:
+
+
+
+**1. Create a new virtual environment**
+
+Run the following command to manually create a new virtual environment, which will be saved via a new `.venv` subfolder:
```bash
-python setup/02_installing-python-libraries/python_environment_check.py
+uv venv --python=python3.10
```
-**Launching JupyterLab**
+**2. Activate virtual environment**
-You can launch a JupyterLab instance via:
+Next, we need to activate this new virtual environment.
+
+On macOS/Linux:
```bash
-uv run jupyter lab
+source .venv/bin/activate
+```
+
+On Windows (PowerShell):
+
+```bash
+.venv\Scripts\activate
+```
+
+
+
+**3. Install dependencies**
+
+Finally, we can install dependencies from a remote location using the `uv pip` interface:
+
+```bash
+uv pip install -U -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/requirements.txt
```
-Or, if you manually activated the environment as described earlier, you can drop the `uv run` prefix.
-
---
diff --git a/setup/02_installing-python-libraries/python_environment_check.py b/setup/02_installing-python-libraries/python_environment_check.py
index 8c785ada..3c47e4cc 100644
--- a/setup/02_installing-python-libraries/python_environment_check.py
+++ b/setup/02_installing-python-libraries/python_environment_check.py
@@ -3,99 +3,100 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
-from importlib.metadata import PackageNotFoundError, import_module
-import importlib.metadata
+
+from importlib.metadata import PackageNotFoundError, import_module, version as get_version
from os.path import dirname, exists, join, realpath
from packaging.version import parse as version_parse
+from packaging.requirements import Requirement
+from packaging.specifiers import SpecifierSet
import platform
import sys
if version_parse(platform.python_version()) < version_parse("3.9"):
- print("[FAIL] We recommend Python 3.9 or newer but"
- " found version %s" % (sys.version))
+ print("[FAIL] We recommend Python 3.9 or newer but found version %s" % sys.version)
else:
- print("[OK] Your Python version is %s" % (platform.python_version()))
+ print("[OK] Your Python version is %s" % platform.python_version())
def get_packages(pkgs):
- versions = []
+ """
+ Returns a dictionary mapping package names (in lowercase) to their installed version.
+ """
+ result = {}
for p in pkgs:
try:
+ # Try to import the package
imported = import_module(p)
try:
- version = (getattr(imported, "__version__", None) or
- getattr(imported, "version", None) or
- getattr(imported, "version_info", None))
+ version = getattr(imported, "__version__", None)
if version is None:
- # If common attributes don"t exist, use importlib.metadata
- version = importlib.metadata.version(p)
- versions.append(version)
+ version = get_version(p)
+ result[p.lower()] = version
except PackageNotFoundError:
- # Handle case where package is not installed
- versions.append("0.0")
+ result[p.lower()] = "0.0"
except ImportError:
- # Fallback if importlib.import_module fails for unexpected reasons
- versions.append("0.0")
- return versions
+ result[p.lower()] = "0.0"
+ return result
def get_requirements_dict():
+ """
+ Parses requirements.txt and returns a dictionary mapping package names (lowercase)
+ to a specifier string (e.g. ">=2.18.0,<3.0"). It uses packaging.requirements.Requirement
+ to properly handle environment markers.
+ """
+
PROJECT_ROOT = dirname(realpath(__file__))
PROJECT_ROOT_UP_TWO = dirname(dirname(PROJECT_ROOT))
REQUIREMENTS_FILE = join(PROJECT_ROOT_UP_TWO, "requirements.txt")
if not exists(REQUIREMENTS_FILE):
REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements.txt")
- d = {}
+ reqs = {}
with open(REQUIREMENTS_FILE) as f:
for line in f:
- if not line.strip():
+ # Remove inline comments and trailing whitespace.
+ # This splits on the first '#' and takes the part before it.
+ line = line.split("#", 1)[0].strip()
+ if not line:
continue
- if "," in line:
- left, right = line.split(",")
- lower = right.split("#")[0].strip()
- package, _, upper = left.split(" ")
- package = package.strip()
- _, lower = lower.split(" ")
- lower = lower.strip()
- upper = upper.strip()
- d[package] = (upper, lower)
- else:
- line = line.split("#")[0].strip()
- line = line.split(" ")
- line = [ln.strip() for ln in line]
- d[line[0]] = line[-1]
- return d
-
+ try:
+ req = Requirement(line)
+ except Exception as e:
+ print(f"Skipping line due to parsing error: {line} ({e})")
+ continue
+ # Evaluate the marker if present.
+ if req.marker is not None and not req.marker.evaluate():
+ continue
+ # Store the package name and its version specifier.
+ spec = str(req.specifier) if req.specifier else ">=0"
+ reqs[req.name.lower()] = spec
+ return reqs
-def check_packages(d):
- versions = get_packages(d.keys())
- for (pkg_name, suggested_ver), actual_ver in zip(d.items(), versions):
- if isinstance(suggested_ver, tuple):
- lower, upper = suggested_ver[0], suggested_ver[1]
- else:
- lower = suggested_ver
- upper = None
+def check_packages(reqs):
+ """
+ Checks the installed versions of packages against the requirements.
+ """
+ installed = get_packages(reqs.keys())
+ for pkg_name, spec_str in reqs.items():
+ spec_set = SpecifierSet(spec_str)
+ actual_ver = installed.get(pkg_name, "0.0")
if actual_ver == "N/A":
continue
- actual_ver = version_parse(actual_ver)
- lower = version_parse(lower)
- if upper is not None:
- upper = version_parse(upper)
- if actual_ver < lower and upper is None:
- print(f"[FAIL] {pkg_name} {actual_ver}, please upgrade to >= {lower}")
- elif actual_ver < lower:
- print(f"[FAIL] {pkg_name} {actual_ver}, please upgrade to >= {lower} and < {upper}")
- elif upper is not None and actual_ver >= upper:
- print(f"[FAIL] {pkg_name} {actual_ver}, please downgrade to >= {lower} and < {upper}")
+ actual_ver_parsed = version_parse(actual_ver)
+ # If the installed version is a pre-release, allow pre-releases in the specifier.
+ if actual_ver_parsed.is_prerelease:
+ spec_set.prereleases = True
+ if actual_ver_parsed not in spec_set:
+ print(f"[FAIL] {pkg_name} {actual_ver_parsed}, please install a version matching {spec_set}")
else:
- print(f"[OK] {pkg_name} {actual_ver}")
+ print(f"[OK] {pkg_name} {actual_ver_parsed}")
def main():
- d = get_requirements_dict()
- check_packages(d)
+ reqs = get_requirements_dict()
+ check_packages(reqs)
if __name__ == "__main__":