|
122 | 122 | "source": [ |
123 | 123 | "import matplotlib.pyplot as plt\n", |
124 | 124 | "from scripts.utils import setup_workspace\n", |
125 | | - "from scripts.dataset import prepare_finqa_dataset\n", |
| 125 | + "from scripts.dataset import prepare_finqa_dataset, prepare_sharegpt_dataset\n", |
126 | 126 | "from scripts.run import get_run_metrics\n", |
127 | 127 | "from scripts.reinforcement_learning import run_rl_training_pipeline\n", |
128 | 128 | "from scripts.evaluation import run_evaluation_pipeline\n", |
129 | 129 | "from scripts.speculative_decoding import (\n", |
130 | 130 | " run_draft_model_pipeline,\n", |
131 | 131 | " prepare_combined_model_for_deployment,\n", |
132 | 132 | " deploy_speculative_decoding_endpoint,\n", |
| 133 | + " deploy_base_model_endpoint,\n", |
| 134 | + " run_evaluation_speculative_decoding,\n", |
133 | 135 | ")\n", |
134 | | - "from scripts.deployment import create_managed_deployment, test_deployment" |
| 136 | + "from scripts.deployment import test_deployment" |
135 | 137 | ] |
136 | 138 | }, |
137 | 139 | { |
|
150 | 152 | "cell_type": "markdown", |
151 | 153 | "metadata": {}, |
152 | 154 | "source": [ |
153 | | - "<p>Prepare dataset for Finetuning. This would save train, test and valid dataset under data folder</p>" |
| 155 | + "<p>Prepare dataset for Fine-tuning. This will save the train, test, and validation datasets under the data folder.</p>" |
154 | 156 | ] |
155 | 157 | }, |
156 | 158 | { |
|
484 | 486 | "<p><strong>Reference:</strong> <a href=\"https://arxiv.org/abs/2503.01840\">https://arxiv.org/abs/2503.01840</a></p>\n" |
485 | 487 | ] |
486 | 488 | }, |
| 489 | + { |
| 490 | + "cell_type": "code", |
| 491 | + "execution_count": null, |
| 492 | + "metadata": {}, |
| 493 | + "outputs": [], |
| 494 | + "source": [ |
| 495 | + "draft_train_data_path = prepare_sharegpt_dataset()" |
| 496 | + ] |
| 497 | + }, |
487 | 498 | { |
488 | 499 | "cell_type": "code", |
489 | 500 | "execution_count": null, |
|
498 | 509 | " num_epochs=1, # Number of train epochs to be run by draft trainer.\n", |
499 | 510 | " monitor=False, # Set to True to wait for completion.\n", |
500 | 511 | " base_model_mlflow_path=\"azureml://registries/azureml-meta/models/Meta-Llama-3-8B-Instruct/versions/9\",\n", |
501 | | - " draft_train_data_path=\"./data_for_draft_model/train/sharegpt_train_small.jsonl\",\n", |
| 512 | + " draft_train_data_path=draft_train_data_path,\n", |
502 | 513 | ")" |
503 | 514 | ] |
504 | 515 | }, |
|
591 | 602 | "endpoint_name = deploy_speculative_decoding_endpoint(\n", |
592 | 603 | " ml_client=ml_client, # ML Client which specifies the workspace where endpoint gets deployed.\n", |
593 | 604 | " combined_model=combined_model, # Reference from previous steps where combined model is created.\n", |
594 | | - " instance_type=\"octagepu\", # Instance type Kubernetes Cluster\n", |
595 | | - " compute_name=\"k8s-a100-compute\",\n", |
| 605 | + " instance_type=\"Standard_NC40ads_H100_v5\", # Instance type\n", |
596 | 606 | ")" |
597 | 607 | ] |
598 | 608 | }, |
|
631 | 641 | "outputs": [], |
632 | 642 | "source": [ |
633 | 643 | "# Deploy managed online endpoint with base model\n", |
634 | | - "base_endpoint_name = create_managed_deployment( # Function to create endpoint for base model.\n", |
| 644 | + "base_endpoint_name = deploy_base_model_endpoint( # Function to create endpoint for base model.\n", |
635 | 645 | " ml_client=ml_client, # ML Client which specifies the workspace where endpoint gets deployed.\n", |
636 | | - " model_asset_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # Huggingface ID of the base model.\n", |
637 | | - " instance_type=\"Standard_ND96amsr_A100_v4\", # Compute SKU on which base model will be deployed.\n", |
| 646 | + " instance_type=\"Standard_NC40ads_H100_v5\", # Compute SKU on which base model will be deployed.\n", |
638 | 647 | ")" |
639 | 648 | ] |
640 | 649 | }, |
|
711 | 720 | "# Run evaluation job to compare base model and speculative decoding endpoints' performance\n", |
712 | 721 | "evaluation_job = run_evaluation_speculative_decoding(\n", |
713 | 722 | " ml_client=ml_client,\n", |
| 723 | + " registry_ml_client=registry_ml_client,\n", |
714 | 724 | " base_endpoint_name=base_endpoint_name, # Base model endpoint from previous step.\n", |
715 | 725 | " speculative_endpoint_name=endpoint_name, # Speculative endpoint from previous step.\n", |
716 | | - " base_model=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n", |
717 | | - " speculative_model=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n", |
| 726 | + " base_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n", |
| 727 | + " speculative_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n", |
| 728 | + " compute_cluster=\"d13-v2\",\n", |
718 | 729 | ")" |
719 | 730 | ] |
720 | 731 | }, |
|
735 | 746 | "cell_type": "markdown", |
736 | 747 | "metadata": {}, |
737 | 748 | "source": [ |
738 | | - "<img src=\"metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">" |
| 749 | + "<img src=\"./images/metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">" |
739 | 750 | ] |
740 | 751 | } |
741 | 752 | ], |
|
0 commit comments