Skip to content

Commit

Permalink
Batch notebooks update for txt2img to address environment error (Azur…
Browse files Browse the repository at this point in the history
…e#2953)

* Pass custom environment in batch txt2img notebooks

* cli example

* cli update

* code quality

---------

Co-authored-by: grajguru <[email protected]>
  • Loading branch information
gauravrajguru and grajguru authored Jan 17, 2024
1 parent a5c29c3 commit 581112f
Show file tree
Hide file tree
Showing 13 changed files with 737 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def init():
global g_logger
global g_file_loader_dictionary
global aacs_client
global aacs_enabled

g_logger = logging.getLogger("azureml")
g_logger.setLevel(logging.INFO)
Expand All @@ -60,10 +61,15 @@ def init():
".pqt": load_parquet,
}

endpoint = os.environ.get("CONTENT_SAFETY_ENDPOINT")
key = os.environ.get("CONTENT_SAFETY_KEY")
endpoint = os.environ.get("CONTENT_SAFETY_ENDPOINT", None)
key = os.environ.get("CONTENT_SAFETY_KEY", None)
# Create a Content Safety client
aacs_client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
if endpoint is not None and key is not None:
aacs_client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
aacs_enabled = True
else:
aacs_enabled = False
g_logger.warn("Azure AI Content Safety (aacs) is disabled.")


def get_input_schema(model_path):
Expand Down Expand Up @@ -372,13 +378,18 @@ def run(batch_input):
predict_result = []
try:
aacs_threshold = int(os.environ.get("CONTENT_SAFETY_THRESHOLD", default=1))
blocked_input = analyze_data(
batch_input, aacs_threshold, blocked_input=None, is_input=True
)
if aacs_enabled:
blocked_input = analyze_data(
batch_input, aacs_threshold, blocked_input=None, is_input=True
)
predict_result = g_model.predict(batch_input)
_ = analyze_data(
predict_result, aacs_threshold, blocked_input=blocked_input, is_input=False
)
if aacs_enabled:
_ = analyze_data(
predict_result,
aacs_threshold,
blocked_input=blocked_input,
is_input=False,
)
except Exception as e:
g_logger.error("Processing mini batch failed with exception: " + str(e))
g_logger.error(traceback.format_exc())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ else
fi

# 4. Submit a sample request to endpoint
data_path="./text_to_image_batch_data/batch_data"
data_path="./text_to_image_batch_data"
python utils/prepare_data.py --payload-path $data_path --mode "batch"
# Path where the processed CSVs are dumped. This is the input to the endpoint
processed_data_path="./text_to_image_batch_data/processed_batch_data"
Expand All @@ -73,11 +73,32 @@ az ml batch-endpoint create --name $endpoint_name $workspace_info || {
echo "endpoint create failed"; exit 1;
}

# Create an environment for the batch deployment.
# If the workspace already has a "latest" version of the environment it is
# reused; otherwise the environment is built from the Docker build context
# under ./scoring-files/docker_env.

environment_name="text-to-image-model-env"
environment_label="latest"

# NOTE(review): $workspace_info is left unquoted on purpose - it is used the
# same way elsewhere in this script and presumably expands to several CLI
# flags that rely on word splitting; confirm against its definition.
if ! az ml environment show --name "$environment_name" --label "$environment_label" $workspace_info
then
    echo "Environment $environment_name:$environment_label does not exist in Workspace."
    echo "---Creating environment---"
    az ml environment create --name "$environment_name" --build-context "./scoring-files/docker_env" \
        $workspace_info || {
            echo "environment create failed"; exit 1;
        }
    # On success, fall through: the version lookup and the deployment below
    # still have to run. (An unconditional `exit 1` here aborted the script
    # with a failure status even though the environment had been created.)
fi

# Resolve the concrete version behind the label so the deployment can pin it.
environment_version=$(az ml environment show --name "$environment_name" --label "$environment_label" $workspace_info --query version --output tsv)

