Fix metadata CSV output in example generation pipeline.
Previously, the function that generated the metadata CSV materialized every row of the CSV in a single worker's memory. This caused the worker to OOM and crash when the number of examples was too large (e.g. >1M). The metadata is now written as sharded CSV files via the Beam dataframe API, so no single worker ever holds the full set of rows.

PiperOrigin-RevId: 687026814
jzxu authored and copybara-github committed Oct 17, 2024
1 parent a06a09c commit 1f52a97
Showing 4 changed files with 98 additions and 97 deletions.
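
For context on the change below, here is a minimal, self-contained sketch of the old and new approaches. Everything in it is illustrative: `Row`, `make_row`, and the `/tmp` path are invented for the example; the real pipeline and paths are in the diff.

```python
import typing

import apache_beam as beam
import apache_beam.dataframe.convert
import apache_beam.dataframe.io


class Row(typing.NamedTuple):
  """Typed rows give the PCollection a schema, as ExampleType does below."""
  example_id: str
  longitude: float
  latitude: float


def make_row(i: int) -> Row:
  return Row(example_id=str(i), longitude=float(i), latitude=-float(i))


with beam.Pipeline() as pipeline:
  rows = (
      pipeline
      | beam.Create(range(100))
      | beam.Map(make_row).with_output_types(Row)
  )

  # Old approach (the one removed below): beam.combiners.ToList() gathers
  # every row into one in-memory list on a single worker before a DoFn
  # writes the CSV, so memory grows with the dataset -- the OOM described
  # in the commit message.

  # New approach: convert the schema'd PCollection to a deferred dataframe
  # and let the dataframe CSV sink write sharded files, so no worker ever
  # materializes the whole dataset.
  df = apache_beam.dataframe.convert.to_dataframe(rows)
  apache_beam.dataframe.io.to_csv(df, '/tmp/metadata.csv', index=False)
```
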
81 changes: 21 additions & 60 deletions src/skai/generate_examples.py
@@ -14,7 +14,6 @@
 """Pipeline for generating tensorflow examples from satellite images."""
 
 import binascii
-import csv
 import dataclasses
 import hashlib
 import itertools
@@ -28,6 +27,8 @@
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple
 
 import apache_beam as beam
+import apache_beam.dataframe.convert
+import apache_beam.dataframe.io
 import cv2
 import geopandas as gpd
 import numpy as np
@@ -944,28 +945,14 @@ def _generate_examples_pipeline(
             num_shards=num_output_shards))
 
   if output_metadata_file:
-    field_names = [
-        'example_id',
-        'encoded_coordinates',
-        'longitude',
-        'latitude',
-        'post_image_id',
-        'pre_image_id',
-        'plus_code',
-    ]
-    _ = (
+    rows = (
         examples
-        | 'convert_metadata_examples_to_dict' >> beam.Map(_get_example_metadata)
-        | 'combine_to_list' >> beam.combiners.ToList()
-        | 'write_metadata_to_file'
-        >> beam.ParDo(
-            WriteMetadataToCSVFn(
-                metadata_output_file_path=(
-                    f'{output_dir}/examples/metadata_examples.csv'
-                ), field_names=field_names
-            )
-        )
+        | 'extract_metadata_rows' >> beam.Map(_get_example_metadata)
+        | 'remove_duplicates' >> beam.Distinct()
     )
+    df = apache_beam.dataframe.convert.to_dataframe(rows)
+    output_prefix = f'{output_dir}/examples/metadata/metadata.csv'
+    apache_beam.dataframe.io.to_csv(df, output_prefix, index=False)
 
   result = pipeline.run()
   if wait_for_dataflow_job:
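
One practical consequence of this hunk, visible in the test updates further down: the dataframe `to_csv` sink does not produce a single `metadata.csv` but a set of shard files under the prefix (the tests glob for `metadata.csv-*-of-*`), so consumers must read and concatenate all shards. A minimal read-back sketch, with `output_dir` as a placeholder:

```python
import glob

import pandas as pd

# Shard names follow the pattern the updated tests glob for, e.g.
# metadata.csv-00000-of-00002. 'output_dir' is a placeholder path.
shards = glob.glob('output_dir/examples/metadata/metadata.csv-*-of-*')
metadata = pd.concat([pd.read_csv(p) for p in shards], ignore_index=True)
```
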
@@ -1103,27 +1090,6 @@ def run_example_generation(
   )
 
 
-class WriteMetadataToCSVFn(beam.DoFn):
-  """DoFn to write meta data of examples to csv file.
-
-  Attributes:
-    metadata_output_file_path: File path to output meta data of all examples.
-    field_names: Field names to be included in output file.
-  """
-
-  def __init__(self, metadata_output_file_path: str, field_names: List[str]):
-    self.metadata_output_file_path = metadata_output_file_path
-    self.field_names = field_names
-
-  def process(self, element):
-    with tf.io.gfile.GFile(
-        self.metadata_output_file_path, 'w'
-    ) as csv_output_file:
-      csv_writer = csv.DictWriter(csv_output_file, fieldnames=self.field_names)
-      csv_writer.writeheader()
-      csv_writer.writerows(element)
-
-
 class ExampleType(typing.NamedTuple):
   example_id: str
   encoded_coordinates: str
@@ -1136,21 +1102,16 @@ class ExampleType(typing.NamedTuple):
 
 @beam.typehints.with_output_types(ExampleType)
 def _get_example_metadata(example: tf.train.Example) -> ExampleType:
-  example_id = utils.get_bytes_feature(example, 'example_id')[0].decode()
-  encoded_coordinates = utils.get_bytes_feature(example, 'encoded_coordinates')[
-      0
-  ].decode()
-  longitude, latitude = utils.get_float_feature(example, 'coordinates')
-  post_image_id = utils.get_bytes_feature(example, 'post_image_id')[0].decode()
-  pre_image_id = utils.get_bytes_feature(example, 'pre_image_id')[0].decode()
-  plus_code = utils.get_bytes_feature(example, 'plus_code')[0].decode()
-
-  return dict({
-      'example_id': example_id,
-      'encoded_coordinates': encoded_coordinates,
-      'longitude': longitude,
-      'latitude': latitude,
-      'post_image_id': post_image_id,
-      'pre_image_id': pre_image_id,
-      'plus_code': plus_code,
-  })
+  return ExampleType(
+      example_id=utils.get_bytes_feature(example, 'example_id')[0].decode(),
+      encoded_coordinates=utils.get_bytes_feature(
+          example, 'encoded_coordinates'
+      )[0].decode(),
+      longitude=utils.get_float_feature(example, 'coordinates')[0],
+      latitude=utils.get_float_feature(example, 'coordinates')[1],
+      post_image_id=utils.get_bytes_feature(example, 'post_image_id')[
+          0
+      ].decode(),
+      pre_image_id=utils.get_bytes_feature(example, 'pre_image_id')[0].decode(),
+      plus_code=utils.get_bytes_feature(example, 'plus_code')[0].decode(),
+  )
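
The `_get_example_metadata` rewrite above is not just cosmetic: `to_dataframe` requires a PCollection with a schema, which the `ExampleType` NamedTuple (together with the `with_output_types` decorator) provides, and its field names become the CSV header — which is why the hand-maintained `field_names` list could be deleted. A standalone illustration of that mechanism (not SKAI code):

```python
import typing

import apache_beam as beam
from apache_beam.dataframe import convert


class Point(typing.NamedTuple):
  name: str
  x: float


with beam.Pipeline() as pipeline:
  # NamedTuple elements carry a schema, so the conversion below works and
  # the resulting deferred dataframe has columns 'name' and 'x'.
  points = pipeline | beam.Create([Point('a', 1.0), Point('b', 2.0)])
  df = convert.to_dataframe(points)
```
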
18 changes: 10 additions & 8 deletions src/skai/generate_examples_test.py
@@ -14,6 +14,7 @@
 
 """Tests for generate_examples.py."""
 
+import glob
 import os
 import pathlib
 import tempfile
@@ -526,26 +527,27 @@ def testGenerateExamplesWithOutputMetaDataFile(self):
     tfrecords = os.listdir(
         os.path.join(output_dir, 'examples', 'unlabeled-large')
     )
-    df_metadata_contents = pd.read_csv(
-        os.path.join(output_dir, 'examples', 'metadata_examples.csv')
+    metadata_pattern = os.path.join(
+        output_dir, 'examples', 'metadata', 'metadata.csv-*-of-*'
     )
+    metadata = pd.concat([pd.read_csv(p) for p in glob.glob(metadata_pattern)])
 
     # No assert for example_id as each example_id depends on the image path
     # which varies with platforms where this test is run
     self.assertEqual(
-        df_metadata_contents.encoded_coordinates[0], 'A17B32432A1085C1'
+        metadata.encoded_coordinates[0], 'A17B32432A1085C1'
     )
     self.assertAlmostEqual(
-        df_metadata_contents.latitude[0], -16.632892608642578
+        metadata.latitude[0], -16.632892608642578
     )
     self.assertAlmostEqual(
-        df_metadata_contents.longitude[0], 178.48292541503906
+        metadata.longitude[0], 178.48292541503906
     )
-    self.assertEqual(df_metadata_contents.pre_image_id[0], self.test_image_path)
+    self.assertEqual(metadata.pre_image_id[0], self.test_image_path)
     self.assertEqual(
-        df_metadata_contents.post_image_id[0], self.test_image_path
+        metadata.post_image_id[0], self.test_image_path
     )
-    self.assertEqual(df_metadata_contents.plus_code[0], '5VMW9F8M+R5V8F4')
+    self.assertEqual(metadata.plus_code[0], '5VMW9F8M+R5V8F4')
     self.assertSameElements(tfrecords, ['unlabeled-00000-of-00001.tfrecord'])
 
   def testConfigLoadedCorrectlyFromJsonFile(self):
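
A small caveat on the read-back pattern in the test above: it works because the test produces a single shard. With multiple shards, `pd.concat` without `ignore_index=True` repeats index labels, so positional-looking lookups like `metadata.latitude[0]` would return several rows; `_read_sharded_csvs` in labeling.py (next file) passes `ignore_index=True` for exactly this reason. For illustration:

```python
import pandas as pd

a = pd.DataFrame({'x': [1, 2]})
b = pd.DataFrame({'x': [3]})

pd.concat([a, b])['x'][0]                     # two rows share label 0
pd.concat([a, b], ignore_index=True)['x'][0]  # scalar 1, as intended
```
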
60 changes: 39 additions & 21 deletions src/skai/labeling.py
Expand Up @@ -220,6 +220,26 @@ def sample_with_buffer(
return sample


def _read_sharded_csvs(pattern: str) -> pd.DataFrame:
"""Reads CSV shards matching pattern and merges them."""
paths = tf.io.gfile.glob(pattern)
if not paths:
raise ValueError(f'File pattern {pattern} did not match any files.')
dfs = []
expected_columns = None
for path in paths:
with tf.io.gfile.GFile(path, 'r') as f:
df = pd.read_csv(f)
if expected_columns is None:
expected_columns = set(df.columns)
else:
actual_columns = set(df.columns)
if actual_columns != expected_columns:
raise ValueError(f'Inconsistent columns in file {path}')
dfs.append(df)
return pd.concat(dfs, ignore_index=True)


def get_buffered_example_ids(
examples_pattern: str,
buffered_sampling_radius: float,
@@ -238,25 +258,23 @@ def get_buffered_example_ids(
   Returns:
     Set of allowed example ids.
   """
-  metadata_path = str(
-      os.path.join(
-          '/'.join(examples_pattern.split('/')[:-2]),
-          'metadata_examples.csv',
-      )
-  )
-  with tf.io.gfile.GFile(metadata_path, 'r') as f:
-    try:
-      df_metadata = pd.read_csv(f)
-      df_metadata = df_metadata[
-          ~df_metadata['example_id'].isin(excluded_example_ids)
-      ].reset_index(drop=True)
-    except tf.errors.NotFoundError as error:
-      raise SystemExit(
-          f'\ntf.errors.NotFoundError: {metadata_path} was not found\nUse'
-          ' examples_to_csv module to generate metadata_examples.csv and/or'
-          ' put metadata_examples.csv in the appropriate directory that is'
-          ' PATH_DIR/examples/'
-      ) from error
+  root_dir = '/'.join(examples_pattern.split('/')[:-2])
+  single_csv_pattern = str(os.path.join(root_dir, 'metadata_examples.csv'))
+  if tf.io.gfile.exists(single_csv_pattern):
+    metadata = _read_sharded_csvs(single_csv_pattern)
+  else:
+    sharded_csv_pattern = str(
+        os.path.join(
+            root_dir,
+            'metadata',
+            'metadata.csv-*-of-*',
+        )
+    )
+    metadata = _read_sharded_csvs(sharded_csv_pattern)
+
+  metadata = metadata[
+      ~metadata['example_id'].isin(excluded_example_ids)
+  ].reset_index(drop=True)
 
   logging.info(
       'Randomly searching for buffered samples with buffer radius %.2f'
@@ -265,11 +283,11 @@
   )
   points = utils.convert_to_utm(
       gpd.GeoSeries(
-          gpd.points_from_xy(df_metadata['longitude'], df_metadata['latitude']),
+          gpd.points_from_xy(metadata['longitude'], metadata['latitude']),
           crs=4326,
       )
   )
-  gpd_df = gpd.GeoDataFrame(df_metadata, geometry=points)
+  gpd_df = gpd.GeoDataFrame(metadata, geometry=points)
   max_examples = len(gpd_df) if max_examples is None else max_examples
   df_buffered_samples = sample_with_buffer(
       gpd_df, max_examples, buffered_sampling_radius
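
One detail that makes the fallback above work: `tf.io.gfile.glob` also accepts a wildcard-free path, returning it as the single match when the file exists, so `_read_sharded_csvs` serves both the legacy single-file layout and the new sharded one. A quick sketch with hypothetical paths:

```python
import tensorflow as tf

# No wildcard: behaves like an existence check that returns the path itself.
tf.io.gfile.glob('/data/examples/metadata_examples.csv')
# -> ['/data/examples/metadata_examples.csv'] if present, else []

# Wildcard: one entry per shard written by the Beam dataframe CSV sink.
tf.io.gfile.glob('/data/examples/metadata/metadata.csv-*-of-*')
```
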
36 changes: 28 additions & 8 deletions src/skai/labeling_test.py
@@ -12,13 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Tests for labeling."""
-
 import os
 import random
 import tempfile
 
 from absl.testing import absltest
+from absl.testing import parameterized
 import geopandas as gpd
 import numpy as np
 import pandas as pd
@@ -80,9 +79,12 @@ def _read_tfrecords(path: str) -> list[Example]:
   return examples
 
 
-class LabelingTest(absltest.TestCase):
+class LabelingTest(parameterized.TestCase):
 
-  def test_create_buffered_tfrecords(self):
+  @parameterized.parameters(
+      dict(sharded_example_metadata=True), dict(sharded_example_metadata=False)
+  )
+  def test_create_buffered_tfrecords(self, sharded_example_metadata: bool):
     """Tests create_buffered_tfrecords."""
     # Create 5 unlabeled examples in 3 tfrecords.
     with tempfile.TemporaryDirectory() as examples_dir:
@@ -93,9 +95,6 @@ def test_create_buffered_tfrecords(self):
       examples_pattern = os.path.join(
          examples_dir, 'examples', 'unlabeled', '*'
       )
-      metadata_examples_path = os.path.join(
-          examples_dir, 'examples', 'metadata_examples.csv'
-      )
       filtered_tfrecords_output_dir = os.path.join(
           examples_dir, 'filtered',
       )
@@ -114,7 +113,28 @@
           columns=['example_id', 'longitude', 'latitude'],
       )
       df_metadata = df_metadata.sample(frac=1)
-      df_metadata.to_csv(metadata_examples_path, index=False)
+      if sharded_example_metadata:
+        metadata_dir = os.path.join(examples_dir, 'examples', 'metadata')
+        os.mkdir(metadata_dir)
+        df_metadata.iloc[:2].to_csv(
+            os.path.join(
+                metadata_dir,
+                'metadata.csv-00000-of-00002',
+            ),
+            index=False,
+        )
+        df_metadata.iloc[2:].to_csv(
+            os.path.join(
+                metadata_dir,
+                'metadata.csv-00001-of-00002',
+            ),
+            index=False,
+        )
+      else:
+        metadata_examples_path = os.path.join(
+            examples_dir, 'examples', 'metadata_examples.csv'
+        )
+        df_metadata.to_csv(metadata_examples_path, index=False)
 
       example_id_lon_lat_create_tfrecords = {
           '001': [('a', [92.850449, 20.148951]), ('b', [92.889694, 20.157515])],
