Skip to content

Commit

Permalink
Merge branch 'main' into aot/bert-nemo-ux
Browse files Browse the repository at this point in the history
  • Loading branch information
suiyoubi authored Dec 3, 2024
2 parents c0a7836 + 9abd81b commit de1c788
Show file tree
Hide file tree
Showing 11 changed files with 64 additions and 60 deletions.
11 changes: 11 additions & 0 deletions examples/audio/process_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def main(cfg: ProcessConfig) -> ProcessConfig:
raise RuntimeError('Model does not have a sampler')

if cfg.audio_dir is not None:
input_dir = cfg.audio_dir
filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True))
else:
# get filenames from manifest
Expand All @@ -193,6 +194,15 @@ def main(cfg: ProcessConfig) -> ProcessConfig:
audio_file = manifest_dir / audio_file
filepaths.append(str(audio_file.absolute()))

# common path for all files
common_path = os.path.commonpath(filepaths)
if Path(common_path).is_relative_to(manifest_dir):
# if all paths are relative to the manifest, use manifest dir as input dir
input_dir = manifest_dir
else:
# use the parent of the common path as input dir
input_dir = Path(common_path).parent

if cfg.max_utts is not None:
# Limit the number of utterances to process
filepaths = filepaths[: cfg.max_utts]
Expand Down Expand Up @@ -238,6 +248,7 @@ def autocast():
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
input_channel_selector=cfg.input_channel_selector,
input_dir=input_dir,
)

