From 7463a121044041ab3d7780bc9b1e468544f9ae6d Mon Sep 17 00:00:00 2001
From: Lemon Qin <57213526+Lemon-412@users.noreply.github.com>
Date: Mon, 29 Aug 2022 14:34:59 +0800
Subject: [PATCH] feat(tree): add multiprocessing tree data loader (#1041)

* feat(tree): add multiprocessing dataloader
* fix(tree): fix lint
* fix(tree): fix multiprocessing start method (#1040)
* fix(tree): fix multiprocessing start method
* chore(tree): multiprocessing reusing parameter `num_parallel`
* style(tree): rename num_data_loaders to num_parallel

Co-authored-by: qinminhao

* test(tree): multiprocessing tree model test.sh
* style(tree): code clean
* style(tree): fix code lint (logging-not-lazy)
* test(tree): move unittest to test_trainer.py
* docs(tree): more comments
* fix(tree): fix code lint
* style(tree): make parameter num_parallel optional

Co-authored-by: qinminhao
---
 example/tree_model/test.sh       |   8 ++
 fedlearner/model/tree/trainer.py | 140 ++++++++++++++++++++-----------
 test/tree_model/test_trainer.py  |  94 ++++++++++++++++++++-
 3 files changed, 191 insertions(+), 51 deletions(-)

diff --git a/example/tree_model/test.sh b/example/tree_model/test.sh
index 71315446a..923327899 100755
--- a/example/tree_model/test.sh
+++ b/example/tree_model/test.sh
@@ -12,6 +12,7 @@ python -m fedlearner.model.tree.trainer follower \
     --verbosity=1 \
     --local-addr=localhost:50052 \
     --peer-addr=localhost:50051 \
+    --num-parallel=4 \
     --verify-example-ids=true \
     --file-ext=.tfrecord \
     --file-wildcard=*tfrecord \
@@ -26,6 +27,7 @@ python -m fedlearner.model.tree.trainer leader \
     --verbosity=1 \
     --local-addr=localhost:50051 \
     --peer-addr=localhost:50052 \
+    --num-parallel=4 \
     --verify-example-ids=true \
     --file-ext=.tfrecord \
     --file-wildcard=*tfrecord \
@@ -42,6 +44,7 @@ python -m fedlearner.model.tree.trainer leader \
     --verbosity=1 \
     --local-addr=localhost:50051 \
     --peer-addr=localhost:50052 \
+    --num-parallel=4 \
     --mode=test \
     --verify-example-ids=true \
     --file-type=tfrecord \
@@ -56,6 +59,7 @@ python -m fedlearner.model.tree.trainer follower \
     --verbosity=1 \
     --local-addr=localhost:50052 \
     --peer-addr=localhost:50051 \
+    --num-parallel=4 \
     --mode=test \
     --verify-example-ids=true \
     --file-type=tfrecord \
@@ -77,6 +81,7 @@ python -m fedlearner.model.tree.trainer follower \
     --verbosity=1 \
     --local-addr=localhost:50052 \
     --peer-addr=localhost:50051 \
+    --num-parallel=4 \
     --file-ext=.csv \
     --file-type=csv \
     --file-wildcard=*csv \
@@ -89,6 +94,7 @@ python -m fedlearner.model.tree.trainer leader \
     --verbosity=1 \
     --local-addr=localhost:50051 \
     --peer-addr=localhost:50052 \
+    --num-parallel=4 \
     --file-ext=.csv \
     --file-type=csv \
     --file-wildcard=*csv \
@@ -103,6 +109,7 @@ python -m fedlearner.model.tree.trainer follower \
     --verbosity=2 \
     --local-addr=localhost:50052 \
     --peer-addr=localhost:50051 \
+    --num-parallel=4 \
     --mode=test \
     --file-ext=.csv \
     --file-wildcard=*csv \
@@ -116,6 +123,7 @@ python -m fedlearner.model.tree.trainer leader \
     --verbosity=2 \
     --local-addr=localhost:50051 \
     --peer-addr=localhost:50052 \
+    --num-parallel=4 \
     --mode=test \
     --file-ext=.csv \
     --file-wildcard=*csv \
diff --git a/fedlearner/model/tree/trainer.py b/fedlearner/model/tree/trainer.py
index 85f1b1c58..77a3eb7b5 100644
--- a/fedlearner/model/tree/trainer.py
+++ b/fedlearner/model/tree/trainer.py
@@ -16,12 +16,15 @@
 import os
 import csv
+import time
 import queue
 import logging
 import argparse
 import traceback
 import itertools
 from typing import Optional
+from concurrent.futures import ProcessPoolExecutor
+import multiprocessing

 import numpy as np
 import tensorflow.compat.v1 as tf
@@ -33,7 +36,6 @@
 from fedlearner.model.tree.trainer_master_client import DataBlockInfo
 from fedlearner.model.tree.utils import filter_files

-
 def create_argument_parser():
     parser = argparse.ArgumentParser(
         description='FedLearner Tree Model Trainer.')
@@ -172,16 +174,16 @@ def parse_tfrecord(record):
     for key, value in example.features.feature.items():
         kind = value.WhichOneof('kind')
         if kind == 'float_list':
-            assert len(value.float_list.value) == 1, "Invalid tfrecord format"
+            assert len(value.float_list.value) == 1, 'Invalid tfrecord format'
             parsed[key] = value.float_list.value[0]
         elif kind == 'int64_list':
-            assert len(value.int64_list.value) == 1, "Invalid tfrecord format"
+            assert len(value.int64_list.value) == 1, 'Invalid tfrecord format'
             parsed[key] = value.int64_list.value[0]
         elif kind == 'bytes_list':
-            assert len(value.bytes_list.value) == 1, "Invalid tfrecord format"
+            assert len(value.bytes_list.value) == 1, 'Invalid tfrecord format'
             parsed[key] = value.bytes_list.value[0]
         else:
-            raise ValueError("Invalid tfrecord format")
+            raise ValueError('Invalid tfrecord format')
     return parsed
@@ -191,7 +193,7 @@ def extract_field(field_names, field_name, required):
         return []
     assert not required, \
-        "Field %s is required but missing in data"%field_name
+        'Field %s is required but missing in data'%field_name
     return None
@@ -220,7 +222,7 @@ def read_data(file_type, filename, require_example_ids, require_labels,
         ignore_fields.update(['example_id', 'raw_id', label_field])
         cat_fields = set(filter(bool, cat_fields.strip().split(',')))
         for name in cat_fields:
-            assert name in field_names, "cat_field %s missing"%name
+            assert name in field_names, 'cat_field %s missing'%name
         cont_columns = list(filter(
             lambda x: x not in ignore_fields and x not in cat_fields,
             field_names))
@@ -231,6 +233,7 @@ def read_data(file_type, filename, require_example_ids, require_labels,
     features = []
     cat_features = []
+
     def to_float(x):
         return float(x if x not in ['', None] else 'nan')
     for line in reader:
@@ -256,7 +259,9 @@ def to_float(x):

 def read_data_dir(file_ext: str, file_wildcard: str, file_type: str, path: str,
                   require_example_ids: bool, require_labels: bool,
-                  ignore_fields: str, cat_fields: str, label_field: str):
+                  ignore_fields: str, cat_fields: str, label_field: str,
+                  num_parallel: Optional[int]):
+
     if not tf.io.gfile.isdir(path):
         return read_data(
             file_type, path, require_example_ids,
@@ -264,39 +269,68 @@ def read_data_dir(file_ext: str, file_wildcard: str, file_type: str, path: str,
     files = filter_files(path, file_ext, file_wildcard)
     files.sort()
+    assert len(files) > 0, f'No file exists in directory (path={path} ' \
+                           f'extension={file_ext} wildcard={file_wildcard})'
+
+    if num_parallel:
+        assert num_parallel >= 1, 'Invalid num_parallel'
+    else:
+        num_parallel = 1
+
+    if num_parallel > len(files):
+        logging.info('Number of files (%s) is less than num_parallel (%s), '
+                     'switching num_parallel to %s',
+                     len(files), num_parallel, len(files))
+        num_parallel = len(files)
+
     features = None
-    for fullname in files:
-        ifeatures, icat_features, icont_columns, icat_columns, \
-            ilabels, iexample_ids, iraw_ids = read_data(
-                file_type, fullname, require_example_ids, require_labels,
-                ignore_fields, cat_fields, label_field
-            )
-        if features is None:
-            features = ifeatures
-            cat_features = icat_features
-            cont_columns = icont_columns
-            cat_columns = icat_columns
-            labels = ilabels
-            example_ids = iexample_ids
-            raw_ids = iraw_ids
-        else:
-            assert cont_columns == icont_columns, \
-                "columns mismatch between files %s vs %s"%(
-                    cont_columns, icont_columns)
-            assert cat_columns == icat_columns, \
-                "columns mismatch between files %s vs %s"%(
-                    cat_columns, icat_columns)
-            features = np.concatenate((features, ifeatures), axis=0)
-            cat_features = np.concatenate(
-                (cat_features, icat_features), axis=0)
-            if labels is not None:
-                labels = np.concatenate((labels, ilabels), axis=0)
-            if example_ids is not None:
-                example_ids.extend(iexample_ids)
-            if raw_ids is not None:
-                raw_ids.extend(iraw_ids)
-
-    assert features is not None, "No data found in %s"%path
+
+    start_time = time.time()
+    logging.info('tasks start time: %s', str(start_time))
+    logging.info('Data loader count = %s', str(num_parallel))
+
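+    # Each input file is parsed by read_data() in its own worker process;
+    # the futures are consumed below in submission (file) order, so the
+    # merged output matches a sequential single-process read.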
+    with ProcessPoolExecutor(max_workers=num_parallel) as pool:
+        futures = []
+        for fullname in files:
+            future = pool.submit(
+                read_data, file_type, fullname,
+                require_example_ids, require_labels,
+                ignore_fields, cat_fields, label_field)
+            futures.append(future)
+        for future in futures:
+            ifeatures, icat_features, icont_columns, icat_columns, \
+                ilabels, iexample_ids, iraw_ids = future.result()
+            if features is None:
+                features = ifeatures
+                cat_features = icat_features
+                cont_columns = icont_columns
+                cat_columns = icat_columns
+                labels = ilabels
+                example_ids = iexample_ids
+                raw_ids = iraw_ids
+            else:
+                assert cont_columns == icont_columns, \
+                    'columns mismatch between files %s vs %s'%(
+                        cont_columns, icont_columns)
+                assert cat_columns == icat_columns, \
+                    'columns mismatch between files %s vs %s'%(
+                        cat_columns, icat_columns)
+                features = np.concatenate((features, ifeatures), axis=0)
+                cat_features = np.concatenate(
+                    (cat_features, icat_features), axis=0)
+                if labels is not None:
+                    labels = np.concatenate((labels, ilabels), axis=0)
+                if example_ids is not None:
+                    example_ids.extend(iexample_ids)
+                if raw_ids is not None:
+                    raw_ids.extend(iraw_ids)
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    logging.info('tasks end time: %s', str(end_time))
+    logging.info('elapsed time for reading data: %ss', str(elapsed_time))
+
+    assert features is not None, 'No data found in %s'%path
     return features, cat_features, cont_columns, cat_columns, \
         labels, example_ids, raw_ids
@@ -306,7 +340,7 @@ def train(args, booster):
     X, cat_X, X_names, cat_X_names, y, example_ids, _ = read_data_dir(
         args.file_ext, args.file_wildcard, args.file_type, args.data_path,
         args.verify_example_ids, args.role != 'follower', args.ignore_fields,
-        args.cat_fields, args.label_field)
+        args.cat_fields, args.label_field, args.num_parallel)

     if args.validation_data_path:
         val_X, val_cat_X, val_X_names, val_cat_X_names, val_y, \
@@ -315,11 +349,11 @@ def train(args, booster):
             val_example_ids, _ = read_data_dir(
                 args.file_ext, args.file_wildcard, args.file_type,
                 args.validation_data_path, args.verify_example_ids,
                 args.role != 'follower', args.ignore_fields,
-                args.cat_fields, args.label_field)
+                args.cat_fields, args.label_field, args.num_parallel)
         assert X_names == val_X_names, \
-            "Train data and validation data must have same features"
+            'Train data and validation data must have same features'
         assert cat_X_names == val_cat_X_names, \
-            "Train data and validation data must have same features"
+            'Train data and validation data must have same features'
     else:
         val_X = val_cat_X = val_y = val_example_ids = None
@@ -343,7 +377,7 @@ def train(args, booster):


 def write_predictions(filename, pred, example_ids=None, raw_ids=None):
-    logging.debug("Writing predictions to %s.tmp", filename)
+    logging.debug('Writing predictions to %s.tmp', filename)
     headers = []
     lines = []
     if example_ids is not None:
@@ -362,7 +396,7 @@ def write_predictions(filename, pred, example_ids=None, raw_ids=None):
             fout.write(','.join([str(i) for i in line]) + '\n')
     fout.close()
-    logging.debug("Renaming %s.tmp to %s", filename, filename)
+    logging.debug('Renaming %s.tmp to %s', filename, filename)
     tf.io.gfile.rename(filename+'.tmp', filename, overwrite=True)


 def test_one_file(args, bridge, booster, data_file, output_file):
@@ -387,7 +421,7 @@ def test_one_file(args, bridge, booster, data_file, output_file):
         booster.iter_metrics_handler(metrics, 'eval')
     else:
         metrics = {}
-    logging.info("Test metrics: %s", metrics)
+    logging.info('Test metrics: %s', metrics)

     if args.role == 'follower':
         bridge.start()
@@ -492,7 +526,7 @@ def get_next_block(self):


 def test(args, bridge, booster):
     if not args.no_data:
-        assert args.data_path, "Data path must not be empty"
+        assert args.data_path, 'Data path must not be empty'
     else:
         assert not args.data_path and args.role == 'leader'
@@ -524,9 +558,9 @@ def run(args):
         logging.basicConfig(level=logging.DEBUG)

     assert args.role in ['leader', 'follower', 'local'], \
-        "role must be leader, follower, or local"
+        'role must be leader, follower, or local'
     assert args.mode in ['train', 'test', 'eval'], \
-        "mode must be train, test, or eval"
+        'mode must be train, test, or eval'

     if args.role != 'local':
         bridge = Bridge(args.role, int(args.local_addr.split(':')[1]),
@@ -569,4 +603,10 @@ def run(args):


 if __name__ == '__main__':
+    # Experiments show the `spawn` start method is essential for
+    # ProcessPoolExecutor to read HDFS data stably across worker processes;
+    # with the default `fork` method the forked processes may deadlock.
+    # Similar cases reported: https://github.com/crs4/pydoop/issues/311
+    # Reason discussed: https://github.com/dask/hdfs3/issues/100
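+    # set_start_method can only be called once and must run before any worker
+    # process or pool is created, hence it sits under the __main__ guard.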
+    multiprocessing.set_start_method('spawn')
     run(create_argument_parser().parse_args())
diff --git a/test/tree_model/test_trainer.py b/test/tree_model/test_trainer.py
index 1fbbe66ed..86f9752f5 100644
--- a/test/tree_model/test_trainer.py
+++ b/test/tree_model/test_trainer.py
@@ -2,12 +2,17 @@
 import tempfile
 import unittest
 import shutil
+import random
 from pathlib import Path
 from threading import Thread

 from fedlearner.trainer.bridge import Bridge
+from tensorflow.train import Feature, FloatList, Features, Example, Int64List
+from fedlearner.model.tree.trainer import read_data_dir
 from fedlearner.model.tree.trainer import DataBlockLoader, DataBlockInfo
-
+import tensorflow as tf
+import numpy as np
+import csv


 class DataBlockLoaderTest(unittest.TestCase):
@@ -60,5 +65,92 @@ def test_get_next_block(self):
         self.assertIsNone(data_block)


+class MultiprocessingDataReadTest(unittest.TestCase):
+
+    def _make_data(self):
+        root = tempfile.mkdtemp()
+        record_root = os.path.join(root, 'tfrecord_test')
+        csv_root = os.path.join(root, 'csv_test')
+        os.mkdir(record_root)
+        os.mkdir(csv_root)
+
+        for i in range(5):
+            file_name = 'part-' + str(i).zfill(4) + '.tfrecord'
+            file_path = os.path.join(record_root, file_name)
+            writer = tf.io.TFRecordWriter(file_path)
+            for _ in range(5):
+                features = {
+                    'f_' + str(j): Feature(
+                        float_list=FloatList(value=[random.random()])
+                    ) for j in range(3)
+                }
+                features['i_0'] = Feature(
+                    int64_list=Int64List(value=[random.randint(0, 100)])
+                )
+                features['label'] = Feature(
+                    int64_list=Int64List(value=[random.randint(0, 1)])
+                )
+                writer.write(Example(
+                    features=Features(feature=features)
+                ).SerializeToString())
+            writer.close()
+
+        for i in range(5):
+            file_name = 'part-' + str(i).zfill(4) + '.csv'
+            file_path = os.path.join(csv_root, file_name)
+            with open(file_path, 'w') as file:
+                csv_writer = csv.writer(file)
+                csv_writer.writerow(['f_0', 'f_1', 'f_2', 'i_0', 'label'])
+                for _ in range(5):
+                    csv_writer.writerow([
+                        random.random(), random.random(), random.random(),
+                        random.randint(0, 100), random.randint(0, 1)
+                    ])
+        return record_root, csv_root
+
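+    # Reading the same data with num_parallel=1 and num_parallel=4 should
+    # yield identical features, categorical features and labels, both for a
+    # whole directory and for a single file, in tfrecord and csv formats.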
+    def test_multiprocessing_data_read(self):
+        record_root, csv_root = self._make_data()
+
+        record_1 = read_data_dir(
+            '.tfrecord', '*tfrecord', 'tfrecord',
+            record_root, False, True, '', 'i_0', 'label', 1)
+        record_4 = read_data_dir(
+            '.tfrecord', '*tfrecord', 'tfrecord',
+            record_root, False, True, '', 'i_0', 'label', 4)
+        # result shape:
+        # features, cat_features, cont_columns,
+        # cat_columns, labels, example_ids, raw_ids
+        np.testing.assert_almost_equal(record_1[0], record_4[0])
+        np.testing.assert_almost_equal(record_1[1], record_4[1])
+        np.testing.assert_almost_equal(record_1[4], record_4[4])
+
+        test_file = os.path.join(record_root, 'part-0000.tfrecord')
+        record_1 = read_data_dir('.tfrecord', '*tfrecord', 'tfrecord',
+                                 test_file, False, True, '', 'i_0', 'label', 1)
+        record_4 = read_data_dir('.tfrecord', '*tfrecord', 'tfrecord',
+                                 test_file, False, True, '', 'i_0', 'label', 4)
+        np.testing.assert_almost_equal(record_1[0], record_4[0])
+        np.testing.assert_almost_equal(record_1[1], record_4[1])
+        np.testing.assert_almost_equal(record_1[4], record_4[4])
+
+        csv_1 = read_data_dir('.csv', '*csv', 'csv', csv_root,
+                              False, True, '', 'i_0', 'label', 1)
+        csv_4 = read_data_dir('.csv', '*csv', 'csv', csv_root,
+                              False, True, '', 'i_0', 'label', 4)
+        np.testing.assert_almost_equal(csv_1[0], csv_4[0])
+        np.testing.assert_almost_equal(csv_1[1], csv_4[1])
+        np.testing.assert_almost_equal(csv_1[4], csv_4[4])
+
+        test_file = os.path.join(csv_root, 'part-0000.csv')
+        csv_1 = read_data_dir('.csv', '*csv', 'csv', test_file,
+                              False, True, '', 'i_0', 'label', 1)
+        csv_4 = read_data_dir('.csv', '*csv', 'csv', test_file,
+                              False, True, '', 'i_0', 'label', 4)
+        np.testing.assert_almost_equal(csv_1[0], csv_4[0])
+        np.testing.assert_almost_equal(csv_1[1], csv_4[1])
+        np.testing.assert_almost_equal(csv_1[4], csv_4[4])
+
+
 if __name__ == '__main__':
     unittest.main()