Skip to content

Commit c6e0a83

Browse files
committed
KEP-2170: Add unit and E2E tests for model and dataset initializers
Signed-off-by: wei-chenglai <[email protected]>
1 parent 95be3c0 commit c6e0a83

17 files changed

+780
-2
lines changed

.github/workflows/integration-tests.yaml

+5-1
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,15 @@ jobs:
9595

9696
- name: Run tests
9797
run: |
98-
pip install pytest
98+
pip install pytest urllib3 huggingface_hub
9999
python3 -m pip install -e sdk/python; pytest -s sdk/python/test --log-cli-level=debug --namespace=default
100100
env:
101101
GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }}
102102

103+
- name: Run specific tests for Python 3.10+
104+
if: ${{ matrix.python-version == '3.10' || matrix.python-version == '3.11' }}
105+
run: pytest pkg/initializer_v2/test/e2e
106+
103107
- name: Collect volcano logs
104108
if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }}
105109
run: |

.github/workflows/test-python.yaml

+3-1
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,6 @@ jobs:
3232
pip install -U './sdk/python[huggingface]'
3333
3434
- name: Run unit test for training sdk
35-
run: pytest ./sdk/python/kubeflow/training/api/training_client_test.py
35+
run: |
36+
pytest ./sdk/python/kubeflow/training/api/training_client_test.py
37+
pytest ./pkg/initializer_v2/test/unit

pkg/initializer_v2/test/__init__.py

Whitespace-only changes.

pkg/initializer_v2/test/conftest.py

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
import sys
3+
4+
import pytest
5+
6+
# Put the repository root (three directories above this file) at the front of sys.path so that `pkg.*` and `sdk.*` imports resolve
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
8+
9+
10+
@pytest.fixture
def mock_env_vars():
    """Yield a helper that edits ``os.environ``; restore the environment afterwards.

    The yielded callable takes keyword arguments. A value of ``None`` removes
    the variable, any other value is stored as ``str(value)``. The callable
    returns ``os.environ`` itself.
    """
    snapshot = os.environ.copy()

    def _set_env_vars(**kwargs):
        for name in kwargs:
            setting = kwargs[name]
            if setting is not None:
                os.environ[name] = str(setting)
                continue
            # None means "make sure this variable is absent".
            os.environ.pop(name, None)
        return os.environ

    yield _set_env_vars

    # Teardown: drop everything the test touched and put the snapshot back.
    os.environ.clear()
    os.environ.update(snapshot)
28+
29+
30+
@pytest.fixture
def huggingface_model_instance():
    """Return a new HuggingFace model initializer built with no arguments."""
    # Imported inside the fixture so the module is only loaded when a test
    # actually requests this fixture.
    from pkg.initializer_v2.model import huggingface as model_huggingface

    instance = model_huggingface.HuggingFace()
    return instance
36+
37+
38+
@pytest.fixture
def huggingface_dataset_instance():
    """Return a new HuggingFace dataset initializer built with no arguments."""
    # Imported inside the fixture so the module is only loaded when a test
    # actually requests this fixture.
    from pkg.initializer_v2.dataset import huggingface as dataset_huggingface

    instance = dataset_huggingface.HuggingFace()
    return instance
44+
45+
46+
@pytest.fixture
def real_hf_token():
    """Return the value of ``HUGGINGFACE_TOKEN``, or ``None`` when it is unset.

    Intended for E2E tests that need a real HuggingFace token.
    """
    # NOTE: this fixture does not skip when the variable is missing; callers
    # that require a token have to handle a ``None`` return themselves.
    return os.environ.get("HUGGINGFACE_TOKEN")

pkg/initializer_v2/test/e2e/__init__.py

