Various fixes for training (#654)
michaelbenayoun authored Jul 23, 2024
1 parent 18c6ab4 commit 5148118
Showing 6 changed files with 188 additions and 207 deletions.
8 changes: 5 additions & 3 deletions optimum/commands/neuron/cache.py
@@ -130,7 +130,7 @@ def parse_args(parser: "ArgumentParser"):
"--example_dir", type=str, default=None, help="Path to where the example scripts are stored."
)
parser.add_argument(
"--max_steps", type=int, default=200, help="The maximum number of steps to run compilation for."
"--max_steps", type=int, default=10, help="The maximum number of steps to run compilation for."
)

def run(self):
@@ -148,7 +148,7 @@ def run(self):
raise ValueError("Both the encoder_sequence_length and the decoder_sequence_length must be provided.")
else:
sequence_length = [self.args.encoder_sequence_length, self.args.decoder_sequence_length]
runner.run(
returncode, stdout = runner.run(
self.args.num_cores,
self.args.precision,
self.args.train_batch_size,
@@ -158,8 +158,10 @@
gradient_accumulation_steps=self.args.gradient_accumulation_steps,
num_epochs=3,
max_steps=self.args.max_steps,
save_steps=10,
save_steps=self.args.max_steps // 2,
)
if returncode != 0:
raise ValueError(f"Could not add the model to the cache. Full log:\n{stdout}.")


class SynchronizeRepoCommand(BaseOptimumCLICommand):
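For context, the command-level effect of these changes: `run()` now surfaces a failed compilation run instead of continuing silently, and a checkpoint is taken midway through the shortened run. A minimal sketch of the resulting control flow — the `runner` object and the keyword names are assumed from the CLI flags above, not verified against the `ExampleRunner` signature:

```python
# Sketch only: `runner` is the ExampleRunner built by the command, and the
# argument names/values here are illustrative.
max_steps = 10
returncode, stdout = runner.run(
    num_cores=2,
    precision="bf16",
    train_batch_size=1,
    sequence_length=128,
    max_steps=max_steps,
    save_steps=max_steps // 2,  # a save step now occurs within the short run
)
if returncode != 0:
    raise ValueError(f"Could not add the model to the cache. Full log:\n{stdout}.")
```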
49 changes: 27 additions & 22 deletions optimum/neuron/trainers.py
@@ -443,33 +443,40 @@ def _reduce_loss(self, tr_loss: torch.Tensor) -> torch.Tensor:
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval):
# We always reduce the loss, even when we do not use it to avoid a new graph.
# This communication is not costly.
reduced_tr_loss = self._reduce_loss(tr_loss)
if self.state.global_step > self._globalstep_last_logged:
reduced_tr_loss = self._reduce_loss(tr_loss)

if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
if isinstance(getattr(self, "_zero_loss_value"), torch.Tensor):
tr_loss.data = self._zero_loss_value.data
else:
tr_loss.zero_()

def log_closure(self, reduced_tr_loss, grad_norm):
if is_main_worker_for_metrics():
logs: Dict[str, float] = {}
tr_loss_scalar = reduced_tr_loss.to("cpu").item()
if self.control.should_log:
with torch.no_grad():
if isinstance(getattr(self, "_zero_loss_value"), torch.Tensor):
tr_loss.data = self._zero_loss_value.data
else:
tr_loss.zero_()

logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()
def log_closure(self, reduced_tr_loss, grad_norm):
if is_main_worker_for_metrics():
logs: Dict[str, float] = {}
tr_loss_scalar = reduced_tr_loss.to("cpu").item()

if grad_norm is not None:
logs["grad_norm"] = (
grad_norm.detach().to("cpu").item() if isinstance(grad_norm, torch.Tensor) else grad_norm
logs["loss"] = round(
tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4
)
logs["learning_rate"] = self._get_learning_rate()

if grad_norm is not None:
logs["grad_norm"] = (
grad_norm.detach().to("cpu").item()
if isinstance(grad_norm, torch.Tensor)
else grad_norm
)

self._total_loss_scalar += tr_loss_scalar
self.store_flos()
self.log(logs)

self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()
self.log(logs)

xm.add_step_closure(log_closure, (self, reduced_tr_loss, grad_norm))
xm.add_step_closure(log_closure, (self, reduced_tr_loss, grad_norm))

metrics = None
if self.control.should_evaluate:
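The logging path above leans on XLA step closures: reading the reduced loss inside a closure defers the host-side `.item()` until the step's graph has actually executed, avoiding an extra compilation. A minimal sketch of the pattern using the public `torch_xla` API (the tensor and closure names are illustrative):

```python
import torch_xla.core.xla_model as xm

def log_closure(loss_tensor):
    # Runs after the step has been flushed, so `.item()` reads a
    # materialized value instead of forcing an early graph execution.
    print("loss:", loss_tensor.to("cpu").item())

# Schedules the closure to run at the next `xm.mark_step()`:
# xm.add_step_closure(log_closure, (reduced_tr_loss,))
```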
@@ -1023,8 +1030,6 @@ def _inner_training_loop(

# Gradient clipping
if args.max_grad_norm is not None and args.max_grad_norm > 0:
# deepspeed does its own clipping

if is_sagemaker_mp_enabled() and args.fp16:
self.optimizer.clip_master_grads(args.max_grad_norm)
_grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
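For reference, the same idea applies on the plain PyTorch path: the clipping call itself returns the gradient norm, so the trainer can log it without a second pass over the parameters. A small runnable sketch with the standard torch API:

```python
import torch

model = torch.nn.Linear(4, 4)  # stand-in model for this sketch
max_grad_norm = 1.0

# Produce some gradients to clip.
model(torch.randn(2, 4)).sum().backward()

# clip_grad_norm_ clips in place and returns the total norm computed
# before clipping, so one call both clips and yields the value to log.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
print(float(grad_norm))
```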
64 changes: 34 additions & 30 deletions optimum/neuron/utils/cache_utils.py
@@ -16,15 +16,15 @@
import re
from pathlib import Path
from typing import List, Optional, Union
from uuid import uuid4

from huggingface_hub import (
HfApi,
RepoUrl,
create_repo,
get_token,
whoami,
)
from huggingface_hub.utils import RepositoryNotFoundError
from huggingface_hub.utils import GatedRepoError, HfHubHTTPError, RepositoryNotFoundError, RevisionNotFoundError
from transformers import PretrainedConfig

from ...utils import logging
@@ -127,36 +127,40 @@ def is_private_repo(repo_id: str) -> bool:
return private


_CACHED_HAS_WRITE_ACCESS_TO_REPO = {}


def has_write_access_to_repo(repo_id: str) -> bool:
# It is assumed that the user does not have write access to a canonical repo.
# In any case, since this function is designed to check write access on cache repos, a canonical repo should
# never be passed here.
if "/" not in repo_id:
return False
# If the result has already been cached, use it instead of requesting the HF Hub again.
token = get_token()
key = (token, repo_id)
if key in _CACHED_HAS_WRITE_ACCESS_TO_REPO:
return _CACHED_HAS_WRITE_ACCESS_TO_REPO[key]

api = HfApi()
has_access = None
try:
user = whoami()
except Exception:
return False
# Token role can either be "read" or "write".
token_role = user["auth"]["accessToken"]["role"]
if token_role == "read":
return False
username_or_organization = repo_id.rsplit("/", maxsplit=1)[0]
if user["name"] == username_or_organization:
return True
has_write_access_in_org = False
for org in user["orgs"]:
if org["name"] == username_or_organization:
# Role in an organization can be either:
# "admin", "write", "contributor", "read".
if is_main_worker() and org["roleInOrg"] == "contributor":
logger.warning(
f"You are logged in as a contributor to the cache repo {repo_id}. It is not possible to infer "
"whether you have write access on this repo or not, so it will be assumed you do not."
)
has_write_access_in_org = org["roleInOrg"] in ["admin", "write"]
break
return has_write_access_in_org
api.delete_branch(repo_id=repo_id, repo_type="model", branch=f"this-branch-does-not-exist-{uuid4()}")
except GatedRepoError:
has_access = False
except RepositoryNotFoundError:
# We could raise an error to indicate to the user that the repository could not even be found:
# raise ValueError(f"Repository {repo_id} not found (repo_type: {repo_type}). Is it a private one?") from e
# But here we simply return `False`, since in the end it means we do not have write access to this repo.
has_access = False
except RevisionNotFoundError:
has_access = True # has write access, otherwise would have been 403 forbidden.
except HfHubHTTPError as e:
if e.response.status_code == 403:
has_access = False

if has_access is None:
raise ValueError(f"Cannot determine write access to {repo_id}")

# Cache the result for subsequent calls.
_CACHED_HAS_WRITE_ACCESS_TO_REPO[key] = has_access

return has_access
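The rewritten check probes write access directly instead of inferring it from `whoami()` metadata: deleting a branch that cannot exist is a write operation, so a 403 (or a gated/missing repo) means no write access, while `RevisionNotFoundError` means the Hub accepted the write attempt. Because results are cached per `(token, repo_id)` pair, repeated calls during a run hit the Hub only once. A hypothetical call site (the repo id is illustrative):

```python
repo_id = "my-org/optimum-neuron-cache"  # illustrative cache repo
if has_write_access_to_repo(repo_id):
    print(f"Compiled artifacts will be pushed to {repo_id}.")
else:
    print(f"No write access to {repo_id}; falling back to read-only mode.")
```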


def get_hf_hub_cache_repos(log_warnings: bool = False) -> List[str]:
82 changes: 59 additions & 23 deletions optimum/neuron/utils/runner.py
@@ -39,9 +39,10 @@

logger = logging.get_logger()

_BASE_RAW_FILES_PATH_IN_GH_REPO = "https://raw.githubusercontent.com/huggingface/optimum-neuron/"
_GH_REPO_RAW_URL = "https://raw.githubusercontent.com/huggingface/optimum-neuron"
_GH_REPO_URL = "https://github.com/huggingface/optimum-neuron"
_GH_REPO_EXAMPLE_FOLDERS = [
"audio-classification",
# "audio-classification",
"image-classification",
"language-modeling",
"multiple-choice",
@@ -62,27 +63,72 @@
"summarization": "run_summarization",
"translation": "run_translation",
"image-classification": "run_image_classification",
"audio-classification": "run_audio_classification",
# "audio-classification": "run_audio_classification",
"speech-recognition": "run_speech_recognition_ctc",
}


def list_filenames_in_github_repo_directory(
github_repo_directory_url: str, only_files: bool = False, only_directories: bool = False
) -> List[str]:
"""
Lists the contents of a directory in a GitHub repository.
"""
if only_files and only_directories:
raise ValueError("Either `only_files` or `only_directories` can be set to True.")

response = requests.get(github_repo_directory_url)

if response.status_code != 200:
raise ValueError(f"Could not fetch the content of the page: {github_repo_directory_url}.")

# Here we use regexes instead of Beautiful Soup to avoid relying on yet another library.
table_regex = r"\<table aria-labelledby=\"folders-and-files\".*\<\/table\>"
filename_column_regex = r"\<div class=\"react-directory-filename-cell\".*?\<\/div>"
if only_files:
filename_regex = r"\<a .* aria-label=\"([\w\.]+), \(File\)\""
elif only_directories:
filename_regex = r"\<a .* aria-label=\"([\w\.]+), \(Directory\)\""
else:
filename_regex = r"\<a .* aria-label=\"([\w\.]+)"

filenames = []

table_match = re.search(table_regex, response.text)
if table_match is not None:
table_content = response.text[table_match.start(0) : table_match.end(0)]
for column in re.finditer(filename_column_regex, table_content):
match = re.search(filename_regex, column.group(0))
if match:
filenames.append(match.group(1))

return list(set(filenames))

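A hedged usage example: the URL is a regular folder page on github.com, whose current HTML the regexes above were written against — that markup is not a stable API and may change:

```python
# Illustrative call; the folder URL follows the github.com /tree/ layout.
files = list_filenames_in_github_repo_directory(
    "https://github.com/huggingface/optimum-neuron/tree/main/examples/summarization",
    only_files=True,
)
print(files)  # e.g. ["run_summarization.py", "requirements.txt"]
```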

def download_example_script_from_github(task_name: str, target_directory: Path, revision: str = "main") -> Path:
# TODO: test that every existing task can be downloaded.
script_name = f"{_TASK_TO_EXAMPLE_SCRIPT[task_name]}.py"
example_script_path = target_directory / script_name
was_saved = False
script_name = f"{_TASK_TO_EXAMPLE_SCRIPT[task_name]}.py"
example_script_path = target_directory
for folder in _GH_REPO_EXAMPLE_FOLDERS:
url = f"{_BASE_RAW_FILES_PATH_IN_GH_REPO}/{revision}/examples/{folder}/{script_name}"
r = requests.get(url)
if r.status_code != 200:
raw_url_folder = f"{_GH_REPO_RAW_URL}/{revision}/examples/{folder}"
url_folder = f"{_GH_REPO_URL}/{revision}/examples/{folder}"
filenames_for_example = list_filenames_in_github_repo_directory(url_folder, only_files=True)
if script_name not in filenames_for_example:
continue
with open(example_script_path, "w") as fp:
fp.write(r.text)
was_saved = True
for filename in filenames_for_example:
r = requests.get(f"{raw_url_folder}/{filename}")
if r.status_code != 200:
continue
local_path = target_directory / filename
with open(local_path, "w") as fp:
fp.write(r.text)
if filename == script_name:
was_saved = True
example_script_path = local_path
if was_saved:
break
if not was_saved:
raise FileNotFoundError(f"Could not find an example script for the task {task_name} on the GitHub repo.")

return example_script_path
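And a sketch of driving the download helper end to end — the task name comes from `_TASK_TO_EXAMPLE_SCRIPT`, and the temporary directory stands in for wherever the runner stages scripts:

```python
import tempfile
from pathlib import Path

# Illustrative invocation: fetches run_summarization.py (plus the other
# files of its example folder) into a scratch directory.
with tempfile.TemporaryDirectory() as tmp:
    script_path = download_example_script_from_github("summarization", Path(tmp))
    print(script_path)  # <tmp>/run_summarization.py
```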


@@ -198,16 +244,6 @@ def __init__(
self.task = task

self.example_dir = example_dir
if example_dir is None:
example_dir = Path(__file__).parent.parent.parent.parent / "examples"
if not example_dir.exists():
logger.info(
f"Could not find the example script for the task {task} locally. Please provide the path manually "
"or install `optimum-neuron` from sources. Otherwise the example will be downloaded from the "
"GitHub repo."
)
else:
self.example_dir = example_dir

if use_venv:
raise NotImplementedError("use_venv=True is not supported yet.")
