67 changes: 67 additions & 0 deletions python/gigl/common/services/vertex_ai.py
@@ -61,6 +61,7 @@ def get_pipeline() -> int:  # NOTE: `get_pipeline` here is the Pipeline name

import datetime
import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Final, Optional, Union

@@ -69,12 +70,17 @@ def get_pipeline() -> int:  # NOTE: `get_pipeline` here is the Pipeline name
ContainerSpec,
DiskSpec,
MachineSpec,
Scheduling,
WorkerPoolSpec,
env_var,
)

from gigl.common import GcsUri, Uri
from gigl.common.logger import Logger
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from snapchat.research.gbml.gigl_resource_config_pb2 import VertexAiResourceConfig

logger = Logger()

@@ -108,6 +114,67 @@ class VertexAiJobConfig:
scheduling_strategy: Optional[aiplatform.gapic.Scheduling.Strategy] = None


def get_job_config_from_vertex_ai_resource_config(
applied_task_identifier: AppliedTaskIdentifier,
is_inference: bool,
task_config_uri: Uri,
resource_config_uri: Uri,
command_str: str,
args: Mapping[str, str],
run_on_cpu: bool,

Collaborator:
we have been using use_cuda throughout the repo.
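
For illustration only, a minimal sketch of the suggested rename — the `build_job_args` helper below is hypothetical and not part of this PR; it just shows the flag flowing through as `use_cuda` instead of `run_on_cpu`:

```python
# Hypothetical helper: accept `use_cuda` (the flag name used across the repo)
# instead of `run_on_cpu`, and emit the same CLI args.
def build_job_args(job_name: str, use_cuda: bool, extra_args: dict[str, str]) -> list[str]:
    args = [f"--job_name={job_name}"]
    if use_cuda:
        args.append("--use_cuda")
    args.extend(f"--{k}={v}" for k, v in extra_args.items())
    return args


# Example:
#   build_job_args("my_task", use_cuda=True, extra_args={"batch_size": "512"})
#   -> ["--job_name=my_task", "--use_cuda", "--batch_size=512"]
```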

container_uri: str,
vertex_ai_resource_config: VertexAiResourceConfig,
env_vars: list[env_var.EnvVar],
) -> VertexAiJobConfig:
job_args = (
[
f"--job_name={applied_task_identifier}",
f"--task_config_uri={task_config_uri}",
f"--resource_config_uri={resource_config_uri}",

Collaborator (@svij-sc), Nov 27, 2025:
Even if we don't have import statements, the logic of formulating these args and passing them through for the trainer to use still intermixes the architecture layers.

Figuring out what arguments to pass in based on a versioned trainer/inferencer implementation (v2) still tightly couples the business layer, and pins us to the specific versioned implementations of 2 classes that "may use VAI".
v1/v3 trainers/inferencers can't use this.

Ideally the purpose of this service should just be:
"Launch some job", "Run some pipeline", "Get status of X", "Wait for Y".

IMO, being able to differentiate and launch graph_store vs regular jobs is probably the max amount of coupling that should be introduced in this service.
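
As a rough illustration of the boundary described above (the names below are partly hypothetical and not part of this PR), the service surface could be limited to launching jobs and reporting on them, with argument assembly left to each versioned component:

```python
from typing import Protocol


class JobLauncher(Protocol):
    """Sketch of an implementation-agnostic service surface.

    `launch_job` and `launch_graph_store_job` mirror methods already used in
    this PR; `get_job_status` and `wait_for_job` are hypothetical, standing in
    for the "Get status of X" / "Wait for Y" responsibilities above.
    """

    def launch_job(self, job_config: "VertexAiJobConfig") -> None:
        ...

    def launch_graph_store_job(
        self,
        compute_pool_job_config: "VertexAiJobConfig",
        storage_pool_job_config: "VertexAiJobConfig",
    ) -> None:
        ...

    def get_job_status(self, job_name: str) -> str:
        ...

    def wait_for_job(self, job_name: str) -> None:
        ...
```

Under that split, building job args from a versioned trainer/inferencer config (e.g. `get_job_config_from_vertex_ai_resource_config`) would live with the component rather than in the shared service.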

]
+ ([] if run_on_cpu else ["--use_cuda"])
+ ([f"--{k}={v}" for k, v in args.items()])
)

command = command_str.strip().split(" ")
if is_inference:
vai_job_name = f"gigl_infer_{applied_task_identifier}"
else:
vai_job_name = f"gigl_train_{applied_task_identifier}"
resource_config_wrapper = get_resource_config(
resource_config_uri=resource_config_uri
)
resource_config_labels = resource_config_wrapper.get_resource_labels(
component=GiGLComponents.Inferencer if is_inference else GiGLComponents.Trainer
)
job_config = VertexAiJobConfig(
job_name=vai_job_name,
container_uri=container_uri,
command=command,
args=job_args,
environment_variables=env_vars,
machine_type=vertex_ai_resource_config.machine_type,
accelerator_type=vertex_ai_resource_config.gpu_type.upper().replace("-", "_"),
accelerator_count=vertex_ai_resource_config.gpu_limit,
replica_count=vertex_ai_resource_config.num_replicas,
labels=resource_config_labels,
timeout_s=vertex_ai_resource_config.timeout
if vertex_ai_resource_config.timeout
else None,
# This should be `aiplatform.gapic.Scheduling.Strategy[inferencer_resource_config.scheduling_strategy]`
# But mypy complains otherwise...
# python/gigl/src/inference/v2/glt_inferencer.py:124: error: The type "type[Strategy]" is not generic and not indexable [misc]
# TODO(kmonte): Fix this
scheduling_strategy=getattr(
Scheduling.Strategy,
vertex_ai_resource_config.scheduling_strategy,
)
if vertex_ai_resource_config.scheduling_strategy
else None,
)
return job_config


class VertexAIService:
"""
A class representing a Vertex AI service.
244 changes: 175 additions & 69 deletions python/gigl/src/inference/v2/glt_inferencer.py
@@ -1,17 +1,21 @@
import argparse
from collections.abc import Mapping
from typing import Optional

from google.cloud.aiplatform_v1.types import Scheduling, accelerator_type, env_var
from google.cloud.aiplatform_v1.types import accelerator_type, env_var

from gigl.common import Uri, UriFactory
from gigl.common.constants import (
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU,
DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA,
)
from gigl.common.logger import Logger
from gigl.common.services.vertex_ai import VertexAiJobConfig, VertexAIService
from gigl.common.services.vertex_ai import (
VertexAIService,
get_job_config_from_vertex_ai_resource_config,
)
from gigl.env.distributed import COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY
from gigl.env.pipelines_config import get_resource_config
from gigl.src.common.constants.components import GiGLComponents
from gigl.src.common.types import AppliedTaskIdentifier
from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper
from gigl.src.common.types.pb_wrappers.gigl_resource_config import (
@@ -20,6 +24,7 @@
from gigl.src.common.utils.metrics_service_provider import initialize_metrics
from snapchat.research.gbml.gigl_resource_config_pb2 import (
LocalResourceConfig,
VertexAiGraphStoreConfig,
VertexAiResourceConfig,
)

@@ -49,6 +54,139 @@ class GLTInferencer:
GiGL Component that runs a GLT Inference using a provided class path
"""

def _launch_single_pool(
self,
vertex_ai_resource_config: VertexAiResourceConfig,
applied_task_identifier: AppliedTaskIdentifier,
task_config_uri: Uri,
resource_config_uri: Uri,
inference_process_command: str,
inference_process_runtime_args: Mapping[str, str],
resource_config_wrapper: GiglResourceConfigWrapper,
cpu_docker_uri: Optional[str],
cuda_docker_uri: Optional[str],
) -> None:
"""Launch a single pool inference job on Vertex AI."""
is_cpu_inference = _determine_if_cpu_inference(
inferencer_resource_config=vertex_ai_resource_config
)
cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
container_uri = cpu_docker_uri if is_cpu_inference else cuda_docker_uri

job_config = get_job_config_from_vertex_ai_resource_config(
applied_task_identifier=applied_task_identifier,
is_inference=True,
task_config_uri=task_config_uri,
resource_config_uri=resource_config_uri,
command_str=inference_process_command,
args=inference_process_runtime_args,
run_on_cpu=is_cpu_inference,
container_uri=container_uri,
vertex_ai_resource_config=vertex_ai_resource_config,
env_vars=[env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3")],
)
logger.info(f"Launching inference job with config: {job_config}")

vertex_ai_service = VertexAIService(
project=resource_config_wrapper.project,
location=resource_config_wrapper.vertex_ai_inferencer_region,
service_account=resource_config_wrapper.service_account_email,
staging_bucket=resource_config_wrapper.temp_assets_regional_bucket_path.uri,
)
vertex_ai_service.launch_job(job_config=job_config)

def _launch_server_client(
self,
vertex_ai_graph_store_config: VertexAiGraphStoreConfig,
applied_task_identifier: AppliedTaskIdentifier,
task_config_uri: Uri,
resource_config_uri: Uri,
inference_process_command: str,
inference_process_runtime_args: Mapping[str, str],
resource_config_wrapper: GiglResourceConfigWrapper,
cpu_docker_uri: Optional[str],
cuda_docker_uri: Optional[str],
) -> None:
"""Launch a server/client inference job on Vertex AI using graph store config."""
storage_pool_config = vertex_ai_graph_store_config.graph_store_pool
compute_pool_config = vertex_ai_graph_store_config.compute_pool

# Determine if CPU or GPU based on compute pool
is_cpu_inference = _determine_if_cpu_inference(
inferencer_resource_config=compute_pool_config
)
cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
container_uri = cpu_docker_uri if is_cpu_inference else cuda_docker_uri

logger.info(f"Running inference with command: {inference_process_command}")

num_compute_processes = (
vertex_ai_graph_store_config.compute_cluster_local_world_size
)
if not num_compute_processes:
if is_cpu_inference:
num_compute_processes = 1
else:
num_compute_processes = (
vertex_ai_graph_store_config.compute_pool.gpu_limit
)
# Add server/client environment variables
environment_variables: list[env_var.EnvVar] = [
env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
env_var.EnvVar(
name=COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY,
value=str(num_compute_processes),
),
]

# Create compute pool job config
compute_job_config = get_job_config_from_vertex_ai_resource_config(
applied_task_identifier=applied_task_identifier,
is_inference=True,
task_config_uri=task_config_uri,
resource_config_uri=resource_config_uri,
command_str=inference_process_command,
args=inference_process_runtime_args,
run_on_cpu=is_cpu_inference,
container_uri=container_uri,
vertex_ai_resource_config=compute_pool_config,
env_vars=environment_variables,
)

# Create storage pool job config
storage_job_config = get_job_config_from_vertex_ai_resource_config(
applied_task_identifier=applied_task_identifier,
is_inference=True,
task_config_uri=task_config_uri,
resource_config_uri=resource_config_uri,
command_str="python -m gigl.distributed.server_client.server_main",

Collaborator (Author):
Note: this module does not exist yet; it will be added in a later PR :)

args={}, # No extra args for storage pool
run_on_cpu=is_cpu_inference,
container_uri=container_uri,
vertex_ai_resource_config=storage_pool_config,
env_vars=environment_variables,
)

# Determine region from compute pool or use default region
region = (
compute_pool_config.gcp_region_override
if compute_pool_config.gcp_region_override
else resource_config_wrapper.region
)

vertex_ai_service = VertexAIService(
project=resource_config_wrapper.project,
location=region,
service_account=resource_config_wrapper.service_account_email,
staging_bucket=resource_config_wrapper.temp_assets_regional_bucket_path.uri,
)
vertex_ai_service.launch_graph_store_job(
compute_pool_job_config=compute_job_config,
storage_pool_job_config=storage_job_config,
)

def __execute_VAI_inference(
self,
applied_task_identifier: AppliedTaskIdentifier,
@@ -74,73 +212,39 @@ def __execute_VAI_inference(
inference_process_runtime_args = (
gbml_config_pb_wrapper.inferencer_config.inferencer_args
)
assert isinstance(
resource_config_wrapper.inferencer_config, VertexAiResourceConfig
)
inferencer_resource_config: VertexAiResourceConfig = (
resource_config_wrapper.inferencer_config
)

is_cpu_training = _determine_if_cpu_inference(
inferencer_resource_config=inferencer_resource_config
)
cpu_docker_uri = cpu_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CPU
cuda_docker_uri = cuda_docker_uri or DEFAULT_GIGL_RELEASE_SRC_IMAGE_CUDA
container_uri = cpu_docker_uri if is_cpu_training else cuda_docker_uri

job_args = (
[
f"--job_name={applied_task_identifier}",
f"--task_config_uri={task_config_uri}",
f"--resource_config_uri={resource_config_uri}",
]
+ ([] if is_cpu_training else ["--use_cuda"])
+ ([f"--{k}={v}" for k, v in inference_process_runtime_args.items()])
)

command = inference_process_command.strip().split(" ")
logger.info(f"Running inference with command: {command}")
vai_job_name = f"gigl_infer_{applied_task_identifier}"
environment_variables: list[env_var.EnvVar] = [
env_var.EnvVar(name="TF_CPP_MIN_LOG_LEVEL", value="3"),
]
job_config = VertexAiJobConfig(
job_name=vai_job_name,
container_uri=container_uri,
command=command,
args=job_args,
environment_variables=environment_variables,
machine_type=inferencer_resource_config.machine_type,
accelerator_type=inferencer_resource_config.gpu_type.upper().replace(
"-", "_"
),
accelerator_count=inferencer_resource_config.gpu_limit,
replica_count=inferencer_resource_config.num_replicas,
labels=resource_config_wrapper.get_resource_labels(
component=GiGLComponents.Inferencer
),
timeout_s=inferencer_resource_config.timeout
if inferencer_resource_config.timeout
else None,
# This should be `aiplatform.gapic.Scheduling.Strategy[inferencer_resource_config.scheduling_strategy]`
# But mypy complains otherwise...
# python/gigl/src/inference/v2/glt_inferencer.py:124: error: The type "type[Strategy]" is not generic and not indexable [misc]
# TODO(kmonte): Fix this
scheduling_strategy=getattr(
Scheduling.Strategy,
inferencer_resource_config.scheduling_strategy,
if isinstance(
resource_config_wrapper.inferencer_config, VertexAiResourceConfig
):
self._launch_single_pool(
vertex_ai_resource_config=resource_config_wrapper.inferencer_config,
applied_task_identifier=applied_task_identifier,
task_config_uri=task_config_uri,
resource_config_uri=resource_config_uri,
inference_process_command=inference_process_command,
inference_process_runtime_args=inference_process_runtime_args,
resource_config_wrapper=resource_config_wrapper,
cpu_docker_uri=cpu_docker_uri,
cuda_docker_uri=cuda_docker_uri,
)
elif isinstance(
resource_config_wrapper.inferencer_config, VertexAiGraphStoreConfig
):
self._launch_server_client(
vertex_ai_graph_store_config=resource_config_wrapper.inferencer_config,
applied_task_identifier=applied_task_identifier,
task_config_uri=task_config_uri,
resource_config_uri=resource_config_uri,
inference_process_command=inference_process_command,
inference_process_runtime_args=inference_process_runtime_args,
resource_config_wrapper=resource_config_wrapper,
cpu_docker_uri=cpu_docker_uri,
cuda_docker_uri=cuda_docker_uri,
)
else:
raise NotImplementedError(
f"Unsupported resource config for glt inference: {type(resource_config_wrapper.inferencer_config).__name__}"
)
if inferencer_resource_config.scheduling_strategy
else None,
)

vertex_ai_service = VertexAIService(
project=resource_config_wrapper.project,
location=resource_config_wrapper.vertex_ai_inferencer_region,
service_account=resource_config_wrapper.service_account_email,
staging_bucket=resource_config_wrapper.temp_assets_regional_bucket_path.uri,
)
vertex_ai_service.launch_job(job_config=job_config)

def run(
self,
@@ -157,10 +261,12 @@ def run(

if isinstance(resource_config_wrapper.inferencer_config, LocalResourceConfig):
raise NotImplementedError(
f"Local GLT Inferencer is not yet supported, please specify a {VertexAiResourceConfig.__name__} resource config field."
f"Local GLT Inferencer is not yet supported, please specify a {VertexAiResourceConfig.__name__} or {VertexAiGraphStoreConfig.__name__} resource config field."
)
elif isinstance(
resource_config_wrapper.inferencer_config, VertexAiResourceConfig
) or isinstance(
resource_config_wrapper.inferencer_config, VertexAiGraphStoreConfig
):
self.__execute_VAI_inference(
applied_task_identifier=applied_task_identifier,