Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KEP-2170: Add unit and E2E tests for model and dataset initializers #2323

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ jobs:
env:
GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }}

- name: Run initializer_v2 e2e for Python 3.11+
if: ${{ matrix.python-version == '3.11' }}
run: |
pip install urllib3 huggingface_hub
pip install -U './sdk_v2'
pytest ./test/e2e/initializer_v2

- name: Collect volcano logs
if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }}
run: |
Expand Down
11 changes: 10 additions & 1 deletion .github/workflows/test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,13 @@ jobs:
pip install -U './sdk/python[huggingface]'

- name: Run unit test for training sdk
run: pytest ./sdk/python/kubeflow/training/api/training_client_test.py
run: |
pytest ./sdk/python/kubeflow/training/api/training_client_test.py

- name: Run Python unit tests for v2
run: |
pip install -U './sdk_v2'
export PYTHONPATH="${{ github.workspace }}:$PYTHONPATH"
pytest ./pkg/initializer_v2/model
pytest ./pkg/initializer_v2/dataset
pytest ./pkg/initializer_v2/utils
7 changes: 6 additions & 1 deletion pkg/initializer_v2/dataset/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
level=logging.INFO,
)

if __name__ == "__main__":

def main():
logging.info("Starting dataset initialization")

try:
Expand All @@ -29,3 +30,7 @@
case _:
logging.error("STORAGE_URI must have the valid dataset provider")
raise Exception


if __name__ == "__main__":
main()
143 changes: 143 additions & 0 deletions pkg/initializer_v2/dataset/huggingface_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from unittest.mock import MagicMock, patch

import pytest
from kubeflow.training import DATASET_PATH

import pkg.initializer_v2.utils.utils as utils


@pytest.fixture
def huggingface_dataset_instance():
"""Fixture for HuggingFace Dataset instance"""
from pkg.initializer_v2.dataset.huggingface import HuggingFace

return HuggingFace()


# Test cases for config loading
@pytest.mark.parametrize(
"test_name, test_config, expected",
[
(
"Full config with token",
{"storage_uri": "hf://dataset/path", "access_token": "test_token"},
{"storage_uri": "hf://dataset/path", "access_token": "test_token"},
),
(
"Minimal config without token",
{"storage_uri": "hf://dataset/path"},
{"storage_uri": "hf://dataset/path", "access_token": None},
),
],
)
def test_load_config(test_name, test_config, expected, huggingface_dataset_instance):
"""Test config loading with different configurations"""
print(f"Running test: {test_name}")

with patch.object(utils, "get_config_from_env", return_value=test_config):
huggingface_dataset_instance.load_config()
assert (
huggingface_dataset_instance.config.storage_uri == expected["storage_uri"]
)
assert (
huggingface_dataset_instance.config.access_token == expected["access_token"]
)

print("Test execution completed")


@pytest.mark.parametrize(
"test_name, test_case",
[
(
"Successful download with token",
{
"config": {
"storage_uri": "hf://username/dataset-name",
"access_token": "test_token",
},
"should_login": True,
"expected_repo_id": "username/dataset-name",
"mock_login_side_effect": None,
"mock_download_side_effect": None,
"expected_error": None,
},
),
(
"Successful download without token",
{
"config": {"storage_uri": "hf://org/dataset-v1", "access_token": None},
"should_login": False,
"expected_repo_id": "org/dataset-v1",
"mock_login_side_effect": None,
"mock_download_side_effect": None,
"expected_error": None,
},
),
(
"Login failure",
{
"config": {
"storage_uri": "hf://username/dataset-name",
"access_token": "test_token",
},
"should_login": True,
"expected_repo_id": "username/dataset-name",
"mock_login_side_effect": Exception,
"mock_download_side_effect": None,
"expected_error": Exception,
},
),
(
"Download failure",
{
"config": {
"storage_uri": "hf://invalid/repo/name",
"access_token": None,
},
"should_login": False,
"expected_repo_id": "invalid/repo/name",
"mock_login_side_effect": None,
"mock_download_side_effect": Exception,
"expected_error": Exception,
},
),
],
)
def test_download_dataset(test_name, test_case, huggingface_dataset_instance):
"""Test dataset download with different configurations"""