logging.info(f"Finished processing {len(filepaths)} files!")
Expand Down
14 changes: 12 additions & 2 deletions nemo/collections/audio/models/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def process(
batch_size: int = 1,
num_workers: Optional[int] = None,
input_channel_selector: Optional[ChannelSelectorType] = None,
input_dir: Optional[str] = None,
) -> List[str]:
"""
Takes paths to audio files and returns a list of paths to processed
Expand All @@ -344,6 +345,7 @@ def process(
num_workers: Number of workers for the dataloader
input_channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio.
If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`.
input_dir: Optional, directory that contains the input files. If provided, the output directory will mirror the input directory structure.
Returns:
Paths to processed audio signals.
Expand Down Expand Up @@ -413,9 +415,17 @@ def process(

for example_idx in range(processed_batch.size(0)):
# This assumes the data loader is not shuffling files
file_name = os.path.basename(paths2audio_files[file_idx])
if input_dir is not None:
# Make sure the output has the same directory structure as the input
filepath_relative = os.path.relpath(paths2audio_files[file_idx], start=input_dir)
else:
# Input dir is not provided, save files in the output directory
filepath_relative = os.path.basename(paths2audio_files[file_idx])
# Prepare output file
output_file = os.path.join(output_dir, f'processed_{file_name}')
output_file = os.path.join(output_dir, filepath_relative)
# Create output dir if necessary
if not os.path.isdir(os.path.dirname(output_file)):
os.makedirs(os.path.dirname(output_file))
# Crop the output signal to the actual length
output_signal = processed_batch[example_idx, :, : input_length[example_idx]].cpu().numpy()
# Write audio
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/llm/gpt/data/packed_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def prepare_packed_sequence_data(
sequences, histogram = create_hist(dataset, max_seq_length)

assignments = create_packing_strategy(histogram, packed_sequence_size, packing_algorithm)
output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size)
output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size, tokenizer.eos_id)

# save output data
np.save(output_path, output_data)
Expand Down
14 changes: 2 additions & 12 deletions nemo/deploy/nlp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.


use_query_llm = True
try:
from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch
except Exception:
use_query_llm = False

use_megatron_llm = True
try:
from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable
except Exception:
use_megatron_llm = False
from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable
from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch
11 changes: 9 additions & 2 deletions nemo/utils/sequence_packing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def create_packing_strategy(


def fill_packing_strategy(
assignments: List[List[int]], sequences: Dict[int, List[Dict]], pack_size: int
assignments: List[List[int]], sequences: Dict[int, List[Dict]], pack_size: int, pad_id: int
) -> List[Dict]:
"""
Fills the packing strategy with actual sequence data based on assignments and sequence information.
Expand All @@ -192,6 +192,7 @@ def fill_packing_strategy(
sequences: A dictionary where keys are sequence lengths and values are lists of corresponding sequences
from the dataset (output of 'create_hist').
pack_size: The maximum capacity of each bin.
pad_id: The tokenizer's padding token.
Returns:
output_data: A list of dictionaries, where each dictionary represents a packed sequence with its input IDs,
Expand All @@ -205,7 +206,13 @@ def fill_packing_strategy(
input_ids = np.array([x['input_ids'] for x in per_seq_data])[perm].tolist()
try:
loss_mask = np.array(
[[idx >= x['answer_start_idx'] for idx in range(len(x['input_ids']))] for x in per_seq_data]
[
[
idx >= x['answer_start_idx'] and x['input_ids'][idx] != pad_id
for idx in range(len(x['input_ids']))
]
for x in per_seq_data
]
)[perm].tolist()
except KeyError:
loss_mask = None
Expand Down
7 changes: 5 additions & 2 deletions scripts/dataset_processing/add_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,15 @@ def add_noise(infile, snrs, noise_manifest, out_dir, num_workers=1):
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_manifest", type=str, required=True, help="clean test set",
"--input_manifest",
type=str,
required=True,
help="clean test set",
)
parser.add_argument("--noise_manifest", type=str, required=True, help="path to noise manifest file")
parser.add_argument("--out_dir", type=str, required=True, help="destination directory for audio and manifests")
parser.add_argument("--snrs", type=int, nargs="+", default=[0, 10, 20, 30])
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--num_workers", default=1, type=int)
parser.add_argument("--sample_rate", default=16000, type=int)
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion scripts/deploy/nlp/query_inframework.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import argparse
import sys

from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch
from nemo.deploy.nlp import NemoQueryLLMPyTorch


def get_args(argv):
Expand Down
42 changes: 14 additions & 28 deletions scripts/llm/llama3_pretraining.py → scripts/llm/pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

# NOTE: This script is only an example of using NeMo with NeMo-Run's APIs and is subject to change without notice.
# This script is used for pretraining a Llama3 model, specifically for the 8b or 70b model variants, on local and slurm executors.
# It uses NeMo 2.0 recipes (https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/llm/recipes/llama3_8b.py#L74) and NeMo-Run (https://github.com/NVIDIA/NeMo-Run) to configure and execute the runs.
# This script is used for pretraining on local and slurm executors.
# It uses NeMo 2.0 recipes (https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/llm/recipes/) and NeMo-Run (https://github.com/NVIDIA/NeMo-Run) to configure and execute the runs.

import argparse
from functools import partial
Expand All @@ -26,12 +26,12 @@


def get_parser():
parser = argparse.ArgumentParser(description="Llama3 Pretraining")
parser = argparse.ArgumentParser(description="NeMo2.0 Pretraining")
parser.add_argument(
"--size",
"--recipe",
type=str,
default="8b",
help="Choose llama3 model size 70b/8b",
default="llama3_8b",
help="Choose NeMo 2.0 recipe. Recipes are named in the format of <model_name>_<model_size>(_<long_sequenth_length> or other special settings)",
)
parser.add_argument(
"--tag",
Expand Down Expand Up @@ -66,7 +66,7 @@ def slurm_executor(
time: str = "01:00:00",
custom_mounts: Optional[list[str]] = None,
custom_env_vars: Optional[dict[str, str]] = None,
container_image: str = "nvcr.io/nvidia/nemo:dev",
container_image: str = "nvcr.io/nvidia/nemo:24.09",
retries: int = 0,
) -> run.SlurmExecutor:
if not (user and host and remote_job_dir and account and partition and nodes and devices):
Expand Down Expand Up @@ -102,7 +102,7 @@ def slurm_executor(
mem="0",
exclusive=True,
gres="gpu:8",
packager=run.GitArchivePackager(subpath="examples/llm/run"),
packager=run.GitArchivePackager(),
)

executor.container_image = container_image
Expand Down Expand Up @@ -134,28 +134,14 @@ def main():
if args.tag and not args.tag.startswith("-"):
args.tag = "-" + args.tag

MODEL_SIZE_MAPPING: dict[str, dict[str, Any]] = {
"8b": {
"exp_name": "llama3-8b",
"nemo": {
"pretrain": partial(llm.llama3_8b.pretrain_recipe, num_nodes=1, num_gpus_per_node=8),
},
},
"70b": {
"exp_name": "llama3-70b",
"nemo": {
"pretrain": partial(llm.llama3_70b.pretrain_recipe, num_nodes=128, num_gpus_per_node=8),
},
},
}

exp_name = MODEL_SIZE_MAPPING[args.size]["exp_name"]
exp_name = args.recipe

# Uses configs from NeMo directly
pretrain = MODEL_SIZE_MAPPING[args.size]["nemo"]["pretrain"](
name=exp_name,
ckpt_dir="/nemo_run/checkpoints",
)
assert hasattr(
llm, args.recipe
), f"Recipe named {args.recipe} not found. General format is <model_name>_<model_size>(_<long_sequenth_length> or other special settings)"
pretrain_recipe = getattr(llm, args.recipe).pretrain_recipe
pretrain = partial(pretrain_recipe)(name=exp_name, dir="/nemo_run/checkpoints")

# Overwrite the dataloader in the recipe to use your custom dataloader.
# dataloader = set_your_custom_dataloader
Expand Down
7 changes: 4 additions & 3 deletions scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def tokenize_dataset(cfg: 'DictConfig'):

max_seq_length = dataset.max_seq_length
pad_id = dataset.tokenizer.eos_id
tokenizer = dataset.tokenizer
pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
dataset = np.array([dataset[i] for i in range(len(dataset))])
if cp_size > 1:
Expand Down Expand Up @@ -162,7 +163,7 @@ def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id):
for data in dataset:
max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult))
pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id)
return dataset
return dataset, tokenizer


