Merge pull request #199 from JohnSnowLabs/binary-img-handling
Binary img handling
C-K-Loan committed Sep 10, 2023
2 parents bd2e257 + 3ec7179 commit 726041d
Showing 5 changed files with 239 additions and 122 deletions.
6 changes: 6 additions & 0 deletions nlu/pipe/nlu_component.py
@@ -72,6 +72,9 @@ def __init__(self,
trained_mirror_anno: Optional[JslAnnoId] = None,
applicable_file_types: List[str] = None, # Used for OCR annotators to deduce applicable file types
is_trained: bool = True, # Set to true for trainable annotators
requires_binary_format: bool = False, # Set to true for OCR annotators that require binary image format
requires_image_format: bool = False, # Set to true for OCR annotators that require image format
is_visual_annotator: bool = False, # Set to true for visual annotators (OCR annotators that operate on images)
):
self.name = name
self.type = type
@@ -110,6 +113,9 @@ def __init__(self,
self.trained_mirror_anno = trained_mirror_anno
self.applicable_file_types = applicable_file_types
self.is_trained = is_trained
self.requires_binary_format = requires_binary_format
self.requires_image_format = requires_image_format
self.is_visual_annotator = is_visual_annotator

def set_metadata(self, jsl_anno_object: Union[AnnotatorApproach, AnnotatorModel],
nlu_ref: str,
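
For context, a hedged sketch of how a visual annotator might declare the new flags when its component is defined (the name, file types, and omitted constructor arguments are illustrative, not taken from this diff):

component = NluComponent(
    name='img2text',                      # hypothetical visual annotator
    applicable_file_types=['jpg', 'png'],
    requires_binary_format=True,          # input read via spark.read.format('binaryFile')
    requires_image_format=False,          # no spark.read.format('image') input needed
    is_visual_annotator=True,             # counted as an OCR component in pipe metadata
)
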
3 changes: 2 additions & 1 deletion nlu/pipe/pipeline.py
@@ -58,7 +58,8 @@ def __init__(self):
self.has_span_classifiers = False
self.prefer_light = False
self.has_table_qa_models = False

self.requires_image_format = False
self.requires_binary_format = False
def add(self, component: NluComponent, nlu_reference=None, pretrained_pipe_component=False,
name_to_add='', idx=None):
'''
7 changes: 6 additions & 1 deletion nlu/pipe/utils/pipe_utils.py
@@ -667,8 +667,13 @@ def add_metadata_to_pipe(pipe: NLUPipeline):

for c in pipe.components:
# Check for OCR components
if c.jsl_anno_py_class in py_class_to_anno_id.keys():
if c.jsl_anno_py_class in py_class_to_anno_id.keys() or c.is_visual_annotator:
pipe.contains_ocr_components = True
if c.requires_image_format:
pipe.requires_image_format = True
if c.requires_binary_format:
pipe.requires_binary_format = True

# Check for licensed components
if c.license in [Licenses.ocr, Licenses.hc]:
pipe.has_licensed_components = True
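
The propagation rule is a simple union: the pipe needs every input format that any of its components needs. A minimal sketch of that rule with illustrative stub components (not the real NluComponent class):

class _StubComponent:
    def __init__(self, img=False, binary=False):
        self.requires_image_format = img
        self.requires_binary_format = binary

components = [_StubComponent(img=True), _StubComponent(binary=True)]
requires_image = any(c.requires_image_format for c in components)    # True
requires_binary = any(c.requires_binary_format for c in components)  # True
# A pipe containing both stubs reads inputs as both `image` and `binaryFile`
# and joins them, as shown in predict_helper.py below.
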
209 changes: 155 additions & 54 deletions nlu/pipe/utils/predict_helper.py
@@ -1,4 +1,6 @@
import logging
import os
from typing import Optional

import sparknlp
from pyspark.sql.functions import monotonically_increasing_id
@@ -8,9 +10,52 @@

logger = logging.getLogger('nlu')
from nlu.pipe.pipe_logic import PipeUtils
from nlu.pipe.utils.data_conversion_utils import DataConversionUtils

import pandas as pd
from pydantic import BaseModel


def serialize(img_path):
    """Read a file from disk and return its raw bytes."""
    with open(img_path, 'rb') as img_file:
        return img_file.read()


def deserialize(binary_image, path):
    """Write raw bytes back to a file at the given path."""
    with open(path, 'wb') as img_file:
        img_file.write(binary_image)
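
A round-trip sketch (file names are hypothetical) of how these two helpers invert each other; the endpoint prediction path below uses deserialize to reconstruct files from raw bytes shipped inside a DataFrame:

payload = serialize('scan.png')          # bytes of an original image
deserialize(payload, 'restored.png')     # write the same bytes back to disk
assert serialize('restored.png') == payload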


class PredictParams(BaseModel):
output_level: Optional[str] = ''
positions: Optional[bool] = False
keep_stranger_features: Optional[bool] = True
metadata: Optional[bool] = False
multithread: Optional[bool] = True
drop_irrelevant_cols: Optional[bool] = True
return_spark_df: Optional[bool] = False
get_embeddings: Optional[bool] = True

@staticmethod
def has_param_cols(df: pd.DataFrame):
        # True if any prediction-parameter column is present in the df
        return any(c in df.columns for c in PredictParams.__fields__.keys())

@staticmethod
def maybe_from_pandas_df(df: pd.DataFrame):
# only first row is used
if df.shape[0] == 0:
return None
        if not PredictParams.has_param_cols(df):
            # no param columns in df
            return None
param_row = df.iloc[0].to_dict()
try:
return PredictParams(**param_row)
except Exception as e:
print(f'Exception trying to parse prediction parameters for param row:'
f' \n{param_row} \n', e)
return None
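
A quick sketch of how maybe_from_pandas_df behaves (column values are illustrative):

# Param columns present: the first row is parsed into a PredictParams instance.
df_with_params = pd.DataFrame([{'output_level': 'document', 'metadata': True}])
assert PredictParams.maybe_from_pandas_df(df_with_params).output_level == 'document'

# No param columns present: nothing to parse, so None is returned.
df_plain = pd.DataFrame({'text': ['hello world']})
assert PredictParams.maybe_from_pandas_df(df_plain) is None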


def __predict_standard_spark(pipe, data, output_level, positions, keep_stranger_features, metadata,
drop_irrelevant_cols, return_spark_df, get_embeddings):
@@ -74,8 +119,31 @@ def __predict_ocr_spark(pipe, data, output_level, positions, keep_stranger_featu
file_paths = OcrDataConversionUtils.glob_files_of_accepted_type(paths, accepted_file_types)
spark = sparknlp.start() # Fetches Spark Session that has already been licensed

data = pipe.vanilla_transformer_pipe.transform(spark.read.format("image").load(file_paths)).withColumn(
'origin_index', monotonically_increasing_id().alias('origin_index'))
# Some annotators require `image` format, others require `binary` format. Figure out which is needed and, if necessary, provide both.
if pipe.requires_image_format and pipe.requires_binary_format:
from pyspark.sql.functions import regexp_replace
# Image & Binary formats required. We read as both and join the dfs
img_df = spark.read.format("image").load(file_paths).withColumn("modified_origin",
regexp_replace("image.origin", ":/{1,}", ":"))

# Read the files in binaryFile format
binary_df = spark.read.format("binaryFile").load(file_paths).withColumn("modified_path",
regexp_replace("path", ":/{1,}", ":"))

data = img_df.join(binary_df, img_df["modified_origin"] == binary_df["modified_path"]).drop('modified_path')

elif pipe.requires_image_format:
# only image format required
data = spark.read.format("image").load(file_paths)
elif pipe.requires_binary_format:
# only binary required
data = spark.read.format("binaryFile").load(file_paths)
else:
# fallback default
data = spark.read.format("binaryFile").load(file_paths)
data = data.withColumn('origin_index', monotonically_increasing_id().alias('origin_index'))

data = pipe.vanilla_transformer_pipe.transform(data)
return pipe.pythonify_spark_dataframe(data,
keep_stranger_features=keep_stranger_features,
output_metadata=metadata,
@@ -121,9 +189,39 @@ def __predict_audio_spark(pipe, data, output_level, positions, keep_stranger_fea
get_embeddings=get_embeddings
)
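
The modified_origin/modified_path columns above exist because the two Spark data sources spell file URIs with different numbers of slashes (for example file:///tmp/scan.png in the image source's image.origin vs file:/tmp/scan.png in binaryFile's path; the exact spellings are an assumption here). A minimal sketch of what regexp_replace(col, ":/{1,}", ":") does to both:

import re

def _normalize(uri):  # hypothetical helper mirroring the regexp_replace pattern
    return re.sub(r':/{1,}', ':', uri)

# Both spellings collapse to 'file:tmp/scan.png', so the image row and the
# binary row for the same file share a join key.
assert _normalize('file:///tmp/scan.png') == _normalize('file:/tmp/scan.png')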

def __db_endpoint_predict__(pipe, data):
    """
    1) Parse prediction params from the first row, if present.
    2) Deserialize any serialized files back to disk before predicting.
    """
    print("CUSTOM NLU MODE!")
    print(data.columns)
params = PredictParams.maybe_from_pandas_df(data)
if params:
params = params.dict()
else:
params = {}
files = []
if 'file' in data.columns and 'file_type' in data.columns:
print("DETECTED FILE COLS")
        skip_first = PredictParams.has_param_cols(data)  # row 0 holds params, not a file
for i, row in data.iterrows():
print(f"DESERIALIZING {row.file_type} file {row.file}")
if i == 0 and skip_first:
continue
file_name = f'file{i}.{row.file_type}'
files.append(file_name)
deserialize(row.file, file_name)
data = files

if params:
return __predict__(pipe, data, **params, normal_pred_on_db=True)
else:
        # no params detected, so call again with default params
        return __predict__(pipe, data, **PredictParams().dict(), normal_pred_on_db=True)
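
Putting it together, a hedged sketch of the payload DataFrame such an endpoint might receive (column values and the file name are illustrative): row 0 carries prediction parameters, later rows carry serialized files that are written back to disk before predicting:

endpoint_df = pd.DataFrame({
    'output_level': ['document', None],     # parsed from row 0 only
    'file': [None, serialize('scan.png')],  # raw bytes, restored via deserialize
    'file_type': [None, 'png'],             # used to name the restored file
})
# __db_endpoint_predict__(pipe, endpoint_df) would restore 'file1.png' to disk
# and then predict on it with output_level='document'.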

def __predict__(pipe, data, output_level, positions, keep_stranger_features, metadata, multithread,
drop_irrelevant_cols, return_spark_df, get_embeddings):
drop_irrelevant_cols, return_spark_df, get_embeddings, normal_pred_on_db=False):
'''
Annotates a Pandas DataFrame / Pandas Series / Numpy Array / Spark DataFrame / Python list of strings / Python string
:param data: Data to predict on
@@ -137,6 +235,9 @@ def __predict__(pipe, data, output_level, positions, keep_stranger_features, met
:return:
'''

if 'DB_ENDPOINT_ENV' in os.environ and not normal_pred_on_db:
        return __db_endpoint_predict__(pipe, data)

if output_level == '' and not pipe.has_table_qa_models:
# Default sentence level for all components
if pipe.has_nlp_components and not PipeUtils.contains_t5_or_gpt(
@@ -165,66 +266,66 @@ def __predict__(pipe, data, output_level, positions, keep_stranger_features, met
else:
pipe.fit()

    pipe.__configure_light_pipe_usage__(DataConversionUtils.size_of(data), multithread)

    if pipe.contains_ocr_components and pipe.contains_audio_components:
        """ Idea:
        Expect an array of paths.
        For every path, classify the file ending and use it to handle image or audio data correctly.
        """
        raise Exception('Cannot mix Audio and OCR components in a Pipe')

    if pipe.contains_audio_components:
        return __predict_audio_spark(pipe, data, output_level, positions, keep_stranger_features,
                                     metadata, drop_irrelevant_cols, get_embeddings=get_embeddings)

    if pipe.contains_ocr_components:
        # OCR processing
        try:
            return __predict_ocr_spark(pipe, data, output_level, positions, keep_stranger_features,
                                       metadata, drop_irrelevant_cols, get_embeddings=get_embeddings)
        except Exception as err:
            logger.warning(f"Predictions Failed={err}")
            pipe.print_exception_err(err)
            raise Exception("Failure to process data with NLU OCR pipe")
    if return_spark_df:
        try:
            return __predict_standard_spark(pipe, data, output_level, positions, keep_stranger_features, metadata,
                                            drop_irrelevant_cols, return_spark_df, get_embeddings)
        except Exception as err:
            logger.warning(f"Predictions Failed={err}")
            pipe.print_exception_err(err)
            raise Exception("Failure to process data with NLU")
    elif not get_embeddings and multithread or pipe.prefer_light:
        # In some scenarios we prefer the light pipeline, e.g. because of bugs in ChunkMapper.
        # Try multithreaded mode with vanilla Spark as fallback. No embeddings in this mode.
        try:
            return predict_multi_threaded_light_pipe(pipe, data, output_level, positions, keep_stranger_features,
                                                     metadata, drop_irrelevant_cols, get_embeddings=get_embeddings)


        except Exception as err:
            logger.warning(
                f"Multithreaded mode with Light pipeline failed. Trying to predict again in non-multithreaded mode, "
                f"err={err}")
            try:
                return __predict_standard_spark(pipe, data, output_level, positions, keep_stranger_features,
                                                metadata,
                                                drop_irrelevant_cols, return_spark_df, get_embeddings)
            except Exception as err:
                logger.warning(f"Predictions Failed={err}")
                pipe.print_exception_err(err)
                raise Exception("Failure to process data with NLU")
    else:
        # Standard predict with no fallback
        try:
            return __predict_standard_spark(pipe, data, output_level, positions, keep_stranger_features, metadata,
                                            drop_irrelevant_cols, return_spark_df, get_embeddings)
        except Exception as err:
            logger.warning(f"Predictions Failed={err}")
            pipe.print_exception_err(err)
            raise Exception("Failure to process data with NLU")


def debug_print_pipe_cols(pipe):
    for c in pipe.components:
        print(f'{c.spark_input_column_names}->{c.name}->{c.spark_output_column_names}')