Merge pull request #83 from Filimoa/png-bug
PNG Bug
Filimoa authored Nov 13, 2024
2 parents 4b054b8 + dd53969 commit 385882e
Showing 5 changed files with 248 additions and 41 deletions.
Empty file added src/cookbooks/images.ipynb
116 changes: 113 additions & 3 deletions src/openparse/processing/basic_transforms.py
@@ -1,8 +1,23 @@
import base64
import io
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Literal
from typing import Dict, List, Literal, Type, TypeVar

from openparse.schemas import Bbox, Node, TextElement
from PIL import Image

from openparse.schemas import Bbox, ImageElement, Node, TableElement, TextElement

E = TypeVar("E", TextElement, ImageElement, TableElement)


def get_elements_of_type(nodes: List[Node], element_type: Type[E]) -> List[E]:
    elements: List[E] = []
    for node in nodes:
        for element in node.elements:
            if isinstance(element, element_type):
                elements.append(element)
    return elements


class ProcessingStep(ABC):
@@ -14,6 +29,96 @@ def process(self, nodes: List[Node]) -> List[Node]:
        raise NotImplementedError("Subclasses must implement this method.")


class CombineSlicedImages(ProcessingStep):
"""
PDF will slice images into multiple pieces if they are too large. This combines them back together.
"""

def _combine_images_in_group(
self, image_elements: List[ImageElement]
) -> ImageElement:
"""Combine a list of ImageElements into a single ImageElement."""
if not image_elements:
raise ValueError("No images to combine.")

images = []
for node in image_elements:
image_data = base64.b64decode(node.image)
image = Image.open(io.BytesIO(image_data))
# image = image.rotate(180)
images.append(image)

# Determine the width and total height of the final image
width = max(img.width for img in images)
total_height = sum(img.height for img in images)

# Create a new blank image
new_image = Image.new("RGB", (width, total_height))

# Paste images one below the other
y_offset = 0
for img in images:
new_image.paste(img, (0, y_offset))
y_offset += img.height

# Save or encode the final image
buffered = io.BytesIO()
new_image.save(buffered, format="PNG")
final_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

return ImageElement(
bbox=image_elements[0].bbox,
image=final_base64,
image_mimetype="image/png",
text="",
)

def _group_overlapping_images(
self, image_elements: List[ImageElement], buffer: float = 1.0
) -> List[List[ImageElement]]:
"""Group images that overlap or are adjacent."""
groups = []
used = set()

for i, elem1 in enumerate(image_elements):
if i in used:
continue
group = [elem1]
used.add(i)
queue = [elem1]
while queue:
current = queue.pop()
for j, elem2 in enumerate(image_elements):
if j in used:
continue
if current.overlaps(elem2, buffer=buffer):
group.append(elem2)
used.add(j)
queue.append(elem2)
groups.append(group)
return groups

def process(self, nodes: List[Node]) -> List[Node]:
nodes_by_page: Dict[int, List[Node]] = defaultdict(list)
for node in nodes:
pages = {element.bbox.page for element in node.elements}
for page in pages:
nodes_by_page[page].append(node)

new_nodes = []
for page, page_nodes in nodes_by_page.items():
image_nodes = [e for e in page_nodes if e.variant == {"image"}]
if image_nodes:
image_elements = get_elements_of_type(image_nodes, ImageElement)
text_elements = get_elements_of_type(page_nodes, TextElement)

combined_image = self._combine_images_in_group(image_elements)
new_nodes.append(Node(elements=(combined_image, *text_elements)))
else:
new_nodes.extend(page_nodes)
return new_nodes


class RemoveTextInsideTables(ProcessingStep):
"""
If we're using the table extraction pipeline, we need to remove text that is inside tables to avoid duplication.
@@ -162,7 +267,12 @@ def __init__(self, min_tokens: int):
        self.min_tokens = min_tokens

    def process(self, nodes: List[Node]) -> List[Node]:
        return [node for node in nodes if node.tokens >= self.min_tokens]
        res = []
        for node in nodes:
            if node.tokens <= self.min_tokens and "image" not in node.variant:
                continue
            res.append(node)
        return res


class CombineNodesSpatially(ProcessingStep):
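For reference, the stitching performed by _combine_images_in_group is plain Pillow compositing: decode each base64 slice, paste the slices top to bottom onto a fresh canvas, and re-encode the result as PNG. The sketch below reproduces that approach outside openparse with synthetic slices; the helper names are illustrative only.

import base64
import io
from typing import List

from PIL import Image


def make_slice(width: int, height: int, color: str) -> str:
    """Create a solid-color PNG slice and return it base64-encoded (test data only)."""
    buf = io.BytesIO()
    Image.new("RGB", (width, height), color).save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


def stack_slices(b64_slices: List[str]) -> str:
    """Vertically stack base64-encoded image slices into one base64-encoded PNG."""
    images = [Image.open(io.BytesIO(base64.b64decode(s))) for s in b64_slices]
    width = max(img.width for img in images)
    total_height = sum(img.height for img in images)

    combined = Image.new("RGB", (width, total_height))
    y_offset = 0
    for img in images:
        combined.paste(img, (0, y_offset))
        y_offset += img.height

    buf = io.BytesIO()
    combined.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")


slices = [make_slice(200, 80, "red"), make_slice(200, 120, "blue")]
combined = Image.open(io.BytesIO(base64.b64decode(stack_slices(slices))))
print(combined.size)  # (200, 200)

Note that, as committed, the combined ImageElement reuses the bbox of the first slice rather than a union of the slices' boxes.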
3 changes: 3 additions & 0 deletions src/openparse/processing/ingest.py
@@ -6,6 +6,7 @@
    CombineBullets,
    CombineHeadingsWithClosestText,
    CombineNodesSpatially,
    CombineSlicedImages,
    ProcessingStep,
    RemoveFullPageStubs,
    RemoveMetadataElements,
@@ -69,6 +70,7 @@ class BasicIngestionPipeline(IngestionPipeline):
    def __init__(self):
        self.transformations = [
            RemoveTextInsideTables(),
            CombineSlicedImages(),
            RemoveFullPageStubs(max_area_pct=0.35),
            # mostly aimed at combining bullets and weird formatting
            CombineNodesSpatially(
Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(

self.transformations = [
RemoveTextInsideTables(),
CombineSlicedImages(),
RemoveFullPageStubs(max_area_pct=0.35),
# mostly aimed at combining bullets and weird formatting
CombineNodesSpatially(
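With this change, CombineSlicedImages is part of both default transformation lists above, so stock usage should pick it up without extra configuration. A hedged usage sketch follows; the file path is a placeholder, and the assumption that DocumentParser falls back to BasicIngestionPipeline when no pipeline is passed is not shown in this diff.

import openparse

# Assumption: with no processing_pipeline argument, DocumentParser uses
# BasicIngestionPipeline, which now includes CombineSlicedImages().
parser = openparse.DocumentParser()
parsed = parser.parse("report-with-sliced-figures.pdf")  # placeholder path

# Image slices that shared a page should now surface as a single image node.
image_nodes = [node for node in parsed.nodes if "image" in node.variant]
print(len(image_nodes))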
29 changes: 26 additions & 3 deletions src/openparse/schemas.py
@@ -360,7 +360,6 @@ class ImageElement(BaseModel):
    def embed_text(self) -> str:
        if self._embed_text:
            return self._embed_text

        return self.text

    @cached_property
@@ -381,9 +380,20 @@ def is_at_similar_height(
        error_margin: float = 1,
    ) -> bool:
        y_distance = abs(self.bbox.y1 - other.bbox.y1)

        return y_distance <= error_margin

    def overlaps(self, other: "ImageElement", buffer: float = 1.0) -> bool:
        """Check if this image overlaps or is adjacent to another image, considering a buffer."""
        if self.bbox.page != other.bbox.page:
            return False

        return not (
            self.bbox.x1 + buffer < other.bbox.x0 - buffer
            or self.bbox.x0 - buffer > other.bbox.x1 + buffer
            or self.bbox.y1 + buffer < other.bbox.y0 - buffer
            or self.bbox.y0 - buffer > other.bbox.y1 + buffer
        )


#############
### NODES ###
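The overlaps method added above treats two images as belonging together when their boxes intersect after every edge is padded by buffer, which also catches slices that sit flush against one another. A standalone sketch of the same test, using a stand-in box tuple rather than openparse's Bbox model:

from typing import NamedTuple


class Box(NamedTuple):
    x0: float
    y0: float
    x1: float
    y1: float


def overlaps(a: Box, b: Box, buffer: float = 1.0) -> bool:
    """True when the boxes intersect once each edge is expanded by `buffer`."""
    return not (
        a.x1 + buffer < b.x0 - buffer
        or a.x0 - buffer > b.x1 + buffer
        or a.y1 + buffer < b.y0 - buffer
        or a.y0 - buffer > b.y1 + buffer
    )


top = Box(0, 100, 200, 200)
bottom = Box(0, 0, 200, 99.5)  # 0.5pt gap between slices
far_away = Box(300, 300, 400, 400)

print(overlaps(top, bottom))    # True: gap is smaller than 2 * buffer
print(overlaps(top, far_away))  # False

With the default buffer of 1.0, slices separated by up to two points still count as overlapping.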
@@ -641,7 +651,20 @@ def _repr_markdown_(self):
        """
        When called in a Jupyter environment, this will display the node as Markdown, which Jupyter will then render as HTML.
        """
        return self.text
        markdown_parts = []
        for element in self.elements:
            if element.variant == NodeVariant.TEXT:
                markdown_parts.append(element.text)
            elif element.variant == NodeVariant.IMAGE:
                image_data = element.image
                mime_type = element.image_mimetype
                if mime_type == "unknown":
                    mime_type = "image/png"
                markdown_image = f"![Image](data:{mime_type};base64,{image_data})"
                markdown_parts.append(markdown_image)
            elif element.variant == NodeVariant.TABLE:
                markdown_parts.append(element.text)
        return "\n\n".join(markdown_parts)

    def __add__(self, other: "Node") -> "Node":
        """
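The Markdown that _repr_markdown_ now emits for an image element is a base64 data URI, which Jupyter renders inline. A small sketch of the same string construction, using a generated 1x1 PNG instead of real document data:

import base64
import io

from PIL import Image

# Build a throwaway 1x1 PNG so the data URI below is self-contained.
buf = io.BytesIO()
Image.new("RGB", (1, 1), "white").save(buf, format="PNG")
image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

mime_type = "unknown"
if mime_type == "unknown":  # same fallback as the diff above
    mime_type = "image/png"

markdown = f"![Image](data:{mime_type};base64,{image_b64})"
print(markdown[:60] + "...")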