Whitespace-only changes.
+107
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import os
2+
import runpy
3+
import shutil
4+
import tempfile
5+
6+
import pytest
7+
8+
import pkg.initializer_v2.utils.utils as utils
9+
from sdk.python.kubeflow.storage_initializer.constants import VOLUME_PATH_DATASET
10+
11+
12+
class TestDatasetE2E:
    """E2E tests for dataset initialization.

    Each test runs the ``pkg.initializer_v2.dataset`` entry point through
    ``runpy`` with a fresh temporary directory as the download target.
    """

    @pytest.fixture(autouse=True)
    def setup_teardown(self, monkeypatch):
        """Give every test a private download directory and a restorable environment.

        The environment snapshot is taken before any variable is written so
        that teardown restores exactly the pre-test environment.
        """
        # Snapshot first: nothing set below may outlive the test.
        self.original_env = dict(os.environ)

        # Create temporary directory for dataset downloads
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.temp_dir = tempfile.mkdtemp(dir=current_dir)
        os.environ[VOLUME_PATH_DATASET] = self.temp_dir

        # Monkeypatch the constant in the module
        import sdk.python.kubeflow.storage_initializer.constants as constants

        monkeypatch.setattr(constants, "VOLUME_PATH_DATASET", self.temp_dir)

        yield

        # Cleanup
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        os.environ.clear()
        os.environ.update(self.original_env)

    def verify_dataset_files(self, expected_files):
        """Assert that every name in ``expected_files`` exists in the temp directory.

        Only top-level directory entries are checked. A falsy
        ``expected_files`` (``None`` or empty) disables the check.
        """
        if expected_files:
            actual_files = set(os.listdir(self.temp_dir))
            missing_files = set(expected_files) - actual_files
            assert not missing_files, f"Missing expected files: {missing_files}"

    @pytest.mark.parametrize(
        "test_name, provider, test_case",
        [
            # Public HuggingFace dataset test
            (
                "HuggingFace - Public dataset",
                "huggingface",
                {
                    "storage_uri": "hf://karpathy/tiny_shakespeare",
                    "access_token": None,
                    "expected_files": ["tiny_shakespeare.py"],
                    "expected_error": None,
                },
            ),
            # Private HuggingFace dataset test
            # (
            #     "HuggingFace - Private dataset",
            #     "huggingface",
            #     {
            #         "storage_uri": "hf://username/private-dataset",
            #         "use_real_token": True,
            #         "expected_files": ["config.json", "dataset.safetensors"],
            #         "expected_error": None
            #     }
            # ),
            # Invalid HuggingFace dataset test
            (
                "HuggingFace - Invalid dataset",
                "huggingface",
                {
                    "storage_uri": "hf://invalid/nonexistent-dataset",
                    "access_token": None,
                    "expected_files": None,
                    "expected_error": Exception,
                },
            ),
        ],
    )
    def test_dataset_download(self, test_name, provider, test_case, real_hf_token):
        """Test end-to-end dataset download for different providers.

        ``test_case`` keys: ``storage_uri`` (required), ``expected_error``
        (required; an exception type or ``None``), and optionally
        ``expected_files``, ``access_token`` and ``use_real_token``.
        """
        print(f"Running E2E test for {provider}: {test_name}")

        # Setup environment variables based on test case
        os.environ[utils.STORAGE_URI_ENV] = test_case["storage_uri"]
        expected_files = test_case.get("expected_files")

        # Handle token/credentials
        if test_case.get("use_real_token"):
            # os.environ only accepts strings; a missing token would raise
            # TypeError here, so skip the case instead of erroring out.
            if not real_hf_token:
                pytest.skip("HUGGINGFACE_TOKEN environment variable not set")
            os.environ["ACCESS_TOKEN"] = real_hf_token
        elif test_case.get("access_token"):
            os.environ["ACCESS_TOKEN"] = test_case["access_token"]

        # Run the main script
        if test_case["expected_error"]:
            with pytest.raises(test_case["expected_error"]):
                runpy.run_module(
                    "pkg.initializer_v2.dataset.__main__", run_name="__main__"
                )
        else:
            runpy.run_module("pkg.initializer_v2.dataset.__main__", run_name="__main__")
            self.verify_dataset_files(expected_files)

        print("Test execution completed")
