From 2957ee9b85775639a2c8418dd8d5d0ee5d76b298 Mon Sep 17 00:00:00 2001
From: Googler <nobody@google.com>
Date: Wed, 8 Nov 2023 21:55:26 -0800
Subject: [PATCH] feat(components):[text2sql] Integration with first party LLM
 model inference pipeline

PiperOrigin-RevId: 580776916
---
 .../evaluation_llm_text2sql_pipeline.py       | 65 +++++++------------
 .../text2sql_preprocess/component.py          |  5 ++
 .../component.py                              |  5 ++
 3 files changed, 34 insertions(+), 41 deletions(-)

diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py
index 74d2d4d14fcf..93f8bad717b1 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Text2SQL evaluation pipeline."""
-from typing import Dict
+from typing import Dict, Optional, Union
 
 from google_cloud_pipeline_components import _placeholders
+from google_cloud_pipeline_components._implementation.model_evaluation.endpoint_batch_predict.component import evaluation_llm_endpoint_batch_predict_pipeline_graph_component as LLMEndpointBatchPredictOp
 from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_evaluation.component import text2sql_evaluation as Text2SQLEvaluationOp
 from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_preprocess.component import text2sql_evaluation_preprocess as Text2SQLEvaluationPreprocessOp
 from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_validate_and_process.component import text2sql_evaluation_validate_and_process as Text2SQLEvaluationValidateAndProcessOp
-from google_cloud_pipeline_components.types import artifact_types
-from google_cloud_pipeline_components.v1.batch_predict_job import ModelBatchPredictOp
 import kfp
 from kfp.dsl import PIPELINE_ROOT_PLACEHOLDER
 
@@ -37,9 +36,7 @@ def evaluation_llm_text2sql_pipeline(
     evaluation_method: str = 'parser',
     project: str = _placeholders.PROJECT_ID_PLACEHOLDER,
     location: str = _placeholders.LOCATION_PLACEHOLDER,
-    model_parameters: Dict[str, str] = {},
-    batch_predict_instances_format: str = 'jsonl',
-    batch_predict_predictions_format: str = 'jsonl',
+    model_parameters: Optional[Dict[str, Union[int, float]]] = {},
     machine_type: str = 'e2-highmem-16',
     service_account: str = '',
     network: str = '',
@@ -48,10 +45,9 @@ def evaluation_llm_text2sql_pipeline(
   """The LLM Evaluation Text2SQL Pipeline.
 
   Args:
-    model_name: The Model used to run text2sql evaluation. Must be a publisher
-      model or a managed Model sharing the same ancestor location. Starting this
-      job has no impact on any existing deployments of the Model and their
-      resources. Supported model is publishers/google/models/text-bison.
+    model_name: The Model used to run text2sql evaluation. Must be a first party
+      publisher model. Supported model names are code-bison, code-gecko,
+      text-bison.
     evaluation_data_source_path: Required. The path for json file containing
       text2sql evaluation input dataset, including natural language question,
       ground truth SQL / SQL results.
