Skip to content

Commit

Permalink
feat(components):[text2sql] Integration with first party LLM model in…
Browse files Browse the repository at this point in the history
…ference pipeline

PiperOrigin-RevId: 580776916
  • Loading branch information
Googler committed Nov 10, 2023
1 parent fb4512d commit 2957ee9
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Text2SQL evaluation pipeline."""
from typing import Dict
from typing import Dict, Optional, Union

from google_cloud_pipeline_components import _placeholders
from google_cloud_pipeline_components._implementation.model_evaluation.endpoint_batch_predict.component import evaluation_llm_endpoint_batch_predict_pipeline_graph_component as LLMEndpointBatchPredictOp
from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_evaluation.component import text2sql_evaluation as Text2SQLEvaluationOp
from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_preprocess.component import text2sql_evaluation_preprocess as Text2SQLEvaluationPreprocessOp
from google_cloud_pipeline_components._implementation.model_evaluation.text2sql_validate_and_process.component import text2sql_evaluation_validate_and_process as Text2SQLEvaluationValidateAndProcessOp
from google_cloud_pipeline_components.types import artifact_types
from google_cloud_pipeline_components.v1.batch_predict_job import ModelBatchPredictOp
import kfp
from kfp.dsl import PIPELINE_ROOT_PLACEHOLDER

Expand All @@ -37,9 +36,7 @@ def evaluation_llm_text2sql_pipeline(
evaluation_method: str = 'parser',
project: str = _placeholders.PROJECT_ID_PLACEHOLDER,
location: str = _placeholders.LOCATION_PLACEHOLDER,
model_parameters: Dict[str, str] = {},
batch_predict_instances_format: str = 'jsonl',
batch_predict_predictions_format: str = 'jsonl',
model_parameters: Optional[Dict[str, Union[int, float]]] = {},
machine_type: str = 'e2-highmem-16',
service_account: str = '',
network: str = '',
Expand All @@ -48,10 +45,9 @@ def evaluation_llm_text2sql_pipeline(
"""The LLM Evaluation Text2SQL Pipeline.
Args:
model_name: The Model used to run text2sql evaluation. Must be a publisher
model or a managed Model sharing the same ancestor location. Starting this
job has no impact on any existing deployments of the Model and their
resources. Supported model is publishers/google/models/text-bison.
model_name: The Model used to run text2sql evaluation. Must be a first party
publisher model. Supported model names are code-bison, code-gecko,
text-bison.
evaluation_data_source_path: Required. The path for json file containing
text2sql evaluation input dataset, including natural language question,
ground truth SQL / SQL results.
Expand All @@ -69,10 +65,6 @@ def evaluation_llm_text2sql_pipeline(
Default value is the same location used to run the pipeline.
model_parameters: Optional. The parameters that govern the predictions, e.g.
temperature.
batch_predict_instances_format: The format in which instances are given,
must be one of the Model's supportedInputStorageFormats. If not set,
default to "jsonl". For more details about this input config, see
https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.batchPredictionJobs#InputConfig.
machine_type: The machine type of this custom job. If not set, defaulted to
`e2-highmem-16`. More details:
https://cloud.google.com/compute/docs/machine-resource
Expand All @@ -88,38 +80,29 @@ def evaluation_llm_text2sql_pipeline(
will be encrypted with the provided encryption key.
"""

get_vertex_model_task = kfp.dsl.importer(
artifact_uri=(
f'https://{location}-aiplatform.googleapis.com/v1/{model_name}'
),
artifact_class=artifact_types.VertexModel,
metadata={'resourceName': model_name},
)
get_vertex_model_task.set_display_name('get-vertex-model')

preprocess_task = Text2SQLEvaluationPreprocessOp(
project=project,
location=location,
evaluation_data_source_path=evaluation_data_source_path,
tables_metadata_path=tables_metadata_path,
prompt_template_path=prompt_template_path,
model_name=model_name,
machine_type=machine_type,
service_account=service_account,
network=network,
encryption_spec_key_name=encryption_spec_key_name,
)

batch_predict_table_names_task = ModelBatchPredictOp(
job_display_name='text2sql-batch-prediction-table-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
model=get_vertex_model_task.outputs['artifact'],
batch_predict_table_names_task = LLMEndpointBatchPredictOp(
display_name='text2sql-batch-prediction-table-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
publisher_model=model_name,
location=location,
instances_format=batch_predict_instances_format,
predictions_format=batch_predict_predictions_format,
gcs_source_uris=preprocess_task.outputs['model_inference_input_path'],
source_gcs_uris=preprocess_task.outputs['model_inference_input_path'],
model_parameters=model_parameters,
gcs_destination_output_uri_prefix=(
f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_table_names_output'
),
service_account=service_account,
encryption_spec_key_name=encryption_spec_key_name,
project=project,
)
Expand All @@ -133,25 +116,25 @@ def evaluation_llm_text2sql_pipeline(
],
tables_metadata_path=tables_metadata_path,
prompt_template_path=prompt_template_path,
model_name=model_name,
machine_type=machine_type,
service_account=service_account,
network=network,
encryption_spec_key_name=encryption_spec_key_name,
)

batch_predict_column_names_task = ModelBatchPredictOp(
job_display_name='text2sql-batch-prediction-column-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
model=get_vertex_model_task.outputs['artifact'],
batch_predict_column_names_task = LLMEndpointBatchPredictOp(
display_name='text2sql-batch-prediction-column-names-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
publisher_model=model_name,
location=location,
instances_format=batch_predict_instances_format,
predictions_format=batch_predict_predictions_format,
gcs_source_uris=validate_table_names_and_process_task.outputs[
source_gcs_uris=validate_table_names_and_process_task.outputs[
'model_inference_input_path'
],
model_parameters=model_parameters,
gcs_destination_output_uri_prefix=(
f'{PIPELINE_ROOT_PLACEHOLDER}/batch_predict_column_names_output'
),
service_account=service_account,
encryption_spec_key_name=encryption_spec_key_name,
project=project,
)
Expand All @@ -165,25 +148,25 @@ def evaluation_llm_text2sql_pipeline(
],
tables_metadata_path=tables_metadata_path,
prompt_template_path=prompt_template_path,
model_name=model_name,
machine_type=machine_type,
service_account=service_account,
network=network,
encryption_spec_key_name=encryption_spec_key_name,
)

batch_prediction_sql_queries_task = ModelBatchPredictOp(
job_display_name='text2sql-batch-prediction-sql-queries-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
model=get_vertex_model_task.outputs['artifact'],
batch_prediction_sql_queries_task = LLMEndpointBatchPredictOp(
display_name='text2sql-batch-prediction-sql-queries-{{$.pipeline_job_uuid}}-{{$.pipeline_task_uuid}}',
publisher_model=model_name,
location=location,
instances_format=batch_predict_instances_format,
predictions_format=batch_predict_predictions_format,
gcs_source_uris=validate_column_names_and_process_task.outputs[
source_gcs_uris=validate_column_names_and_process_task.outputs[
'model_inference_input_path'
],
model_parameters=model_parameters,
gcs_destination_output_uri_prefix=(
f'{PIPELINE_ROOT_PLACEHOLDER}/batch_prediction_sql_queris_output'
),
service_account=service_account,
encryption_spec_key_name=encryption_spec_key_name,
project=project,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def text2sql_evaluation_preprocess(
evaluation_data_source_path: str,
tables_metadata_path: str,
prompt_template_path: str = '',
model_name: str = '',
display_name: str = 'text2sql-evaluation-preprocess',
machine_type: str = 'e2-highmem-16',
service_account: str = '',
Expand All @@ -48,6 +49,9 @@ def text2sql_evaluation_preprocess(
metadata, including table names, schema fields.
prompt_template_path: Required. The path for json file containing prompt
template. Will provide default value if users do not specify.
model_name: The Model used to run text2sql evaluation. Must be a first
party publisher model. Supported model name values are code-bison,
code-gecko, text-bison.
display_name: The name of the Evaluation job.
machine_type: The machine type of this custom job. If not set, defaulted
to `e2-highmem-16`. More details:
Expand Down Expand Up @@ -90,6 +94,7 @@ def text2sql_evaluation_preprocess(
f'--evaluation_data_source_path={evaluation_data_source_path}',
f'--tables_metadata_path={tables_metadata_path}',
f'--prompt_template_path={prompt_template_path}',
f'--model_name={model_name}',
f'--root_dir={PIPELINE_ROOT_PLACEHOLDER}',
f'--gcp_resources={gcp_resources}',
f'--model_inference_input_path={model_inference_input_path}',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def text2sql_evaluation_validate_and_process(
model_inference_results_directory: Input[Artifact],
tables_metadata_path: str,
prompt_template_path: str = '',
model_name: str = '',
display_name: str = 'text2sql-evaluation-validate-and-process',
machine_type: str = 'e2-highmem-16',
service_account: str = '',
Expand All @@ -53,6 +54,9 @@ def text2sql_evaluation_validate_and_process(
metadata, including table names, schema fields.
prompt_template_path: Required. The path for json file containing prompt
template. Will provide default value if users do not specify.
model_name: The Model used to run text2sql evaluation. Must be a first
party publisher model. Supported model name values are code-bison,
code-gecko, text-bison.
display_name: The name of the Evaluation job.
machine_type: The machine type of this custom job. If not set, defaulted
to `e2-highmem-16`. More details:
Expand Down Expand Up @@ -96,6 +100,7 @@ def text2sql_evaluation_validate_and_process(
f'--model_inference_results_directory={model_inference_results_directory.path}',
f'--tables_metadata_path={tables_metadata_path}',
f'--prompt_template_path={prompt_template_path}',
f'--model_name={model_name}',
f'--root_dir={PIPELINE_ROOT_PLACEHOLDER}',
f'--gcp_resources={gcp_resources}',
f'--model_inference_input_path={model_inference_input_path}',
Expand Down

0 comments on commit 2957ee9

Please sign in to comment.