+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import os
2+
import runpy
3+
import shutil
4+
import tempfile
5+
6+
import pytest
7+
8+
import pkg.initializer_v2.utils.utils as utils
9+
from sdk.python.kubeflow.storage_initializer.constants import VOLUME_PATH_MODEL
10+
11+
12+
class TestModelE2E:
    """E2E tests for model initialization.

    Each test runs the ``pkg.initializer_v2.model`` entry point through
    ``runpy`` with a fresh temporary directory as the download target.
    """

    @pytest.fixture(autouse=True)
    def setup_teardown(self, monkeypatch):
        """Give every test a private download directory and a restorable environment.

        The environment snapshot is taken before any variable is written so
        that teardown restores exactly the pre-test environment.
        """
        # Snapshot first: nothing set below may outlive the test.
        self.original_env = dict(os.environ)

        # Create temporary directory for model downloads
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.temp_dir = tempfile.mkdtemp(dir=current_dir)
        print(self.temp_dir)
        os.environ[VOLUME_PATH_MODEL] = self.temp_dir

        # Monkeypatch the constant in the module
        import sdk.python.kubeflow.storage_initializer.constants as constants

        monkeypatch.setattr(constants, "VOLUME_PATH_MODEL", self.temp_dir)

        yield

        # Cleanup
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        os.environ.clear()
        os.environ.update(self.original_env)

    def verify_model_files(self, expected_files):
        """Assert that every name in ``expected_files`` exists in the temp directory.

        Only top-level directory entries are checked. A falsy
        ``expected_files`` (``None`` or empty) disables the check.
        """
        if expected_files:
            actual_files = set(os.listdir(self.temp_dir))
            missing_files = set(expected_files) - actual_files
            assert not missing_files, f"Missing expected files: {missing_files}"

    @pytest.mark.parametrize(
        "test_name, provider, test_case",
        [
            # Public HuggingFace model test
            (
                "HuggingFace - Public model",
                "huggingface",
                {
                    "storage_uri": "hf://hf-internal-testing/tiny-random-bert",
                    "access_token": None,
                    "expected_files": [
                        "config.json",
                        "model.safetensors",
                        "tokenizer.json",
                        "tokenizer_config.json",
                    ],
                    "expected_error": None,
                },
            ),
            # Private HuggingFace model test
            # (
            #     "HuggingFace - Private model",
            #     "huggingface",
            #     {
            #         "storage_uri": "hf://username/private-model",
            #         "use_real_token": True,
            #         "expected_files": ["config.json", "model.safetensors"],
            #         "expected_error": None
            #     }
            # ),
            # Invalid HuggingFace model test
            (
                "HuggingFace - Invalid model",
                "huggingface",
                {
                    "storage_uri": "hf://invalid/nonexistent-model",
                    "access_token": None,
                    "expected_files": None,
                    "expected_error": Exception,
                },
            ),
        ],
    )
    def test_model_download(self, test_name, provider, test_case, real_hf_token):
        """Test end-to-end model download for different providers.

        ``test_case`` keys: ``storage_uri`` (required), ``expected_error``
        (required; an exception type or ``None``), and optionally
        ``expected_files``, ``access_token`` and ``use_real_token``.
        """
        print(f"Running E2E test for {provider}: {test_name}")

        # Setup environment variables based on test case
        os.environ[utils.STORAGE_URI_ENV] = test_case["storage_uri"]
        expected_files = test_case.get("expected_files")

        # Handle token/credentials
        if test_case.get("use_real_token"):
            # os.environ only accepts strings; a missing token would raise
            # TypeError here, so skip the case instead of erroring out.
            if not real_hf_token:
                pytest.skip("HUGGINGFACE_TOKEN environment variable not set")
            os.environ["ACCESS_TOKEN"] = real_hf_token
        elif test_case.get("access_token"):
            os.environ["ACCESS_TOKEN"] = test_case["access_token"]

        # Run the main script
        if test_case["expected_error"]:
            with pytest.raises(test_case["expected_error"]):
                runpy.run_module(
                    "pkg.initializer_v2.model.__main__", run_name="__main__"
                )
        else:
            runpy.run_module("pkg.initializer_v2.model.__main__", run_name="__main__")
            self.verify_model_files(expected_files)

        print("Test execution completed")

pkg/initializer_v2/test/unit/__init__.py

Whitespace-only changes.

pkg/initializer_v2/test/unit/dataset/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import runpy
2+
from unittest.mock import MagicMock, patch
3+
4+
import pytest
5+
6+
7+
@pytest.mark.parametrize(
    "test_name, test_case",
    [
        (
            "Successful download with HuggingFace provider",
            {
                "storage_uri": "hf://dataset/path",
                "access_token": "test_token",
                "mock_config_error": False,
                "expected_error": None,
            },
        ),
        (
            "Missing storage URI environment variable",
            {
                "storage_uri": None,
                "access_token": None,
                "mock_config_error": False,
                "expected_error": Exception,
            },
        ),
        (
            "Invalid storage URI scheme",
            {
                "storage_uri": "invalid://dataset/path",
                "access_token": None,
                "mock_config_error": False,
                "expected_error": Exception,
            },
        ),
        (
            "Config loading failure",
            {
                "storage_uri": "hf://dataset/path",
                "access_token": None,
                "mock_config_error": True,
                "expected_error": Exception,
            },
        ),
    ],
)
def test_dataset_main(test_name, test_case, mock_env_vars):
    """Run the dataset entry point with a mocked HuggingFace provider.

    For each scenario the environment is prepared through ``mock_env_vars``,
    the ``HuggingFace`` class is patched to return a ``MagicMock``, and the
    module is executed as ``__main__``. Error scenarios must raise the
    expected exception type; the success scenario must call ``load_config``
    and ``download_dataset`` exactly once each.
    """
    print(f"Running test: {test_name}")

    # A None value makes the fixture helper remove that variable.
    mock_env_vars(
        STORAGE_URI=test_case["storage_uri"],
        ACCESS_TOKEN=test_case["access_token"],
    )

    # Stand-in for the provider instance returned by the patched class.
    hf_stub = MagicMock()
    if test_case["mock_config_error"]:
        hf_stub.load_config.side_effect = Exception

    uri = test_case["storage_uri"]
    error_type = test_case["expected_error"]

    with patch(
        "pkg.initializer_v2.dataset.huggingface.HuggingFace",
        return_value=hf_stub,
    ) as hf_cls:
        if not error_type:
            runpy.run_module("pkg.initializer_v2.dataset.__main__", run_name="__main__")

            # Success path: config loaded and dataset downloaded exactly once.
            hf_stub.load_config.assert_called_once()
            hf_stub.download_dataset.assert_called_once()
        else:
            with pytest.raises(error_type):
                runpy.run_module(
                    "pkg.initializer_v2.dataset.__main__", run_name="__main__"
                )

        # Any hf:// URI must instantiate the provider class exactly once.
        if uri and uri.startswith("hf://"):
            hf_cls.assert_called_once()

    print("Test execution completed")

0 commit comments

Comments
 (0)