Skip to content

Commit

Permalink
Merge pull request #5 from cloudera/grpc
Browse files Browse the repository at this point in the history
gRPC API surface plus response/requests datatype naming conventions
  • Loading branch information
jasonmeverett authored Aug 7, 2024
2 parents 324c165 + cca4959 commit 83eb7c8
Show file tree
Hide file tree
Showing 23 changed files with 2,441 additions and 509 deletions.
4 changes: 2 additions & 2 deletions .app/state.json
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@
"huggingface_model_name": "distilbert/distilgpt2"
}
],
"jobs": [],
"mlflow": [],
"fine_tuning_jobs": [],
"evaluation_jobs": [],
"prompts": [
{
"id": "14674c2e-6641-4494-aff9-6f3cd021d710",
Expand Down
9 changes: 8 additions & 1 deletion bin/run-app.py
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
!streamlit run main.py --server.port $CDSW_APP_PORT --server.address 127.0.0.1
import subprocess
import os

# Port that the CML/CDSW runtime assigns to the hosted application.
CDSW_APP_PORT = os.environ.get("CDSW_APP_PORT")
if not CDSW_APP_PORT:
    # Fail fast with a clear error instead of passing the string "None"
    # (or an empty value) down to streamlit's --server.port flag.
    raise RuntimeError("CDSW_APP_PORT environment variable is not set.")

# Launch the app startup script with the port as its only argument.
# Use an argv list with shell=False (the default) rather than an
# interpolated shell string, so the port value cannot be shell-expanded.
out = subprocess.run(
    ["bash", "./bin/start-app-script.sh", CDSW_APP_PORT],
    check=True,
)
print(out)

print("App start script is complete.")
22 changes: 22 additions & 0 deletions bin/start-app-script.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#!/bin/bash
# start-app-script.sh
# Starts the gRPC backend in the background, then launches the Streamlit UI.
#
# Usage: start-app-script.sh <streamlit-port>

# Fail fast if the caller did not provide the Streamlit port; otherwise
# streamlit would be started with an empty --server.port value.
if [ -z "${1:-}" ]; then
    echo "Usage: $0 <streamlit-port>" >&2
    exit 1
fi

# gRPC server will spawn on 50051 (hard-coded in bin/start-grpc-server.py).
PORT=50051

{ # Try to start up the server
    echo "Starting up the gRPC server..."
    # NOTE(review): a backgrounded `nohup ... &` returns immediately, so this
    # || fallback cannot actually observe a server startup failure. The real
    # double-start guard lives in start-grpc-server.py (port-in-use check).
    nohup python bin/start-grpc-server.py &
} || {
    echo "gRPC server initialization script failed. Is there already a local server running in the pod?"
}


# Give the gRPC server time to bind before bringing up the UI.
echo "Waiting 5 seconds..."
sleep 5


# Start up the streamlit application on the requested port.
# Quote "$1" so an unusual value cannot be word-split by the shell.
echo "Starting up streamlit application..."
streamlit run main.py --server.port "$1" --server.address 127.0.0.1

33 changes: 33 additions & 0 deletions bin/start-grpc-server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# start-grpc-server.py
from concurrent import futures
import logging
import grpc
from ft.proto import fine_tuning_studio_pb2_grpc
from ft.service import FineTuningStudioApp
from multiprocessing import Process
import socket

def start_server(blocking: bool = False, port: str = "50051"):
    """Create and start the Fine Tuning Studio gRPC server.

    Args:
        blocking: If True, block the calling thread until the server
            terminates. If False, return immediately after startup.
        port: TCP port to listen on. Defaults to the app's standard
            gRPC port so existing callers are unaffected.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    fine_tuning_studio_pb2_grpc.add_FineTuningStudioServicer_to_server(FineTuningStudioApp(), server=server)
    # Plaintext (insecure) transport: the server is only reached from
    # within the same pod over localhost — TODO confirm this assumption.
    server.add_insecure_port("[::]:" + port)
    server.start()
    print("Server started, listening on " + port)

    if blocking:
        # Keep the process alive so the daemonized launcher keeps serving.
        server.wait_for_termination()

def is_port_in_use(port: int) -> bool:
    """Return True if something is accepting TCP connections on localhost:port."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # connect_ex returns 0 only when the connection attempt succeeds,
        # i.e. a listener is already bound to the port.
        return probe.connect_ex(('localhost', port)) == 0


# Bring up the gRPC server unless another process in the pod already owns
# the port (e.g. the launcher script was run twice).
port = 50051
if is_port_in_use(port):
    print("Server is already running.")
else:
    print("Starting up the gRPC server.")
    # Start the gRPC server if it's not already running; block forever so
    # this script keeps the server process alive.
    start_server(blocking=True)
30 changes: 15 additions & 15 deletions ft/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def __init__(self, props: FineTuningAppProps):
self.datasets = props.datasets_manager
return

def add_dataset(self, request: ImportDatasetRequest) -> ImportDatasetResponse:
def add_dataset(self, request: AddDatasetRequest) -> AddDatasetResponse:
"""
Add a dataset to the App based on the request.
"""
import_response: ImportDatasetResponse = self.datasets.import_dataset(request)
import_response: AddDatasetResponse = self.datasets.import_dataset(request)

# If we've successfully imported a new dataset, then make sure we update
# the app's dataset state with this data. The way we detect this in protobuf
Expand All @@ -122,8 +122,8 @@ def remove_dataset(self, id: str):
datasets=datasets,
prompts=prompts,
adapters=state.adapters,
jobs=state.jobs,
mlflow=state.mlflow,
fine_tuning_jobs=state.fine_tuning_jobs,
evaluation_jobs=state.evaluation_jobs,
models=state.models
))

