diff --git a/comps/finetuning/src/Dockerfile b/comps/finetuning/src/Dockerfile index 5c1fe569c1..840a1be3f8 100644 --- a/comps/finetuning/src/Dockerfile +++ b/comps/finetuning/src/Dockerfile @@ -20,9 +20,9 @@ ENV PATH=$PATH:/home/user/.local/bin ARG uvpip='uv pip install --system --no-cache-dir' RUN python -m pip install --no-cache-dir --upgrade pip setuptools uv && \ - python -m $uvpip torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu && \ - python -m $uvpip intel-extension-for-pytorch && \ - python -m $uvpip oneccl_bind_pt --index-strategy unsafe-best-match --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \ + python -m $uvpip oneccl_bind_pt==2.7.0+cpu --index-strategy unsafe-best-match --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ && \ + python -m $uvpip torch==2.7.0+cpu torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu && \ + python -m $uvpip intel-extension-for-pytorch==2.7.0 && \ python -m $uvpip -r /home/user/comps/finetuning/src/requirements-cpu.txt ENV PYTHONPATH=$PYTHONPATH:/home/user diff --git a/comps/finetuning/src/README.md b/comps/finetuning/src/README.md index c01a44e72c..58f795a5b1 100644 --- a/comps/finetuning/src/README.md +++ b/comps/finetuning/src/README.md @@ -252,6 +252,42 @@ Below is an example for the format of the DPO dataset: {"system": "You are an AI assistant. You will be given a task. You must generate a detailed and long answer.", "question": "Generate an approximately fifteen-word sentence that describes all this data: Midsummer House eatType restaurant; Midsummer House food Chinese; Midsummer House priceRange moderate; Midsummer House customer rating 3 out of 5; Midsummer House near All Bar One", "chosen": "Midsummer House is a moderately priced Chinese restaurant with a 3/5 customer rating, located near All Bar One.", "rejected": " Sure! Here's a sentence that describes all the data you provided:\n\n\"Midsummer House is a moderately priced Chinese restaurant with a customer rating of 3 out of 5, located near All Bar One, offering a variety of delicious dishes.\""} ``` +#### 3.2.6 Reasoning model finetuning + +Use the following command to launch a job for finetuning reasoning model, such as `Qwen/Qwen2.5-7B`: + +```bash +# create a finetuning job +curl http://${your_ip}:8015/v1/fine_tuning/jobs \ + -X POST \ + -H "Content-Type: application/json" \ + -d '{ + "training_file": "medical_o1_sft.json", + "model": "Qwen/Qwen2.5-3B", + "Dataset": { + "max_length":2048 + }, + "Training":{ + "epochs":1, + "batch_size":1, + "gradient_accumulation_steps":8 + }, + "General": { + "task":"reasoning" + } + }' +``` + +Below is an example for the format of the reasoning dataset [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT): + +```json +{ + "Question": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?", + "Complex_CoT": "Okay, let's see what's going on here. We've got sudden weakness in the person's left arm and leg - and that screams something neuro-related, maybe a stroke?\n\nBut wait, there's more. The right lower leg is swollen and tender, which is like waving a big flag for deep vein thrombosis, especially after a long flight or sitting around a lot.\n\nSo, now I'm thinking, how could a clot in the leg end up causing issues like weakness or stroke symptoms?\n\nOh, right! There's this thing called a paradoxical embolism. It can happen if there's some kind of short circuit in the heart - like a hole that shouldn't be there.\n\nLet's put this together: if a blood clot from the leg somehow travels to the left side of the heart, it could shoot off to the brain and cause that sudden weakness by blocking blood flow there.\n\nHmm, but how would the clot get from the right side of the heart to the left without going through the lungs and getting filtered out?\n\nHere's where our cardiac anomaly comes in: a patent foramen ovale or PFO. That's like a sneaky little shortcut in the heart between the right and left atria.\n\nAnd it's actually pretty common, found in about a quarter of adults, which definitely makes it the top suspect here.\n\nSo with all these pieces - long travel, leg clot, sudden weakness - a PFO fits the bill perfectly, letting a clot cross over and cause all this.\n\nEverything fits together pretty neatly, so I'd bet PFO is the heart issue waiting to be discovered. Yeah, that really clicks into place!", + "Response": "The specific cardiac abnormality most likely to be found in this scenario is a patent foramen ovale (PFO). This condition could allow a blood clot from the venous system, such as one from a deep vein thrombosis in the leg, to bypass the lungs and pass directly into the arterial circulation. This can occur when the clot moves from the right atrium to the left atrium through the PFO. Once in the arterial system, the clot can travel to the brain, potentially causing an embolic stroke, which would explain the sudden weakness in the left arm and leg. The connection between the recent travel, which increases the risk of deep vein thrombosis, and the neurological symptoms suggests the presence of a PFO facilitating a paradoxical embolism." +} +``` + ### 3.3 Manage fine-tuning job Below commands show how to list finetuning jobs, retrieve a finetuning job, cancel a finetuning job and list checkpoints of a finetuning job. diff --git a/comps/finetuning/src/integrations/finetune_config.py b/comps/finetuning/src/integrations/finetune_config.py index eec96c5dc2..10aebeab54 100644 --- a/comps/finetuning/src/integrations/finetune_config.py +++ b/comps/finetuning/src/integrations/finetune_config.py @@ -74,7 +74,7 @@ def check_report_to(cls, v: str): @validator("task") def check_task(cls, v: str): - assert v in ["instruction_tuning", "pretraining", "dpo", "rerank", "embedding"] + assert v in ["instruction_tuning", "pretraining", "dpo", "rerank", "embedding", "reasoning"] return v @@ -116,6 +116,9 @@ class DatasetConfig(BaseModel): ) query_instruction_for_retrieval: Optional[str] = Field(default=None, description="instruction for query") passage_instruction_for_retrieval: Optional[str] = Field(default=None, description="instruction for passage") + reasoning_dataset_keys: Optional[List[str]] = Field( + default=["Question", "Complex_CoT", "Response"], description="keys of reasoning dataset" + ) class RayResourceConfig(BaseModel): diff --git a/comps/finetuning/src/integrations/llm_on_ray/finetune/data_process.py b/comps/finetuning/src/integrations/llm_on_ray/finetune/data_process.py index 070f77a651..3c57cfa7c9 100644 --- a/comps/finetuning/src/integrations/llm_on_ray/finetune/data_process.py +++ b/comps/finetuning/src/integrations/llm_on_ray/finetune/data_process.py @@ -44,8 +44,8 @@ def make_prompt(self, examples): prompts["prompt_targets"] = [] for rec in examples: instruction = rec["instruction"] - response = rec["input"] - context = rec.get("output") + context = rec["input"] + response = rec.get("output") if not instruction: raise ValueError(f"Expected an instruction in: {rec}") # if not response: @@ -193,7 +193,94 @@ def tokenize(self, examples): if self.mask_input: labels[:input_id_len] = [IGNORE_INDEX] * input_id_len # mask response - if self.mask_response: + elif self.mask_response: + labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len) + + examples["input_ids"].append(results["input_ids"]) + examples["labels"].append(labels) + examples["attention_mask"].append(results["attention_mask"]) + return examples + + +class ReasoningDataProcessor: + def __init__(self, config, tokenizer): + self.tokenizer = tokenizer + self.think_tokens = ["", ""] + tokenizer.add_special_tokens({"additional_special_tokens": self.think_tokens}) + tokenizer.pad_token = tokenizer.eos_token + self.system = """ +You are an assistant that engages in extremely thorough, self-questioning reasoning. You will help the user to answer the question they propose. +Your responses must be given after the thorough and rigorous reasoning, and output your reasoning content between and tags. +Make sure to always include the final answer, and output the final answer after the tag. +""" + self.keys = config["Dataset"].get("reasoning_dataset_keys", ["Question", "Complex_CoT", "Response"]) + assert len(self.keys) >= 2, "dataset must have 2 keys or more." + self.padding_side = config["Dataset"].get("padding_side", "right") + self.truncation_side = config["Dataset"].get("truncation_side", "right") + self.max_length = self.max_seq_length = config["Dataset"].get("max_length", 2048) + self.truncation = config["Dataset"].get("truncation", True) + # set padding to max_length for hpu to avoid bug in GaudiTrainer and accelerate training on hpu + self.padding = config["Dataset"].get("padding", True) if config["Training"]["device"] != "hpu" else "max_length" + self.mask_input = config["Dataset"].get("mask_input", True) + self.mask_response = config["Dataset"].get("mask_response", True) + + def make_prompt(self, examples): + prompts = {} + prompts["prompt_sources"] = [] + prompts["prompt_targets"] = [] + for rec in examples: + for key in self.keys: + assert key in rec, f"Key {key} not in dataset, provide correct keys in reasoning_dataset_keys argument." + question = rec[self.keys[0]] + reasoning = rec[self.keys[1]] if len(self.keys) == 3 else "" + response = rec[self.keys[2]] if len(self.keys) == 3 else rec[self.keys[1]] + if not question: + raise ValueError(f"Expected a question in: {rec}") + prompt = self.system + "\n" + "### User" + "\n" + question + "\n" + "### Assistant" + "\n" + prompts["prompt_sources"].append(prompt) + if reasoning: + prompt_response = ( + self.think_tokens[0] + reasoning + self.think_tokens[1] + "\n" + response + self.tokenizer.eos_token + ) + else: + prompt_response = response + self.tokenizer.eos_token + prompts["prompt_targets"].append(prompt_response) + return prompts + + def tokenize(self, examples): + keys = list(examples.data.keys()) + if len(keys) != 2: + raise ValueError("Unsupported dataset format") + + examples["input_ids"] = [] + examples["labels"] = [] + examples["attention_mask"] = [] + for s, t in zip(examples[keys[0]], examples[keys[1]]): + results = self.tokenizer( + s + t, + padding=self.padding, + truncation=self.truncation, + return_tensors=None, + max_length=self.max_length, + ) + + input_ids = results["input_ids"] + input_len = len(input_ids) + labels = copy.deepcopy(input_ids) + if self.mask_input or self.mask_response: + sources_tokenized = self.tokenizer( + s, + padding=False, + truncation=True, + return_tensors=None, + max_length=self.max_length, + ) + input_id_len = len(sources_tokenized["input_ids"]) + # mask input + if self.mask_input: + labels[:input_id_len] = [IGNORE_INDEX] * input_id_len + # mask response + elif self.mask_response: labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len) examples["input_ids"].append(results["input_ids"]) diff --git a/comps/finetuning/src/integrations/llm_on_ray/finetune/finetune.py b/comps/finetuning/src/integrations/llm_on_ray/finetune/finetune.py index 12101a581c..f86ca7cab9 100644 --- a/comps/finetuning/src/integrations/llm_on_ray/finetune/finetune.py +++ b/comps/finetuning/src/integrations/llm_on_ray/finetune/finetune.py @@ -32,6 +32,7 @@ GroupCollator, InstructionDataProcessor, PretrainingDataProcessor, + ReasoningDataProcessor, TrainDatasetForCE, TrainDatasetForEmbedding, ) @@ -301,6 +302,21 @@ def group_texts(examples): desc="Tokenize dataset", ) return tokenized_dataset + elif task == "reasoning": + processor = ReasoningDataProcessor(config, tokenizer) + for key in dataset: + prompts = processor.make_prompt(dataset[key]) + dataset[key] = datasets.Dataset.from_dict(prompts) + + column_names = list(dataset["train"].features) + tokenized_dataset = dataset.map( + processor.tokenize, + remove_columns=column_names, + batched=True, + load_from_cache_file=False, + desc="Tokenize dataset", + ) + return tokenized_dataset elif task == "rerank": dataset["train"] = TrainDatasetForCE(dataset["train"], config["Dataset"], tokenizer) return dataset @@ -313,10 +329,14 @@ def group_texts(examples): def prepare_data_collator(config: Dict, tokenizer): task = config["General"].get("task", "instruction_tuning") - if task == "instruction_tuning" or task == "pretraining": + if task in ["instruction_tuning", "pretraining"]: return transformers.DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8 ) + elif task == "reasoning": + return transformers.DataCollatorForSeq2Seq( + tokenizer=tokenizer, max_length=config["Dataset"]["max_length"], return_tensors="pt" + ) elif task == "dpo": return DPOCollator(tokenizer) elif task == "rerank": @@ -338,14 +358,14 @@ def load_model(config: Dict): model_config = config["General"].get("config", {}) task = config["General"].get("task", "instruction_tuning") ref_model = None - if task in ["instruction_tuning", "pretraining", "dpo"]: + if task in ["instruction_tuning", "pretraining", "dpo", "reasoning"]: model = transformers.AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, **model_config) if task == "dpo": ref_model = transformers.AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=model_dtype, **model_config ) lora_config = config["General"].get("lora_config", None) - if lora_config and task == "instruction_tuning": + if lora_config and task in ["instruction_tuning"]: peft_config = LoraConfig(**lora_config) model = get_peft_model(model, peft_config) elif task == "rerank": @@ -389,6 +409,8 @@ def load_model(config: Dict): def get_trainer(config: Dict, model, ref_model, tokenizer, tokenized_dataset, data_collator): device = config["Training"]["device"] task = config["General"].get("task", "instruction_tuning") + if task == "reasoning": + model.resize_token_embeddings(len(tokenizer)) if device in ["cpu", "gpu", "cuda"]: training_args = convert_to_training_args(TrainingArguments, config) if task == "dpo": diff --git a/tests/finetuning/test_finetuning_native.sh b/tests/finetuning/test_finetuning_native.sh index 8f129258be..898016c6f1 100644 --- a/tests/finetuning/test_finetuning_native.sh +++ b/tests/finetuning/test_finetuning_native.sh @@ -10,11 +10,12 @@ ip_address=$(hostname -I | awk '{print $1}') finetuning_service_port=8015 ray_port=8265 service_name="finetuning" +TAG="comp" function build_docker_images() { cd $WORKPATH echo $(pwd) - docker build -t opea/finetuning:comps --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy --build-arg HF_TOKEN=$HF_TOKEN -f comps/finetuning/src/Dockerfile . + docker build -t opea/finetuning:$TAG --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy --build-arg HF_TOKEN=$HF_TOKEN -f comps/finetuning/src/Dockerfile . if [ $? -ne 0 ]; then echo "opea/finetuning built fail" exit 1 @@ -25,6 +26,7 @@ function build_docker_images() { function start_service() { export no_proxy="localhost,127.0.0.1,"${ip_address} + export TAG=$TAG cd $WORKPATH/comps/finetuning/deployment/docker_compose docker compose -f compose.yaml up ${service_name} -d > start_services_with_compose.log sleep 1m @@ -113,6 +115,10 @@ function validate_finetune() { break elif [[ "$STATUS" == "failed" ]]; then echo "training: failed." + docker logs $DOCKER_NAME 2>&1 | tee ${LOG_PATH}/finetuning-job.log + RAY_JOB_ID=$(grep -o 'raysubmit_[A-Za-z0-9]\+' ${LOG_PATH}/finetuning-job.log | tail -n 1) + echo "Log for failure." + docker exec $DOCKER_NAME python -c "from ray.job_submission import JobSubmissionClient;client = JobSubmissionClient();print(client.get_job_logs('$RAY_JOB_ID'))" exit 1 else echo "training: '$STATUS'" @@ -225,6 +231,39 @@ EOF '{"id":"ft-job' \ '{"training_file": "test_data_dpo.jsonl","model": "facebook/opt-125m","General":{"task":"dpo"}}' + ########################## + # reasoning SFT test # + ########################## + # test /v1/dataprep upload file + cat < test_data_reasoning.jsonl +[ + { + "Question": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?", + "Complex_CoT": "Okay, let's see what's going on here. We've got sudden weakness in the person's left arm and leg - and that screams something neuro-related, maybe a stroke?\n\nBut wait, there's more. The right lower leg is swollen and tender, which is like waving a big flag for deep vein thrombosis, especially after a long flight or sitting around a lot.\n\nSo, now I'm thinking, how could a clot in the leg end up causing issues like weakness or stroke symptoms?\n\nOh, right! There's this thing called a paradoxical embolism. It can happen if there's some kind of short circuit in the heart - like a hole that shouldn't be there.\n\nLet's put this together: if a blood clot from the leg somehow travels to the left side of the heart, it could shoot off to the brain and cause that sudden weakness by blocking blood flow there.\n\nHmm, but how would the clot get from the right side of the heart to the left without going through the lungs and getting filtered out?\n\nHere's where our cardiac anomaly comes in: a patent foramen ovale or PFO. That's like a sneaky little shortcut in the heart between the right and left atria.\n\nAnd it's actually pretty common, found in about a quarter of adults, which definitely makes it the top suspect here.\n\nSo with all these pieces - long travel, leg clot, sudden weakness - a PFO fits the bill perfectly, letting a clot cross over and cause all this.\n\nEverything fits together pretty neatly, so I'd bet PFO is the heart issue waiting to be discovered. Yeah, that really clicks into place!", + "Response": "The specific cardiac abnormality most likely to be found in this scenario is a patent foramen ovale (PFO). This condition could allow a blood clot from the venous system, such as one from a deep vein thrombosis in the leg, to bypass the lungs and pass directly into the arterial circulation. This can occur when the clot moves from the right atrium to the left atrium through the PFO. Once in the arterial system, the clot can travel to the brain, potentially causing an embolic stroke, which would explain the sudden weakness in the left arm and leg. The connection between the recent travel, which increases the risk of deep vein thrombosis, and the neurological symptoms suggests the presence of a PFO facilitating a paradoxical embolism." + }, + { + "Question": "A 33-year-old woman is brought to the emergency department 15 minutes after being stabbed in the chest with a screwdriver. Given her vital signs of pulse 110/min, respirations 22/min, and blood pressure 90/65 mm Hg, along with the presence of a 5-cm deep stab wound at the upper border of the 8th rib in the left midaxillary line, which anatomical structure in her chest is most likely to be injured?", + "Complex_CoT": "Okay, let's figure out what's going on here. A woman comes in with a stab wound from a screwdriver. It's in her chest, upper border of the 8th rib, left side, kind of around the midaxillary line. First thought, that's pretty close to where the lung sits, right?\n\nLet's talk about location first. This spot is along the left side of her body. Above the 8th rib, like that, is where a lot of important stuff lives, like the bottom part of the left lung, possibly the diaphragm too, especially considering how deep the screwdriver went.\n\nThe wound is 5 cm deep. That sounds pretty deep. I mean, it could definitely reach down to the lung tissue or maybe the diaphragm. Given that it's midaxillary, we're in the territory where the lower lobe of the left lung hangs out. It's also possible there's some intersection with where the diaphragm begins, but the lung feels more probable somehow.\n\nNow, her vitals are concerning: elevated heart rate and low blood pressure. This is serious. My gut says that this kind of vital sign picture could mean something like pneumothorax or maybe hemothorax. Both can happen if the lung gets punctured, and they can make the blood pressure drop and the heart rate skyrocket since she's obviously distressed.\n\nSo, putting it all together, the most obvious culprit is the lower lobe of the left lung. The wound's depth and her condition point that way. And, yeah, this adds up with pneumothorax or maybe blood collecting in the chest—the kind of stuff that can really mess with breathing and circulation.\n\nAlright, with this in mind, it sure seems like the lung is the most likely thing that got hurt here. Makes sense given the situation, where the wound is, considering her symptoms, and the whole setup.", + "Response": "In this scenario, the most likely anatomical structure to be injured is the lower lobe of the left lung. The location of the stab wound—at the upper border of the 8th rib in the left midaxillary line—indicates proximity to the lower lobe of the lung. The depth of the wound (5 cm) suggests it is sufficient to reach lung tissue. Her vital signs of elevated heart rate and low blood pressure could signal complications like a pneumothorax or hemothorax, common consequences of lung trauma that would result from a penetrating injury in this area. Given these considerations, the lower lobe of the left lung is the most probable structure injured." + } +] +EOF + validate_upload \ + "http://${ip_address}:$finetuning_service_port/v1/files" \ + "dpo - upload" \ + "finetuning" \ + "fine-tune" \ + "test_data_reasoning.jsonl" + + # test /v1/fine_tuning/jobs + validate_finetune \ + "http://${ip_address}:$finetuning_service_port/v1/fine_tuning/jobs" \ + "reasoning - finetuning" \ + "finetuning" \ + '{"id":"ft-job' \ + '{"training_file": "test_data_reasoning.jsonl","model": "Qwen/Qwen2.5-0.5B","General":{"task":"reasoning"}}' + } function stop_docker() {