Databricks Endpoint Mode
C-K-Loan committed Sep 10, 2023
1 parent 5477927 commit 3ec7179
Showing 1 changed file with 130 additions and 52 deletions.
182 changes: 130 additions & 52 deletions nlu/pipe/utils/predict_helper.py
@@ -1,4 +1,6 @@
import logging
import os
from typing import Optional

import sparknlp
from pyspark.sql.functions import monotonically_increasing_id
@@ -8,9 +10,52 @@

logger = logging.getLogger('nlu')
from nlu.pipe.pipe_logic import PipeUtils
from nlu.pipe.utils.data_conversion_utils import DataConversionUtils

import pandas as pd
from pydantic import BaseModel


def serialize(img_path):
    """Read a file from disk and return its raw bytes."""
    with open(img_path, 'rb') as img_file:
        return img_file.read()


def deserialize(binary_image, path):
    """Write raw bytes back to a file at the given path."""
    with open(path, 'wb') as img_file:
        img_file.write(binary_image)
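
# A minimal usage sketch (an assumption, not part of the commit) of how these two
# helpers round-trip a binary payload; 'cat.jpg' and 'restored.jpg' are
# hypothetical paths:
#
#   payload = serialize('cat.jpg')        # raw bytes, as shipped in the 'file' column
#   deserialize(payload, 'restored.jpg')  # rebuild the file on the endpoint host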


class PredictParams(BaseModel):
    output_level: Optional[str] = ''
    positions: Optional[bool] = False
    keep_stranger_features: Optional[bool] = True
    metadata: Optional[bool] = False
    multithread: Optional[bool] = True
    drop_irrelevant_cols: Optional[bool] = True
    return_spark_df: Optional[bool] = False
    get_embeddings: Optional[bool] = True

    @staticmethod
    def has_param_cols(df: pd.DataFrame):
        # True if the dataframe carries at least one prediction-parameter column
        return any(c in df.columns for c in PredictParams.__fields__.keys())

    @staticmethod
    def maybe_from_pandas_df(df: pd.DataFrame):
        # only the first row is used to carry prediction parameters
        if df.shape[0] == 0:
            return None
        if not PredictParams.has_param_cols(df):
            # no params in df
            return None
        param_row = df.iloc[0].to_dict()
        try:
            return PredictParams(**param_row)
        except Exception as e:
            print(f'Exception while parsing prediction parameters from param row:\n{param_row}\n', e)
            return None
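
# A minimal sketch (an assumed example, not from the commit) of the param-row
# contract: the first DataFrame row doubles as a parameter carrier when it
# contains any PredictParams field:
#
#   df = pd.DataFrame({'output_level': ['sentence'], 'metadata': [True]})
#   params = PredictParams.maybe_from_pandas_df(df)  # -> PredictParams instance
#   params.dict()['output_level']                    # 'sentence'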


def __predict_standard_spark(pipe, data, output_level, positions, keep_stranger_features, metadata,
                             drop_irrelevant_cols, return_spark_df, get_embeddings):
@@ -144,9 +189,39 @@ def __predict_audio_spark(pipe, data, output_level, positions, keep_stranger_fea
                                    get_embeddings=get_embeddings
                                    )

def __db_endpoint_predict__(pipe, data):
    """
    Databricks-endpoint entry point:
    1) parse prediction params from the first row, if it carries param columns
    2) deserialize binary file rows into local files before predicting
    """
    print("CUSTOM NLU MODE!")
    print(data.columns)
    params = PredictParams.maybe_from_pandas_df(data)
    if params:
        params = params.dict()
    else:
        params = {}
    files = []
    if 'file' in data.columns and 'file_type' in data.columns:
        print("DETECTED FILE COLS")
        # if the first row holds prediction params, it is not a file row
        skip_first = PredictParams.has_param_cols(data)
        for i, row in data.iterrows():
            if i == 0 and skip_first:
                continue
            print(f"DESERIALIZING {row.file_type} file {row.file}")
            file_name = f'file{i}.{row.file_type}'
            files.append(file_name)
            deserialize(row.file, file_name)
        data = files

    if params:
        return __predict__(pipe, data, **params, normal_pred_on_db=True)
    else:
        # no params detected, call again with default params
        return __predict__(pipe, data, **PredictParams().dict(), normal_pred_on_db=True)
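
# A hypothetical endpoint payload (illustration only): the first row carries
# prediction params, later rows carry serialized binaries in the 'file' and
# 'file_type' columns; 'scan.png' is an assumed local file:
#
#   df = pd.DataFrame({
#       'output_level': ['document', None],
#       'file': [None, serialize('scan.png')],
#       'file_type': [None, 'png'],
#   })
#   __db_endpoint_predict__(pipe, df)  # writes 'file1.png', then predicts on it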

def __predict__(pipe, data, output_level, positions, keep_stranger_features, metadata, multithread,
                drop_irrelevant_cols, return_spark_df, get_embeddings, normal_pred_on_db=False):
    '''
    Annotates a Pandas Dataframe / Pandas Series / Numpy Array / Spark DataFrame / Python list of strings / Python string
    :param data: Data to predict on
@@ -160,6 +235,9 @@ def __predict__(pipe, data, output_level, positions, keep_stranger_features, met
    :return:
    '''

    # Databricks endpoint mode: if DB_ENDPOINT_ENV is set and this call was not
    # already re-dispatched from the endpoint handler, route through it first
    if 'DB_ENDPOINT_ENV' in os.environ and not normal_pred_on_db:
        return __db_endpoint_predict__(pipe, data)

    if output_level == '' and not pipe.has_table_qa_models:
        # Default sentence level for all components
        if pipe.has_nlp_components and not PipeUtils.contains_t5_or_gpt(
@@ -188,66 +266,66 @@ def __predict__(pipe, data, output_level, positions, keep_stranger_features, met
    else:
        pipe.fit()

    pipe.__configure_light_pipe_usage__(DataConversionUtils.size_of(data), multithread)

    if pipe.contains_ocr_components and pipe.contains_audio_components:
        """ Idea:
        Expect Array of Paths
        For every path classify file ending and use it to correctly handle Img or Audio stuff
        """
        raise Exception('Cannot mix Audio and OCR components in a Pipe?')

    if pipe.contains_audio_components:
        return __predict_audio_spark(pipe, data, output_level, positions, keep_stranger_features,
                                     metadata, drop_irrelevant_cols, get_embeddings=get_embeddings)

    if pipe.contains_ocr_components:
        # OCR processing
        try:
            return __predict_ocr_spark(pipe, data, output_level, positions, keep_stranger_features,
                                       metadata, drop_irrelevant_cols, get_embeddings=get_embeddings)
        except Exception as err:
            logger.warning(f"Predictions Failed={err}")
            pipe.print_exception_err(err)
            raise Exception("Failure to process data with NLU OCR pipe")
    if return_spark_df:
        try:
            return __predict_standard_spark(pipe, data, output_level, positions, keep_stranger_features, metadata,
                                            drop_irrelevant_cols, return_spark_df, get_embeddings)
        except Exception as err:
            logger.warning(f"Predictions Failed={err}")
            pipe.print_exception_err(err)
            raise Exception("Failure to process data with NLU")
    elif (not get_embeddings and multithread) or pipe.prefer_light:
        # In some scenarios we prefer the light pipeline, e.g. because of bugs in ChunkMapper.
        # Try multithreaded light pipe first, with vanilla Spark as fallback. No embeddings in this mode.
        try:
            return predict_multi_threaded_light_pipe(pipe, data, output_level, positions, keep_stranger_features,
                                                     metadata, drop_irrelevant_cols, get_embeddings=get_embeddings)

        except Exception as err:
            logger.warning(
                f"Multithreaded mode with Light pipeline failed. Trying to predict again in non-multithreaded mode, "
                f"err={err}")
            try:
                return __predict_standard_spark(pipe, data, output_level, positions, keep_stranger_features,
                                                metadata,
                                                drop_irrelevant_cols, return_spark_df, get_embeddings)
            except Exception as err:
                logger.warning(f"Predictions Failed={err}")
                pipe.print_exception_err(err)
                raise Exception("Failure to process data with NLU")
    else:
        # Standard predict with no fallback
        try:
            return __predict_standard_spark(pipe, data, output_level, positions, keep_stranger_features, metadata,
                                            drop_irrelevant_cols, return_spark_df, get_embeddings)
        except Exception as err:
            logger.warning(f"Predictions Failed={err}")
            pipe.print_exception_err(err)
            raise Exception("Failure to process data with NLU")


def debug_print_pipe_cols(pipe):
    for c in pipe.components:
        print(f'{c.spark_input_column_names}->{c.name}->{c.spark_output_column_names}')
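
# End-to-end sketch (hypothetical, not part of the commit) of how a Databricks
# serving container could exercise the endpoint path; 'sentiment' is an assumed
# model reference, and the routing via pipe.predict is an assumption:
#
#   import os
#   import nlu
#   os.environ['DB_ENDPOINT_ENV'] = '1'  # membership check only, value unused
#   pipe = nlu.load('sentiment')
#   df = pd.DataFrame({'text': ['I love Spark NLP'], 'output_level': ['document']})
#   pipe.predict(df)  # dispatches through __db_endpoint_predict__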
