|
27 | 27 | import tempfile |
28 | 28 | import time |
29 | 29 | import uuid |
| 30 | +import warnings |
30 | 31 | from typing import Optional |
31 | 32 |
|
32 | 33 | import pandas as pd |
@@ -937,15 +938,47 @@ def commit(self, message: str, project_id: str, force: bool = False): |
937 | 938 | print("Keeping the existing commit message.") |
938 | 939 | return |
939 | 940 |
|
| 941 | + llm_and_no_outputs = self._check_llm_and_no_outputs(project_dir=project_dir) |
| 942 | + if llm_and_no_outputs: |
| 943 | + warnings.warn( |
| 944 | + "You are committing an LLM without validation outputs computed " |
| 945 | + "in the validation set. This means that the platform will try to " |
| 946 | + "compute the validation outputs for you. This may take a while and " |
| 947 | + "there are costs associated with it." |
| 948 | + ) |
940 | 949 | commit = { |
941 | 950 | "message": message, |
942 | 951 | "date": time.ctime(), |
| 952 | + "computeOutputs": llm_and_no_outputs, |
943 | 953 | } |
944 | 954 | with open(f"{project_dir}/commit.yaml", "w", encoding="UTF-8") as commit_file: |
945 | 955 | yaml.dump(commit, commit_file) |
946 | 956 |
|
947 | 957 | print("Committed!") |
948 | 958 |
|
| 959 | + def _check_llm_and_no_outputs(self, project_dir: str) -> bool: |
| 960 | + """Checks if the project's staging area contains an LLM and no outputs.""" |
| 961 | + # Check if validation set has outputs |
| 962 | + validation_has_no_outputs = False |
| 963 | + if os.path.exists(f"{project_dir}/validation"): |
| 964 | + validation_dataset_config = utils.load_dataset_config_from_bundle( |
| 965 | + bundle_path=project_dir, label="validation" |
| 966 | + ) |
| 967 | + output_column_name = validation_dataset_config.get("outputColumnName") |
| 968 | + validation_has_no_outputs = output_column_name is None |
| 969 | + |
| 970 | + # Check if the model is an LLM |
| 971 | + model_is_llm = False |
| 972 | + if os.path.exists(f"{project_dir}/model"): |
| 973 | + model_config = utils.read_yaml(f"{project_dir}/model/model_config.yaml") |
| 974 | + architecture_type = model_config.get("architectureType") |
| 975 | + model_type = model_config.get("modelType") |
| 976 | + |
| 977 | + if architecture_type == "llm" and model_type != "shell": |
| 978 | + model_is_llm = True |
| 979 | + |
| 980 | + return validation_has_no_outputs and model_is_llm |
| 981 | + |
949 | 982 | def push(self, project_id: str, task_type: TaskType) -> Optional[ProjectVersion]: |
950 | 983 | """Pushes the commited resources to the platform. |
951 | 984 |
|
|
0 commit comments