Expand All @@ -143,16 +143,16 @@ def remove_prompt(self, id: str):
datasets=state.datasets,
prompts=prompts,
adapters=state.adapters,
jobs=state.jobs,
mlflow=state.mlflow,
fine_tuning_jobs=state.fine_tuning_jobs,
evaluation_jobs=state.evaluation_jobs,
models=state.models
))

def import_model(self, request: ImportModelRequest) -> ImportModelResponse:
def import_model(self, request: AddModelRequest) -> AddModelResponse:
"""
Add a dataset to the App based on the request.
"""
import_response: ImportModelResponse = self.models.import_model(request)
import_response: AddModelResponse = self.models.import_model(request)

# If we've successfully imported a new dataset, then make sure we update
# the app's dataset state with this data. For now, using protobuf, we will
Expand Down Expand Up @@ -193,8 +193,8 @@ def remove_model(self, id: str):
datasets=state.datasets,
prompts=state.prompts,
adapters=state.adapters,
jobs=state.jobs,
mlflow=state.mlflow,
fine_tuning_jobs=state.fine_tuning_jobs,
evaluation_jobs=state.evaluation_jobs,
models=models
))

Expand All @@ -206,20 +206,20 @@ def launch_ft_job(self, request: StartFineTuningJobRequest) -> StartFineTuningJo

if not job_launch_response.job == StartFineTuningJobResponse().job:
state: AppState = get_state()
state.jobs.append(job_launch_response.job)
state.fine_tuning_jobs.append(job_launch_response.job)
write_state(state)

return job_launch_response

def launch_mlflow_job(self, request: StartMLflowEvaluationJobRequest) -> StartMLflowEvaluationJobResponse:
def launch_mlflow_job(self, request: StartEvaluationJobRequest) -> StartEvaluationJobResponse:
"""
Create and launch a job for MLflow
"""
job_launch_response: StartMLflowEvaluationJobResponse = self.mlflow.start_ml_flow_evaluation_job(request)
job_launch_response: StartEvaluationJobResponse = self.mlflow.start_ml_flow_evaluation_job(request)

if not job_launch_response.job == StartMLflowEvaluationJobResponse().job:
if not job_launch_response.job == StartEvaluationJobResponse().job:
state: AppState = get_state()
state.mlflow.append(job_launch_response.job)
state.evaluation_jobs.append(job_launch_response.job)
write_state(state)

return job_launch_response
Expand Down
12 changes: 12 additions & 0 deletions ft/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

import streamlit as st
import grpc

from ft.api import *
from ft.proto.fine_tuning_studio_pb2_grpc import FineTuningStudioStub


def main():
    """Smoke-test the gRPC API by listing datasets from the local server.

    Expects the Fine Tuning Studio gRPC server to be running in-pod on
    its standard port (see bin/start-grpc-server.py).
    """
    with grpc.insecure_channel('localhost:50051') as channel:
        stub = FineTuningStudioStub(channel=channel)
        datasets: ListDatasetsResponse = stub.ListDatasets(ListDatasetsRequest())
        print(datasets)


if __name__ == "__main__":
    # Guard the RPC behind a main-entry check so merely importing this
    # module (e.g. from a Streamlit page) does not trigger a network call.
    main()
10 changes: 5 additions & 5 deletions ft/managers/datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from ft.api import DatasetMetadata, ImportDatasetRequest, DatasetType, ImportDatasetResponse
from ft.api import *
from typing import List
from datasets import load_dataset_builder
from ft.state import get_state
Expand All @@ -24,7 +24,7 @@ def get_dataset(self, id: str) -> DatasetMetadata:
pass

