Update fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb
MrzEsma authored Jun 17, 2024
1 parent 67b05c7 commit 9b2cabc
Showing 1 changed file with 59 additions and 18 deletions.
@@ -186,7 +186,6 @@
"fp16 = False\n",
"bf16 = False\n",
"per_device_train_batch_size = 4\n",
"per_device_eval_batch_size = 4\n",
"gradient_accumulation_steps = 1\n",
"gradient_checkpointing = True\n",
"learning_rate = 0.00015\n",
@@ -221,7 +220,7 @@
}
},
"source": [
"## Train Code"
"## Model Training"
]
},
{
@@ -247,6 +246,48 @@
"print(f\"Size of the train set: {len(train_dataset)}. Size of the validation set: {len(eval_dataset)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a5216910d0a339a",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Load LoRA configuration\n",
"peft_config = LoraConfig(\n",
" r=lora_r,\n",
" lora_alpha=lora_alpha,\n",
" lora_dropout=lora_dropout,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\",\n",
" target_modules=target_modules\n",
")"
]
},
{
"cell_type": "markdown",
"id": "230bfceb895c6738",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"The LoraConfig object is used to configure the LoRA (Low-Rank Adaptation) settings for the model when using the Peft library. This can help to reduce the number of parameters that need to be fine-tuned, which can lead to faster training and lower memory usage. Here's a breakdown of the parameters:\n",
"- `r`: The rank of the low-rank matrices used in LoRA. This parameter controls the dimensionality of the low-rank adaptation and directly impacts the model's capacity to adapt and the computational cost.\n",
"- `lora_alpha`: This parameter controls the scaling factor for the low-rank adaptation matrices. A higher alpha value can increase the model's capacity to learn new tasks.\n",
"- `lora_dropout`: The dropout rate for LoRA. This can help to prevent overfitting during fine-tuning. In this case, it's set to 0.1.\n",
"- `bias`: Specifies whether to add a bias term to the low-rank matrices. In this case, it's set to \"none\", which means that no bias term will be added.\n",
"- `task_type`: Defines the type of task for which the model is being fine-tuned. Here, \"CAUSAL_LM\" indicates that the task is a causal language modeling task, which predicts the next word in a sequence.\n",
"- `target_modules`: Specifies the modules in the model to which LoRA will be applied. In this case, it's set to `[\"q_proj\", \"v_proj\", 'k_proj']`, which are the query, value, and key projection layers in the model's attention mechanism."
]
},
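For reference, a configuration like this is typically attached to the model with PEFT's `get_peft_model`. The sketch below is not part of this commit and assumes `model` already holds the loaded base causal LM:

# Minimal sketch (assumption, not from this notebook): wrap the base model with the LoRA adapters
from peft import get_peft_model

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()  # only the low-rank adapter weights are trainable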
{
"cell_type": "code",
"execution_count": null,
@@ -271,25 +312,24 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8a5216910d0a339a",
"cell_type": "markdown",
"id": "535275d96f478839",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"outputs": [],
"source": [
"# Load LoRA configuration\n",
"peft_config = LoraConfig(\n",
" lora_alpha=lora_alpha,\n",
" lora_dropout=lora_dropout,\n",
" r=lora_r,\n",
" bias=\"none\",\n",
" task_type=\"CAUSAL_LM\",\n",
")"
"This block configures the settings for using BitsAndBytes (bnb), a library that provides efficient memory management and compression techniques for PyTorch models. Specifically, it defines how the model weights will be loaded and quantized in 4-bit precision, which is useful for reducing memory usage and potentially speeding up inference.\n",
"\n",
"- `load_in_4bit`: A boolean that determines whether to load the model in 4-bit precision.\n",
"- `bnb_4bit_quant_type`: Specifies the type of 4-bit quantization to use. Here, it's set to 4-bit NormalFloat (NF4) quantization type, which is a new data type introduced in QLoRA. This type is information-theoretically optimal for normally distributed weights, providing an efficient way to quantize the model for fine-tuning.\n",
"- `bnb_4bit_compute_dtype`: Sets the data type used for computations involving the quantized model. In QLoRA, it's set to \"float16\", which is commonly used for mixed-precision training to balance performance and precision.\n",
"- `bnb_4bit_use_double_quant`: This boolean parameter indicates whether to use double quantization. Setting it to False means that only single quantization will be used, which is typically faster but might be slightly less accurate.\n",
"\n",
"Why we have two data type (quant_type and compute_type)? \n",
"QLoRA employs two distinct data types: one for storing base model weights (in here 4-bit NormalFloat) and another for computational operations (16-bit). During the forward and backward passes, QLoRA dequantizes the weights from the storage format to the computational format. However, it only calculates gradients for the LoRA parameters, which utilize 16-bit bfloat. This approach ensures that weights are decompressed only when necessary, maintaining low memory usage throughout both training and inference phases.\n"
]
},
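For illustration, a 4-bit quantization config matching the description above might look like the sketch below; the variable name `bnb_config` and the literal values are assumptions, not taken from this diff:

# Hypothetical BitsAndBytesConfig mirroring the settings described above
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store base weights in 4-bit precision
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization from QLoRA
    bnb_4bit_compute_dtype=torch.float16,  # dequantize to fp16 for forward/backward passes
    bnb_4bit_use_double_quant=False,       # single quantization only
)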
{
@@ -502,7 +542,7 @@
}
},
"source": [
"## Inference Code"
"## Inference"
]
},
{
@@ -527,6 +567,7 @@
" torch.cuda.empty_cache()\n",
" gc.collect()\n",
"\n",
"\n",
"clear_hardwares()\n",
"clear_hardwares()"
]
@@ -545,13 +586,13 @@
"source": [
"def generate(model, prompt: str, kwargs):\n",
" tokenized_prompt = tokenizer(prompt, return_tensors='pt').to(model.device)\n",
" \n",
"\n",
" prompt_length = len(tokenized_prompt.get('input_ids')[0])\n",
" \n",
"\n",
" with torch.cuda.amp.autocast():\n",
" output_tokens = model.generate(**tokenized_prompt, **kwargs) if kwargs else model.generate(**tokenized_prompt)\n",
" output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)\n",
" \n",
"\n",
" return output"
]
},
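As a usage note, the helper above can be called with an explicit generation-kwargs dict; the prompt string and parameter values below are placeholders, not taken from this commit:

# Hypothetical call to the generate() helper defined above
result = generate(model, "<formatted product-description prompt>", {"max_new_tokens": 512, "do_sample": False})
print(result)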
