diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 4eee084aa..400626399 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -8,24 +8,24 @@ - local: containers title: Optimum Containers - sections: - - local: tutorials/overview - title: Overview - - local: tutorials/notebooks + - local: training_tutorials/notebooks title: Notebooks - - local: tutorials/fine_tune_bert + - local: training_tutorials/fine_tune_bert title: Fine-tune BERT for Text Classification on AWS Trainium - - local: tutorials/stable_diffusion - title: Generate images with Stable Diffusion models on AWS Inferentia - - local: tutorials/llama2-13b-chatbot + - local: training_tutorials/finetune_llm + title: Fine-tune Llama 3 8B on AWS Trainium + title: Training Tutorials + - sections: + - local: inference_tutorials/notebooks + title: Notebooks + - local: inference_tutorials/llama2-13b-chatbot title: Create your own chatbot with llama-2-13B on AWS Inferentia - - local: tutorials/fine_tune_llama_7b - title: Fine-tune Llama 2 7B on AWS Trainium - - local: tutorials/sentence_transformers + - local: inference_tutorials/sentence_transformers title: Sentence Transformers on AWS Inferentia - title: Tutorials + - local: inference_tutorials/stable_diffusion + title: Generate images with Stable Diffusion models on AWS Inferentia + title: Inference Tutorials - sections: - - local: guides/overview - title: Overview - local: guides/setup_aws_instance title: Set up AWS Trainium instance - local: guides/sagemaker diff --git a/docs/source/guides/distributed_training.mdx b/docs/source/guides/distributed_training.mdx index d15a332a0..7ecd9fd03 100644 --- a/docs/source/guides/distributed_training.mdx +++ b/docs/source/guides/distributed_training.mdx @@ -18,8 +18,9 @@ But there is a caveat: each Neuron core is an independent data-parallel worker b To alleviate that, `optimum-neuron` supports parallelism features enabling you to harness the full power of your Trainium instance: 1. [ZeRO-1](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/zero1_gpt2.html): It is an optimization of data-parallelism which consists in sharding the optimizer state (which usually represents half of the memory needed on the device) over the data-parallel ranks. - 2. [Tensor Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html): It is a technique which consists in sharding each of your model parameters along a given dimension on multiple devices. The number of devices to shard your parameters on is called the `tensor_parallel_size`. - 3. [Pipeline Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html): **coming soon!** + 2. [Tensor Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html): It is a technique which consists in sharding each of your model matrix-multiplications along a given axis (row or column) on multiple devices. It also known as intra-layer model parallelism. The number of devices to shard your parameters on is called the `tensor_parallel_size`. + 3. [Sequence parallelism](https://arxiv.org/pdf/2205.05198.pdf): It is an optimization over Tensor Parallelism which shards the activations on the sequence axis outside of the tensor parallel regions. It is useful because it saves memory by sharding the activations. + 4. [Pipeline Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html): It consists in sharding the model block layers on multiple devices. It is also known as inter-layer model parallelism. The number of devices to shard your layers on is called the `pipeline_parallel_size`. The good news is that is it possible to combine those techniques, and `optimum-neuron` makes it very easy! diff --git a/docs/source/guides/overview.mdx b/docs/source/guides/overview.mdx deleted file mode 100644 index 4e80df869..000000000 --- a/docs/source/guides/overview.mdx +++ /dev/null @@ -1,30 +0,0 @@ - - -# Overview - -Welcome to the 🤗 Optimum Neuron how-to guides! - -These guides tackle more advanced topics and will show you how to easily get the best from AWS Trainium / Inferentia: - -- [How to setup AWS Trainium instance](./setup_aws_instance) -- [Training and Deployment using Amazon Sagemaker](./sagemaker) -- [Neuron model cache](./cache_system) -- [How to fine-tune a Transformers model with AWS Trainium](./fine_tune) -- [Distributed training with AWS Neuron](./distributed_training) -- [Export a model to Inferentia](./export_model) -- [Neuron Model Inference](./models) -- [Inference pipelines with AWS Neuron (Inf2/Trn1)](./pipelines) diff --git a/docs/source/guides/setup_aws_instance.mdx b/docs/source/guides/setup_aws_instance.mdx index a93cada83..fc8772a8c 100644 --- a/docs/source/guides/setup_aws_instance.mdx +++ b/docs/source/guides/setup_aws_instance.mdx @@ -16,6 +16,13 @@ limitations under the License. # Set up AWS Trainium instance +In this guide, we will show you: + +1. How to create an AWS Trainium instance +2. How to use and run Jupyter Notebooks on your instance + +## Create an AWS Trainium Instance + The simplest way to work with AWS Trainium and Hugging Face Transformers is the [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2) (DLAMI). The DLAMI comes with all required libraries pre-packaged for you, including the Neuron Drivers, Transformers, Datasets, and Accelerate. To create an EC2 Trainium instance, you can start from the console or the Marketplace. This guide will start from the [EC2 console](https://console.aws.amazon.com/ec2sp/v2/). @@ -96,4 +103,18 @@ instance-id: i-0570615e41700a481 +--------+--------+--------+---------+ ``` +## Configuring `Jupyter Notebook` on your AWS Trainium Instance + +With the instance is up and running, we can ssh into it. +But instead of developing inside a terminal it is also possible to use a `Jupyter Notebook` environment. We can use it for preparing our dataset and launching the training (at least when working on a single node). + +For this, we need to add a port for forwarding in the `ssh` command, which will tunnel our localhost traffic to the Trainium instance. + +```bash +PUBLIC_DNS="" # IP address, e.g. ec2-3-80-.... +KEY_PATH="" # local path to key, e.g. ssh/trn.pem + +ssh -L 8080:localhost:8080 -i ${KEY_NAME}.pem ubuntu@$PUBLIC_DNS +``` + You are done! You can now start using the Trainium accelerators with Hugging Face Transformers. Check out the [Fine-tune Transformers with AWS Trainium](./fine_tune) guide to get started. diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 8f460b5f2..e2f46d5a2 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -24,7 +24,7 @@ The list of officially validated models and tasks is available [here](https://hu
Tutorials @@ -34,7 +34,7 @@ The list of officially validated models and tasks is available [here](https://hu Start here if you are using 🤗 Optimum Neuron for the first time!

- +
How-to guides
diff --git a/docs/source/tutorials/llama2-13b-chatbot.mdx b/docs/source/inference_tutorials/llama2-13b-chatbot.mdx similarity index 100% rename from docs/source/tutorials/llama2-13b-chatbot.mdx rename to docs/source/inference_tutorials/llama2-13b-chatbot.mdx diff --git a/docs/source/tutorials/notebooks.mdx b/docs/source/inference_tutorials/notebooks.mdx similarity index 71% rename from docs/source/tutorials/notebooks.mdx rename to docs/source/inference_tutorials/notebooks.mdx index 10815aee5..fee75190f 100644 --- a/docs/source/tutorials/notebooks.mdx +++ b/docs/source/inference_tutorials/notebooks.mdx @@ -1,5 +1,5 @@ + +# Fine-tune and Test Llama-3 8B on AWS Trainium + +_Note: The complete script for this tutorial can be downloaded [here](https://github.com/huggingface/optimum-neuron/docs/source/training_tutorials/finetune_llm.py)._ + +This tutorial will teach you how to fine-tune open source LLMs like [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) on AWS Trainium. In our example, we are going to leverage the [Optimum Neuron](https://huggingface.co/docs/optimum-neuron/index), [Transformers](https://huggingface.co/docs/transformers/index) and [Datasets](https://huggingface.co/docs/datasets/index) libraries. + +You will learn how to: + +1. [Setup AWS Environment](#1-setup-aws-environment) +2. [Load and process the dataset](#2-load-and-prepare-the-dataset) +3. [Fine-tune Llama on AWS Trainium using the `NeuronTrainer`](#3-fine-tune-llama-on-aws-trainium-using-the-neurontrainer) +4. [Launch Training](#4-launch-training) +5. [Evaluate and test fine-tuned Llama model](#5-evaluate-and-test-fine-tuned-llama-model) + + + + +While we will use `Llama-3 8B` in this tutorial, it is completely possible to use other models, simply by swtiching the `model_id`. +For instance, it is possible to fine-tune: + +- Mistral models, such as [Mistral 7b (`mistralai/Mistral-7B-Instruct-v0.3`)](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) +- Llama-2 models, such as [Llama-2 7b (`meta-llama/Llama-2-7b-hf`)](https://huggingface.co/meta-llama/Llama-2-7b-hf) + +And many others! + + + +## 1. Setup AWS Environment + +Before starting this tutorial, you will need to setup your environment: + +1. Create an AWS Trainium instance. You can follow this [guide](https://huggingface.co/docs/optimum-neuron/guides/setup_aws_instance) to create one. +2. Make sure you are logged in on the Hugging Face Hub: +```bash +huggingface-cli login --token YOUR_TOKEN +``` +3. Check that you have access to the model. Some open source models are gated, meaning that users need to apply to the model owner to be able to use the model weights. Here we will be training Llama-3 8B, for which there are two possibilities: + * The official gated repo: [`meta-llama/Meta-Llama-3-8B`](https://huggingface.co/meta-llama/Meta-Llama-3-8B) + * The non-official un-gated repo: [`NousResearch/Meta-Llama-3-8B`](https://huggingface.co/NousResearch/Meta-Llama-3-8B) +4. Clone the Optimum Neuron repository, **which contains the [complete script](https://github.com/huggingface/optimum-neuron/docs/source/training_tutorials/finetune_llm.py) described in this tutorial:** +```bash +git clone https://github.com/huggingface/optimum-neuron.git +``` + +## 2. Load and prepare the dataset + +For this tutorial, we will use [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k), an open source dataset of instruction-following records on categories outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155), including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization. + +Example: + +```python +{ + "instruction": "What is world of warcraft", + "context": "", + "response": "World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment" +} +``` + +We can use the `load_dataset()` method from the 🤗 Datasets library to load the `dolly` dataset very easily. + +```python +from datasets import load_dataset +from random import randrange + +# Load dataset from the hub +dataset = load_dataset("databricks/databricks-dolly-15k", split="train") + +print(f"dataset size: {len(dataset)}") +print(dataset[randrange(len(dataset))]) +# dataset size: 15011 +``` + +To instruct tune our model we need to convert our structured examples into a collection of tasks described via instructions. We define a `format_dolly` that takes a raw sample and returns a string with our format instruction. + +```python +def format_dolly(sample): + instruction = f"### Instruction\n{sample['instruction']}" + context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None + response = f"### Answer\n{sample['response']}" + # join all the parts together + prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) + return prompt +``` + +In addition to formatting our samples, we also want to pack multiple samples to one sequence to have a more efficient training. In other words, we are stacking multiple samples to one sequence and split them with an EOS Token. Packing/stacking samples can be done during training or before. Here, we will do it before training to save time. + +The following function `pack_dataset` takes a `dataset` and a `chunk_length` and returns a packed dataset: + +```python +from functools import partial +from itertools import chain + +# empty list to save remainder from batches to use in next batch +remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []} + +def pack_dataset(dataset, chunk_length=2048): + print(f"Chunking dataset into chunks of {chunk_length} tokens.") + + def chunk(sample, chunk_length=chunk_length): + # define global remainder variable to save remainder from batches to use in next batch + global remainder + # Concatenate all texts and add remainder from previous batch + concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()} + concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()} + # get total number of tokens for batch + batch_total_length = len(concatenated_examples[list(sample.keys())[0]]) + + # get max number of chunks for batch + if batch_total_length >= chunk_length: + batch_chunk_length = (batch_total_length // chunk_length) * chunk_length + + # Split by chunks of max_len. + result = { + k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] + for k, t in concatenated_examples.items() + } + # add remainder to global variable for next batch + remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()} + # prepare labels + result["labels"] = result["input_ids"].copy() + return result + + # tokenize and chunk dataset + lm_dataset = dataset.map( + partial(chunk, chunk_length=chunk_length), + batched=True, + ) + print(f"Total number of samples: {len(lm_dataset)}") + return lm_dataset +``` + +To summarize to prepare our dataset we will: + +1. Format our samples using the template method and add an EOS token at the end of each sample +2. Tokenize our dataset to convert it from text to tokens +3. Pack our dataset to 2048 tokens + +```python +from transformers import AutoTokenizer +from random import randint + +# Hugging Face Hub model id +# model_id = "meta-llama/Meta-Llama-3-8B" # gated +model_id = "NousResearch/Meta-Llama-3-8B" # ungated + +tokenizer = AutoTokenizer.from_pretrained(model_id) + +# template dataset to add prompt to each sample +def template_dataset(sample): + sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}" + return sample + +# apply prompt template per sample +dataset = dataset.map(template_dataset, remove_columns=list(dataset.features)) + +# print random sample +print(dataset[randint(0, len(dataset))]["text"]) + +# tokenize dataset +dataset = dataset.map( + lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features) +) + +# chunk dataset +lm_dataset = pack_dataset(dataset, chunk_length=2048) # We use 2048 as the maximum length for packing +``` + +After we processed the datasets we are going save it to disk. You could also save it to S3 or the Hugging Face Hub for later use. + +_Note: Packing and preprocessing your dataset can be run outside of the Trainium instance._ + +```python +# save train_dataset to disk +dataset_path = "tokenized_dolly" +lm_dataset.save_to_disk(dataset_path) +``` + +## 3. Fine-tune Llama on AWS Trainium using the `NeuronTrainer` + +Normally you would use the **[Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer)** and **[TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments)** classes to fine-tune PyTorch-based transformer models. + +But together with AWS, we have developed the [~`optimum.neuron.NeuronTrainer`] to improve performance, robustness, and ease-of-use when training on Trainium instances. It can be used as a 1-to-1 replacement for the `Trainer`. + +Since Llama-3 8B is a big model it will not fit on a single Neuron core, we need distributed training. In Optimum Neuron we support: + 1. [ZeRO-1](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/zero1_gpt2.html): It is an optimization of data-parallelism which consists in sharding the optimizer state (which usually represents half or more of the memory needed on the device) over the data-parallel ranks. + 2. [Tensor Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html): It is a technique which consists in sharding each of your model matrix-multiplications along a given axis (row or column) on multiple devices. It also known as intra-layer model parallelism. The number of devices to shard your parameters on is called the `tensor_parallel_size`. + 3. [Sequence parallelism](https://arxiv.org/pdf/2205.05198.pdf): It is an optimization over Tensor Parallelism which shards the activations on the sequence axis outside of the tensor parallel regions. It is useful because it saves memory by sharding the activations. + 4. [Pipeline Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html): It consists in sharding the model block layers on multiple devices. It is also known as inter-layer model parallelism. The number of devices to shard your layers on is called the `pipeline_parallel_size`. + + + +If you want to know more about distributed training you can take a look at the [documentation](https://huggingface.co/docs/optimum-neuron/guides/distributed_training). + + + +Here, since we want to fine-tune an 8B model, we will not need to use pipeline parallelism. +Our training code will look as follows: + +```python +from optimum.neuron import NeuronTrainer as Trainer +from optimum.neuron.distributed import lazy_load_for_parallelism + +# Define the tensor_parallel_size +tensor_parallel_size = 8 + +# Load model from the Hugging face Hub +with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size): + model = AutoModelForCausalLM.from_pretrained(model_id) + +trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=dataset, + data_collator=default_data_collator, # no special collator needed since we stacked the dataset +) + +# Start training +trainer.train() + +trainer.save_model() # saves the tokenizer too for easy upload +``` + +The key points here are: + +- We use the `lazy_load_for_parallelism` context manager to lazily load the model. This will not load the full model weights on each worker, but instead only load the required weights (sharded or full). **This is much more memory efficient, and often mandatory to use.** +- We use the [~`optimum.neuron.NeuronTrainer`] to perform training. It will take the lazily loaded model, along with the `training_args`, which are an instance of [~`optimum.neuron.NeuronTrainingArguments`], and will handle all the parallelization and training on the Neuron cores. + +## 4. Launch Training + +We prepared a script called [finetune_llm.py](https://github.com/huggingface/optimum-neuron/docs/source/training_tutorials/finetune_llm.py) summing up everything mentioned in this tutorial. + + + +This script is a minimalistic version of our official example training script to run causal language modeling fine-tuning, called [run_clm.py](https://github.com/huggingface/optimum-neuron/blob/main/examples/language-modeling/run_clm.py). For the sake of this tutorial, we tried to get rid of anything that is not necessary, but if you want to do more custom things, maybe the solution is already implemented in `run_clm.py`! + +Also, these scripts are more designed as templates than final scripts. Feel free to take `finetune_llm.py` or `run_clm.py` and adapt them to your own needs! + + + +### Precompilation + +When training models on AWS Trainium we first need to compile our model with our training arguments. + +To overcome this, we added a [model cache repository](https://huggingface.co/docs/optimum-neuron/guides/cache_system), which allows us to use precompiled models from the Hugging Face Hub to skip the compilation step. But be careful: every change in the model configuration might lead to a new compilation, which could result in some cache misses. + +_Note: If your model configuration is not cached please open an issue on [Github](https://github.com/huggingface/optimum-neuron/issues), we are happy to include it._ + +The compilation command simply consists in calling your script as an input to the `neuron_parallel_compile` utility: + +```bash +MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node=32 finetune_llm.py \ + --model_id {model_id} \ + --dataset_path {dataset_path} \ + --bf16 True \ + --learning_rate 5e-5 \ + --output_dir dolly_llama \ + --overwrite_output_dir True \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --gradient_checkpointing True \ + --tensor_parallel_size 8 \ + --max_steps 10 \ + --logging_steps 10 +``` + + + +Make sure to run this precompilation phase for around 10 training steps. It is usually enough to accumulate and compile all the graphs that will be needed during the actual training. + + + +_Note: Compiling without a cache can take a while. It will also create dummy files in the `dolly_llama_sharded` during compilation you will have to remove them afterwards. We also need to add `MALLOC_ARENA_MAX=64` to limit the CPU allocation to avoid potential crashes, don't remove it for now._ + +```bash +# remove dummy artifacts which are created by the precompilation command +rm -rf dolly_llama +``` + +### Actual Training + +After compilation is done we can start our actual training with a similar command, we just need to remove the use of `neuron_parallel_compile`. + +We will use `torchrun` to launch our training script. `torchrun` is a tool that automatically distributes a PyTorch model across multiple accelerators. We can pass the number of accelerators as `nproc_per_node` arguments alongside our hyperparameters. + +The difference to the compilation command is that we changed from `max_steps=10` to `num_train_epochs=3`. + +Launch the training, with the following command. + +```bash +MALLOC_ARENA_MAX=64 XLA_USE_BF16=1 torchrun --nproc_per_node=32 finetune_llm.py \ + --model_id {model_id} \ + --dataset_path {dataset_path} \ + --bf16 True \ + --learning_rate 5e-5 \ + --output_dir dolly_llama \ + --overwrite_output_dir True \ + --skip_cache_push True \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --gradient_checkpointing True \ + --tensor_parallel_size 8 \ + --num_train_epochs 3 \ + --logging_steps 10 +``` + +That's it, we successfully trained Llama-3 8B on AWS Trainium! + +But before we can share and test our model we need to consolidate our model. Since we used Tensor Parallelism during training, we saved sharded versions of the checkpoints. We need to consolidate them now. + +### Consolidate the Checkpoint + +The Optimum CLI provides a way of doing that very easily via the `optimum neuron consolidate [sharded_checkpoint] [output_dir]` command: + +```bash +optimum-cli neuron consolidate dolly_llama dolly_llama +``` + +## 5. Evaluate and test fine-tuned Llama model + +As for training, to be able to run inference on AWS Trainium or AWS Inferentia2 we need to compile our model. In this case, we will use our Trainium instance for the inference test, but we recommend customer to switch to Inferentia2 for inference. + +Optimum Neuron implements similar to Transformers AutoModel classes for easy inference use. We will use the `NeuronModelForCausalLM` class to load our vanilla transformers checkpoint and convert it to neuron. + +```python +from optimum.neuron import NeuronModelForCausalLM +from transformers import AutoTokenizer + +compiler_args = {"num_cores": 2, "auto_cast_type": 'fp16'} +input_shapes = {"batch_size": 1, "sequence_length": 2048} + +tokenizer = AutoTokenizer.from_pretrained("dolly_llama") +model = NeuronModelForCausalLM.from_pretrained( + "dolly_llama", + export=True, + **compiler_args, + **input_shapes) +``` + +_Note: Inference compilation can take ~25minutes. Luckily, you need to only run this onces. Since you can save the model afterwards. If you are going to run on Inferentia2 you need to recompile again. The compilation is parameter and hardware specific._ + +```python +# COMMENT IN if you want to save the compiled model +# model.save_pretrained("compiled_dolly_llama") +``` + +We can now test inference, but have to make sure we format our input to our prompt format we used for fine-tuning. Therefore we created a helper method, which accepts a `dict` with our `instruction` and optionally a `context`. + +```python +def format_dolly_inference(sample): + instruction = f"### Instruction\n{sample['instruction']}" + context = f"### Context\n{sample['context']}" if "context" in sample else None + response = f"### Answer\n" + prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) + return prompt + + +def generate(sample): + prompt = format_dolly_inference(sample) + inputs = tokenizer(prompt, return_tensors="pt") + outputs = model.generate( + **inputs, + max_new_tokens=512, + do_sample=True, + temperature=0.9, + top_k=50, + top_p=0.9 + ) + return tokenizer.decode(outputs[0], skip_special_tokens=False)[len(prompt):] +``` + +Let's test inference. First we test without a context. + +_Note: Inference is not expected to be super fast on AWS Trainium using 2 cores. For Inference we recommend using Inferentia2._ + +```python +prompt = { + "instruction": "Can you tell me something about AWS?" +} +res = generate(prompt) + +print(res) +``` + +> AWS stands for Amazon Web Services. AWS is a suite of remote computing services offered by Amazon. The most widely used of these include Amazon Elastic Compute Cloud (Amazon EC2), which provides resizable compute capacity in the cloud; Amazon Simple Storage Service (Amazon S3), which is an object storage service; and Amazon Elastic Block Store (Amazon EBS), which is designed to provide high performance, durable block storage volumes for use with AWS instances. AWS also provides other services, such as AWS Identity and Access Management (IAM), a service that enables organizations to control access to their AWS resources, and AWS Key Management Service (AWS KMS), which helps customers create and control the use of encryption keys. + +That looks correct. Now, lets add some context, e.g. as you would do for RAG applications: + +```python +prompt = { + "instruction": "How can I train models on AWS Trainium?", + "context": "🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS Accelerators including [AWS Trainium](https://aws.amazon.com/machine-learning/trainium/?nc1=h_ls) and [AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/?nc1=h_ls). It provides a set of tools enabling easy model loading, training and inference on single- and multi-Accelerator settings for different downstream tasks." +} +res = generate(prompt) + +print(res) +``` + +> You can use the Optimum Neuron interface to train models on AWS Trainium. + +Awesome, our model also correctly uses the provided context. We are done. Congrats on fine-tuning Llama on AWS Trainium. diff --git a/docs/source/training_tutorials/finetune_llm.py b/docs/source/training_tutorials/finetune_llm.py new file mode 100644 index 000000000..d3fd2bfd0 --- /dev/null +++ b/docs/source/training_tutorials/finetune_llm.py @@ -0,0 +1,147 @@ +from dataclasses import dataclass, field +from functools import partial +from itertools import chain +from typing import Optional + +from datasets import load_dataset, load_from_disk +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + default_data_collator, + set_seed, +) + +from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser +from optimum.neuron import NeuronTrainer as Trainer +from optimum.neuron import NeuronTrainingArguments as TrainingArguments +from optimum.neuron.distributed import lazy_load_for_parallelism + + +# Load dataset from the hub +dataset = load_dataset("databricks/databricks-dolly-15k", split="train") + + +def format_dolly(sample): + instruction = f"### Instruction\n{sample['instruction']}" + context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None + response = f"### Answer\n{sample['response']}" + # join all the parts together + prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) + return prompt + + +# empty list to save remainder from batches to use in next batch +remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []} + + +def pack_dataset(dataset, chunk_length=2048): + print(f"Chunking dataset into chunks of {chunk_length} tokens.") + + def chunk(sample, chunk_length=chunk_length): + # define global remainder variable to save remainder from batches to use in next batch + global remainder + # Concatenate all texts and add remainder from previous batch + concatenated_examples = {k: list(chain(*sample[k])) for k in sample.keys()} + concatenated_examples = {k: remainder[k] + concatenated_examples[k] for k in concatenated_examples.keys()} + # get total number of tokens for batch + batch_total_length = len(concatenated_examples[list(sample.keys())[0]]) + + # get max number of chunks for batch + if batch_total_length >= chunk_length: + batch_chunk_length = (batch_total_length // chunk_length) * chunk_length + + # Split by chunks of max_len. + result = { + k: [t[i : i + chunk_length] for i in range(0, batch_chunk_length, chunk_length)] + for k, t in concatenated_examples.items() + } + # add remainder to global variable for next batch + remainder = {k: concatenated_examples[k][batch_chunk_length:] for k in concatenated_examples.keys()} + # prepare labels + result["labels"] = result["input_ids"].copy() + return result + + # tokenize and chunk dataset + lm_dataset = dataset.map( + partial(chunk, chunk_length=chunk_length), + batched=True, + ) + print(f"Total number of samples: {len(lm_dataset)}") + return lm_dataset + + +def create_and_save_dataset(model_id: str, dataset_path: str): + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # template dataset to add prompt to each sample + def template_dataset(sample): + sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}" + return sample + + # apply prompt template per sample + dataset = dataset.map(template_dataset, remove_columns=list(dataset.features)) + + # tokenize dataset + dataset = dataset.map( + lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features) + ) + + # chunk dataset + lm_dataset = pack_dataset(dataset, chunk_length=2048) # We use 2048 as the maximum length for packing + + # save train_dataset to disk + lm_dataset.save_to_disk(dataset_path) + + +def training_function(script_args, training_args): + # load dataset + dataset = load_from_disk(script_args.dataset_path) + + tokenizer = AutoTokenizer.from_pretrained(script_args.model_id) + with lazy_load_for_parallelism(tensor_parallel_size=training_args.tensor_parallel_size): + model = AutoModelForCausalLM.from_pretrained(script_args.model_id) + + # Create Trainer instance + trainer = Trainer( + model=model, + tokenizer=tokenizer, + args=training_args, + train_dataset=dataset, + data_collator=default_data_collator, # no special collator needed since we stacked the dataset + ) + + # Start training + trainer.train() + + trainer.save_model() # Saves the tokenizer too for easy upload + + +@dataclass +class ScriptArguments: + model_id: str = field( + default="meta-llama/Meta-Llama-3-8B", + metadata={"help": "The model that you want to train from the Hugging Face hub."}, + ) + dataset_path: Optional[str] = field( + metadata={"help": "Path to the preprocessed and tokenized dataset."}, + default=None, + ) + + +def main(): + parser = HfArgumentParser([ScriptArguments, TrainingArguments]) + script_args, training_args = parser.parse_args_into_dataclasses() + + if script_args.dataset_path is None: + create_and_save_dataset(script_args.model_id, "tokenized_dolly") + script_args.dataset_path = "tokenized_dolly" + + # set seed + set_seed(training_args.seed) + + # run training function + training_function(script_args, training_args) + + +if __name__ == "__main__": + main() diff --git a/docs/source/training_tutorials/notebooks.mdx b/docs/source/training_tutorials/notebooks.mdx new file mode 100644 index 000000000..2916d1c3f --- /dev/null +++ b/docs/source/training_tutorials/notebooks.mdx @@ -0,0 +1,24 @@ + + +# Notebooks + +We prepared some notebooks for you, so that you can run directly tutorials in the documentation. + +| Notebook | Description | Studio Lab | +|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:| +| [Fine-tune BERT for text classification on AWS Trainium](https://github.com/huggingface/optimum-neuron/blob/main/notebooks/text-classification/notebook.ipynb) | Show how to fine-tune BERT on AWS Trainium for text classification. | [![Open in AWS Studio](https://studiolab.sagemaker.aws/studiolab.svg)](https://studiolab.sagemaker.aws/import/github/huggingface/optimum-neuron/blob/main/notebooks/text-classification/notebook.ipynb) | + diff --git a/docs/source/tutorials/fine_tune_llama_7b.mdx b/docs/source/tutorials/fine_tune_llama_7b.mdx deleted file mode 100644 index ddd814885..000000000 --- a/docs/source/tutorials/fine_tune_llama_7b.mdx +++ /dev/null @@ -1,358 +0,0 @@ - - -# Fine-tune and Test Llama 2 7B on AWS Trainium - -*There is a notebook version of that tutorial [here](https://github.com/huggingface/optimum-neuron/blob/main/notebooks/text-generation/llama2-7b-fine-tuning.ipynb)*. - -This tutorial will teach you how to fine-tune open LLMs like [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) on AWS Trainium. In our example, we are going to leverage Hugging Face https://huggingface.co/docs/optimum-neuron/index, [Transformers](https://huggingface.co/docs/transformers/index) and https://huggingface.co/docs/datasets/index. - -You will learn how to: - -1. [Setup AWS environment](#1-setup-aws-environment) -2. [Load and process the dataset](#2-load-and-prepare-the-dataset) -3. [Fine-tune Llama on AWS Trainium using the `NeuronTrainer`](#3-fine-tune-llama-on-aws-trainium-using-the-neurontrainer) -4. [Evaluate and test fine-tuned Llama model](#4-evaluate-and-test-fine-tuned-llama-model) - -## Quick intro: AWS Trainium - -[AWS Trainium (Trn1)](https://aws.amazon.com/de/ec2/instance-types/trn1/) is a purpose-built EC2 for deep learning (DL) training workloads. Trainium is the successor of [AWS Inferentia](https://aws.amazon.com/ec2/instance-types/inf1/?nc1=h_ls) focused on high-performance training workloads. Trainium has been optimized for training natural language processing, computer vision, and recommender models. The accelerator supports a wide range of data types, including FP32, TF32, BF16, FP16, UINT8, and configurable FP8. - -The biggest Trainium instance, the `trn1.32xlarge` comes with over 500GB of memory, making it easy to fine-tune ~10B parameter models on a single instance. Below you will find an overview of the available instance types. More details [here](https://aws.amazon.com/de/ec2/instance-types/trn1/#Product_details): - -| instance size | accelerators | accelerator memory | vCPU | CPU Memory | price per hour | -| ----------------------------- | ------------ | ------------------ | ---- | ---------- | -------------- | -| trn1.2xlarge | 1 | 32 | 8 | 32 | \$1.34 | -| trn1.32xlarge | 16 | 512 | 128 | 512 | \$21.50 | -| trn1n.32xlarge (2x bandwidth) | 16 | 512 | 128 | 512 | \$24.78 | - -_Note: This tutorial was created on a trn1.32xlarge AWS EC2 Instance._ - -## 1. Setup AWS environment - -In this example, we will use the `trn1.32xlarge` instance on AWS with 16 Accelerator, including 32 Neuron Cores and the [Hugging Face Neuron Deep Learning AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2). The Hugging Face AMI comes with all important libraries, like Transformers, Datasets, Optimum and Neuron packages pre-installed. This makes it super easy to get started, since there is no need for environment management. - -This tutorial doesn’t cover how to create the instance in detail. You can check out the dedicated tutorial about [“Setting up AWS Trainium for Hugging Face Transformers”](https://huggingface.co/docs/optimum-neuron/guides/setup_aws_instance), which includes a step-by-step guide on setting up the environment. - -Once the instance is up and running, we can ssh into it. But instead of developing inside a terminal we want to use a `Jupyter` environment, which we can use for preparing our dataset and launching the training. For this, we need to add a port for forwarding in the `ssh` command, which will tunnel our localhost traffic to the Trainium instance. - -```bash -PUBLIC_DNS="" # IP address, e.g. ec2-3-80-.... -KEY_PATH="" # local path to key, e.g. ssh/trn.pem - -ssh -L 8080:localhost:8080 -i ${KEY_NAME}.pem ubuntu@$PUBLIC_DNS -``` - -Let's now pull the optimum repository with the [example notebook and scripts](https://github.com/huggingface/optimum-neuron/tree/main/notebooks/text-generation). - -```bash -git clone https://github.com/huggingface/optimum-neuron.git -``` - -Next we can change our directory to `notbooks/text-generation` and launch the `jupyter` environment. - -```bash -# change directory -cd optimum-neuron/notebooks/text-generation -# launch jupyter -python -m notebook --allow-root --port=8080 -``` - -You should see a familiar **`jupyter`** output with a URL to the notebook. - -**`http://localhost:8080/?token=8c1739aff1755bd7958c4cfccc8d08cb5da5234f61f129a9`** - -We can click on it, and a **`jupyter`** environment opens in our local browser. Open the notebook **`llama2-7b-fine-tuning.ipynb`** and lets get started. - -_Note: We are going to use the Jupyter environment only for preparing the dataset and then `torchrun` for launching our training script for distributed training._ - -If you are going to use official Llama 2 checkpoint you need to login into our hugging face account, which has access to the model, to use your token for accessing the gated repository. We can do this by running the following command: - -_Note: We also provide an ungated checkpoint._ - -```python -!huggingface-cli login --token YOUR_TOKEN -``` - -## 2. Load and prepare the dataset - -We will use [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k) an open source dataset of instruction-following records on categories outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155), including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization. - -```python -{ - "instruction": "What is world of warcraft", - "context": "", - "response": "World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment" -} -``` - -To load the `dolly` dataset, we use the `load_dataset()` method from the 🤗 Datasets library. - -```python -from datasets import load_dataset -from random import randrange - -# Load dataset from the hub -dataset = load_dataset("databricks/databricks-dolly-15k", split="train") - -print(f"dataset size: {len(dataset)}") -print(dataset[randrange(len(dataset))]) -# dataset size: 15011 - -``` - -To instruct tune our model we need to convert our structured examples into a collection of tasks described via instructions. We define a `formatting_function` that takes a sample and returns a string with our format instruction. - -```python -def format_dolly(sample): - instruction = f"### Instruction\n{sample['instruction']}" - context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None - response = f"### Answer\n{sample['response']}" - # join all the parts together - prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) - return prompt - -``` - -let's test our formatting function on a random example. - -```python -from random import randrange - -print(format_dolly(dataset[randrange(len(dataset))])) -``` - -In addition, to formatting our samples we also want to pack multiple samples to one sequence to have a more efficient training. This means that we are stacking multiple samples to one sequence and split them with an EOS Token. This makes the training more efficient. Packing/stacking samples can be done during training or before. We will do it before training to save time. We created a utility method [pack_dataset](https://github.com/huggingface/optimum-neuron/tree/main/notebooks/text-generation/scripts/utils/pack_dataset.py) that takes a dataset and a packing function and returns a packed dataset. - -```python -from transformers import AutoTokenizer - -# Hugging Face model id -model_id = "philschmid/Llama-2-7b-hf" # ungated -# model_id = "meta-llama/Llama-2-7b-hf" # gated - -tokenizer = AutoTokenizer.from_pretrained(model_id) -``` - -To pack/stack our dataset we need to first tokenize it and then we can pack it with the `pack_dataset` method. To prepare our dataset we will now: - -1. Format our samples using the template method and add an EOS token at the end of each sample -2. Tokenize our dataset to convert it from text to tokens -3. Pack our dataset to 2048 tokens - -```python -from random import randint -# add utils method to path for loading dataset -import sys -sys.path.append("./scripts/utils") # make sure you change this to the correct path -from pack_dataset import pack_dataset - - -# template dataset to add prompt to each sample -def template_dataset(sample): - sample["text"] = f"{format_dolly(sample)}{tokenizer.eos_token}" - return sample - -# apply prompt template per sample -dataset = dataset.map(template_dataset, remove_columns=list(dataset.features)) -# print random sample -print(dataset[randint(0, len(dataset))]["text"]) - -# tokenize dataset -dataset = dataset.map( - lambda sample: tokenizer(sample["text"]), batched=True, remove_columns=list(dataset.features) -) - -# chunk dataset -lm_dataset = pack_dataset(dataset, chunk_length=2048) # We use 2048 as the maximum length for packing -``` - -After we processed the datasets we are going save it to disk. You could also save it to S3 or the Hugging Face Hub for later use. - -_Note: Packing and preprocessing your dataset can be run outside of the Trainium instance._ - -```python -# save train_dataset to disk -dataset_path = "tokenized_dolly" -lm_dataset.save_to_disk(dataset_path) -``` - -## 3. Fine-tune Llama on AWS Trainium using the `NeuronTrainer` - -Normally you would use the **[Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer)** and **[TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments)** to fine-tune PyTorch-based transformer models. - -But together with AWS, we have developed a `NeuronTrainer` to improve performance, robustness, and safety when training on Trainium instances. The `NeuronTrainer` is part of the `optimum-neuron` library and can be used as a 1-to-1 replacement for the `Trainer`. - -When it comes to distributed training on AWS Trainium there are a few things we need to take care of. Since Llama is a big model it might not fit on a single accelerator, thats why we added support for different distributed training strategies to the `NeuronTrainer` including: - -- [ZeRO-1](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/tutorials/training/zero1_gpt2.html): shards the optimizer state over multiple devices. -- [Tensor Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/tensor_parallelism_overview.html): shards the model parameters along a given dimension on multiple devices, defined with `tensor_parallel_size` -- [Sequence parallelism](https://arxiv.org/pdf/2205.05198.pdf) shards the activations on the sequence axis outside of the tensor parallel regions. It is useful because it saves memory by sharding the activations. -- [Pipeline Parallelism](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/neuronx-distributed/pipeline_parallelism_overview.html): _coming soon_ - -We prepared a [run_clm.py](https://github.com/huggingface/optimum-neuron/blob/main/notebooks/text-generation/scripts/run_clm.py), which implements those distributed training strategies for you already. If you want to know more about the details you can take a look at the [documentation](https://huggingface.co/docs/optimum-neuron/guides/distributed_training). When training models on AWS Accelerators we first need to compile our model with our training arguments. - -To overcome this we added a [model cache](https://huggingface.co/docs/optimum-neuron/guides/cache_system), which allows us to use precompiled models and configuration from Hugging Face Hub to skip the compilation step. But every change in the config, will lead to a new compilation, which could result in some cache misses. - -_Note: If your configuration is not cached please open an issue on [Github](https://github.com/huggingface/optimum-neuron/issues), we are happy to include it._ - -We pre-compiled the config for our training already meaning you can either skip the cell below or rerun it will only take a few minutes since it reuses the cached configuration. - -```python -# precompilation command -!MALLOC_ARENA_MAX=64 neuron_parallel_compile torchrun --nproc_per_node=32 scripts/run_clm.py \ - --model_id {model_id} \ - --dataset_path {dataset_path} \ - --bf16 True \ - --learning_rate 5e-5 \ - --output_dir dolly_llama \ - --overwrite_output_dir True \ - --per_device_train_batch_size 1 \ - --gradient_checkpointing True \ - --tensor_parallel_size 8 \ - --max_steps 10 \ - --logging_steps 10 \ - --gradient_accumulation_steps 16 -``` - -_Note: Compiling without a cache can take ~40 minutes. It will also create dummy files in the `dolly_llama_sharded` during compilation you we have to remove them afterwards. We also need to add `MALLOC_ARENA_MAX=64` to limit the CPU allocation to avoid potential crashes, don't remove it for now._ - -```python -# remove dummy artifacts which are created by the precompilation command -!rm -rf dolly_llama -``` - -After the compilation is done we can start our training with a similar command, we just need to remove the `neuron_parallel_compile`. We will use `torchrun` to launch our training script. `torchrun` is a tool that automatically distributes a PyTorch model across multiple accelerators. We can pass the number of accelerators as `nproc_per_node` arguments alongside our hyperparameters. -The difference to the compilation command is that we changed from `max_steps=10` to `num_train_epochs=3`. - -Launch the training, with the following command. - -```python -!MALLOC_ARENA_MAX=64 torchrun --nproc_per_node=32 scripts/run_clm.py \ - --model_id {model_id} \ - --dataset_path {dataset_path} \ - --bf16 True \ - --learning_rate 5e-5 \ - --output_dir dolly_llama \ - --overwrite_output_dir True \ - --skip_cache_push True \ - --per_device_train_batch_size 1 \ - --gradient_checkpointing True \ - --tensor_parallel_size 8 \ - --num_train_epochs 3 \ - --logging_steps 10 \ - --gradient_accumulation_steps 16 -``` - -Thats it, we successfully trained Llama 7B on AWS Trainium. The training took for 3 epochs on dolly (15k samples) took 43:24 minutes where the raw training time was only 31:46 minutes. This leads to a cost of ~$15.5 for the e2e training on the trn1.32xlarge instance. Not Bad! - -But before we can share and test our model we need to consolidate our model. Since we used Tensor Parallelism during training, we need to consolidate the model weights before we can use it. Tensor Parallelism shards the model weights accross different workers, only sharded checkpoints will be saved during training. - -The Optimum CLI provides a way of doing that very easily via the `optimum neuron consolidate`` command: - -```python -!optimum-cli neuron consolidate dolly_llama/tensor_parallel_shards dolly_llama -``` - -Lets remove our "sharded" checkpoints as we have consolidated them already to safetensors. - -```python -!rm -rf dolly_llama/tensor_parallel_shards -``` - -## 4. Evaluate and test fine-tuned Llama model - -Similar to training to be able to run inferece on AWS Trainium or AWS Inferentia2 we need to compile our model for the correct use. We will use our Trainium instance for the inference test, but we recommend customer to switch to Inferentia2 for inference. - -Optimum Neuron implements similar to Transformers AutoModel classes for easy inference use. We will use the `NeuronModelForCausalLM` class to load our vanilla transformers checkpoint and convert it to neuron. - -```python -from optimum.neuron import NeuronModelForCausalLM -from transformers import AutoTokenizer - -compiler_args = {"num_cores": 2, "auto_cast_type": 'fp16'} -input_shapes = {"batch_size": 1, "sequence_length": 2048} - -tokenizer = AutoTokenizer.from_pretrained("dolly_llama") -model = NeuronModelForCausalLM.from_pretrained( - "dolly_llama", - export=True, - **compiler_args, - **input_shapes) - -``` - -_Note: Inference compilation can take ~25minutes. Luckily, you need to only run this onces. Since you can save the model afterwards. If you are going to run on Inferentia2 you need to recompile again. The compilation is parameter and hardware specific._ - -```python -# COMMENT IN if you want to save the compiled model -# model.save_pretrained("compiled_dolly_llama") -``` - -We can now test inference, but have to make sure we format our input to our prompt format we used for fine-tuning. Therefore we created a helper method, which accepts a `dict` with our `instruction` and optionally a `context`. - -```python -def format_dolly_inference(sample): - instruction = f"### Instruction\n{sample['instruction']}" - context = f"### Context\n{sample['context']}" if "context" in sample else None - response = f"### Answer\n" - # join all the parts together - prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) - return prompt - - -def generate(sample): - prompt = format_dolly_inference(sample) - inputs = tokenizer(prompt, return_tensors="pt") - outputs = model.generate(**inputs, - max_new_tokens=512, - do_sample=True, - temperature=0.9, - top_k=50, - top_p=0.9) - return tokenizer.decode(outputs[0], skip_special_tokens=False)[len(prompt):] -``` - -Lets test inference. First we test without a context. - -_Note: Inference is not expected to be super fast on AWS Trainium using 2 cores. For Inference we recommend using Inferentia2._ - -```python -prompt = { - "instruction": "Can you tell me something about AWS?" -} -res = generate(prompt) - -print(res) -``` - -> AWS stands for Amazon Web Services. AWS is a suite of remote computing services offered by Amazon. The most widely used of these include Amazon Elastic Compute Cloud (Amazon EC2), which provides resizable compute capacity in the cloud; Amazon Simple Storage Service (Amazon S3), which is an object storage service; and Amazon Elastic Block Store (Amazon EBS), which is designed to provide high performance, durable block storage volumes for use with AWS instances. AWS also provides other services, such as AWS Identity and Access Management (IAM), a service that enables organizations to control access to their AWS resources, and AWS Key Management Service (AWS KMS), which helps customers create and control the use of encryption keys. - -That looks correct. Now, lets add some context, e.g. as you would do for RAG applications - -```python -prompt = { - "instruction": "How can I train models on AWS Trainium?", - "context": "🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS Accelerators including [AWS Trainium](https://aws.amazon.com/machine-learning/trainium/?nc1=h_ls) and [AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/?nc1=h_ls). It provides a set of tools enabling easy model loading, training and inference on single- and multi-Accelerator settings for different downstream tasks." -} -res = generate(prompt) - -print(res) -``` - -> You can use the Optimum Neuron interface to train models on AWS Trainium. - -Awesome, our model also correctly uses the provided context. We are done. Congrats on fine-tuning Llama on AWS Trainium. diff --git a/docs/source/tutorials/overview.mdx b/docs/source/tutorials/overview.mdx deleted file mode 100644 index 52e86948a..000000000 --- a/docs/source/tutorials/overview.mdx +++ /dev/null @@ -1,27 +0,0 @@ - - -# Overview - -Welcome to the 🤗 Optimum Neuron tutorials! - -These tutorials will help you quickly get started with AWS Trainium / Inferentia on the following topics: - -- [Getting started with AWS Trainium and Hugging Face Transformers](./fine_tune_bert) -- [Generate images with Stable Diffusion models on AWS Inferentia2](./stable_diffusion) -- [Create your own chatbot with llama-2-13B on AWS Inferentia](./llama2-13b-chatbot) -- [Fine-tune and Test Llama 2 7B on AWS Trainium](./fine_tune_llama_7b) -- [Sentence Transformers on AWS Inferentia with Optimum Neuron](./sentence_transformers) diff --git a/notebooks/text-generation/llama2-7b-fine-tuning.ipynb b/notebooks/text-generation/llama2-7b-fine-tuning.ipynb index 70f2da861..f86eef356 100644 --- a/notebooks/text-generation/llama2-7b-fine-tuning.ipynb +++ b/notebooks/text-generation/llama2-7b-fine-tuning.ipynb @@ -114,9 +114,45 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cd7d187e4cbe448eaeafa10c9e803b9b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading readme: 0%| | 0.00/8.20k [00:00 5\u001b[0m dataset \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdatabricks/databricks-dolly-15k\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msplit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataset size: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(dataset)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(dataset[randrange(\u001b[38;5;28mlen\u001b[39m(dataset))])\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/site-packages/datasets/load.py:2556\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, verification_mode, ignore_verifications, keep_in_memory, save_infos, revision, token, use_auth_token, task, streaming, num_proc, storage_options, trust_remote_code, **config_kwargs)\u001b[0m\n\u001b[1;32m 2551\u001b[0m verification_mode \u001b[38;5;241m=\u001b[39m VerificationMode(\n\u001b[1;32m 2552\u001b[0m (verification_mode \u001b[38;5;129;01mor\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mBASIC_CHECKS) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m save_infos \u001b[38;5;28;01melse\u001b[39;00m VerificationMode\u001b[38;5;241m.\u001b[39mALL_CHECKS\n\u001b[1;32m 2553\u001b[0m )\n\u001b[1;32m 2555\u001b[0m \u001b[38;5;66;03m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 2556\u001b[0m builder_instance \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset_builder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2557\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2558\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2559\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2560\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2561\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2562\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2563\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2564\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2565\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2566\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2567\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2568\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrust_remote_code\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrust_remote_code\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2569\u001b[0m \u001b[43m \u001b[49m\u001b[43m_require_default_config_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 2570\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2571\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2573\u001b[0m \u001b[38;5;66;03m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m 2574\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m streaming:\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/site-packages/datasets/load.py:2265\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, token, use_auth_token, storage_options, trust_remote_code, _require_default_config_name, **config_kwargs)\u001b[0m\n\u001b[1;32m 2263\u001b[0m builder_cls \u001b[38;5;241m=\u001b[39m get_dataset_builder_class(dataset_module, dataset_name\u001b[38;5;241m=\u001b[39mdataset_name)\n\u001b[1;32m 2264\u001b[0m \u001b[38;5;66;03m# Instantiate the dataset builder\u001b[39;00m\n\u001b[0;32m-> 2265\u001b[0m builder_instance: DatasetBuilder \u001b[38;5;241m=\u001b[39m \u001b[43mbuilder_cls\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2266\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2267\u001b[0m \u001b[43m \u001b[49m\u001b[43mdataset_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2268\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2269\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2270\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2271\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mhash\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdataset_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mhash\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2272\u001b[0m \u001b[43m \u001b[49m\u001b[43minfo\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minfo\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2273\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2275\u001b[0m \u001b[43m \u001b[49m\u001b[43mstorage_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstorage_options\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2276\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mbuilder_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2277\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2278\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2279\u001b[0m builder_instance\u001b[38;5;241m.\u001b[39m_use_legacy_cache_dir_if_possible(dataset_module)\n\u001b[1;32m 2281\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m builder_instance\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/site-packages/datasets/builder.py:365\u001b[0m, in \u001b[0;36mDatasetBuilder.__init__\u001b[0;34m(self, cache_dir, dataset_name, config_name, hash, base_path, info, features, token, use_auth_token, repo_id, data_files, data_dir, storage_options, writer_batch_size, name, **config_kwargs)\u001b[0m\n\u001b[1;32m 358\u001b[0m data_files \u001b[38;5;241m=\u001b[39m DataFilesDict\u001b[38;5;241m.\u001b[39mfrom_patterns(\n\u001b[1;32m 359\u001b[0m sanitize_patterns(data_files),\n\u001b[1;32m 360\u001b[0m base_path\u001b[38;5;241m=\u001b[39mbase_path,\n\u001b[1;32m 361\u001b[0m download_config\u001b[38;5;241m=\u001b[39mDownloadConfig(token\u001b[38;5;241m=\u001b[39mtoken, storage_options\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstorage_options),\n\u001b[1;32m 362\u001b[0m )\n\u001b[1;32m 364\u001b[0m \u001b[38;5;66;03m# Prepare config: DatasetConfig contains name, version and description but can be extended by each dataset\u001b[39;00m\n\u001b[0;32m--> 365\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfeatures\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01min\u001b[39;00m \u001b[43minspect\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msignature\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mBUILDER_CONFIG_CLASS\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mparameters \u001b[38;5;129;01mand\u001b[39;00m features \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 366\u001b[0m config_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfeatures\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m features\n\u001b[1;32m 367\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data_files \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/inspect.py:3254\u001b[0m, in \u001b[0;36msignature\u001b[0;34m(obj, follow_wrapped, globals, locals, eval_str)\u001b[0m\n\u001b[1;32m 3252\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msignature\u001b[39m(obj, \u001b[38;5;241m*\u001b[39m, follow_wrapped\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28mglobals\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28mlocals\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, eval_str\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 3253\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Get a signature object for the passed callable.\"\"\"\u001b[39;00m\n\u001b[0;32m-> 3254\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mSignature\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfollow_wrapped\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfollow_wrapped\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3255\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mglobals\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mglobals\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_str\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meval_str\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/inspect.py:3002\u001b[0m, in \u001b[0;36mSignature.from_callable\u001b[0;34m(cls, obj, follow_wrapped, globals, locals, eval_str)\u001b[0m\n\u001b[1;32m 2998\u001b[0m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[1;32m 2999\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfrom_callable\u001b[39m(\u001b[38;5;28mcls\u001b[39m, obj, \u001b[38;5;241m*\u001b[39m,\n\u001b[1;32m 3000\u001b[0m follow_wrapped\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;28mglobals\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;28mlocals\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, eval_str\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 3001\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Constructs Signature for the given callable object.\"\"\"\u001b[39;00m\n\u001b[0;32m-> 3002\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_signature_from_callable\u001b[49m\u001b[43m(\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msigcls\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3003\u001b[0m \u001b[43m \u001b[49m\u001b[43mfollow_wrapper_chains\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfollow_wrapped\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3004\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mglobals\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mglobals\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_str\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meval_str\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/inspect.py:2463\u001b[0m, in \u001b[0;36m_signature_from_callable\u001b[0;34m(obj, follow_wrapper_chains, skip_bound_arg, globals, locals, eval_str, sigcls)\u001b[0m\n\u001b[1;32m 2458\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m sig\u001b[38;5;241m.\u001b[39mreplace(parameters\u001b[38;5;241m=\u001b[39mnew_params)\n\u001b[1;32m 2460\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m isfunction(obj) \u001b[38;5;129;01mor\u001b[39;00m _signature_is_functionlike(obj):\n\u001b[1;32m 2461\u001b[0m \u001b[38;5;66;03m# If it's a pure Python function, or an object that is duck type\u001b[39;00m\n\u001b[1;32m 2462\u001b[0m \u001b[38;5;66;03m# of a Python function (Cython functions, for instance), then:\u001b[39;00m\n\u001b[0;32m-> 2463\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_signature_from_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43msigcls\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mobj\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2464\u001b[0m \u001b[43m \u001b[49m\u001b[43mskip_bound_arg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mskip_bound_arg\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2465\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mglobals\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mglobals\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlocals\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43meval_str\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43meval_str\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2467\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _signature_is_builtin(obj):\n\u001b[1;32m 2468\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _signature_from_builtin(sigcls, obj,\n\u001b[1;32m 2469\u001b[0m skip_bound_arg\u001b[38;5;241m=\u001b[39mskip_bound_arg)\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/inspect.py:2325\u001b[0m, in \u001b[0;36m_signature_from_function\u001b[0;34m(cls, func, skip_bound_arg, globals, locals, eval_str)\u001b[0m\n\u001b[1;32m 2323\u001b[0m kind \u001b[38;5;241m=\u001b[39m _POSITIONAL_ONLY \u001b[38;5;28;01mif\u001b[39;00m posonly_left \u001b[38;5;28;01melse\u001b[39;00m _POSITIONAL_OR_KEYWORD\n\u001b[1;32m 2324\u001b[0m annotation \u001b[38;5;241m=\u001b[39m annotations\u001b[38;5;241m.\u001b[39mget(name, _empty)\n\u001b[0;32m-> 2325\u001b[0m parameters\u001b[38;5;241m.\u001b[39mappend(\u001b[43mParameter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mannotation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mannotation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2326\u001b[0m \u001b[43m \u001b[49m\u001b[43mkind\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkind\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 2327\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m posonly_left:\n\u001b[1;32m 2328\u001b[0m posonly_left \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/inspect.py:2639\u001b[0m, in \u001b[0;36mParameter.__init__\u001b[0;34m(self, name, kind, default, annotation)\u001b[0m\n\u001b[1;32m 2637\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, kind, \u001b[38;5;241m*\u001b[39m, default\u001b[38;5;241m=\u001b[39m_empty, annotation\u001b[38;5;241m=\u001b[39m_empty):\n\u001b[1;32m 2638\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 2639\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_kind \u001b[38;5;241m=\u001b[39m \u001b[43m_ParameterKind\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkind\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2640\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m:\n\u001b[1;32m 2641\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvalue \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkind\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m is not a valid Parameter.kind\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/enum.py:385\u001b[0m, in \u001b[0;36mEnumMeta.__call__\u001b[0;34m(cls, value, names, module, qualname, type, start)\u001b[0m\n\u001b[1;32m 360\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;124;03mEither returns an existing member, or creates a new enum class.\u001b[39;00m\n\u001b[1;32m 362\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 382\u001b[0m \u001b[38;5;124;03m`type`, if set, will be mixed in as the first base class.\u001b[39;00m\n\u001b[1;32m 383\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 384\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m names \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m: \u001b[38;5;66;03m# simple value lookup\u001b[39;00m\n\u001b[0;32m--> 385\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__new__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mcls\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvalue\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 386\u001b[0m \u001b[38;5;66;03m# otherwise, functional API: we're creating a new Enum type\u001b[39;00m\n\u001b[1;32m 387\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_create_(\n\u001b[1;32m 388\u001b[0m value,\n\u001b[1;32m 389\u001b[0m names,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 393\u001b[0m start\u001b[38;5;241m=\u001b[39mstart,\n\u001b[1;32m 394\u001b[0m )\n", + "File \u001b[0;32m~/micromamba/envs/optimum_neuron/lib/python3.10/enum.py:678\u001b[0m, in \u001b[0;36mEnum.__new__\u001b[0;34m(cls, value)\u001b[0m\n\u001b[1;32m 672\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mEnum\u001b[39;00m(metaclass\u001b[38;5;241m=\u001b[39mEnumMeta):\n\u001b[1;32m 673\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 674\u001b[0m \u001b[38;5;124;03m Generic enumeration.\u001b[39;00m\n\u001b[1;32m 675\u001b[0m \n\u001b[1;32m 676\u001b[0m \u001b[38;5;124;03m Derive from this class to define new enumerations.\u001b[39;00m\n\u001b[1;32m 677\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 678\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__new__\u001b[39m(\u001b[38;5;28mcls\u001b[39m, value):\n\u001b[1;32m 679\u001b[0m \u001b[38;5;66;03m# all enum instances are actually created during class construction\u001b[39;00m\n\u001b[1;32m 680\u001b[0m \u001b[38;5;66;03m# without calling this method; this method is called by the metaclass'\u001b[39;00m\n\u001b[1;32m 681\u001b[0m \u001b[38;5;66;03m# __call__ (i.e. Color(3) ), and by pickle\u001b[39;00m\n\u001b[1;32m 682\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(value) \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28mcls\u001b[39m:\n\u001b[1;32m 683\u001b[0m \u001b[38;5;66;03m# For lookups like Color(Color.RED)\u001b[39;00m\n\u001b[1;32m 684\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m value\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "from datasets import load_dataset\n", "from random import randrange\n", @@ -126,7 +162,7 @@ "\n", "print(f\"dataset size: {len(dataset)}\")\n", "print(dataset[randrange(len(dataset))])\n", - "# dataset size: 15011\n" + "# dataset size: 15011" ] }, { @@ -134,12 +170,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To instruct tune our model we need to convert our structured examples into a collection of tasks described via instructions. We define a `formatting_function` that takes a sample and returns a string with our format instruction." + "To instruct tune our model we need to convert our structured examples into a collection of tasks described via instructions. We define a `format_dolly` that takes a sample and returns a string with our format instruction." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -149,21 +185,33 @@ " response = f\"### Answer\\n{sample['response']}\"\n", " # join all the parts together\n", " prompt = \"\\n\\n\".join([i for i in [instruction, context, response] if i is not None])\n", - " return prompt\n" + " return prompt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "lets test our formatting function on a random example." + "Let's test our formatting function on a random example:\n" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'dataset' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mrandom\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m randrange\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(format_dolly(\u001b[43mdataset\u001b[49m[randrange(\u001b[38;5;28mlen\u001b[39m(dataset))]))\n", + "\u001b[0;31mNameError\u001b[0m: name 'dataset' is not defined" + ] + } + ], "source": [ "from random import randrange\n", "\n", @@ -174,7 +222,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In addition, to formatting our samples we also want to pack multiple samples to one sequence to have a more efficient training. This means that we are stacking multiple samples to one sequence and split them with an EOS Token. This makes the training more efficient. Packing/stacking samples can be done during training or before. We will do it before training to save time. We created a utility method [pack_dataset](./scripts/utils/pack_dataset.py) that takes a dataset and a packing function and returns a packed dataset.\n" + "In addition to formatting our samples, we also want to pack multiple samples to one sequence to have a more efficient training. This means that we are stacking multiple samples to one sequence and split them with an EOS Token. This makes the training more efficient. Packing/stacking samples can be done during training or before. We will do it before training to save time. We created a utility method [pack_dataset](./scripts/utils/pack_dataset.py) that takes a dataset and a packing function and returns a packed dataset.\n" ] }, { @@ -546,7 +594,7 @@ ], "metadata": { "kernelspec": { - "display_name": "pytorch", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -560,7 +608,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.13" }, "vscode": { "interpreter": { diff --git a/notebooks/text-generation/scripts/utils/pack_dataset.py b/notebooks/text-generation/scripts/utils/pack_dataset.py index 9f06a8637..885011fde 100644 --- a/notebooks/text-generation/scripts/utils/pack_dataset.py +++ b/notebooks/text-generation/scripts/utils/pack_dataset.py @@ -2,10 +2,10 @@ from itertools import chain +# empty list to save remainder from batches to use in next batch remainder = {"input_ids": [], "attention_mask": [], "token_type_ids": []} -# empty list to save remainder from batches to use in next batch def pack_dataset(dataset, chunk_length=2048): print(f"Chunking dataset into chunks of {chunk_length} tokens.")