def extract_page(
    pdf_path: Path,
    page_number: int,
) -> Path:
    """Render a single page of a PDF to a PNG file stored next to the PDF.

    Args:
        pdf_path: Location of the PDF document on disk.
        page_number: 0-based index of the page to render.

    Returns:
        Path of the written image, named ``<pdf stem>_<page_number>.png`` and
        placed in the same directory as ``pdf_path``.
    """
    # Build the file name from the stem. Appending "_<n>" to the full name and
    # then calling `with_suffix(".png")` treats ".pdf_<n>" as the suffix and
    # replaces it, so every page of "doc.pdf" would map to the same "doc.png".
    target_file_path = pdf_path.with_name(f"{pdf_path.stem}_{page_number}.png")

    # The context manager closes the document even when the page lookup,
    # rendering or saving raises.
    with pymupdf.open(pdf_path) as doc:
        page = doc[page_number]  # 0-based index
        # Render page to an image (pixmap) and write it to disk.
        pix = page.get_pixmap()
        pix.save(target_file_path)

    return target_file_path
extract_page(file_path, frame) + except ImportError as e: + raise PrintableError( + "Trying to access a crop from a pdf. Please install encord-agents[pdf] to access pdf support" + ) from e + file_path = page_file_path + yield file_path diff --git a/encord_agents/gcp/dependencies.py b/encord_agents/gcp/dependencies.py index 21ec43da..b853200a 100644 --- a/encord_agents/gcp/dependencies.py +++ b/encord_agents/gcp/dependencies.py @@ -74,7 +74,7 @@ def ( return get_user_client() -def dep_single_frame(storage_item: StorageItem) -> NDArray[np.uint8]: +def dep_single_frame(storage_item: StorageItem, frame_data: FrameData) -> NDArray[np.uint8]: """ Dependency to inject the first frame of the underlying asset. @@ -103,7 +103,7 @@ def my_agent( Numpy array of shape [h, w, 3] RGB colors. """ - with download_asset(storage_item, frame=0) as asset: + with download_asset(storage_item, frame=frame_data.frame) as asset: img = cv2.cvtColor(cv2.imread(asset.as_posix()), cv2.COLOR_BGR2RGB) return np.asarray(img, dtype=np.uint8) diff --git a/poetry.lock b/poetry.lock index efa37089..937cc71e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2593,6 +2593,23 @@ pyyaml = "*" [package.extras] extra = ["pygments (>=2.12)"] +[[package]] +name = "pymupdf" +version = "1.25.5" +description = "A high performance Python library for data extraction, analysis, conversion & manipulation of PDF (and other) documents." 
+optional = false +python-versions = ">=3.9" +files = [ + {file = "pymupdf-1.25.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:cde4e1c9cfb09c0e1e9c2b7f4b787dd6bb34a32cfe141a4675e24af7c0c25dd3"}, + {file = "pymupdf-1.25.5-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:5a35e2725fae0ab57f058dff77615c15eb5961eac50ba04f41ebc792cd8facad"}, + {file = "pymupdf-1.25.5-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d94b800e9501929c42283d39bc241001dd87fdeea297b5cb40d5b5714534452f"}, + {file = "pymupdf-1.25.5-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ee22155d3a634642d76553204867d862ae1bdd9f7cf70c0797d8127ebee6bed5"}, + {file = "pymupdf-1.25.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6ed7fc25271004d6d3279c20a80cb2bb4cda3efa9f9088dcc07cd790eca0bc63"}, + {file = "pymupdf-1.25.5-cp39-abi3-win32.whl", hash = "sha256:65e18ddb37fe8ec4edcdbebe9be3a8486b6a2f42609d0a142677e42f3a0614f8"}, + {file = "pymupdf-1.25.5-cp39-abi3-win_amd64.whl", hash = "sha256:7f44bc3d03ea45b2f68c96464f96105e8c7908896f2fb5e8c04f1fb8dae7981e"}, + {file = "pymupdf-1.25.5.tar.gz", hash = "sha256:5f96311cacd13254c905f6654a004a0a2025b71cabc04fda667f5472f72c15a0"}, +] + [[package]] name = "pytest" version = "8.3.3" @@ -3740,4 +3757,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "bd3d5d4e71c0dcea2d01c2aa2655a406f7c3b3cff04247ce4dc7e115e11a1e0c" +content-hash = "67a116c974698063e0c065b30245d3375d28b80096a40a45fea12bc1e800d46b" diff --git a/pyproject.toml b/pyproject.toml index 735d66f7..a43b243a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,11 @@ notebook = "^7.3.2" mkdocs-macros-plugin = "^1.3.7" fastapi = "^0.115.0" +[tool.poetry.group.pdf] +optional = true + +[tool.poetry.group.pdf.dependencies] +PyMuPdf = ">1.25.0" [tool.poetry.scripts] encord-agents = "encord_agents.cli.main:app" diff --git a/tests/integration_tests/fastapi/test_dependencies.py 
@app.post("/object-instance-crops-video")
def post_object_instance_crops_video(
    frame_data: FrameData,
    crops: Annotated[list[InstanceCrop], Depends(dep_object_crops())],
) -> None:
    """Check the single crop injected for the requested video frame.

    Frame 0 must yield the object identified by ``video_object_hash`` with a
    box of ``VIDEO_BOX_1_SIZE``; any other request is expected to yield the
    frame-1 object (a different hash) with a box of ``VIDEO_BOX_2_SIZE``.
    """
    assert crops
    assert len(crops) == 1
    crop = crops[0]

    if frame_data.frame == 0:
        box_size = VIDEO_BOX_1_SIZE
        assert crop.frame == 0
        assert crop.instance.object_hash == video_object_hash
    else:
        box_size = VIDEO_BOX_2_SIZE
        assert crop.frame == 1
        assert crop.instance.object_hash != video_object_hash

    # Boxes are anchored at the top-left corner and sized relative to the
    # video, so the crop shape is the video shape scaled by the box size.
    assert video_label_row.height is not None and video_label_row.width is not None
    expected_shape = (video_label_row.height * box_size, video_label_row.width * box_size, 3)
    assert crop.content.shape == expected_shape


@app.post("/object-instance-crops-pdf")
def post_object_instance_crops_pdf(
    frame_data: FrameData,
    crops: Annotated[list[InstanceCrop], Depends(dep_object_crops())],
) -> None:
    """Check the single crop injected for page 0 of the PDF."""
    assert crops
    assert len(crops) == 1
    pdf_crop = crops[0]
    assert pdf_crop.frame == 0
    # Hard-coded shape: Depends on object crop size and PDF file
    assert pdf_crop.content.shape == (554, 428, 3)
def test_object_instance_crops_video(self) -> None:
    """Post frames 0 and 1 to the video crops endpoint; each must return 200."""
    for frame_index in (0, 1):
        payload = {
            "projectHash": self.context.project.project_hash,
            "dataHash": self.context.video_label_row.data_hash,
            "frame": frame_index,
        }
        resp = self.client.post("/object-instance-crops-video", json=payload)
        assert resp.status_code == 200, resp.content


def test_object_instance_crops_pdf(self) -> None:
    """Post page 0 to the PDF crops endpoint; it must return 200."""
    payload = {
        "projectHash": self.context.project.project_hash,
        "dataHash": self.context.pdf_label_row.data_hash,
        "frame": 0,
    }
    resp = self.client.post("/object-instance-crops-pdf", json=payload)
    assert resp.status_code == 200, resp.content