From 2957ee9b85775639a2c8418dd8d5d0ee5d76b298 Mon Sep 17 00:00:00 2001 From: Googler <nobody@google.com> Date: Wed, 8 Nov 2023 21:55:26 -0800 Subject: [PATCH] feat(components):[text2sql] Integration with first party LLM model inference pipeline PiperOrigin-RevId: 580776916 --- .../evaluation_llm_text2sql_pipeline.py | 65 +++++++------------ .../text2sql_preprocess/component.py | 5 ++ .../component.py | 5 ++ 3 files changed, 34 insertions(+), 41 deletions(-) diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py index 74d2d4d14fcf..93f8bad717b1 100644 --- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py +++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql/evaluation_llm_text2sql_pipeline.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Text2SQL evaluation pipeline.""" -from typing import Dict +from typing import Dict, Optional, Union from google_cloud_pipeline_components import _placeholders +from google_cloud_pipeline_components._implementation.model_evaluation.endpoint_batch_predict.component import evaluation_llm_endpoint_batch_predict_pipeline_graph_component as LLMEndpointBatchPredictOp from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_evaluation.component import text2sql_evaluation as Text2SQLEvaluationOp from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_preprocess.component import text2sql_evaluation_preprocess as Text2SQLEvaluationPreprocessOp from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_validate_and_process.component import text2sql_evaluation_validate_and_process as Text2SQLEvaluationValidateAndProcessOp -from google_cloud_pipeline_components.types import artifact_types -from google_cloud_pipeline_components.v1.batch_predict_job import ModelBatchPredictOp import kfp from kfp.dsl import PIPELINE_ROOT_PLACEHOLDER @@ -37,9 +36,7 @@ def evaluation_llm_text2sql_pipeline( evaluation_method: str = 'parser', project: str = _placeholders.PROJECT_ID_PLACEHOLDER, location: str = _placeholders.LOCATION_PLACEHOLDER, - model_parameters: Dict[str, str] = {}, - batch_predict_instances_format: str = 'jsonl', - batch_predict_predictions_format: str = 'jsonl', + model_parameters: Optional[Dict[str, Union[int, float]]] = {}, machine_type: str = 'e2-highmem-16', service_account: str = '', network: str = '', @@ -48,10 +45,9 @@ def evaluation_llm_text2sql_pipeline( """The LLM Evaluation Text2SQL Pipeline. Args: - model_name: The Model used to run text2sql evaluation. Must be a publisher - model or a managed Model sharing the same ancestor location. Starting this - job has no impact on any existing deployments of the Model and their - resources. Supported model is publishers/google/models/text-bison. 
+ model_name: The Model used to run text2sql evaluation. Must be a first party + publisher model. Supported model names are code-bison, code-gecko, + text-bison. evaluation_data_source_path: Required. The path for json file containing text2sql evaluation input dataset, including natural language question, ground truth SQL / SQL results. @@ -69,10 +65,6 @@ def evaluation_llm_text2sql_pipeline( Default value is the same location used to run the pipeline. model_parameters: Optional. The parameters that govern the predictions, e.g. temperature, - batch_predict_instances_format: The format in which instances are given, - must be one of the Model's supportedInputStorageFormats. If not set, - default to "jsonl". For more details about this input config, see - https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig. machine_type: The machine type of this custom job. If not set, defaulted to `e2-highmem-16`. More details: https://cloud.google.com/compute/docs/machine-resource @@ -88,38 +80,29 @@ def evaluation_llm_text2sql_pipeline( will be encrypted with the provided encryption key. 
""" - get_vertex_model_task = kfp.dsl.importer( - artifact_uri=( - f'https://{location}-aiplatform.googleapis.com/v1/{model_name}' - ), - artifact_class=artifact_types.VertexModel, - metadata={'resourceName': model_name}, - ) - get_vertex_model_task.set_display_name('get-vertex-model') - preprocess_task = Text2SQLEvaluationPreprocessOp( project=project, location=location, evaluation_data_source_path=evaluation_data_source_path, tables_metadata_path=tables_metadata_path, prompt_template_path=prompt_template_path, + model_name=model_name, machine_type=machine_type, service_account=service_account, network=network, encryption_spec_key_name=encryption_spec_key_name, ) - batch_predict_table_names_task = ModelBatchPredictOp( - job_display_name='text2sql-batch-prediction-table-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}', - model=get_vertex_model_task.outputs['artifact'], + batch_predict_table_names_task = LLMEndpointBatchPredictOp( + display_name='text2sql-batch-prediction-table-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}', + publisher_model=model_name, location=location, - instances_format=batch_predict_instances_format, - predictions_format=batch_predict_predictions_format, - gcs_source_uris=preprocess_task.outputs['model_inference_input_path'], + source_gcs_uris=preprocess_task.outputs['model_inference_input_path'], model_parameters=model_parameters, gcs_destination_output_uri_prefix=( f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_table_names_output' ), + service_account=service_account, encryption_spec_key_name=encryption_spec_key_name, project=project, ) @@ -133,25 +116,25 @@ def evaluation_llm_text2sql_pipeline( ], tables_metadata_path=tables_metadata_path, prompt_template_path=prompt_template_path, + model_name=model_name, machine_type=machine_type, service_account=service_account, network=network, encryption_spec_key_name=encryption_spec_key_name, ) - batch_predict_column_names_task = ModelBatchPredictOp( - 
job_display_name='text2sql-batch-prediction-column-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}', - model=get_vertex_model_task.outputs['artifact'], + batch_predict_column_names_task = LLMEndpointBatchPredictOp( + display_name='text2sql-batch-prediction-column-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}', + publisher_model=model_name, location=location, - instances_format=batch_predict_instances_format, - predictions_format=batch_predict_predictions_format, - gcs_source_uris=validate_table_names_and_process_task.outputs[ + source_gcs_uris=validate_table_names_and_process_task.outputs[ 'model_inference_input_path' ], model_parameters=model_parameters, gcs_destination_output_uri_prefix=( f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_column_names_output' ), + service_account=service_account, encryption_spec_key_name=encryption_spec_key_name, project=project, ) @@ -165,25 +148,25 @@ def evaluation_llm_text2sql_pipeline( ], tables_metadata_path=tables_metadata_path, prompt_template_path=prompt_template_path, + model_name=model_name, machine_type=machine_type, service_account=service_account, network=network, encryption_spec_key_name=encryption_spec_key_name, ) - batch_prediction_sql_queries_task = ModelBatchPredictOp( - job_display_name='text2sql-batch-prediction-sql-queries-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}', - model=get_vertex_model_task.outputs['artifact'], + batch_prediction_sql_queries_task = LLMEndpointBatchPredictOp( + display_name='text2sql-batch-prediction-sql-queries-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}', + publisher_model=model_name, location=location, - instances_format=batch_predict_instances_format, - predictions_format=batch_predict_predictions_format, - gcs_source_uris=validate_column_names_and_process_task.outputs[ + source_gcs_uris=validate_column_names_and_process_task.outputs[ 'model_inference_input_path' ], model_parameters=model_parameters, gcs_destination_output_uri_prefix=( 
f'{PIPELINE_ROOT_PLACEHOLDER}/batch_prediction_sql_queris_output' ), + service_account=service_account, encryption_spec_key_name=encryption_spec_key_name, project=project, ) diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_preprocess/component.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_preprocess/component.py index 583da4c23bba..1d5ad68f4d52 100644 --- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_preprocess/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_preprocess/component.py @@ -30,6 +30,7 @@ def text2sql_evaluation_preprocess( evaluation_data_source_path: str, tables_metadata_path: str, prompt_template_path: str = '', + model_name: str = '', display_name: str = 'text2sql-evaluation-preprocess', machine_type: str = 'e2-highmem-16', service_account: str = '', @@ -48,6 +49,9 @@ def text2sql_evaluation_preprocess( metadata, including table names, schema fields. prompt_template_path: Required. The path for json file containing prompt template. Will provide default value if users do not sepecify. + model_name: The Model used to run text2sql evaluation. Must be a first + party publisher model. Supported model name values are code-bison, + code-gecko, text-bison. display_name: The name of the Evaluation job. machine_type: The machine type of this custom job. If not set, defaulted to `e2-highmem-16`. 
More details: @@ -90,6 +94,7 @@ def text2sql_evaluation_preprocess( f'--evaluation_data_source_path={evaluation_data_source_path}', f'--tables_metadata_path={tables_metadata_path}', f'--prompt_template_path={prompt_template_path}', + f'--model_name={model_name}', f'--root_dir={PIPELINE_ROOT_PLACEHOLDER}', f'--gcp_resources={gcp_resources}', f'--model_inference_input_path={model_inference_input_path}', diff --git a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_validate_and_process/component.py b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_validate_and_process/component.py index 3f1b09726207..d58cb13fc407 100644 --- a/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_validate_and_process/component.py +++ b/components/google-cloud/google_cloud_pipeline_components/_implementation/model_evaluation/text2sql_validate_and_process/component.py @@ -33,6 +33,7 @@ def text2sql_evaluation_validate_and_process( model_inference_results_directory: Input[Artifact], tables_metadata_path: str, prompt_template_path: str = '', + model_name: str = '', display_name: str = 'text2sql-evaluation-validate-and-process', machine_type: str = 'e2-highmem-16', service_account: str = '', @@ -53,6 +54,9 @@ def text2sql_evaluation_validate_and_process( metadata, including table names, schema fields. prompt_template_path: Required. The path for json file containing prompt template. Will provide default value if users do not sepecify. + model_name: The Model used to run text2sql evaluation. Must be a first + party publisher model. Supported model name values are code-bison, + code-gecko, text-bison. display_name: The name of the Evaluation job. machine_type: The machine type of this custom job. If not set, defaulted to `e2-highmem-16`. 
More details: @@ -96,6 +100,7 @@ def text2sql_evaluation_validate_and_process( f'--model_inference_results_directory={model_inference_results_directory.path}', f'--tables_metadata_path={tables_metadata_path}', f'--prompt_template_path={prompt_template_path}', + f'--model_name={model_name}', f'--root_dir={PIPELINE_ROOT_PLACEHOLDER}', f'--gcp_resources={gcp_resources}', f'--model_inference_input_path={model_inference_input_path}',