@abstractmethod
def import_dataset(self, request: ImportDatasetRequest) -> ImportDatasetResponse:
def import_dataset(self, request: AddDatasetRequest) -> AddDatasetResponse:
pass


Expand All @@ -42,11 +42,11 @@ def get_dataset(self, id: str) -> DatasetMetadata:
assert len(datasets) == 1
return datasets[0]

def import_dataset(self, request: ImportDatasetRequest) -> ImportDatasetResponse:
def import_dataset(self, request: AddDatasetRequest) -> AddDatasetResponse:
"""
Retrieve dataset information without fully loading it into memory.
"""
response = ImportDatasetResponse()
response = AddDatasetResponse()

# Create a new dataset metadata for the imported dataset.
if request.type == DatasetType.DATASET_TYPE_HUGGINGFACE:
Expand All @@ -72,7 +72,7 @@ def import_dataset(self, request: ImportDatasetRequest) -> ImportDatasetResponse
name=request.huggingface_name,
description=dataset_info.description
)
response = ImportDatasetResponse(dataset=metadata)
response = AddDatasetResponse(dataset=metadata)

except Exception as e:
raise ValueError(f"Failed to load dataset. {e}")
Expand Down
12 changes: 6 additions & 6 deletions ft/managers/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ def list_ml_flow_evaluation_jobs(self):
pass

@abstractmethod
def get_ml_flow_evaluation_job(self, job_id: str) -> MLflowEvaluationJobMetadata:
def get_ml_flow_evaluation_job(self, job_id: str) -> EvaluationJobMetadata:
pass

@abstractmethod
def start_ml_flow_evaluation_job(self, request: StartMLflowEvaluationJobRequest):
def start_ml_flow_evaluation_job(self, request: StartEvaluationJobRequest):
pass


Expand All @@ -36,11 +36,11 @@ def list_ml_flow_evaluation_jobs(self):
# Method to list ML flow evaluation jobs
pass

def get_ml_flow_evaluation_job(self, job_id: str) -> MLflowEvaluationJobMetadata:
def get_ml_flow_evaluation_job(self, job_id: str) -> EvaluationJobMetadata:
# Method to get a specific ML flow evaluation job
return super().get_ml_flow_evaluation_job(job_id)

def start_ml_flow_evaluation_job(self, request: StartMLflowEvaluationJobRequest):
def start_ml_flow_evaluation_job(self, request: StartEvaluationJobRequest):
"""
Launch a CML Job which runs/orchestrates a finetuning operation.
The CML Job itself does not run the finetuning work; it will launch a CML Worker(s) to allow
Expand Down Expand Up @@ -127,7 +127,7 @@ def start_ml_flow_evaluation_job(self, request: StartMLflowEvaluationJobRequest)
job_id=created_job.id
)

metadata = MLflowEvaluationJobMetadata(
metadata = EvaluationJobMetadata(
job_id=job_id,
base_model_id=request.base_model_id,
dataset_id=request.dataset_id,
Expand All @@ -142,4 +142,4 @@ def start_ml_flow_evaluation_job(self, request: StartMLflowEvaluationJobRequest)
evaluation_dir=result_dir
)

return StartMLflowEvaluationJobResponse(job=metadata)
return StartEvaluationJobResponse(job=metadata)
10 changes: 5 additions & 5 deletions ft/managers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def list_models(self) -> List[ModelMetadata]:
pass

@abstractmethod
def import_model(self, request: ImportModelRequest) -> ImportModelResponse:
def import_model(self, request: AddModelRequest) -> AddModelResponse:
pass

@abstractmethod
Expand All @@ -45,8 +45,8 @@ def __init__(self):
def list_models(self) -> List[ModelMetadata]:
return get_state().models

def import_model(self, request: ImportModelRequest) -> ImportModelResponse:
response: ImportModelResponse = ImportModelResponse()
def import_model(self, request: AddModelRequest) -> AddModelResponse:
response: AddModelResponse = AddModelResponse()

if request.type == ModelType.MODEL_TYPE_HUGGINGFACE:
try:
Expand All @@ -67,7 +67,7 @@ def import_model(self, request: ImportModelRequest) -> ImportModelResponse:
huggingface_model_name=request.huggingface_name,
)

response = ImportModelResponse(
response = AddModelResponse(
model=model_metadata
)
except Exception as e:
Expand Down Expand Up @@ -96,7 +96,7 @@ def import_model(self, request: ImportModelRequest) -> ImportModelResponse:
)
)

response = ImportModelResponse(
response = AddModelResponse(
model=model_metadata
)
except Exception as e:
Expand Down
Loading

0 comments on commit 83eb7c8

Please sign in to comment.