@@ -69,10 +65,6 @@ def evaluation_llm_text2sql_pipeline(
       Default value is the same location used to run the pipeline.
     model_parameters: Optional. The parameters that govern the predictions, e.g.
       temperature,
-    batch_predict_instances_format: The format in which instances are given,
-      must be one of the Model's supportedInputStorageFormats. If not set,
-      default to "jsonl".  For more details about this input config, see
-      https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
     machine_type: The machine type of this custom job. If not set, defaulted to
       `e2-highmem-16`. More details:
       https://cloud.google.com/compute/docs/machine-resource
@@ -88,38 +80,29 @@ def evaluation_llm_text2sql_pipeline(
       will be encrypted with the provided encryption key.
   """
 
-  get_vertex_model_task = kfp.dsl.importer(
-      artifact_uri=(
-          f'https://{location}-aiplatform.googleapis.com/v1/{model_name}'
-      ),
-      artifact_class=artifact_types.VertexModel,
-      metadata={'resourceName': model_name},
-  )
-  get_vertex_model_task.set_display_name('get-vertex-model')
-
   preprocess_task = Text2SQLEvaluationPreprocessOp(
       project=project,
       location=location,
       evaluation_data_source_path=evaluation_data_source_path,
       tables_metadata_path=tables_metadata_path,
       prompt_template_path=prompt_template_path,
+      model_name=model_name,
       machine_type=machine_type,
       service_account=service_account,
       network=network,
       encryption_spec_key_name=encryption_spec_key_name,
   )
 
-  batch_predict_table_names_task = ModelBatchPredictOp(
-      job_display_name='text2sql-batch-prediction-table-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
-      model=get_vertex_model_task.outputs['artifact'],
+  batch_predict_table_names_task = LLMEndpointBatchPredictOp(
+      display_name='text2sql-batch-prediction-table-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
+      publisher_model=model_name,
       location=location,
-      instances_format=batch_predict_instances_format,
-      predictions_format=batch_predict_predictions_format,
-      gcs_source_uris=preprocess_task.outputs['model_inference_input_path'],
+      source_gcs_uris=preprocess_task.outputs['model_inference_input_path'],
       model_parameters=model_parameters,
       gcs_destination_output_uri_prefix=(
           f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_table_names_output'
       ),
+      service_account=service_account,
       encryption_spec_key_name=encryption_spec_key_name,
       project=project,
   )
@@ -133,25 +116,25 @@ def evaluation_llm_text2sql_pipeline(
       ],
       tables_metadata_path=tables_metadata_path,
       prompt_template_path=prompt_template_path,
+      model_name=model_name,
       machine_type=machine_type,
       service_account=service_account,
       network=network,
       encryption_spec_key_name=encryption_spec_key_name,
   )
 
-  batch_predict_column_names_task = ModelBatchPredictOp(
-      job_display_name='text2sql-batch-prediction-column-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
-      model=get_vertex_model_task.outputs['artifact'],
+  batch_predict_column_names_task = LLMEndpointBatchPredictOp(
+      display_name='text2sql-batch-prediction-column-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
+      publisher_model=model_name,
       location=location,
-      instances_format=batch_predict_instances_format,
-      predictions_format=batch_predict_predictions_format,
-      gcs_source_uris=validate_table_names_and_process_task.outputs[
+      source_gcs_uris=validate_table_names_and_process_task.outputs[
           'model_inference_input_path'
       ],
       model_parameters=model_parameters,
       gcs_destination_output_uri_prefix=(
           f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_column_names_output'
       ),
+      service_account=service_account,
       encryption_spec_key_name=encryption_spec_key_name,
       project=project,
   )
@@ -165,25 +148,25 @@ def evaluation_llm_text2sql_pipeline(
       ],
       tables_metadata_path=tables_metadata_path,
       prompt_template_path=prompt_template_path,
+      model_name=model_name,
       machine_type=machine_type,
       service_account=service_account,
       network=network,
       encryption_spec_key_name=encryption_spec_key_name,
   )
 
-  batch_prediction_sql_queries_task = ModelBatchPredictOp(
-      job_display_name='text2sql-batch-prediction-sql-queries-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
-      model=get_vertex_model_task.outputs['artifact'],
+  batch_prediction_sql_queries_task = LLMEndpointBatchPredictOp(
+      display_name='text2sql-batch-prediction-sql-queries-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
+      publisher_model=model_name,
       location=location,
-      instances_format=batch_predict_instances_format,
-      predictions_format=batch_predict_predictions_format,
-      gcs_source_uris=validate_column_names_and_process_task.outputs[
+      source_gcs_uris=validate_column_names_and_process_task.outputs[
           'model_inference_input_path'
       ],
       model_parameters=model_parameters,
       gcs_destination_output_uri_prefix=(
           f'{PIPELINE_ROOT_PLACEHOLDER}/batch_prediction_sql_queris_output'
       ),
+      service_account=service_account,
       encryption_spec_key_name=encryption_spec_key_name,
       project=project,
   )
diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_preprocess/component.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_preprocess/component.py
index 583da4c23bba..1d5ad68f4d52 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_preprocess/component.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_preprocess/component.py
@@ -30,6 +30,7 @@ def text2sql_evaluation_preprocess(
     evaluation_data_source_path: str,
     tables_metadata_path: str,
     prompt_template_path: str = '',
+    model_name: str = '',
     display_name: str = 'text2sql-evaluation-preprocess',
     machine_type: str = 'e2-highmem-16',
     service_account: str = '',
@@ -48,6 +49,9 @@ def text2sql_evaluation_preprocess(
         metadata, including table names, schema fields.
       prompt_template_path: Required. The path for json file containing prompt
         template. Will provide default value if users do not sepecify.
+      model_name: The Model used to run text2sql evaluation. Must be a first
+        party publisher model. Supported model name values are code-bison,
+        code-gecko, text-bison.
       display_name: The name of the Evaluation job.
       machine_type: The machine type of this custom job. If not set, defaulted
         to `e2-highmem-16`. More details:
@@ -90,6 +94,7 @@ def text2sql_evaluation_preprocess(
               f'--evaluation_data_source_path={evaluation_data_source_path}',
               f'--tables_metadata_path={tables_metadata_path}',
               f'--prompt_template_path={prompt_template_path}',
+              f'--model_name={model_name}',
               f'--root_dir={PIPELINE_ROOT_PLACEHOLDER}',
               f'--gcp_resources={gcp_resources}',
               f'--model_inference_input_path={model_inference_input_path}',
diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_validate_and_process/component.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_validate_and_process/component.py
index 3f1b09726207..d58cb13fc407 100644
--- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_validate_and_process/component.py
+++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_validate_and_process/component.py
@@ -33,6 +33,7 @@ def text2sql_evaluation_validate_and_process(
     model_inference_results_directory: Input[Artifact],
     tables_metadata_path: str,
     prompt_template_path: str = '',
+    model_name: str = '',
     display_name: str = 'text2sql-evaluation-validate-and-process',
     machine_type: str = 'e2-highmem-16',
     service_account: str = '',
@@ -53,6 +54,9 @@ def text2sql_evaluation_validate_and_process(
         metadata, including table names, schema fields.
       prompt_template_path: Required. The path for json file containing prompt
         template. Will provide default value if users do not sepecify.
+      model_name: The Model used to run text2sql evaluation. Must be a first
+        party publisher model. Supported model name values are code-bison,
+        code-gecko, text-bison.
       display_name: The name of the Evaluation job.
       machine_type: The machine type of this custom job. If not set, defaulted
         to `e2-highmem-16`. More details:
@@ -96,6 +100,7 @@ def text2sql_evaluation_validate_and_process(
               f'--model_inference_results_directory={model_inference_results_directory.path}',
               f'--tables_metadata_path={tables_metadata_path}',
               f'--prompt_template_path={prompt_template_path}',
+              f'--model_name={model_name}',
               f'--root_dir={PIPELINE_ROOT_PLACEHOLDER}',
               f'--gcp_resources={gcp_resources}',
               f'--model_inference_input_path={model_inference_input_path}',