diff --git a/optimum/commands/neuron/cache.py b/optimum/commands/neuron/cache.py
index 9cc3bf425..b3beab853 100644
--- a/optimum/commands/neuron/cache.py
+++ b/optimum/commands/neuron/cache.py
@@ -130,7 +130,7 @@ def parse_args(parser: "ArgumentParser"):
"--example_dir", type=str, default=None, help="Path to where the example scripts are stored."
)
parser.add_argument(
- "--max_steps", type=int, default=200, help="The maximum number of steps to run compilation for."
+ "--max_steps", type=int, default=10, help="The maximum number of steps to run compilation for."
)
def run(self):
@@ -148,7 +148,7 @@ def run(self):
raise ValueError("Both the encoder_sequence_length and the decoder_sequence_length must be provided.")
else:
sequence_length = [self.args.encoder_sequence_length, self.args.decoder_sequence_length]
- runner.run(
+ returncode, stdout = runner.run(
self.args.num_cores,
self.args.precision,
self.args.train_batch_size,
@@ -158,8 +158,10 @@ def run(self):
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
num_epochs=3,
max_steps=self.args.max_steps,
- save_steps=10,
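+            # Save halfway through the run so that a checkpoint save (and any graphs it
+            # triggers) happens within max_steps and gets compiled and cached as well.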
+ save_steps=self.args.max_steps // 2,
)
+ if returncode != 0:
+ raise ValueError(f"Could not add the model to the cache. Full log:\n{stdout}.")
class SynchronizeRepoCommand(BaseOptimumCLICommand):
diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index f63b6b469..103c29ba3 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -443,33 +443,40 @@ def _reduce_loss(self, tr_loss: torch.Tensor) -> torch.Tensor:
     def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
         # We always reduce the loss, even when we do not use it, to avoid a new graph.
         # This communication is not costly.
-        reduced_tr_loss = self._reduce_loss(tr_loss)
-
-        if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
-            if isinstance(getattr(self, "_zero_loss_value"), torch.Tensor):
-                tr_loss.data = self._zero_loss_value.data
-            else:
-                tr_loss.zero_()
-
-            def log_closure(self, reduced_tr_loss, grad_norm):
-                if is_main_worker_for_metrics():
-                    logs: Dict[str, float] = {}
-                    tr_loss_scalar = reduced_tr_loss.to("cpu").item()
-
-                    logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
-                    logs["learning_rate"] = self._get_learning_rate()
-
-                    if grad_norm is not None:
-                        logs["grad_norm"] = (
-                            grad_norm.detach().to("cpu").item() if isinstance(grad_norm, torch.Tensor) else grad_norm
-                        )
-
-                    self._total_loss_scalar += tr_loss_scalar
-                    self._globalstep_last_logged = self.state.global_step
-                    self.store_flos()
-                    self.log(logs)
-
-            xm.add_step_closure(log_closure, (self, reduced_tr_loss, grad_norm))
+        if self.state.global_step > self._globalstep_last_logged:
+            reduced_tr_loss = self._reduce_loss(tr_loss)
+
+        if self.control.should_log:
+            with torch.no_grad():
+                if isinstance(getattr(self, "_zero_loss_value"), torch.Tensor):
+                    tr_loss.data = self._zero_loss_value.data
+                else:
+                    tr_loss.zero_()
+
+            def log_closure(self, reduced_tr_loss, grad_norm):
+                if is_main_worker_for_metrics():
+                    logs: Dict[str, float] = {}
+                    tr_loss_scalar = reduced_tr_loss.to("cpu").item()
+
+                    logs["loss"] = round(
+                        tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4
+                    )
+                    logs["learning_rate"] = self._get_learning_rate()
+
+                    if grad_norm is not None:
+                        logs["grad_norm"] = (
+                            grad_norm.detach().to("cpu").item()
+                            if isinstance(grad_norm, torch.Tensor)
+                            else grad_norm
+                        )
+
+                    self._total_loss_scalar += tr_loss_scalar
+                    self.store_flos()
+                    self.log(logs)
+
+                self._globalstep_last_logged = self.state.global_step
+
+            xm.add_step_closure(log_closure, (self, reduced_tr_loss, grad_norm))
 
         metrics = None
         if self.control.should_evaluate:
@@ -1023,8 +1030,6 @@ def _inner_training_loop(
# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
- # deepspeed does its own clipping
-
if is_sagemaker_mp_enabled() and args.fp16:
                         _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
diff --git a/optimum/neuron/utils/cache_utils.py b/optimum/neuron/utils/cache_utils.py
index e87ed63e5..28845713c 100644
--- a/optimum/neuron/utils/cache_utils.py
+++ b/optimum/neuron/utils/cache_utils.py
@@ -16,15 +16,15 @@
import re
from pathlib import Path
from typing import List, Optional, Union
+from uuid import uuid4
from huggingface_hub import (
HfApi,
RepoUrl,
create_repo,
get_token,
- whoami,
)
-from huggingface_hub.utils import RepositoryNotFoundError
+from huggingface_hub.utils import GatedRepoError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
from transformers import PretrainedConfig
from ...utils import logging
@@ -127,36 +127,40 @@ def is_private_repo(repo_id: str) -> bool:
return private
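+# Cache of (token, repo_id) pairs mapped to the write-access result, so repeated checks
+# do not query the Hub again.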
+_CACHED_HAS_WRITE_ACCESS_TO_REPO = {}
+
+
def has_write_access_to_repo(repo_id: str) -> bool:
- # It is assumed that the user does not have write access to a canonical repo.
- # In any case, since this function is designed to check for write access on cache repos, it should never be the
- # case.
- if "/" not in repo_id:
- return False
+ # If the result has already been cached, use it instead of requesting the HF Hub again.
+ token = get_token()
+ key = (token, repo_id)
+ if key in _CACHED_HAS_WRITE_ACCESS_TO_REPO:
+ return _CACHED_HAS_WRITE_ACCESS_TO_REPO[key]
+
+ api = HfApi()
+ has_access = None
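+    # Try to delete a branch that cannot exist: the Hub checks write permission before
+    # resolving the branch, so the kind of error returned tells us whether we have
+    # write access to the repo.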
try:
- user = whoami()
- except Exception:
- return False
- # Token role can either be "read" or "write".
- token_role = user["auth"]["accessToken"]["role"]
- if token_role == "read":
- return False
- username_or_organization = repo_id.rsplit("/", maxsplit=1)[0]
- if user["name"] == username_or_organization:
- return True
- has_write_access_in_org = False
- for org in user["orgs"]:
- if org["name"] == username_or_organization:
- # Role in an organization can be either:
- # "admin", "write", "contributor", "read".
- if is_main_worker() and org["roleInOrg"] == "contributor":
- logger.warning(
- f"You are logged in as a contributor to the cache repo {repo_id}. It is not possible to infer "
- "whether you have write access on this repo or not, so it will be assumed you do not."
- )
- has_write_access_in_org = org["roleInOrg"] in ["admin", "write"]
- break
- return has_write_access_in_org
+ api.delete_branch(repo_id=repo_id, repo_type="model", branch=f"this-branch-does-not-exist-{uuid4()}")
+ except GatedRepoError:
+ has_access = False
+ except RepositoryNotFoundError:
+        # We could raise an error to indicate to the user that the repository could not even be found:
+        # raise ValueError(f"Repository {repo_id} not found (repo_type: {repo_type}). Is it a private one?") from e
+        # But we simply return `False` here, since it means that we do not have write access to this repo anyway.
+ has_access = False
+ except RevisionNotFoundError:
+        has_access = True  # We have write access, otherwise the Hub would have returned 403 Forbidden.
+ except HfHubHTTPError as e:
+ if e.response.status_code == 403:
+ has_access = False
+
+ if has_access is None:
+ raise ValueError(f"Cannot determine write access to {repo_id}")
+
+ # Cache the result for subsequent calls.
+ _CACHED_HAS_WRITE_ACCESS_TO_REPO[key] = has_access
+
+ return has_access
def get_hf_hub_cache_repos(log_warnings: bool = False) -> List[str]:
diff --git a/optimum/neuron/utils/runner.py b/optimum/neuron/utils/runner.py
index 42e599d3a..108ea6d90 100644
--- a/optimum/neuron/utils/runner.py
+++ b/optimum/neuron/utils/runner.py
@@ -39,9 +39,10 @@
logger = logging.get_logger()
-_BASE_RAW_FILES_PATH_IN_GH_REPO = "https://raw.githubusercontent.com/huggingface/optimum-neuron/"
+_GH_REPO_RAW_URL = "https://raw.githubusercontent.com/huggingface/optimum-neuron"
+_GH_REPO_URL = "https://github.com/huggingface/optimum-neuron"
_GH_REPO_EXAMPLE_FOLDERS = [
- "audio-classification",
+ # "audio-classification",
"image-classification",
"language-modeling",
"multiple-choice",
@@ -62,27 +63,72 @@
"summarization": "run_summarization",
"translation": "run_translation",
"image-classification": "run_image_classification",
- "audio-classification": "run_audio_classification",
+ # "audio-classification": "run_audio_classification",
"speech-recognition": "run_speech_recognition_ctc",
}
+def list_filenames_in_github_repo_directory(
+ github_repo_directory_url: str, only_files: bool = False, only_directories: bool = False
+) -> List[str]:
+ """
+    Lists the filenames inside a directory of a repository on GitHub.
+ """
+ if only_files and only_directories:
+ raise ValueError("Either `only_files` or `only_directories` can be set to True.")
+
+ response = requests.get(github_repo_directory_url)
+
+ if response.status_code != 200:
+ raise ValueError(f"Could not fetch the content of the page: {github_repo_directory_url}.")
+
+ # Here we use regex instead of beautiful soup to not rely on yet another library.
+ table_regex = r"\
"
+ filename_column_regex = r"\"
+ if only_files:
+ filename_regex = r"\
Path:
- # TODO: test that every existing task can be downloaded.
- script_name = f"{_TASK_TO_EXAMPLE_SCRIPT[task_name]}.py"
- example_script_path = target_directory / script_name
was_saved = False
+ script_name = f"{_TASK_TO_EXAMPLE_SCRIPT[task_name]}.py"
+ example_script_path = target_directory
for folder in _GH_REPO_EXAMPLE_FOLDERS:
- url = f"{_BASE_RAW_FILES_PATH_IN_GH_REPO}/{revision}/examples/{folder}/{script_name}"
- r = requests.get(url)
- if r.status_code != 200:
+ raw_url_folder = f"{_GH_REPO_RAW_URL}/{revision}/examples/{folder}"
+ url_folder = f"{_GH_REPO_URL}/{revision}/examples/{folder}"
+ filenames_for_example = list_filenames_in_github_repo_directory(url_folder, only_files=True)
+ if script_name not in filenames_for_example:
continue
- with open(example_script_path, "w") as fp:
- fp.write(r.text)
- was_saved = True
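+        # Download every file in the example folder, not just the main script, so that
+        # any auxiliary files the script depends on end up next to it.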
+ for filename in filenames_for_example:
+ r = requests.get(f"{raw_url_folder}/{filename}")
+ if r.status_code != 200:
+ continue
+ local_path = target_directory / filename
+ with open(local_path, "w") as fp:
+ fp.write(r.text)
+ if filename == script_name:
+ was_saved = True
+ example_script_path = local_path
+ if was_saved:
+ break
if not was_saved:
raise FileNotFoundError(f"Could not find an example script for the task {task_name} on the GitHub repo")
-
return example_script_path
@@ -198,16 +244,6 @@ def __init__(
self.task = task
self.example_dir = example_dir
- if example_dir is None:
- example_dir = Path(__file__).parent.parent.parent.parent / "examples"
- if not example_dir.exists():
- logger.info(
- f"Could not find the example script for the task {task} locally. Please provide the path manually "
- "or install `optimum-neuron` from sources. Otherwise the example will be downloaded from the "
- "GitHub repo."
- )
- else:
- self.example_dir = example_dir
if use_venv:
raise NotImplementedError("use_venv=True is not supported yet.")
diff --git a/tests/cli/test_neuron_cache_cli.py b/tests/cli/test_neuron_cache_cli.py
index ad8c6649b..0d7887445 100644
--- a/tests/cli/test_neuron_cache_cli.py
+++ b/tests/cli/test_neuron_cache_cli.py
@@ -19,22 +19,18 @@
import subprocess
from pathlib import Path
from tempfile import TemporaryDirectory
-from unittest import TestCase
+from typing import Optional
-from huggingface_hub import HfApi, create_repo, delete_repo
+import pytest
+from huggingface_hub import HfApi, delete_repo
from huggingface_hub.utils import RepositoryNotFoundError
from transformers import BertConfig, BertModel, BertTokenizer
-from transformers.testing_utils import is_staging_test
from optimum.neuron.utils.cache_utils import (
- CACHE_REPO_FILENAME,
CACHE_REPO_NAME,
load_custom_cache_repo_name_from_hf_home,
)
from optimum.neuron.utils.testing_utils import is_trainium_test
-from optimum.utils.testing_utils import USER
-
-from ..utils import StagingTestMixin
# Taken from https://pynative.com/python-generate-random-string/
@@ -44,116 +40,53 @@ def get_random_string(length: int) -> str:
@is_trainium_test
-@is_staging_test
-class TestNeuronCacheCLI(StagingTestMixin, TestCase):
- def setUp(self):
- self._hf_home = os.environ.get("HF_HOME", "")
-
- self.repo_name = "blabla"
- self.repo_id = f"{USER}/{self.repo_name}"
-
- self.default_repo_name = CACHE_REPO_NAME
- self.default_repo_id = f"{USER}/{self.default_repo_name}"
-
- def tearDown(self):
- super().tearDown()
- os.environ["HF_HOME"] = self._hf_home
+class TestNeuronCacheCLI:
+ def _optimum_neuron_cache_create(self, cache_repo_id: Optional[str] = None, public: bool = False):
+ name_str = f"--name {cache_repo_id}" if cache_repo_id is not None else ""
+ public_str = "--public" if public else ""
+ command = f"optimum-cli neuron cache create {name_str} {public_str}".split()
+ p = subprocess.Popen(command)
+ _ = p.wait()
try:
- delete_repo(self.default_repo_id, repo_type="model")
- except RepositoryNotFoundError:
- pass
+ repo_id = cache_repo_id if cache_repo_id is not None else CACHE_REPO_NAME
+ info = HfApi().repo_info(repo_id, repo_type="model")
+ assert info.private == (
+ not public
+ ), "The privacy of the repo should match the presence of the --public flag."
- try:
- delete_repo(self.repo_id, repo_type="model")
except RepositoryNotFoundError:
- pass
-
- def _optimum_neuron_cache_create(self, default_name: bool = True, public: bool = False):
- with TemporaryDirectory() as tmpdirname:
- repo_id = self.default_repo_id if default_name else self.repo_id
-
- env = dict(self._env, HF_HOME=tmpdirname)
-
- command = f"huggingface-cli login --token {self._staging_token}".split()
- p = subprocess.Popen(command, env=env)
- returncode = p.wait()
- self.assertEqual(returncode, 0)
-
- name_str = f"--name {self.repo_name}" if not default_name else ""
- public_str = "--public" if public else ""
- command = f"optimum-cli neuron cache create {name_str} {public_str}".split()
- p = subprocess.Popen(command, env=env)
- returncode = p.wait()
-
- try:
- info = HfApi().repo_info(repo_id, repo_type="model")
- self.assertEqual(
- info.private, not public, "The privacy of the repo should match the presence of the --public flag."
- )
- except RepositoryNotFoundError:
- self.fail("The repo was not created.")
-
- hf_home_cache_repo_file = f"{tmpdirname}/{CACHE_REPO_FILENAME}"
- self.assertEqual(
- repo_id,
- load_custom_cache_repo_name_from_hf_home(hf_home_cache_repo_file),
- f"Saved local Neuron cache name should be equal to {repo_id}.",
- )
-
- def test_optimum_neuron_cache_create_with_default_name(self):
- return self._optimum_neuron_cache_create(public=False)
-
- def test_optimum_neuron_cache_create_public_with_default_name(self):
- return self._optimum_neuron_cache_create(public=True)
-
- def test_optimum_neuron_cache_create_with_custom_name(self):
- return self._optimum_neuron_cache_create(default_name=False)
-
- def test_optimum_neuron_cache_set(self):
- with TemporaryDirectory() as tmpdirname:
- os.environ["HF_HOME"] = tmpdirname
-
- create_repo(self.repo_name, repo_type="model")
-
- command = f"optimum-cli neuron cache set {self.repo_id}".split()
- env = dict(self._env, HF_HOME=tmpdirname)
- p = subprocess.Popen(command, env=env)
- returncode = p.wait()
- self.assertEqual(returncode, 0)
-
- hf_home_cache_repo_file = f"{tmpdirname}/{CACHE_REPO_FILENAME}"
- self.assertEqual(
- self.repo_id,
- load_custom_cache_repo_name_from_hf_home(hf_home_cache_repo_file),
- f"Saved local Neuron cache name should be equal to {self.repo_id}.",
- )
-
- def test_optimum_neuron_cache_add(self):
+ pytest.fail("The repo was not created.")
+ finally:
+ delete_repo(repo_id)
+
+ assert (
+ repo_id == load_custom_cache_repo_name_from_hf_home()
+ ), f"Saved local Neuron cache name should be equal to {repo_id}."
+
+ def test_optimum_neuron_cache_create_with_custom_name(self, hub_test):
+ seed = random.randint(0, 100)
+ repo_id = f"{hub_test}-{seed}"
+ return self._optimum_neuron_cache_create(cache_repo_id=repo_id)
+
+ def test_optimum_neuron_cache_create_public_with_custom_name(self, hub_test):
+ seed = random.randint(0, 100)
+ repo_id = f"{hub_test}-{seed}"
+ return self._optimum_neuron_cache_create(cache_repo_id=repo_id, public=True)
+
+ def test_optimum_neuron_cache_set(self, hub_test):
+ repo_id = hub_test
+ command = f"optimum-cli neuron cache set {repo_id}".split()
+ p = subprocess.Popen(command)
+ returncode = p.wait()
+ assert returncode == 0
+ assert (
+ repo_id == load_custom_cache_repo_name_from_hf_home()
+ ), f"Saved local Neuron cache name should be equal to {repo_id}."
+
+ def test_optimum_neuron_cache_add(self, hub_test):
with TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
- os.environ["CUSTOM_CACHE_REPO"] = self.CUSTOM_CACHE_REPO
- # TODO: activate those later.
- # Without any sequence length, it should fail.
- # command = (
- # "optimum-cli neuron cache add -m bert-base-uncased --task text-classification --train_batch_size 16 "
- # "--precision bf16 --num_cores 2"
- # ).split()
- # p = subprocess.Popen(command, stderr=PIPE)
- # _, stderr = p.communicate()
- # stderr = stderr.decode("utf-8")
- # self.assertIn("either sequence_length or encoder_sequence and decoder_sequence_length", stderr)
-
- # Without both encoder and decoder sequence lengths, it should fail.
- # command = (
- # "optimum-cli neuron cache add -m t5-small --task translation --train_batch_size 16 --precision bf16 "
- # "--num_cores 2 --encoder_sequence_length 512"
- # ).split()
- # p = subprocess.Popen(command, stderr=PIPE)
- # _, stderr = p.communicate()
- # stderr = stderr.decode("utf-8")
- # self.assertIn("Both the encoder_sequence and decoder_sequence_length", stderr)
-
# Create dummy BERT model.
bert_model_name = tmpdir / "bert_model"
config = BertConfig()
@@ -162,10 +95,12 @@ def test_optimum_neuron_cache_add(self):
config.num_attention_heads = 2
config.vocab_size = 100
- with open(tmpdir / "vocab.txt", "w") as fp:
- fp.write("\n".join(get_random_string(random.randint(10, 20))))
+ mandatory_tokens = ["[UNK]", "[SEP]", "[CLS]"]
- tokenizer = BertTokenizer(tmpdir / "vocab.txt")
+ with open(tmpdir / "bert_vocab.txt", "w") as fp:
+ fp.write("\n".join([get_random_string(random.randint(10, 20))] + mandatory_tokens))
+
+ tokenizer = BertTokenizer((tmpdir / "bert_vocab.txt").as_posix())
tokenizer.save_pretrained(bert_model_name)
model = BertModel(config)
@@ -181,7 +116,7 @@ def test_optimum_neuron_cache_add(self):
).split()
p = subprocess.Popen(command, env=env)
returncode = p.wait()
- self.assertNotEqual(returncode, 0)
+ assert returncode != 0
# With wrong num_cores value, it should fail.
command = (
@@ -190,22 +125,15 @@ def test_optimum_neuron_cache_add(self):
).split()
p = subprocess.Popen(command, env=env)
returncode = p.wait()
- self.assertNotEqual(returncode, 0)
+ assert returncode != 0
# Non seq2seq model.
command = (
f"optimum-cli neuron cache add -m {bert_model_name} --task text-classification --train_batch_size 1 "
"--precision bf16 --num_cores 2 --sequence_length 128"
).split()
- p = subprocess.Popen(command, env=env)
- returncode = p.wait()
- self.assertEqual(returncode, 0)
-
- # seq2seq model.
- command = (
- f"optimum-cli neuron cache add -m {bert_model_name} --task translation --train_batch_size 1 --precision bf16 "
- "--num_cores 2 --encoder_sequence_length 12 --decoder_sequence_length 12"
- ).split()
- p = subprocess.Popen(command, env=env)
- returncode = p.wait()
- self.assertEqual(returncode, 0)
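+            # Capture stdout and print it so that the compilation logs appear in the test output.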
+ p = subprocess.Popen(command, stdout=subprocess.PIPE, env=env)
+ stdout, _ = p.communicate()
+            print(stdout.decode("utf-8"))
+ returncode = p.returncode
+ assert returncode == 0
diff --git a/tests/conftest.py b/tests/conftest.py
index a681ed087..8062756a5 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,7 +17,7 @@
from pathlib import Path
import pytest
-from huggingface_hub import create_repo, delete_repo, get_token, login, logout
+from huggingface_hub import HfApi, create_repo, delete_repo, get_token, login, logout
from optimum.neuron.utils.cache_utils import (
delete_custom_cache_repo_name_from_hf_home,
@@ -115,6 +115,12 @@ def _hub_test(create_local_cache: bool = False):
yield custom_cache_repo_with_seed
delete_repo(custom_cache_repo_with_seed, repo_type="model")
+
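+    # Clean up any leftover cache repos created by previous test runs for this user.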
+    username = HfApi().whoami()["name"]
+    model_repos = HfApi().list_models(author=username)
+    for repo in model_repos:
+        if "optimum-neuron-cache-for-testing-" in repo.id:
+            delete_repo(repo.id)
+
if local_cache_path_with_seed.is_dir():
shutil.rmtree(local_cache_path_with_seed)
if orig_token is not None: