Skip to content

Commit

Permalink
fix(tree): fix filter files error (#1046)
Browse files Browse the repository at this point in the history
  • Loading branch information
gejielun authored Sep 8, 2022
1 parent e68b975 commit 9318944
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 16 deletions.
2 changes: 1 addition & 1 deletion fedlearner/model/tree/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def extract_field(field_names, field_name, required):

def read_data(file_type, filename, require_example_ids, require_labels,
ignore_fields, cat_fields, label_field):
logging.debug('Reading data file from %s', filename)
logging.info('Reading data file from %s', filename)

if file_type == 'tfrecord':
reader = tf.io.tf_record_iterator(filename)
Expand Down
10 changes: 10 additions & 0 deletions fedlearner/model/tree/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import logging
from fnmatch import fnmatch
from typing import List, Optional

Expand All @@ -8,6 +9,7 @@
def filter_files(path: str, file_ext: Optional[str],
file_wildcard: Optional[str]) -> List[str]:
files = []
depth = 0
for dirname, _, filenames in tf.io.gfile.walk(path):
for filename in filenames:
_, ext = os.path.splitext(filename)
Expand All @@ -18,4 +20,12 @@ def filter_files(path: str, file_ext: Optional[str],
if file_wildcard and not fnmatch(fpath, file_wildcard):
continue
files.append(fpath)
depth += 1
# Not retrieving recursively since there might be
# some unrecognized files.
if depth > 1:
break
logging.info("file wildcard is %s, file ext is %s, "
"filtered files num: %d", file_wildcard,
file_ext, len(files))
return files
35 changes: 20 additions & 15 deletions test/tree_model/test_filter_files.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os.path
import tempfile
import logging
import unittest
from pathlib import Path
from fedlearner.model.tree.trainer import filter_files
Expand All @@ -9,28 +11,31 @@ def test_filter_files(self):
path = tempfile.mkdtemp()
path = Path(path, 'test').resolve()
path.mkdir()
path.joinpath('test1').mkdir()
path.joinpath('test2').mkdir()
path.joinpath('sub_test').mkdir()
path.joinpath('1.csv').touch()
path.joinpath('2.csv').touch()
path.joinpath('3.csv').touch()
path.joinpath('1.tfrecord').touch()
path.joinpath('2.tfrecord').touch()
path.joinpath('3.tfrecord').touch()
path.joinpath('test1').joinpath('1.csv').touch()
path.joinpath('test1').joinpath('2.tfrecord').touch()
path.joinpath('test2').joinpath('2.csv').touch()
path.joinpath('test2').joinpath('1.tfrecord').touch()
path.joinpath('test1/test').mkdir()
path.joinpath('test1/test').joinpath('4.csv').touch()
path.joinpath('test2/test').mkdir()
path.joinpath('test2/test').joinpath('4.tfrecord').touch()

files = filter_files(path, '.csv', '')
path.joinpath('sub_test').joinpath('sub_sub_test').mkdir()
path.joinpath('sub_test').joinpath('4.csv').touch()
path.joinpath('sub_test').joinpath('4.tfrecord').touch()
path.joinpath('sub_test').joinpath('sub_sub_test').joinpath('5.csv').touch()
path.joinpath('sub_test').joinpath('sub_sub_test').joinpath('5.tfrecord').touch()

files = filter_files(str(path), '.csv', '')
self.assertEqual(len(files), 4)
files = filter_files(path, '', '*tfr*')
files = filter_files(str(path), '', '*tfr*')
self.assertEqual(len(files), 4)
files = filter_files(path, '', '')
files = filter_files(str(path), '', '')
self.assertEqual(len(files), 8)
files = filter_files(path, '.csv', '*1.*')
files = filter_files(str(path), '.csv', '*1.*')
self.assertEqual(len(files), 1)
files = filter_files((str(os.path.join(path, 'sub_test'))), '', '*csv')
self.assertEqual(len(files), 2)


if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
unittest.main()

0 comments on commit 9318944

Please sign in to comment.