print(f"Running test: {test_name}")

huggingface_dataset_instance.config = MagicMock(**test_case["config"])

with patch("huggingface_hub.login") as mock_login, patch(
"huggingface_hub.snapshot_download"
) as mock_download:

# Configure mock behavior
if test_case["mock_login_side_effect"]:
mock_login.side_effect = test_case["mock_login_side_effect"]
if test_case["mock_download_side_effect"]:
mock_download.side_effect = test_case["mock_download_side_effect"]

# Execute test
if test_case["expected_error"]:
with pytest.raises(test_case["expected_error"]):
huggingface_dataset_instance.download_dataset()
else:
huggingface_dataset_instance.download_dataset()

# Verify login behavior
if test_case["should_login"]:
mock_login.assert_called_once_with(test_case["config"]["access_token"])
else:
mock_login.assert_not_called()

# Verify download parameters
mock_download.assert_called_once_with(
repo_id=test_case["expected_repo_id"],
local_dir=DATASET_PATH,
repo_type="dataset",
)
print("Test execution completed")
122 changes: 122 additions & 0 deletions pkg/initializer_v2/dataset/main_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import os
from unittest.mock import MagicMock, patch

import pytest

from pkg.initializer_v2.dataset.__main__ import main


@pytest.fixture
def mock_env_vars():
"""Fixture to set and clean up environment variables"""
original_env = dict(os.environ)

def _set_env_vars(**kwargs):
for key, value in kwargs.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = str(value)
return os.environ

yield _set_env_vars

# Cleanup
os.environ.clear()
os.environ.update(original_env)


@pytest.mark.parametrize(
"test_name, test_case",
[
(
"Successful download with HuggingFace provider",
{
"storage_uri": "hf://dataset/path",
"access_token": "test_token",
"mock_config_error": False,
"mock_download_error": False,
"expected_error": None,
},
),
(
"Missing storage URI environment variable",
{
"storage_uri": None,
"access_token": None,
"mock_config_error": False,
"mock_download_error": False,
"expected_error": Exception,
},
),
(
"Invalid storage URI scheme",
{
"storage_uri": "invalid://dataset/path",
"access_token": None,
"mock_config_error": False,
"mock_download_error": False,
"expected_error": Exception,
},
),
(
"Config loading failure",
{
"storage_uri": "hf://dataset/path",
"access_token": None,
"mock_config_error": True,
"mock_download_error": False,
"expected_error": Exception,
},
),
(
"Dataset download failure",
{
"storage_uri": "hf://dataset/path/error",
"access_token": None,
"mock_config_error": False,
"mock_download_error": True,
"expected_error": Exception,
},
),
],
)
def test_dataset_main(test_name, test_case, mock_env_vars):
"""Test main script with different scenarios"""
print(f"Running test: {test_name}")

# Setup mock environment variables
env_vars = {
"STORAGE_URI": test_case["storage_uri"],
"ACCESS_TOKEN": test_case["access_token"],
}
mock_env_vars(**env_vars)

# Setup mock HuggingFace instance
mock_hf_instance = MagicMock()
if test_case["mock_config_error"]:
mock_hf_instance.load_config.side_effect = Exception
if test_case["mock_download_error"]:
mock_hf_instance.download_dataset.side_effect = Exception

with patch(
"pkg.initializer_v2.dataset.__main__.HuggingFace",
return_value=mock_hf_instance,
) as mock_hf:

# Execute test
if test_case["expected_error"]:
with pytest.raises(test_case["expected_error"]):
main()
else:
main()

# Verify HuggingFace instance methods were called
mock_hf_instance.load_config.assert_called_once()
mock_hf_instance.download_dataset.assert_called_once()

# Verify HuggingFace class instantiation
if test_case["storage_uri"] and test_case["storage_uri"].startswith("hf://"):
mock_hf.assert_called_once()

print("Test execution completed")
7 changes: 6 additions & 1 deletion pkg/initializer_v2/model/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
level=logging.INFO,
)

if __name__ == "__main__":

def main():
logging.info("Starting pre-trained model initialization")

try:
Expand All @@ -31,3 +32,7 @@
f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
)
raise Exception


if __name__ == "__main__":
main()
Loading
Loading