Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions doctr/models/kie_predictor/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def forward(
) -> Document:
# Dimension check
if any(page.ndim != 3 for page in pages):
raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
raise ValueError(
"incorrect input shape: all pages are expected to be multi-channel 2D images."
)

origin_page_shapes = [page.shape[:2] for page in pages]

Expand All @@ -82,22 +84,30 @@ def forward(

# Detect document rotation and rotate pages
seg_maps = [
np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
np.uint8
)
np.where(
np.expand_dims(np.amax(out_map, axis=-1), axis=-1)
> kwargs.get("bin_thresh", 0.3),
255,
0,
).astype(np.uint8)
for out_map in out_maps
]
if self.detect_orientation:
general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
general_pages_orientations, origin_pages_orientations = (
self._get_orientations(pages, seg_maps)
)
orientations = [
{"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
{"value": orientation_page, "confidence": None}
for orientation_page in origin_pages_orientations
]
else:
orientations = None
general_pages_orientations = None
origin_pages_orientations = None
if self.straighten_pages:
pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
pages = self._straighten_pages(
pages, seg_maps, general_pages_orientations, origin_pages_orientations
)
# update page shapes after straightening
origin_page_shapes = [page.shape[:2] for page in pages]

Expand Down Expand Up @@ -130,37 +140,58 @@ def forward(
crop_orientations: Any = {}
if not self.assume_straight_pages:
for class_name in dict_loc_preds.keys():
crops[class_name], dict_loc_preds[class_name], word_orientations = self._rectify_crops(
crops[class_name], dict_loc_preds[class_name]
crops[class_name], dict_loc_preds[class_name], word_orientations = (
self._rectify_crops(crops[class_name], dict_loc_preds[class_name])
)
crop_orientations[class_name] = [
{"value": orientation[0], "confidence": orientation[1]} for orientation in word_orientations
{"value": orientation[0], "confidence": orientation[1]}
for orientation in word_orientations
]

# Identify character sequences
word_preds = {
k: self.reco_predictor([crop for page_crops in crop_value for crop in page_crops], **kwargs)
k: self.reco_predictor(
[crop for page_crops in crop_value for crop in page_crops], **kwargs
)
for k, crop_value in crops.items()
}
if not crop_orientations:
crop_orientations = {k: [{"value": 0, "confidence": None} for _ in word_preds[k]] for k in word_preds}
crop_orientations = {
k: [{"value": 0, "confidence": None} for _ in word_preds[k]]
for k in word_preds
}

boxes: dict = {}
text_preds: dict = {}
word_crop_orientations: dict = {}
for class_name in dict_loc_preds.keys():
boxes[class_name], text_preds[class_name], word_crop_orientations[class_name] = self._process_predictions(
dict_loc_preds[class_name], word_preds[class_name], crop_orientations[class_name]
(
boxes[class_name],
text_preds[class_name],
word_crop_orientations[class_name],
) = self._process_predictions(
dict_loc_preds[class_name],
word_preds[class_name],
crop_orientations[class_name],
)

boxes_per_page: list[dict] = invert_data_structure(boxes) # type: ignore[assignment]
objectness_scores_per_page: list[dict] = invert_data_structure(objectness_scores) # type: ignore[assignment]
objectness_scores_per_page: list[dict] = invert_data_structure(
objectness_scores
) # type: ignore[assignment]
text_preds_per_page: list[dict] = invert_data_structure(text_preds) # type: ignore[assignment]
crop_orientations_per_page: list[dict] = invert_data_structure(word_crop_orientations) # type: ignore[assignment]
crop_orientations_per_page: list[dict] = invert_data_structure(
word_crop_orientations
) # type: ignore[assignment]

if self.detect_language:
languages = [get_language(self.get_text(text_pred)) for text_pred in text_preds_per_page]
languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
languages = [
get_language(self.get_text(text_pred))
for text_pred in text_preds_per_page
]
languages_dict = [
{"value": lang[0], "confidence": lang[1]} for lang in languages
]
else:
languages_dict = None

Expand All @@ -178,8 +209,6 @@ def forward(

@staticmethod
def get_text(text_pred: dict) -> str:
text = []
for value in text_pred.values():
text += [item[0] for item in value]

# Build list of all item[0] in one pass, avoid repeated list concatenations
text = [item[0] for value in text_pred.values() for item in value]
return " ".join(text)