# deploy model from registry to endpoint in workspace
# The deployment is defined by batch-deploy.yml; the --set arguments override
# individual fields of that definition:
#   - endpoint_name / name / compute come from variables set earlier in the script
#   - environment is the workspace environment pinned to $environment_version
#   - code_configuration points at the scoring code folder and its entry script
#   - model is referenced by its registry path and version
# On failure a message is printed and the script exits with status 1.
az ml batch-deployment create --file batch-deploy.yml $workspace_info --set \
endpoint_name=$endpoint_name \
name=$deployment_name \
compute=$deployment_compute \
environment=azureml:$environment_name:$environment_version \
code_configuration.code="scoring-files/score" \
code_configuration.scoring_script="score_batch.py" \
model=azureml://registries/$registry_name/models/$model_name/versions/$model_version || {
echo "deployment create failed"; exit 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@
" print(\"---Creating environment---\")\n",
" env = Environment(\n",
" name=environment_name,\n",
" build=BuildContext(path=\"./aacs-scoring-files/docker_env\"),\n",
" build=BuildContext(path=\"./scoring-files/docker_env\"),\n",
" )\n",
" ml_client.environments.create_or_update(env)\n",
" env = ml_client.environments.get(environment_name, label=\"latest\")\n",
Expand Down Expand Up @@ -447,7 +447,7 @@
" model=model,\n",
" environment=env,\n",
" code_configuration=CodeConfiguration(\n",
" code=\"aacs-scoring-files/score\",\n",
" code=\"scoring-files/score\",\n",
" scoring_script=\"score_batch.py\",\n",
" ),\n",
" compute=compute_name,\n",
Expand Down Expand Up @@ -516,7 +516,7 @@
"import pandas as pd\n",
"\n",
"# Specify the folder where your CSV files are located\n",
"data_path = \"aacs-scoring-files/text-to-image-batch-data\"\n",
"data_path = \"scoring-files/text-to-image-batch-data\"\n",
"\n",
"# Use glob to get a list of CSV files in the folder\n",
"csv_files = glob.glob(os.path.join(data_path, \"*.csv\"))\n",
Expand All @@ -537,7 +537,7 @@
"from pathlib import Path\n",
"\n",
"# Specify the folder where your CSV files should be saved\n",
"processed_dataset_parent_dir = \"aacs-scoring-files/processed-text-to-image-batch-data\"\n",
"processed_dataset_parent_dir = \"scoring-files/processed-text-to-image-batch-data\"\n",
"os.makedirs(processed_dataset_parent_dir, exist_ok=True)\n",
"batch_input_file = \"batch_input.csv\"\n",
"\n",
Expand Down Expand Up @@ -712,7 +712,7 @@
" names=[\n",
" \"row_number_per_file\",\n",
" \"prompt\",\n",
" \"image_file_name\",\n",
" \"generated_image\",\n",
" \"nsfw_content_detected\",\n",
" \"file_name\",\n",
" ],\n",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Build context image for the text-to-image batch scoring environment:
# an AzureML OpenMPI base image plus a conda environment created from
# conda_dependencies.yaml.
FROM mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:20230620.v1

# Prefix under which the conda environment is created.
# (`ENV key=value` is used throughout; the whitespace-separated legacy form
# is discouraged by Docker.)
ENV CONDA_ENVIRONMENT_PATH=/azureml-envs/text-to-image

# Prepend path to AzureML conda environment
ENV PATH=$CONDA_ENVIRONMENT_PATH/bin:$PATH

# Create conda environment, then drop the spec file and the pip/conda caches
# in the same layer so they do not end up in the image.
COPY conda_dependencies.yaml .
RUN conda env create -p $CONDA_ENVIRONMENT_PATH -f conda_dependencies.yaml -q && \
    rm conda_dependencies.yaml && \
    conda run -p $CONDA_ENVIRONMENT_PATH pip cache purge && \
    conda clean -a -y

# Print the installed packages so they show up in the build output.
RUN pip freeze

# This is needed for mpi to locate libpython
ENV LD_LIBRARY_PATH=$CONDA_ENVIRONMENT_PATH/lib:$LD_LIBRARY_PATH
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Conda environment specification.
# NOTE(review): presumably this is the conda_dependencies.yaml copied by the
# Dockerfile in the same build context - confirm the file name.
channels:
  - conda-forge
dependencies:
  - python=3.8.16
  - pip<=23.1.2
  # Requirements installed with pip. They must be nested under the `pip:` key;
  # listed as siblings they would be read as conda package specs and `pip:`
  # itself would be empty.
  - pip:
      - mlflow==2.3.2
      - torch==1.13.0
      - transformers==4.29.1
      - diffusers==0.23.0
      - accelerate==0.22.0
      - azureml-core==1.52.0
      - azureml-mlflow==1.52.0
      - azure-ai-contentsafety==1.0.0b1
      - aiolimiter==1.1.0
      - azure-ai-mlmonitoring==0.1.0a3
      - azure-mgmt-cognitiveservices==13.4.0
      - azure-identity==1.13.0
name: mlflow-env
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"append_row": {"pandas.DataFrame.to_csv": {"sep": ",","index": true}}}
Loading

0 comments on commit 581112f

Please sign in to comment.