@dataclass
Expand All @@ -187,11 +188,11 @@ def from_config(self, cfg: 'DictConfig'):
)
def main(cfg: 'DictConfig') -> None:
args = PackingArgs().from_config(cfg)
dataset = tokenize_dataset(cfg)
dataset, tokenizer = tokenize_dataset(cfg)
sequences, histogram = create_hist(dataset, cfg.model.data.train_ds.max_seq_length)
for pack_size in args.pack_sizes:
assignments = create_packing_strategy(histogram, pack_size, args.packing_algorithm)
output_data = fill_packing_strategy(assignments, sequences, pack_size)
output_data = fill_packing_strategy(assignments, sequences, pack_size, tokenizer.eos_id)

# save output data
os.makedirs(args.output_dir, exist_ok=True)
Expand Down
2 changes: 1 addition & 1 deletion tests/deploy/nemo_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import torch

from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable
from nemo.deploy.nlp import MegatronLLMDeployable
from tests.infer_data_path import get_infer_test_data

run_export_tests = True
Expand Down
12 changes: 4 additions & 8 deletions tests/export/nemo_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
from nemo.deploy.nlp import MegatronLLMDeployable, NemoQueryLLMPyTorch
except Exception as e:
LOGGER.warning(
f"Cannot import MegatronLLMDeployable, in-framework inference will not be available. {type(e).__name__}: {e}"
"Cannot import MegatronLLMDeployable or NemoQueryLLMPyTorch,"
f" in-framework inference will not be available. {type(e).__name__}: {e}"
)
in_framework_supported = False

Expand Down Expand Up @@ -104,12 +105,7 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path):
all_expected_outputs.append(expected_output)
if model is not None:

in_framework_model = False
if in_framework_supported:
if isinstance(model, MegatronLLMDeployable):
in_framework_model = True

if in_framework_model:
if in_framework_supported and isinstance(model, MegatronLLMDeployable):
model_output = model.generate(
inputs=[prompt],
length_params={"min_length": 1, "max_length": 1},
Expand Down Expand Up @@ -153,7 +149,7 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path):
correct_answers_relaxed += 1

if nq is not None:
if isinstance(nq, NemoQueryLLMPyTorch):
if in_framework_supported and isinstance(nq, NemoQueryLLMPyTorch):
deployed_output = nq.query_llm(
prompts=[prompt],
max_length=1,
Expand Down

0 comments on commit de1c788

Please sign in to comment.