Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
[WIP] Example of DataPipes and DataFrames integration (pytorch#60840)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#60840

Test Plan: Imported from OSS

Reviewed By: wenleix, ejguan

Differential Revision: D29461080

Pulled By: VitalyFedyunin

fbshipit-source-id: 4909394dcd39e97ee49b699fda542b311b7e0d82
  • Loading branch information
VitalyFedyunin authored and facebook-github-bot committed Sep 14, 2021
1 parent ee554e2 commit ab5e1c6
Show file tree
Hide file tree
Showing 14 changed files with 1,108 additions and 12 deletions.
65 changes: 65 additions & 0 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@
HAS_DILL = False
skipIfNoDill = skipIf(not HAS_DILL, "no dill")

# Optional dependency: pandas enables the DataFrame-integration tests below.
try:
    import pandas  # type: ignore[import] # noqa: F401 F403
except ImportError:
    HAS_PANDAS = False
else:
    HAS_PANDAS = True
skipIfNoDataFrames = skipIf(not HAS_PANDAS, "no dataframes (pandas)")

T_co = TypeVar("T_co", covariant=True)


Expand Down Expand Up @@ -393,6 +400,64 @@ def test_demux_mux_datapipe(self):
self.assertEqual(source_numbers, list(n))


class TestDataFramesPipes(TestCase):
    """Tests for the (experimental) DataFrames-over-DataPipes integration.

    Most of these tests will fail if pandas is installed but dill is not
    available.  TODO: rework them to avoid multiple skips.
    """

    def _get_datapipe(self, limit=10, dataframe_size=7):
        # Plain (eager) pipe of ``(i, i % 3)`` tuples for ``i`` in [0, limit).
        # ``dataframe_size`` is unused here; it is accepted only for signature
        # parity with ``_get_dataframes_pipe``.
        return NumbersDataset(limit) \
            .map(lambda i: (i, i % 3))

    def _get_dataframes_pipe(self, limit=10, dataframe_size=7):
        # Same data as ``_get_datapipe`` but routed through the DataFrames
        # tracing API, packed into frames of ``dataframe_size`` rows with
        # columns named 'i' and 'j'.
        return NumbersDataset(limit) \
            .map(lambda i: (i, i % 3)) \
            ._to_dataframes_pipe(
                columns=['i', 'j'],
                dataframe_size=dataframe_size)

    @skipIfNoDataFrames
    @skipIfNoDill  # TODO(VitalyFedyunin): Decouple tests from dill by avoiding lambdas in map
    def test_capture(self):
        # Column arithmetic captured on the traced pipe must match the
        # equivalent eager map over tuples.
        dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0]))
        df_numbers = self._get_dataframes_pipe()
        df_numbers['k'] = df_numbers['j'] + df_numbers.i * 3
        self.assertEqual(list(dp_numbers), list(df_numbers))

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_shuffle(self):
        # With non-zero (but extremely low) probability the shuffle happens to
        # produce the identity permutation and this test fails; feel free to
        # restart it.
        df_numbers = self._get_dataframes_pipe(limit=1000).shuffle()
        dp_numbers = self._get_datapipe(limit=1000)
        df_result = [tuple(item) for item in df_numbers]
        self.assertNotEqual(list(dp_numbers), df_result)
        self.assertEqual(list(dp_numbers), sorted(df_result))

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_batch(self):
        # 100 rows batched by 8 -> last batch holds the 4 remaining rows.
        df_numbers = self._get_dataframes_pipe(limit=100).batch(8)
        df_numbers_list = list(df_numbers)
        last_batch = df_numbers_list[-1]
        self.assertEqual(4, len(last_batch))
        unpacked_batch = [tuple(row) for row in last_batch]
        self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_unbatch(self):
        # Unbatching both nesting levels must restore the original stream.
        df_numbers = self._get_dataframes_pipe(limit=100).batch(8).batch(3)
        dp_numbers = self._get_datapipe(limit=100)
        self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2)))

    @skipIfNoDataFrames
    @skipIfNoDill
    def test_filter(self):
        # Filter expressed over the DataFrame column accessor `x.i`.
        df_numbers = self._get_dataframes_pipe(limit=10).filter(lambda x: x.i > 5)
        self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], list(df_numbers))

class FileLoggerSimpleHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
def __init__(self, *args, logfile=None, **kwargs):
self.__loggerHandle = None
Expand Down
2 changes: 1 addition & 1 deletion tools/linter/clang_tidy/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def find_changed_lines(diff: str) -> Dict[str, List[Tuple[int, int]]]:
i += 1
ranges[-1][1] = added_line_nos[-1]

files[file.path].append(*ranges)
files[file.path] += ranges

return dict(files)

Expand Down
1 change: 1 addition & 0 deletions torch/utils/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DataChunk,
Dataset,
Dataset as MapDataPipe,
DFIterDataPipe,
IterableDataset,
IterableDataset as IterDataPipe,
Subset,
Expand Down
12 changes: 9 additions & 3 deletions torch/utils/data/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@
class functional_datapipe(object):
name: str

def __init__(self, name: str) -> None:
    def __init__(self, name: str, enable_df_api_tracing: bool = False) -> None:
        """
        Args:
            name: name under which the decorated DataPipe class is registered
                as a functional-style method on DataPipe instances.
            enable_df_api_tracing - if set, any returned DataPipe would accept
                DataFrames API in tracing mode.
        """
        self.name = name
        self.enable_df_api_tracing = enable_df_api_tracing

def __call__(self, cls):
if issubclass(cls, IterDataPipe):
Expand All @@ -25,9 +31,9 @@ def __call__(self, cls):
not (hasattr(cls, '__self__') and
isinstance(cls.__self__, non_deterministic)):
raise TypeError('`functional_datapipe` can only decorate IterDataPipe')
IterDataPipe.register_datapipe_as_function(self.name, cls)
IterDataPipe.register_datapipe_as_function(self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)
elif issubclass(cls, MapDataPipe):
MapDataPipe.register_datapipe_as_function(self.name, cls)
MapDataPipe.register_datapipe_as_function(self.name, cls, enable_df_api_tracing=self.enable_df_api_tracing)

return cls

Expand Down
Loading

0 comments on commit ab5e1c6

Please sign in to comment.