Commit

Merge pull request #11 from climatepolicyradar/refactor
Refactoring.
THOR300 authored Sep 28, 2023
2 parents 2cc7359 + 968ef3f commit 8541282
Showing 7 changed files with 68 additions and 55 deletions.
42 changes: 23 additions & 19 deletions cli/text2embeddings.py
@@ -2,7 +2,6 @@

import logging
import logging.config
import json
import os
from pathlib import Path
from typing import Optional
@@ -61,8 +60,8 @@
@click.option(
"--redo",
"-r",
help="Redo encoding for files that have already been parsed. By default, files with IDs that already exist "
"in the output directory are skipped.",
help="Redo encoding for files that have already been parsed. By default, "
"files with IDs that already exist in the output directory are skipped.",
is_flag=True,
default=False,
)
@@ -88,25 +87,30 @@ def run_as_cli(
limit: Optional[int],
):
"""
Run CLI to produce embeddings from document parser JSON outputs. Each embeddings file is called {id}.json
where {id} is the document ID of the input. Its first line is the description embedding and all other lines
are embeddings of each of the text blocks in the document in order. Encoding will automatically run on the
GPU if one is available.
Args: input_dir: Directory containing JSON files output_dir: Directory to save embeddings to s3: Whether we
are reading from and writing to S3. redo: Redo encoding for files that have already been parsed. By default,
files with IDs that already exist in the output directory are skipped. limit (Optional[int]): Optionally
limit the number of text samples to process. Useful for debugging. device (str): Device to use for
embeddings generation. Must be either "cuda" or "cpu".
Run CLI to produce embeddings from document parser JSON outputs.
Each embeddings file is called {id}.json where {id} is the document ID of the
input. Its first line is the description embedding and all other lines are
embeddings of each of the text blocks in the document in order. Encoding will
automatically run on the GPU if one is available.
Args: input_dir: Directory containing JSON files output_dir: Directory to save
embeddings to s3: Whether we are reading from and writing to S3. redo: Redo
encoding for files that have already been parsed. By default, files with IDs that
already exist in the output directory are skipped. limit (Optional[int]):
Optionally limit the number of text samples to process. Useful for debugging.
device (str): Device to use for embeddings generation. Must be either "cuda" or
"cpu".
"""
# FIXME: This solution assumes that we have a json document with language = en (supported target language)
# for every document in the parser output. This isn't very robust. This solution also requires passing
# every document into the embeddings stage so we are declaring tasks that are immediately dropped due to
# content. Filter only to tasks that have one language and where the language is supported. These could
# either be translated or in the original language.
# FIXME: This solution assumes that we have a json document with language = en (
# supported target language) for every document in the parser output. This isn't
# very robust. This solution also requires passing every document into the
# embeddings stage so we are declaring tasks that are immediately dropped due to
# content. Filter only to tasks that have one language and where the language is
# supported. These could either be translated or in the original language.

logger.info(
f"Running embeddings generation...",
"Running embeddings generation...",
extra={
"props": {
"input_dir": input_dir,
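Note (not part of the diff): the new docstring above describes one {id}.json output per document, with the description embedding on its first line and one embedding per text block on the following lines. A minimal sketch of reading such a file, assuming each line holds a JSON-encoded list of floats (the exact serialization is not shown in this diff):

import json
from pathlib import Path

import numpy as np


def read_embeddings_file(path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Read an {id}.json embeddings file as described by the CLI docstring.

    Assumes one JSON-encoded vector per line: the first line is the document
    description embedding, the remaining lines are the text-block embeddings
    in document order.
    """
    lines = path.read_text().splitlines()
    description_embedding = np.array(json.loads(lines[0]))
    text_block_embeddings = np.array([json.loads(line) for line in lines[1:]])
    return description_embedding, text_block_embeddings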
24 changes: 12 additions & 12 deletions src/languages.py
@@ -12,8 +12,8 @@ def validate_languages_decorator(func):

def wrapper(*args, **kwargs):
if (
unsupported_languages := config.TARGET_LANGUAGES
- config.ENCODER_SUPPORTED_LANGUAGES
unsupported_languages := config.TARGET_LANGUAGES
- config.ENCODER_SUPPORTED_LANGUAGES
):
logger.warning(
f"The following languages have been requested for encoding but are not "
@@ -29,22 +29,22 @@ def wrapper(*args, **kwargs):
def task_has_one_lang_that_is_supported(task: ParserOutput) -> bool:
"""Return true if the task has one language that is supported by the encoder."""
return (
task.languages
and (len(task.languages) == 1)
and (
task.languages[0]
in config.ENCODER_SUPPORTED_LANGUAGES.union(config.TARGET_LANGUAGES)
)
task.languages
and (len(task.languages) == 1)
and (
task.languages[0]
in config.ENCODER_SUPPORTED_LANGUAGES.union(config.TARGET_LANGUAGES)
)
)


def task_has_no_source_url_languages_or_data(task: ParserOutput) -> bool:
"""Return true if the task has no source url, languages or html/pdf data."""
return (
not task.document_source_url
and not task.languages
and task.html_data is None
and task.pdf_data is None
not task.document_source_url
and not task.languages
and task.html_data is None
and task.pdf_data is None
)


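Note (not part of the diff): the two predicates above are the building blocks for filtering parser outputs before encoding; the tests below exercise get_docs_of_supported_language, whose implementation is not shown here. A plausible sketch of how such a filter could combine them, consistent with those tests:

from src.languages import (
    task_has_no_source_url_languages_or_data,
    task_has_one_lang_that_is_supported,
)


def filter_encodable_tasks(tasks):
    """Hypothetical filter: keep tasks that have nothing to detect a language
    from (no source URL, languages, or html/pdf data), or exactly one language
    that the encoder supports; drop everything else."""
    return [
        task
        for task in tasks
        if task_has_no_source_url_languages_or_data(task)
        or task_has_one_lang_that_is_supported(task)
    ]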
3 changes: 1 addition & 2 deletions src/s3.py
@@ -1,11 +1,10 @@
import json
import tempfile
from typing import Any, Sequence

import boto3
import numpy as np
from botocore.exceptions import ClientError
from aws_error_utils import errors
from botocore.exceptions import ClientError

from src.config import S3_PATTERN

32 changes: 20 additions & 12 deletions src/test/test_languages.py
@@ -14,18 +14,26 @@ def test_get_docs_of_supported_language(
test_parser_output_source_url_un_supported_lang_data: List[ParserOutput],
):
"""Tests that the function returns only docs of a supported language."""
assert get_docs_of_supported_language(
test_parser_output_no_source_url_no_lang_no_data
) == test_parser_output_no_source_url_no_lang_no_data
assert (
get_docs_of_supported_language(test_parser_output_no_source_url_no_lang_no_data)
== test_parser_output_no_source_url_no_lang_no_data
)

assert get_docs_of_supported_language(
test_parser_output_source_url_no_lang_no_data
) == []
assert (
get_docs_of_supported_language(test_parser_output_source_url_no_lang_no_data)
== []
)

assert get_docs_of_supported_language(
test_parser_output_source_url_supported_lang_data
) == test_parser_output_source_url_supported_lang_data
assert (
get_docs_of_supported_language(
test_parser_output_source_url_supported_lang_data
)
== test_parser_output_source_url_supported_lang_data
)

assert get_docs_of_supported_language(
test_parser_output_source_url_un_supported_lang_data
) == []
assert (
get_docs_of_supported_language(
test_parser_output_source_url_un_supported_lang_data
)
== []
)
11 changes: 6 additions & 5 deletions src/test/test_s3.py
@@ -21,7 +21,7 @@ def test_validate_s3_pattern(test_file_key):
assert s3client is not None

try:
validate_s3_pattern(f"random_string")
validate_s3_pattern("random_string")
except Exception as e:
assert "Key does not represent an s3 path: random_string" in str(e)

@@ -30,7 +30,7 @@ def test_check_file_exists_in_s3(pipeline_s3_client, test_file_key):
"""Test whether we can check whether a file exists in s3."""

assert check_file_exists_in_s3(f"s3://{test_file_key}")
assert not check_file_exists_in_s3(f"s3://random_bucket/prefix/file.json")
assert not check_file_exists_in_s3("s3://random_bucket/prefix/file.json")


def test_get_s3_keys_with_prefix(
@@ -42,7 +42,7 @@ def test_get_s3_keys_with_prefix(
) == [f"{test_prefix}/test_id.json"]

try:
get_s3_keys_with_prefix(f"random_string")
get_s3_keys_with_prefix("random_string")
except Exception as e:
assert "Prefix does not represent an s3 path: random_string" in str(e)

@@ -55,7 +55,8 @@ def test_s3_object_read_text(pipeline_s3_client, test_file_key, test_file_json):
def test_write_json_to_s3(pipeline_s3_client, s3_bucket_and_region, test_file_json):
"""Test that we can write json to an s3 object."""
write_json_to_s3(
json.dumps(test_file_json), f"s3://{s3_bucket_and_region['bucket']}/prefix/test.json"
json.dumps(test_file_json),
f"s3://{s3_bucket_and_region['bucket']}/prefix/test.json",
)
assert (
json.loads(
@@ -82,4 +83,4 @@ def test_save_ndarray_to_s3_as_npy(pipeline_s3_client, s3_bucket_and_region):
np.array([1, 2, 3]), "s3://random_bucket/prefix/test.npy"
)
except Exception as e:
assert f"Bucket random_bucket does not exist" in str(e)
assert "Bucket random_bucket does not exist" in str(e)
3 changes: 2 additions & 1 deletion src/test/test_utils.py
@@ -46,7 +46,8 @@ def test_has_valid_text_override(test_parser_output_array: Sequence[ParserOutput
"""
Test that the get_text_blocks method provides the right response.
Particularly when using the including_invalid_html parameter."""
Particularly when using the including_invalid_html parameter.
"""

output = test_parser_output_array[1]
assert output.get_text_blocks() == []
8 changes: 4 additions & 4 deletions src/utils.py
@@ -150,17 +150,17 @@ def get_files_to_process(
files_to_process = os.listdir(input_dir)

files_to_process_ids = get_ids_with_suffix(files_to_process, ".json")
files_already_processed = document_ids_previously_parsed.intersection(files_to_process_ids)
files_already_processed = document_ids_previously_parsed.intersection(
files_to_process_ids
)
if not redo and files_already_processed:
logger.warning(
f"{len(files_already_processed)} "
f"documents found that have already been encoded. Skipping. "
)

files_to_process_ids_sequence = [
id_
for id_ in files_to_process_ids
if id_ not in document_ids_previously_parsed
id_ for id_ in files_to_process_ids if id_ not in document_ids_previously_parsed
]
if not files_to_process_ids_sequence:
logger.warning("No more documents to encode. Exiting.")
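Note (not part of the diff): the reflowed intersection above is the core of the skip-unless-redo behaviour of get_files_to_process: IDs already present in the output are reported and dropped unless --redo is set. A toy illustration with made-up IDs:

# Toy data; real IDs come from the parser outputs in the input directory.
document_ids_previously_parsed = {"doc_a", "doc_b"}
files_to_process_ids = ["doc_a", "doc_c"]

files_already_processed = document_ids_previously_parsed.intersection(
    files_to_process_ids
)
print(f"{len(files_already_processed)} documents already encoded. Skipping.")

remaining = [
    id_ for id_ in files_to_process_ids if id_ not in document_ids_previously_parsed
]
print(remaining)  # ['doc_c'] - only documents not yet encoded are kept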
