Support launching graph store jobs from trainer and inferencer #390
base: main
@@ -61,6 +61,7 @@ def get_pipeline() -> int:  # NOTE: `get_pipeline` here is the Pipeline name
import datetime
import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Final, Optional, Union
@@ -69,12 +70,17 @@ def get_pipeline() -> int:  # NOTE: `get_pipeline` here is the Pipeline name
    ContainerSpec,
    DiskSpec,
    MachineSpec,
    Scheduling,
    WorkerPoolSpec,
    env_var,
)

from gigl.common import GcsUri, Uri
from gigl.common.logger import Logger
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from snapchat.research.gbml.gigl_resource_config_pb2 import VertexAiResourceConfig

logger = Logger()
@@ -108,6 +114,67 @@ class VertexAiJobConfig:
    scheduling_strategy: Optional[aiplatform.gapic.Scheduling.Strategy] = None


def get_job_config_from_vertex_ai_resource_config(
    applied_task_identifier: AppliedTaskIdentifier,
    is_inference: bool,
    task_config_uri: Uri,
    resource_config_uri: Uri,
    command_str: str,
    args: Mapping[str, str],
    run_on_cpu: bool,

Collaborator: we have been using

    container_uri: str,
    vertex_ai_resource_config: VertexAiResourceConfig,
    env_vars: list[env_var.EnvVar],
) -> VertexAiJobConfig:
    job_args = (
        [
            f"--job_name={applied_task_identifier}",
            f"--task_config_uri={task_config_uri}",
            f"--resource_config_uri={resource_config_uri}",

Collaborator: Even if we don't have the import statements, the logic of formulating these args and passing them through for the trainer to use still intermixes the architecture layers. Figuring out what arguments to pass in based off the versioned trainer/inferencer implementation (v2) still tightly couples the business layer, and pins this service to the specific versioned implementation of 2 classes that "may use VAI". Ideally the purpose of this service should just be: IMO, being able to differentiate and launch graph_store vs regular jobs is probably the max amount of coupling that should be introduced in this service.
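A minimal sketch of the boundary described above, assuming only the two launch entry points that VertexAIService exposes elsewhere in this diff (launch_job and launch_graph_store_job); the job config variables and constructor values below are hypothetical:

    # Hypothetical call-site view under the proposed split: the component builds its own
    # VertexAiJobConfig(s); the service only launches them.
    service = VertexAIService(
        project="my-gcp-project",  # illustrative
        location="us-central1",  # illustrative
        service_account="runner@my-gcp-project.iam.gserviceaccount.com",  # illustrative
        staging_bucket="gs://my-staging-bucket",  # illustrative
    )
    service.launch_job(job_config=regular_job_config)  # regular single-pool job
    service.launch_graph_store_job(  # graph-store (server/client) job
        compute_pool_job_config=compute_job_config,
        storage_pool_job_config=storage_job_config,
    )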
        ]
        + ([] if run_on_cpu else ["--use_cuda"])
        + ([f"--{k}={v}" for k, v in args.items()])
    )

    command = command_str.strip().split(" ")
    if is_inference:
        vai_job_name = f"gigl_infer_{applied_task_identifier}"
    else:
        vai_job_name = f"gigl_train_{applied_task_identifier}"
    resource_config_wrapper = get_resource_config(
        resource_config_uri=resource_config_uri
    )
    resource_config_labels = resource_config_wrapper.get_resource_labels(
        component=GiGLComponents.Inferencer if is_inference else GiGLComponents.Trainer
    )
    job_config = VertexAiJobConfig(
        job_name=vai_job_name,
        container_uri=container_uri,
        command=command,
        args=job_args,
        environment_variables=env_vars,
        machine_type=vertex_ai_resource_config.machine_type,
        accelerator_type=vertex_ai_resource_config.gpu_type.upper().replace("-", "_"),
        accelerator_count=vertex_ai_resource_config.gpu_limit,
        replica_count=vertex_ai_resource_config.num_replicas,
        labels=resource_config_labels,
        timeout_s=vertex_ai_resource_config.timeout
        if vertex_ai_resource_config.timeout
        else None,
        # This should be `aiplatform.gapic.Scheduling.Strategy[inferencer_resource_config.scheduling_strategy]`
        # But mypy complains otherwise...
        # python/gigl/src/inference/v2/glt_inferencer.py:124: error: The type "type[Strategy]" is not generic and not indexable [misc]
        # TODO(kmonte): Fix this
        scheduling_strategy=getattr(
            Scheduling.Strategy,
            vertex_ai_resource_config.scheduling_strategy,
        )
        if vertex_ai_resource_config.scheduling_strategy
        else None,
    )
    return job_config
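For orientation, a minimal sketch (not part of the diff) of how a component might call this helper; every literal value below is illustrative, and the identifier, URIs, container image, and resource config are assumed to be built elsewhere:

    # Hypothetical call site; values are illustrative only.
    job_config = get_job_config_from_vertex_ai_resource_config(
        applied_task_identifier=applied_task_identifier,  # AppliedTaskIdentifier for this run
        is_inference=True,
        task_config_uri=task_config_uri,  # gigl.common Uri
        resource_config_uri=resource_config_uri,  # gigl.common Uri
        command_str="python -m my_inference_entrypoint",  # illustrative command
        args={"batch_size": "512"},  # illustrative runtime args
        run_on_cpu=False,
        container_uri=cuda_docker_uri,  # illustrative container image
        vertex_ai_resource_config=vertex_ai_resource_config,  # VertexAiResourceConfig proto
        env_vars=[env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3")],
    )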


class VertexAIService:
    """
    A class representing a Vertex AI service.
@@ -1,17 +1,21 @@
import argparse
from collections.abc import Mapping
from typing import Optional

from google.cloud.aiplatform_v1.types import Scheduling, accelerator_type, env_var
from google.cloud.aiplatform_v1.types import accelerator_type, env_var

from gigl.common import Uri, UriFactory
from gigl.common.constants import (
    DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
    DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA,
)
from gigl.common.logger import Logger
from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService
from gigl.common.services.vertex_ai import (
    VertexAIService,
    get_job_config_from_vertex_ai_resource_config,
)
from gigl.env.distributed import COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (

@@ -20,6 +24,7 @@
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from snapchat.research.gbml.gigl_resource_config_pb2 import (
    LocalResourceConfig,
    VertexAiGraphStoreConfig,
    VertexAiResourceConfig,
)
@@ -49,6 +54,139 @@ class GLTInferencer:
    GiGL Component that runs a GLT Inference using a provided class path
    """

    def _launch_single_pool(
        self,
        vertex_ai_resource_config: VertexAiResourceConfig,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        inference_process_command: str,
        inference_process_runtime_args: Mapping[str, str],
        resource_config_wrapper: GiglResourceConfigWrapper,
        cpu_docker_uri: Optional[str],
        cuda_docker_uri: Optional[str],
    ) -> None:
        """Launch a single pool inference job on Vertex AI."""
        is_cpu_inference = _determine_if_cpu_inference(
            inferencer_resource_config=vertex_ai_resource_config
        )
        cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
        cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
        container_uri = cpu_docker_uri if is_cpu_inference else cuda_docker_uri

        job_config = get_job_config_from_vertex_ai_resource_config(
            applied_task_identifier=applied_task_identifier,
            is_inference=True,
            task_config_uri=task_config_uri,
            resource_config_uri=resource_config_uri,
            command_str=inference_process_command,
            args=inference_process_runtime_args,
            run_on_cpu=is_cpu_inference,
            container_uri=container_uri,
            vertex_ai_resource_config=vertex_ai_resource_config,
            env_vars=[env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3")],
        )
        logger.info(f"Launching inference job with config: {job_config}")

        vertex_ai_service = VertexAIService(
            project=resource_config_wrapper.project,
            location=resource_config_wrapper.vertex_ai_inferencer_region,
            service_account=resource_config_wrapper.service_account_email,
            staging_bucket=resource_config_wrapper.temp_assets_regional_bucket_path.uri,
        )
        vertex_ai_service.launch_job(job_config=job_config)

    def _launch_server_client(
        self,
        vertex_ai_graph_store_config: VertexAiGraphStoreConfig,
        applied_task_identifier: AppliedTaskIdentifier,
        task_config_uri: Uri,
        resource_config_uri: Uri,
        inference_process_command: str,
        inference_process_runtime_args: Mapping[str, str],
        resource_config_wrapper: GiglResourceConfigWrapper,
        cpu_docker_uri: Optional[str],
        cuda_docker_uri: Optional[str],
    ) -> None:
        """Launch a server/client inference job on Vertex AI using graph store config."""
        storage_pool_config = vertex_ai_graph_store_config.graph_store_pool
        compute_pool_config = vertex_ai_graph_store_config.compute_pool

        # Determine if CPU or GPU based on compute pool
        is_cpu_inference = _determine_if_cpu_inference(
            inferencer_resource_config=compute_pool_config
        )
        cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
        cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
        container_uri = cpu_docker_uri if is_cpu_inference else cuda_docker_uri

        logger.info(f"Running inference with command: {inference_process_command}")

        num_compute_processes = (
            vertex_ai_graph_store_config.compute_cluster_local_world_size
        )
        if not num_compute_processes:
            if is_cpu_inference:
                num_compute_processes = 1
            else:
                num_compute_processes = (
                    vertex_ai_graph_store_config.compute_pool.gpu_limit
                )
        # Add server/client environment variables
        environment_variables: list[env_var.EnvVar] = [
            env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
            env_var.EnvVar(
                name=COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY,
                value=str(num_compute_processes),
            ),
        ]
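        # (Hypothetical consumer, not part of this diff: a compute-side entrypoint would
        # typically read this env var back to size its local process group, e.g.
        #     local_world_size = int(os.environ[COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY])
        # assuming `os` is imported in that entrypoint.)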

        # Create compute pool job config
        compute_job_config = get_job_config_from_vertex_ai_resource_config(
            applied_task_identifier=applied_task_identifier,
            is_inference=True,
            task_config_uri=task_config_uri,
            resource_config_uri=resource_config_uri,
            command_str=inference_process_command,
            args=inference_process_runtime_args,
            run_on_cpu=is_cpu_inference,
            container_uri=container_uri,
            vertex_ai_resource_config=compute_pool_config,
            env_vars=environment_variables,
        )

        # Create storage pool job config
        storage_job_config = get_job_config_from_vertex_ai_resource_config(
            applied_task_identifier=applied_task_identifier,
            is_inference=True,
            task_config_uri=task_config_uri,
            resource_config_uri=resource_config_uri,
            command_str="python -m gigl.distributed.server_client.server_main",

Author: Note: DNE atm, will be added in a later pr :)

            args={},  # No extra args for storage pool
            run_on_cpu=is_cpu_inference,
            container_uri=container_uri,
            vertex_ai_resource_config=storage_pool_config,
            env_vars=environment_variables,
        )

        # Determine region from compute pool or use default region
        region = (
            compute_pool_config.gcp_region_override
            if compute_pool_config.gcp_region_override
            else resource_config_wrapper.region
        )

        vertex_ai_service = VertexAIService(
            project=resource_config_wrapper.project,
            location=region,
            service_account=resource_config_wrapper.service_account_email,
            staging_bucket=resource_config_wrapper.temp_assets_regional_bucket_path.uri,
        )
        vertex_ai_service.launch_graph_store_job(
            compute_pool_job_config=compute_job_config,
            storage_pool_job_config=storage_job_config,
        )
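For reference, a minimal sketch of the graph-store resource config shape this method consumes; the message and field names come from the diff above, while the values and inline construction are purely illustrative (a real config would come from the parsed resource config proto):

    # Hand-constructed, illustrative-only config.
    example_graph_store_config = VertexAiGraphStoreConfig(
        compute_pool=VertexAiResourceConfig(
            machine_type="n1-standard-16",  # illustrative
            gpu_type="nvidia-tesla-t4",  # illustrative
            gpu_limit=2,  # illustrative
            num_replicas=2,  # illustrative
        ),
        graph_store_pool=VertexAiResourceConfig(
            machine_type="n1-highmem-32",  # illustrative
            num_replicas=4,  # illustrative
        ),
        compute_cluster_local_world_size=2,  # illustrative
    )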

    def __execute_VAI_inference(
        self,
        applied_task_identifier: AppliedTaskIdentifier,

@@ -74,73 +212,39 @@ def __execute_VAI_inference(
        inference_process_runtime_args = (
            gbml_config_pb_wrapper.inferencer_config.inferencer_args
        )
-        assert isinstance(
-            resource_config_wrapper.inferencer_config, VertexAiResourceConfig
-        )
-        inferencer_resource_config: VertexAiResourceConfig = (
-            resource_config_wrapper.inferencer_config
-        )
-
-        is_cpu_training = _determine_if_cpu_inference(
-            inferencer_resource_config=inferencer_resource_config
-        )
-        cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
-        cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
-        container_uri = cpu_docker_uri if is_cpu_training else cuda_docker_uri
-
-        job_args = (
-            [
-                f"--job_name={applied_task_identifier}",
-                f"--task_config_uri={task_config_uri}",
-                f"--resource_config_uri={resource_config_uri}",
-            ]
-            + ([] if is_cpu_training else ["--use_cuda"])
-            + ([f"--{k}={v}" for k, v in inference_process_runtime_args.items()])
-        )
-
-        command = inference_process_command.strip().split(" ")
-        logger.info(f"Running inference with command: {command}")
-        vai_job_name = f"gigl_infer_{applied_task_identifier}"
-        environment_variables: list[env_var.EnvVar] = [
-            env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
-        ]
-        job_config = VertexAiJobConfig(
-            job_name=vai_job_name,
-            container_uri=container_uri,
-            command=command,
-            args=job_args,
-            environment_variables=environment_variables,
-            machine_type=inferencer_resource_config.machine_type,
-            accelerator_type=inferencer_resource_config.gpu_type.upper().replace(
-                "-", "_"
-            ),
-            accelerator_count=inferencer_resource_config.gpu_limit,
-            replica_count=inferencer_resource_config.num_replicas,
-            labels=resource_config_wrapper.get_resource_labels(
-                component=GiGLComponents.Inferencer
-            ),
-            timeout_s=inferencer_resource_config.timeout
-            if inferencer_resource_config.timeout
-            else None,
-            # This should be `aiplatform.gapic.Scheduling.Strategy[inferencer_resource_config.scheduling_strategy]`
-            # But mypy complains otherwise...
-            # python/gigl/src/inference/v2/glt_inferencer.py:124: error: The type "type[Strategy]" is not generic and not indexable [misc]
-            # TODO(kmonte): Fix this
-            scheduling_strategy=getattr(
-                Scheduling.Strategy,
-                inferencer_resource_config.scheduling_strategy,
-            )
-            if inferencer_resource_config.scheduling_strategy
-            else None,
-        )
-
-        vertex_ai_service = VertexAIService(
-            project=resource_config_wrapper.project,
-            location=resource_config_wrapper.vertex_ai_inferencer_region,
-            service_account=resource_config_wrapper.service_account_email,
-            staging_bucket=resource_config_wrapper.temp_assets_regional_bucket_path.uri,
-        )
-        vertex_ai_service.launch_job(job_config=job_config)
+        if isinstance(
+            resource_config_wrapper.inferencer_config, VertexAiResourceConfig
+        ):
+            self._launch_single_pool(
+                vertex_ai_resource_config=resource_config_wrapper.inferencer_config,
+                applied_task_identifier=applied_task_identifier,
+                task_config_uri=task_config_uri,
+                resource_config_uri=resource_config_uri,
+                inference_process_command=inference_process_command,
+                inference_process_runtime_args=inference_process_runtime_args,
+                resource_config_wrapper=resource_config_wrapper,
+                cpu_docker_uri=cpu_docker_uri,
+                cuda_docker_uri=cuda_docker_uri,
+            )
+        elif isinstance(
+            resource_config_wrapper.inferencer_config, VertexAiGraphStoreConfig
+        ):
+            self._launch_server_client(
+                vertex_ai_graph_store_config=resource_config_wrapper.inferencer_config,
+                applied_task_identifier=applied_task_identifier,
+                task_config_uri=task_config_uri,
+                resource_config_uri=resource_config_uri,
+                inference_process_command=inference_process_command,
+                inference_process_runtime_args=inference_process_runtime_args,
+                resource_config_wrapper=resource_config_wrapper,
+                cpu_docker_uri=cpu_docker_uri,
+                cuda_docker_uri=cuda_docker_uri,
+            )
+        else:
+            raise NotImplementedError(
+                f"Unsupported resource config for glt inference: {type(resource_config_wrapper.inferencer_config).__name__}"
+            )

    def run(
        self,

@@ -157,10 +261,12 @@ def run(
        if isinstance(resource_config_wrapper.inferencer_config, LocalResourceConfig):
            raise NotImplementedError(
-                f"Local GLT Inferencer is not yet supported, please specify a {VertexAiResourceConfig.__name__} resource config field."
+                f"Local GLT Inferencer is not yet supported, please specify a {VertexAiResourceConfig.__name__} or {VertexAiGraphStoreConfig.__name__} resource config field."
            )
        elif isinstance(
            resource_config_wrapper.inferencer_config, VertexAiResourceConfig
        ) or isinstance(
            resource_config_wrapper.inferencer_config, VertexAiGraphStoreConfig
        ):
            self.__execute_VAI_inference(
                applied_task_identifier=applied_task_identifier,