Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add unit test
Browse files Browse the repository at this point in the history
benieric committed Jan 30, 2025
1 parent 0bd5104 commit 0eb8359
Showing 1 changed file with 50 additions and 1 deletion.
51 changes: 50 additions & 1 deletion tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,10 @@
from sagemaker import image_uris
from sagemaker.pytorch import defaults
from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel
from sagemaker.pytorch.estimator import _get_training_recipe_image_uri
from sagemaker.pytorch.estimator import (
_get_training_recipe_image_uri,
_get_training_recipe_gpu_script,
)
from sagemaker.instance_group import InstanceGroup
from sagemaker.session_settings import SessionSettings

@@ -1049,6 +1052,52 @@ def test_training_recipe_for_trainium(sagemaker_session):
assert pytorch.distribution == expected_distribution


@pytest.mark.parametrize(
"test_case",
[
{
"script": "llama_pretrain.py",
"recipe": {
"model": {
"model_type": "llama_v3",
},
},
},
{
"script": "mistral_pretrain.py",
"recipe": {
"model": {
"model_type": "mistral",
},
},
},
{
"script": "deepseek_pretrain.py",
"recipe": {
"model": {
"model_type": "deepseek_llamav3",
},
},
},
{
"script": "deepseek_pretrain.py",
"recipe": {
"model": {
"model_type": "deepseek_qwenv2",
},
},
},
],
)
@patch("shutil.copyfile")
def test_get_training_recipe_gpu_script(mock_copyfile, test_case):
script = test_case["script"]
recipe = test_case["recipe"]
mock_copyfile.return_value = None

assert _get_training_recipe_gpu_script("code_dir", recipe, "source_dir") == script


def test_training_recipe_for_trainium_custom_source_dir(sagemaker_session):
container_log_level = '"logging.INFO"'

0 comments on commit 0eb8359

Please sign in to comment.