Eliminate redundant code blocks in modules and stages (#1123)
Eliminated redundant code blocks in both modules and stages by introducing controllers to improve maintainability, and updated the tests to align with these changes.

- file_to_df
- filter_detections
- mlflow_model_writer
- serializer
- write_to_file

Fixed the preserve-columns property issue.
closes #965 #1074
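
The pattern is the same for each of the components listed above: logic that was duplicated between a module implementation and a stage implementation moves into a single controller that both thin wrappers delegate to. A minimal sketch of that pattern, using hypothetical names rather than the actual Morpheus classes:

import pandas as pd


class FileWriterController:
    # Hypothetical controller: the shared write logic lives in exactly one place.
    def __init__(self, filename: str, overwrite: bool = False):
        self._filename = filename
        self._mode = "w" if overwrite else "x"

    def write(self, df: pd.DataFrame) -> None:
        df.to_csv(self._filename, mode=self._mode, index=False)


def write_to_file_module(df: pd.DataFrame, module_config: dict) -> None:
    # The module builds a controller from its config and delegates to it.
    controller = FileWriterController(filename=module_config["filename"],
                                      overwrite=module_config.get("overwrite", False))
    controller.write(df)


class WriteToFileStage:
    # The stage wraps the same controller instead of duplicating the logic.
    def __init__(self, filename: str, overwrite: bool = False):
        self._controller = FileWriterController(filename=filename, overwrite=overwrite)

    def on_data(self, df: pd.DataFrame) -> pd.DataFrame:
        self._controller.write(df)
        return df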

Authors:
  - Bhargav Suryadevara (https://github.com/bsuryadevara)
  - Michael Demoret (https://github.com/mdemoret-nv)

Approvers:
  - Christopher Harris (https://github.com/cwharris)
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1123
bsuryadevara authored Sep 6, 2023
1 parent 214232c commit be608fd
Showing 36 changed files with 1,590 additions and 1,666 deletions.
@@ -17,17 +17,17 @@

# When segment modules are imported, they're added to the module registry.
# To avoid flake8 warnings about unused code, the noqa flag is used during import.
from dfp.modules import dfp_monitor
from dfp.modules import dfp_split_users
from dfp.modules import dfp_data_prep
from dfp.modules import dfp_deployment
from dfp.modules import dfp_inference
from dfp.modules import dfp_inference_pipe
from dfp.modules import dfp_monitor
from dfp.modules import dfp_postprocessing
from dfp.modules import dfp_preproc
from dfp.modules import dfp_rolling_window
from dfp.modules import dfp_split_users
from dfp.modules import dfp_training
from dfp.modules import dfp_inference_pipe
from dfp.modules import dfp_training_pipe
from dfp.modules import dfp_deployment

__all__ = [
"dfp_monitor",
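
The comment in the hunk above refers to suppressing flake8's unused-import warning with a noqa flag; in practice that looks like the line below (illustrative only — the exact flag placement in the file is not shown in this diff):

from dfp.modules import dfp_monitor  # noqa: F401  (silences flake8's "imported but unused")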
@@ -64,15 +64,23 @@ def dfp_inference(builder: mrc.Builder):

model_name_formatter = config.get("model_name_formatter", None)
fallback_user = config.get("fallback_username", "generic_user")

model_fetch_timeout = config.get("model_fetch_timeout", 1.0)
timestamp_column_name = config.get("timestamp_column_name", "timestamp")

client = MlflowClient()
model_manager = ModelManager(model_name_formatter=model_name_formatter)

model_manager = None

def get_model(user: str) -> ModelCache:
nonlocal model_manager

if not model_manager:
model_manager = ModelManager(model_name_formatter=model_name_formatter)

return model_manager.load_user_model(client, user_id=user, fallback_user_ids=[fallback_user])
return model_manager.load_user_model(client,
user_id=user,
fallback_user_ids=[fallback_user],
timeout=model_fetch_timeout)

def process_task(control_message: ControlMessage):
start_time = time.time()
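
The hunk above defers construction of the ModelManager until the first model request and threads the new model_fetch_timeout config value through to load_user_model. A self-contained sketch of that lazy-initialization pattern, using stand-in classes rather than the real Morpheus or MLflow APIs:

import typing


class FakeModelCache:
    """Stand-in for the ModelCache returned by the real model manager."""

    def __init__(self, user_id: str):
        self.user_id = user_id


class FakeModelManager:
    """Stand-in for ModelManager; the real one fetches models from MLflow."""

    def __init__(self, model_name_formatter: str):
        self._model_name_formatter = model_name_formatter

    def load_user_model(self, client, user_id: str, fallback_user_ids: list, timeout: float) -> FakeModelCache:
        # A real implementation would query MLflow and respect `timeout`.
        return FakeModelCache(user_id)


def make_get_model(config: dict, client) -> typing.Callable[[str], FakeModelCache]:
    model_name_formatter = config.get("model_name_formatter", None)
    fallback_user = config.get("fallback_username", "generic_user")
    model_fetch_timeout = config.get("model_fetch_timeout", 1.0)

    model_manager = None  # construction deferred until the first request

    def get_model(user: str) -> FakeModelCache:
        nonlocal model_manager

        if not model_manager:
            model_manager = FakeModelManager(model_name_formatter=model_name_formatter)

        return model_manager.load_user_model(client,
                                              user_id=user,
                                              fallback_user_ids=[fallback_user],
                                              timeout=model_fetch_timeout)

    return get_model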
@@ -21,9 +21,9 @@
from mrc.core import operators as ops
from tqdm import tqdm

from morpheus.controllers.monitor_controller import MonitorController
from morpheus.utils.module_ids import MORPHEUS_MODULE_NAMESPACE
from morpheus.utils.module_utils import register_module
from morpheus.utils.monitor_utils import MonitorController
from morpheus.utils.monitor_utils import MorpheusTqdm
from morpheus.utils.monitor_utils import SilentMorpheusTqdm

@@ -16,6 +16,7 @@

import mrc
from mrc.core import operators as ops
from sklearn.model_selection import train_test_split

import cudf

@@ -87,8 +88,16 @@ def on_data(control_message: ControlMessage):
# Only train on the feature columns
train_df = final_df[final_df.columns.intersection(feature_columns)]

validation_df = None
run_validation = False

# Split into training and validation sets
if validation_size > 0.0:
train_df, validation_df = train_test_split(train_df, test_size=validation_size, shuffle=False)
run_validation = True

logger.debug("Training AE model for user: '%s'...", user_id)
model.fit(train_df, epochs=epochs)
model.fit(train_df, epochs=epochs, val_data=validation_df, run_validation=run_validation)
logger.debug("Training AE model for user: '%s'... Complete.", user_id)

dfp_mm = DFPMessageMeta(cudf.from_pandas(final_df), user_id=user_id)
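
The training change above carves a validation set out of the training frame with scikit-learn when validation_size is positive; shuffle=False preserves chronological order. A standalone example of that call, on made-up data:

import pandas as pd
from sklearn.model_selection import train_test_split

train_df = pd.DataFrame({"feature_a": range(100), "feature_b": range(100, 200)})

validation_size = 0.1
validation_df = None
run_validation = False

if validation_size > 0.0:
    # shuffle=False keeps the rows in their original (chronological) order.
    train_df, validation_df = train_test_split(train_df, test_size=validation_size, shuffle=False)
    run_validation = True

print(len(train_df), len(validation_df))  # 90 10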
@@ -46,7 +46,7 @@ class DFPFileBatcherStage(SinglePortStage):
Parameters
----------
c : `morpheus.config.Config`
config : `morpheus.config.Config`
Pipeline configuration instance.
date_conversion_func : callable
A function that takes a file object and returns a `datetime` object representing the date of the file.
@@ -69,14 +69,14 @@
"""

def __init__(self,
c: Config,
config: Config,
date_conversion_func: typing.Callable[[fsspec.core.OpenFile], datetime],
period: str = "D",
sampling_rate_s: typing.Optional[int] = None,
start_time: datetime = None,
end_time: datetime = None,
sampling: typing.Union[str, float, int, None] = None):
super().__init__(c)
super().__init__(config)

self._date_conversion_func = date_conversion_func
self._period = period
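
Call sites of DFPFileBatcherStage now pass the renamed config= keyword. A hedged usage sketch — the date-parsing callable below is illustrative only, assumes filenames that embed an ISO date, and is not part of Morpheus:

import re
from datetime import datetime

import fsspec


def date_from_filename(file_object: fsspec.core.OpenFile) -> datetime:
    # Hypothetical: assumes filenames such as "azure_2023-09-06.json".
    match = re.search(r"(\d{4}-\d{2}-\d{2})", file_object.path)
    return datetime.strptime(match.group(1), "%Y-%m-%d")


# pipeline.add_stage(
#     DFPFileBatcherStage(config=config,
#                         date_conversion_func=date_from_filename,
#                         period="D"))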
@@ -13,62 +13,24 @@
# limitations under the License.
"""Stage for converting fsspec file objects to a DataFrame."""

import hashlib
import json
import logging
import os
import time
import typing
from functools import partial

import fsspec
import mrc
import pandas as pd
from mrc.core import operators as ops

from morpheus.common import FileTypes
from morpheus.config import Config
from morpheus.io.deserializers import read_file_to_df
from morpheus.controllers.file_to_df_controller import FileToDFController
from morpheus.pipeline.preallocator_mixin import PreallocatorMixin
from morpheus.pipeline.single_port_stage import SinglePortStage
from morpheus.pipeline.stream_pair import StreamPair
from morpheus.utils.column_info import DataFrameInputSchema
from morpheus.utils.column_info import process_dataframe
from morpheus.utils.downloader import Downloader

logger = logging.getLogger(f"morpheus.{__name__}")


def _single_object_to_dataframe(file_object: fsspec.core.OpenFile,
schema: DataFrameInputSchema,
file_type: FileTypes,
filter_null: bool,
parser_kwargs: dict) -> pd.DataFrame:
retries = 0
df = None
while (retries < 2):
try:
with file_object as f:
df = read_file_to_df(f,
file_type,
filter_nulls=filter_null,
df_type="pandas",
parser_kwargs=parser_kwargs)

break
except Exception as e:
if (retries < 2):
logger.warning("Error fetching %s: %s\nRetrying...", file_object, e)
retries += 1

# Optimistically prep the dataframe (not necessary since this will happen again in process_dataframe, but it
# increases performance significantly)
if (schema.prep_dataframe is not None):
df = schema.prep_dataframe(df)

return df


class DFPFileToDataFrameStage(PreallocatorMixin, SinglePortStage):
"""
Stage for converting fsspec file objects to a DataFrame, pre-processing the DataFrame according to `schema`, and
@@ -102,14 +64,12 @@ def __init__(self,
cache_dir: str = "./.cache/dfp"):
super().__init__(config)

self._schema = schema

self._file_type = file_type
self._filter_null = filter_null
self._parser_kwargs = {} if parser_kwargs is None else parser_kwargs
self._cache_dir = os.path.join(cache_dir, "file_cache")

self._downloader = Downloader()
self._controller = FileToDFController(schema=schema,
filter_null=filter_null,
file_type=file_type,
parser_kwargs=parser_kwargs,
cache_dir=cache_dir,
timestamp_column_name=config.ae.timestamp_column_name)

@property
def name(self) -> str:
@@ -124,103 +84,10 @@ def accepted_types(self) -> typing.Tuple:
"""Accepted input types."""
return (typing.Any, )

def _get_or_create_dataframe_from_batch(
self, file_object_batch: typing.Tuple[fsspec.core.OpenFiles, int]) -> typing.Tuple[pd.DataFrame, bool]:

if (not file_object_batch):
raise RuntimeError("No file objects to process")

file_list = file_object_batch[0]
batch_count = file_object_batch[1]

file_system: fsspec.AbstractFileSystem = file_list.fs

# Create a list of dictionaries that only contains the information we are interested in hashing. `ukey` just
# hashes all the output of `info()` which is perfect
hash_data = [{"ukey": file_system.ukey(file_object.path)} for file_object in file_list]

# Use a hex digest so the cache key contains only filesystem-safe characters
objects_hash_hex = hashlib.md5(json.dumps(hash_data, sort_keys=True).encode()).hexdigest()

batch_cache_location = os.path.join(self._cache_dir, "batches", f"{objects_hash_hex}.pkl")

# Return the cache if it exists
if (os.path.exists(batch_cache_location)):
output_df = pd.read_pickle(batch_cache_location)
output_df["batch_count"] = batch_count
output_df["origin_hash"] = objects_hash_hex

return (output_df, True)

# Cache miss
download_method = partial(_single_object_to_dataframe,
schema=self._schema,
file_type=self._file_type,
filter_null=self._filter_null,
parser_kwargs=self._parser_kwargs)

download_buckets = file_list

# Loop over dataframes and concat into one
try:
dfs = self._downloader.download(download_buckets, download_method)
except Exception:
logger.exception("Failed to download logs. Error: ", exc_info=True)
raise

if (dfs is None or len(dfs) == 0):
raise ValueError("No logs were downloaded")

output_df: pd.DataFrame = pd.concat(dfs)
output_df = process_dataframe(df_in=output_df, input_schema=self._schema)

# Finally sort by timestamp and then reset the index
output_df.sort_values(by=[self._config.ae.timestamp_column_name], inplace=True)

output_df.reset_index(drop=True, inplace=True)

# Save dataframe to cache future runs
os.makedirs(os.path.dirname(batch_cache_location), exist_ok=True)

try:
output_df.to_pickle(batch_cache_location)
except Exception:
logger.warning("Failed to save batch cache. Skipping cache for this batch.", exc_info=True)

output_df["batch_count"] = batch_count
output_df["origin_hash"] = objects_hash_hex

return (output_df, False)

def convert_to_dataframe(self, fsspec_batch: typing.Tuple[fsspec.core.OpenFiles, int]):
"""Converts a batch of fsspec objects to a DataFrame."""
if (not fsspec_batch):
return None

start_time = time.time()

try:

output_df, cache_hit = self._get_or_create_dataframe_from_batch(fsspec_batch)

duration = (time.time() - start_time) * 1000.0

if (output_df is not None and logger.isEnabledFor(logging.DEBUG)):
logger.debug("fsspec objects to DF complete. Rows: %s, Cache: %s, Duration: %s ms, Rate: %s rows/s",
len(output_df),
"hit" if cache_hit else "miss",
duration,
len(output_df) / (duration / 1000.0))

return output_df
except Exception:
logger.exception("Error while converting fsspec batch to DF.")
raise

def _build_single(self, builder: mrc.Builder, input_stream: StreamPair) -> StreamPair:
stream = builder.make_node(self.unique_name,
ops.map(self.convert_to_dataframe),
ops.on_completed(self._downloader.close))
ops.map(self._controller.convert_to_dataframe),
ops.on_completed(self._controller.close))
builder.make_edge(input_stream[0], stream)

return stream, pd.DataFrame
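
After this change the stage is a thin wrapper: the constructor builds a FileToDFController and _build_single wires its convert_to_dataframe and close methods into the node. Based only on the calls visible in this diff, a sketch of that delegation with a stand-in controller (the real FileToDFController may differ):

import pandas as pd


class StubFileToDFController:
    """Stand-in exposing only the surface visible in this diff."""

    def __init__(self, schema, filter_null: bool, file_type, parser_kwargs: dict,
                 cache_dir: str, timestamp_column_name: str):
        self._timestamp_column_name = timestamp_column_name

    def convert_to_dataframe(self, file_object_batch) -> pd.DataFrame:
        # The real controller downloads, parses, caches and sorts the batch;
        # this stub just returns an empty frame with the timestamp column.
        return pd.DataFrame({self._timestamp_column_name: []})

    def close(self) -> None:
        # The real controller releases its downloader here.
        pass


controller = StubFileToDFController(schema=None,
                                    filter_null=True,
                                    file_type="JSON",
                                    parser_kwargs={},
                                    cache_dir="./.cache/dfp",
                                    timestamp_column_name="timestamp")

df = controller.convert_to_dataframe(file_object_batch=([], 1))
controller.close()
print(list(df.columns))  # ['timestamp']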