Skip to content

Commit d9b6cb0

Browse files
committed
refactor: refatoring updating image_request.py and updating tests
1 parent e9ce2d4 commit d9b6cb0

File tree

5 files changed

+346
-414
lines changed

5 files changed

+346
-414
lines changed

src/aws/osml/model_runner/api/image_request.py

+154-95
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright 2023-2024 Amazon.com, Inc. or its affiliates.
22

33
import logging
4+
from dataclasses import dataclass, field
45
from json import dumps, loads
56
from typing import Any, Dict, List, Optional
67

78
import shapely.geometry
89
import shapely.wkt
10+
from dacite import from_dict
911
from shapely.geometry.base import BaseGeometry
1012

1113
from aws.osml.model_runner.common import (
@@ -21,148 +23,204 @@
2123

2224
from .inference import ModelInvokeMode
2325
from .request_utils import shared_properties_are_valid
24-
from .sink import SinkType
26+
from .sink import VALID_SYNC_TYPES, SinkType
2527

2628
logger = logging.getLogger(__name__)
2729

2830

29-
class ImageRequest(object):
31+
@dataclass
32+
class ImageRequest:
3033
"""
3134
Request for the Model Runner to process an image.
3235
33-
This class contains the attributes that make up an image processing request along with
36+
This class contains the attributes that make up an image processing request, along with
3437
constructors and factory methods used to create these requests from common constructs.
35-
"""
3638
37-
def __init__(self, *initial_data: Dict[str, Any], **kwargs: Any):
38-
"""
39-
This constructor allows users to create these objects using a combination of dictionaries
40-
and keyword arguments.
39+
Attributes:
40+
job_id: The unique identifier for the image processing job.
41+
image_id: A combined identifier for the image, usually composed of the job ID and image URL.
42+
image_url: The URL location of the image to be processed.
43+
image_read_role: The IAM role used to read the image from the provided URL.
44+
outputs: A list of output configurations where results should be stored.
45+
model_name: The name of the model to use for image processing.
46+
model_invoke_mode: The mode in which the model is invoked, such as synchronous or asynchronous.
47+
tile_size: Dimensions of the tiles into which the image is split for processing.
48+
tile_overlap: Overlap between tiles, defined in dimensions.
49+
tile_format: The format of the tiles (e.g., NITF, GeoTIFF).
50+
tile_compression: Compression type to use for the tiles (e.g., None, JPEG).
51+
model_invocation_role: IAM role assumed for invoking the model.
52+
feature_properties: Additional properties to include in the feature processing.
53+
roi: Region of interest within the image, defined as a geometric shape.
54+
post_processing: List of post-processing steps to apply to the features detected.
55+
"""
4156

42-
:param initial_data: Dict[str, Any] = dictionaries that contain attributes/values that map to this class's
43-
attributes
44-
:param kwargs: Any = keyword arguments provided on the constructor to set specific attributes
45-
"""
46-
default_post_processing = [
57+
job_id: str = ""
58+
image_id: str = ""
59+
image_url: str = ""
60+
image_read_role: str = ""
61+
outputs: List[Dict[str, Any]] = field(default_factory=list)
62+
model_name: str = ""
63+
model_invoke_mode: ModelInvokeMode = ModelInvokeMode.NONE
64+
tile_size: ImageDimensions = (1024, 1024)
65+
tile_overlap: ImageDimensions = (50, 50)
66+
tile_format: str = ImageFormats.NITF.value
67+
tile_compression: str = ImageCompression.NONE.value
68+
model_invocation_role: str = ""
69+
feature_properties: List[Dict[str, Any]] = field(default_factory=list)
70+
roi: Optional[BaseGeometry] = None
71+
post_processing: List[MRPostProcessing] = field(
72+
default_factory=lambda: [
4773
MRPostProcessing(step=MRPostprocessingStep.FEATURE_DISTILLATION, algorithm=FeatureDistillationNMS())
4874
]
75+
)
4976

50-
self.job_id: str = ""
51-
self.image_id: str = ""
52-
self.image_url: str = ""
53-
self.image_read_role: str = ""
54-
self.outputs: List[dict] = []
55-
self.model_name: str = ""
56-
self.model_invoke_mode: ModelInvokeMode = ModelInvokeMode.NONE
57-
self.tile_size: ImageDimensions = (1024, 1024)
58-
self.tile_overlap: ImageDimensions = (50, 50)
59-
self.tile_format: ImageFormats = ImageFormats.NITF
60-
self.tile_compression: ImageCompression = ImageCompression.NONE
61-
self.model_invocation_role: str = ""
62-
self.feature_properties: List[dict] = []
63-
self.roi: Optional[BaseGeometry] = None
64-
self.post_processing: List[MRPostProcessing] = default_post_processing
65-
66-
for dictionary in initial_data:
67-
for key in dictionary:
68-
setattr(self, key, dictionary[key])
69-
for key in kwargs:
70-
setattr(self, key, kwargs[key])
77+
@staticmethod
78+
def from_external_message(image_request: Dict[str, Any]) -> "ImageRequest":
79+
"""
80+
Constructs an ImageRequest from a dictionary that represents an external message.
81+
82+
:param image_request: Dictionary of values from the decoded JSON request.
83+
:return: ImageRequest instance.
84+
"""
85+
properties: Dict[str, Any] = {
86+
"job_id": image_request.get("jobId", ""),
87+
"image_url": image_request.get("imageUrls", [""])[0],
88+
"image_id": f"{image_request.get('jobId', '')}:{image_request.get('imageUrls', [''])[0]}",
89+
"image_read_role": image_request.get("imageReadRole", ""),
90+
"model_name": image_request["imageProcessor"]["name"],
91+
"model_invoke_mode": ImageRequest._parse_model_invoke_mode(image_request["imageProcessor"].get("type")),
92+
"model_invocation_role": image_request["imageProcessor"].get("assumedRole", ""),
93+
"tile_size": ImageRequest._parse_tile_dimension(image_request.get("imageProcessorTileSize")),
94+
"tile_overlap": ImageRequest._parse_tile_dimension(image_request.get("imageProcessorTileOverlap")),
95+
"tile_format": ImageRequest._parse_tile_format(image_request.get("imageProcessorTileFormat")),
96+
"tile_compression": ImageRequest._parse_tile_compression(image_request.get("imageProcessorTileCompression")),
97+
"roi": ImageRequest._parse_roi(image_request.get("regionOfInterest")),
98+
"outputs": ImageRequest._parse_outputs(image_request),
99+
"feature_properties": image_request.get("featureProperties", []),
100+
"post_processing": ImageRequest._parse_post_processing(image_request.get("postProcessing")),
101+
}
102+
return from_dict(ImageRequest, properties)
71103

72104
@staticmethod
73-
def from_external_message(image_request: Dict[str, Any]):
105+
def _parse_tile_dimension(value: Optional[str]) -> ImageDimensions:
74106
"""
75-
This method is used to construct an ImageRequest given a dictionary reconstructed from the
76-
JSON representation of a request that appears on the Image Job Queue. The structure of
77-
that message is generally governed by AWS API best practices and may evolve over time as
78-
the public APIs for this service mature.
107+
Converts a string value to a tuple of integers representing tile dimensions.
79108
80-
:param image_request: Dict[str, Any] = dictionary of values from the decoded JSON request
109+
:param value: String value representing tile dimension.
110+
:return: Tuple of integers as tile dimensions.
111+
"""
112+
return (int(value), int(value)) if value else None
81113

82-
:return: the ImageRequest
114+
@staticmethod
115+
def _parse_roi(roi: Optional[str]) -> Optional[BaseGeometry]:
83116
"""
84-
properties: Dict[str, Any] = {}
85-
if "imageProcessorTileSize" in image_request:
86-
tile_dimension = int(image_request["imageProcessorTileSize"])
87-
properties["tile_size"] = (tile_dimension, tile_dimension)
117+
Parses the region of interest from a WKT string.
88118
89-
if "imageProcessorTileOverlap" in image_request:
90-
overlap_dimension = int(image_request["imageProcessorTileOverlap"])
91-
properties["tile_overlap"] = (overlap_dimension, overlap_dimension)
119+
:param roi: WKT string representing the region of interest.
120+
:return: Parsed BaseGeometry object or None.
121+
"""
122+
return shapely.wkt.loads(roi) if roi else None
92123

93-
if "imageProcessorTileFormat" in image_request:
94-
properties["tile_format"] = image_request["imageProcessorTileFormat"]
124+
@staticmethod
125+
def _parse_tile_format(tile_format: Optional[str]) -> Optional[ImageFormats]:
126+
"""
127+
Parses the region desired tile format to use for processing.
95128
96-
if "imageProcessorTileCompression" in image_request:
97-
properties["tile_compression"] = image_request["imageProcessorTileCompression"]
129+
:param tile_format: String representing the tile format to use.
130+
:return: Parsed ImageFormats object or ImageFormats.NITF.
131+
"""
132+
return ImageFormats[tile_format].value if tile_format else ImageFormats.NITF.value
98133

99-
properties["job_id"] = image_request["jobId"]
134+
@staticmethod
135+
def _parse_tile_compression(tile_compression: Optional[str]) -> Optional[ImageCompression]:
136+
"""
137+
Parses the region desired tile compression format to use for processing.
100138
101-
properties["image_url"] = image_request["imageUrls"][0]
102-
properties["image_id"] = image_request["jobId"] + ":" + properties["image_url"]
103-
if "imageReadRole" in image_request:
104-
properties["image_read_role"] = image_request["imageReadRole"]
139+
:param tile_compression: String representing the tile compression format to use.
140+
:return: Parsed ImageFormats object or ImageCompression.NONE.
141+
"""
142+
return ImageCompression[tile_compression].value if tile_compression else ImageCompression.NONE.value
105143

106-
properties["model_name"] = image_request["imageProcessor"]["name"]
107-
properties["model_invoke_mode"] = image_request["imageProcessor"]["type"]
108-
if "assumedRole" in image_request["imageProcessor"]:
109-
properties["model_invocation_role"] = image_request["imageProcessor"]["assumedRole"]
144+
@staticmethod
145+
def _parse_model_invoke_mode(model_invoke_mode: Optional[str]) -> Optional[ModelInvokeMode]:
146+
"""
147+
Parses the region desired tile compression format to use for processing.
110148
111-
if "regionOfInterest" in image_request:
112-
properties["roi"] = shapely.wkt.loads(image_request["regionOfInterest"])
149+
:param model_invoke_mode: String representing the tile compression format to use.
150+
:return: Parsed ModelInvokeMode object or ModelInvokeMode.SM_ENDPOINT.
151+
"""
152+
return ModelInvokeMode[model_invoke_mode] if model_invoke_mode else ModelInvokeMode.SM_ENDPOINT
113153

114-
# Support explicit outputs
154+
@staticmethod
155+
def _parse_outputs(image_request: Dict[str, Any]) -> List[Dict[str, Any]]:
156+
"""
157+
Parses the output configuration from the image request, including support for legacy inputs.
158+
159+
:param image_request: Dictionary of image request attributes.
160+
:return: List of output configurations.
161+
"""
115162
if image_request.get("outputs"):
116-
properties["outputs"] = image_request["outputs"]
117-
# Support legacy image request
118-
elif image_request.get("outputBucket") and image_request.get("outputPrefix"):
119-
properties["outputs"] = [
163+
return image_request["outputs"]
164+
165+
# Support legacy image request fields: outputBucket and outputPrefix
166+
if image_request.get("outputBucket") and image_request.get("outputPrefix"):
167+
return [
120168
{
121169
"type": SinkType.S3.value,
122170
"bucket": image_request["outputBucket"],
123171
"prefix": image_request["outputPrefix"],
124172
}
125173
]
126-
if image_request.get("featureProperties"):
127-
properties["feature_properties"] = image_request["featureProperties"]
128-
if image_request.get("postProcessing"):
129-
image_request["postProcessing"] = loads(
130-
dumps(image_request["postProcessing"])
131-
.replace("algorithmType", "algorithm_type")
132-
.replace("iouThreshold", "iou_threshold")
133-
.replace("skipBoxThreshold", "skip_box_threshold")
134-
)
135-
properties["post_processing"] = deserialize_post_processing_list(image_request.get("postProcessing"))
136-
137-
return ImageRequest(properties)
174+
# No outputs were defined in the request
175+
logger.warning("No output syncs were present in this request.")
176+
return []
177+
178+
@staticmethod
179+
def _parse_post_processing(post_processing: Optional[Dict[str, Any]]) -> List[MRPostProcessing]:
180+
"""
181+
Deserializes and cleans up post-processing data.
182+
183+
:param post_processing: Dictionary of post-processing configurations.
184+
:return: List of MRPostProcessing instances.
185+
"""
186+
if not post_processing:
187+
return [MRPostProcessing(step=MRPostprocessingStep.FEATURE_DISTILLATION, algorithm=FeatureDistillationNMS())]
188+
cleaned_post_processing = loads(
189+
dumps(post_processing)
190+
.replace("algorithmType", "algorithm_type")
191+
.replace("iouThreshold", "iou_threshold")
192+
.replace("skipBoxThreshold", "skip_box_threshold")
193+
)
194+
return deserialize_post_processing_list(cleaned_post_processing)
138195

139196
def is_valid(self) -> bool:
140197
"""
141-
Check to see if this request contains required attributes and meaningful values
198+
Validates whether the ImageRequest instance has all required attributes.
142199
143-
:return: bool = True if the request contains all the mandatory attributes with acceptable values,
144-
False otherwise
200+
:return: True if valid, False otherwise.
145201
"""
146202
if not shared_properties_are_valid(self):
147203
logger.error("Invalid shared properties in ImageRequest")
148204
return False
149-
150-
if not self.job_id or not self.outputs:
151-
logger.error("Missing job id or outputs properties in ImageRequest")
205+
if not self.job_id:
206+
logger.error("Missing job id in ImageRequest")
152207
return False
153-
154-
num_feature_detection_options = len(self.get_feature_distillation_option())
155-
if num_feature_detection_options > 1:
156-
logger.error(f"{num_feature_detection_options} feature distillation options in ImageRequest")
208+
if len(self.get_feature_distillation_option()) > 1:
209+
logger.error("Multiple feature distillation options in ImageRequest")
157210
return False
158-
211+
if len(self.outputs) > 0:
212+
for output in self.outputs:
213+
sink_type = output.get("type")
214+
if sink_type not in VALID_SYNC_TYPES:
215+
logger.error(f"Invalid sink type '{sink_type}' in ImageRequest")
216+
return False
159217
return True
160218

161219
def get_shared_values(self) -> Dict[str, Any]:
162220
"""
163-
Returns a formatted dict that contains the properties of an image
221+
Retrieves a dictionary of shared values related to the image.
164222
165-
:return: Dict[str, Any] = the properties of an image
223+
:return: Dictionary of shared image properties.
166224
"""
167225
return {
168226
"image_id": self.image_id,
@@ -180,8 +238,9 @@ def get_shared_values(self) -> Dict[str, Any]:
180238

181239
def get_feature_distillation_option(self) -> List[FeatureDistillationAlgorithm]:
182240
"""
183-
Parses the post-processing property and extracts the relevant feature distillation selection, if present
184-
:return:
241+
Extracts the feature distillation options from the post-processing configuration.
242+
243+
:return: List of FeatureDistillationAlgorithm instances.
185244
"""
186245
return [
187246
op.algorithm

0 commit comments

Comments
 (0)