From 570e56dd09af32e173cf041eed7497e4533ec186 Mon Sep 17 00:00:00 2001 From: Googler Date: Wed, 25 Oct 2023 10:42:41 -0700 Subject: [PATCH] fix(components): [text2sql] Turn model_inference_results_path to model_inference_results_directory and remove duplicate comment PiperOrigin-RevId: 576576299 --- .../text2sql/evaluation_llm_text2sql_pipeline.py | 12 +++++------- .../text2sql_evaluation/component.py | 10 ++++++---- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py index e106efa6981..6f0af29e52c 100644 --- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py +++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py @@ -73,10 +73,6 @@ def evaluation_llm_text2sql_pipeline( must be one of the Model's supportedInputStorageFormats. If not set, default to "jsonl". For more details about this input config, see https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig. - batch_predict_instances_format: The format in which perdictions are made, - must be one of the Model's supportedInputStorageFormats. If not set, - default to "jsonl". For more details about this input config, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig. machine_type: The machine type of this custom job. If not set, defaulted to `e2-highmem-16`. 
More details: https://cloud.google.com/compute/docs/machine-resource @@ -148,9 +144,11 @@ def evaluation_llm_text2sql_pipeline( location=location, sql_dialect=sql_dialect, evaluation_method=evaluation_method, - # TODO(bozhengbz) Add value to model_inference_results_path - # when model batch prediction component is added. - model_inference_results_path='gs://test/model_inference_results.json', + # TODO(bozhengbz) Switch model_inference_results_directory to the output + # of the sql query model batch prediction component once it is added. + model_inference_results_directory=batch_predict_table_names_task.outputs[ + 'gcs_output_directory' + ], tables_metadata_path=tables_metadata_path, machine_type=machine_type, service_account=service_account, diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_evaluation/component.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_evaluation/component.py index 063172067a2..a084de02d42 100644 --- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_evaluation/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_evaluation/component.py @@ -16,7 +16,9 @@ from google_cloud_pipeline_components import utils as gcpc_utils from google_cloud_pipeline_components._implementation.model_evaluation import utils from google_cloud_pipeline_components._implementation.model_evaluation import version +from kfp.dsl import Artifact from kfp.dsl import container_component +from kfp.dsl import Input from kfp.dsl import Metrics from kfp.dsl import Output from kfp.dsl import OutputPath @@ -33,7 +35,7 @@ def text2sql_evaluation( location: str, sql_dialect: str, evaluation_method: str, - model_inference_results_path: str, + model_inference_results_directory: Input[Artifact], tables_metadata_path: str, display_name: str = 'text2sql-evaluation', 
machine_type: str = 'e2-highmem-16', @@ -49,8 +51,8 @@ def text2sql_evaluation( sql_dialect: Required. SQL dialect type, e.g. bigquery, mysql, etc. evaluation_method: Required. Text2SQL evaluation method, value can be 'parser', 'execution', 'all'. - model_inference_results_path: Required. The path for json file containing - text2sql model inference results from the last step. + model_inference_results_directory: Required. The directory containing + the json file(s) of text2sql model inference results from the last step. tables_metadata_path: Required. The path for json file containing database metadata, including table names, schema fields. display_name: The name of the Evaluation job. @@ -98,7 +100,7 @@ def text2sql_evaluation( f'--location={location}', f'--sql_dialect={sql_dialect}', f'--evaluation_method={evaluation_method}', - f'--model_inference_results_path={model_inference_results_path}', + f'--model_inference_results_directory={model_inference_results_directory.path}', f'--tables_metadata_path={tables_metadata_path}', f'--root_dir={PIPELINE_ROOT_PLACEHOLDER}', f'--gcp_resources={gcp_resources}',