feat(tree): add multiprocessing tree data loader (#1041)
* feat(tree): add multiprocessing dataloader

* fix(tree): fix lint

* fix(tree): fix multiprocessing start method (#1040)

* fix(tree): fix multiprocessing start method

* chore(tree): reuse parameter `num_parallel` for multiprocessing

* style(tree): rename num_data_loaders to num_parallel

Co-authored-by: qinminhao <qinminhao@bytedance.com>

* test(tree): multiprocessing tree model test.sh

* style(tree): code clean

* style(tree): fix code lint (logging-not-lazy)

* test(tree): move unittest to test_trainer.py

* docs(tree): more comments

* fix(tree): fix code lint

* style(tree): make parameter num_parallel optional

Co-authored-by: qinminhao <qinminhao@bytedance.com>
Lemon-412 and qinminhao authored Aug 29, 2022
1 parent 049ec1c commit 7463a12
Showing 3 changed files with 191 additions and 51 deletions.
8 changes: 8 additions & 0 deletions example/tree_model/test.sh
@@ -12,6 +12,7 @@ python -m fedlearner.model.tree.trainer follower \
--verbosity=1 \
--local-addr=localhost:50052 \
--peer-addr=localhost:50051 \
--num-parallel=4 \
--verify-example-ids=true \
--file-ext=.tfrecord \
--file-wildcard=*tfrecord \
@@ -26,6 +27,7 @@ python -m fedlearner.model.tree.trainer leader \
--verbosity=1 \
--local-addr=localhost:50051 \
--peer-addr=localhost:50052 \
--num-parallel=4 \
--verify-example-ids=true \
--file-ext=.tfrecord \
--file-wildcard=*tfrecord \
@@ -42,6 +44,7 @@ python -m fedlearner.model.tree.trainer leader \
--verbosity=1 \
--local-addr=localhost:50051 \
--peer-addr=localhost:50052 \
--num-parallel=4 \
--mode=test \
--verify-example-ids=true \
--file-type=tfrecord \
@@ -56,6 +59,7 @@ python -m fedlearner.model.tree.trainer follower \
--verbosity=1 \
--local-addr=localhost:50052 \
--peer-addr=localhost:50051 \
--num-parallel=4 \
--mode=test \
--verify-example-ids=true \
--file-type=tfrecord \
@@ -77,6 +81,7 @@ python -m fedlearner.model.tree.trainer follower \
--verbosity=1 \
--local-addr=localhost:50052 \
--peer-addr=localhost:50051 \
--num-parallel=4 \
--file-ext=.csv \
--file-type=csv \
--file-wildcard=*csv \
@@ -89,6 +94,7 @@ python -m fedlearner.model.tree.trainer leader \
--verbosity=1 \
--local-addr=localhost:50051 \
--peer-addr=localhost:50052 \
--num-parallel=4 \
--file-ext=.csv \
--file-type=csv \
--file-wildcard=*csv \
@@ -103,6 +109,7 @@ python -m fedlearner.model.tree.trainer follower \
--verbosity=2 \
--local-addr=localhost:50052 \
--peer-addr=localhost:50051 \
--num-parallel=4 \
--mode=test \
--file-ext=.csv \
--file-wildcard=*csv \
@@ -116,6 +123,7 @@ python -m fedlearner.model.tree.trainer leader \
--verbosity=2 \
--local-addr=localhost:50051 \
--peer-addr=localhost:50052 \
--num-parallel=4 \
--mode=test \
--file-ext=.csv \
--file-wildcard=*csv \
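Every trainer invocation in test.sh now passes `--num-parallel=4`. The argparse hunk that registers this flag is not included in the excerpt below, so the following is only a plausible sketch of how it could be wired into create_argument_parser(); the default value and help text are assumptions, not taken from the diff.

# Hypothetical sketch: registering --num-parallel (the real hunk is not shown).
import argparse

parser = argparse.ArgumentParser(
    description='FedLearner Tree Model Trainer.')
parser.add_argument(
    '--num-parallel', type=int, default=None,
    help='number of worker processes used to read input files; '
         'a single process is used when omitted')

args = parser.parse_args(['--num-parallel=4'])
assert args.num_parallel == 4  # argparse maps --num-parallel to num_parallel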
140 changes: 90 additions & 50 deletions fedlearner/model/tree/trainer.py
@@ -16,12 +16,15 @@

import os
import csv
import time
import queue
import logging
import argparse
import traceback
import itertools
from typing import Optional
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
import numpy as np

import tensorflow.compat.v1 as tf
@@ -33,7 +36,6 @@
from fedlearner.model.tree.trainer_master_client import DataBlockInfo
from fedlearner.model.tree.utils import filter_files


def create_argument_parser():
parser = argparse.ArgumentParser(
description='FedLearner Tree Model Trainer.')
@@ -172,16 +174,16 @@ def parse_tfrecord(record):
for key, value in example.features.feature.items():
kind = value.WhichOneof('kind')
if kind == 'float_list':
assert len(value.float_list.value) == 1, "Invalid tfrecord format"
assert len(value.float_list.value) == 1, 'Invalid tfrecord format'
parsed[key] = value.float_list.value[0]
elif kind == 'int64_list':
assert len(value.int64_list.value) == 1, "Invalid tfrecord format"
assert len(value.int64_list.value) == 1, 'Invalid tfrecord format'
parsed[key] = value.int64_list.value[0]
elif kind == 'bytes_list':
assert len(value.bytes_list.value) == 1, "Invalid tfrecord format"
assert len(value.bytes_list.value) == 1, 'Invalid tfrecord format'
parsed[key] = value.bytes_list.value[0]
else:
raise ValueError("Invalid tfrecord format")
raise ValueError('Invalid tfrecord format')

return parsed

@@ -191,7 +193,7 @@ def extract_field(field_names, field_name, required):
return []

assert not required, \
"Field %s is required but missing in data"%field_name
'Field %s is required but missing in data'%field_name
return None


@@ -220,7 +222,7 @@ def read_data(file_type, filename, require_example_ids, require_labels,
ignore_fields.update(['example_id', 'raw_id', label_field])
cat_fields = set(filter(bool, cat_fields.strip().split(',')))
for name in cat_fields:
assert name in field_names, "cat_field %s missing"%name
assert name in field_names, 'cat_field %s missing'%name

cont_columns = list(filter(
lambda x: x not in ignore_fields and x not in cat_fields, field_names))
@@ -231,6 +233,7 @@ def read_data(file_type, filename, require_example_ids, require_labels,

features = []
cat_features = []

def to_float(x):
return float(x if x not in ['', None] else 'nan')
for line in reader:
@@ -256,47 +259,78 @@ def to_float(x):

def read_data_dir(file_ext: str, file_wildcard: str, file_type: str, path: str,
require_example_ids: bool, require_labels: bool,
ignore_fields: str, cat_fields: str, label_field: str):
ignore_fields: str, cat_fields: str, label_field: str,
num_parallel: Optional[int]):

if not tf.io.gfile.isdir(path):
return read_data(
file_type, path, require_example_ids,
require_labels, ignore_fields, cat_fields, label_field)

files = filter_files(path, file_ext, file_wildcard)
files.sort()
    assert len(files) > 0, f'No file exists in directory (path={path} ' \
f'extension={file_ext} wildcard={file_wildcard})'

if num_parallel:
assert num_parallel >= 1, 'Invalid num_parallel'
else:
num_parallel = 1

if num_parallel > len(files):
        logging.info('Number of files (%s) is less than num_parallel (%s); '
                     'reducing num_parallel to %s',
len(files), num_parallel, len(files))
num_parallel = len(files)

features = None
for fullname in files:
ifeatures, icat_features, icont_columns, icat_columns, \
ilabels, iexample_ids, iraw_ids = read_data(
file_type, fullname, require_example_ids, require_labels,
ignore_fields, cat_fields, label_field
)
if features is None:
features = ifeatures
cat_features = icat_features
cont_columns = icont_columns
cat_columns = icat_columns
labels = ilabels
example_ids = iexample_ids
raw_ids = iraw_ids
else:
assert cont_columns == icont_columns, \
"columns mismatch between files %s vs %s"%(
cont_columns, icont_columns)
assert cat_columns == icat_columns, \
"columns mismatch between files %s vs %s"%(
cat_columns, icat_columns)
features = np.concatenate((features, ifeatures), axis=0)
cat_features = np.concatenate(
(cat_features, icat_features), axis=0)
if labels is not None:
labels = np.concatenate((labels, ilabels), axis=0)
if example_ids is not None:
example_ids.extend(iexample_ids)
if raw_ids is not None:
raw_ids.extend(iraw_ids)

assert features is not None, "No data found in %s"%path

start_time = time.time()
    logging.info('tasks start time: %s', str(start_time))
logging.info('Data loader count = %s', str(num_parallel))

with ProcessPoolExecutor(max_workers=num_parallel) as pool:
futures = []
for fullname in files:
future = pool.submit(
read_data, file_type, fullname,
require_example_ids, require_labels,
ignore_fields, cat_fields, label_field)
futures.append(future)
for future in futures:
ifeatures, icat_features, icont_columns, icat_columns, \
ilabels, iexample_ids, iraw_ids = future.result()
if features is None:
features = ifeatures
cat_features = icat_features
cont_columns = icont_columns
cat_columns = icat_columns
labels = ilabels
example_ids = iexample_ids
raw_ids = iraw_ids
else:
assert cont_columns == icont_columns, \
'columns mismatch between files %s vs %s'%(
cont_columns, icont_columns)
assert cat_columns == icat_columns, \
'columns mismatch between files %s vs %s'%(
cat_columns, icat_columns)
features = np.concatenate((features, ifeatures), axis=0)
cat_features = np.concatenate(
(cat_features, icat_features), axis=0)
if labels is not None:
labels = np.concatenate((labels, ilabels), axis=0)
if example_ids is not None:
example_ids.extend(iexample_ids)
if raw_ids is not None:
raw_ids.extend(iraw_ids)

end_time = time.time()
elapsed_time = end_time - start_time
    logging.info('tasks end time: %s', str(end_time))
logging.info('elapsed time for reading data: %ss', str(elapsed_time))

assert features is not None, 'No data found in %s'%path

return features, cat_features, cont_columns, cat_columns, \
labels, example_ids, raw_ids
@@ -306,7 +340,7 @@ def train(args, booster):
X, cat_X, X_names, cat_X_names, y, example_ids, _ = read_data_dir(
args.file_ext, args.file_wildcard, args.file_type, args.data_path,
args.verify_example_ids, args.role != 'follower', args.ignore_fields,
args.cat_fields, args.label_field)
args.cat_fields, args.label_field, args.num_parallel)

if args.validation_data_path:
val_X, val_cat_X, val_X_names, val_cat_X_names, val_y, \
@@ -315,11 +349,11 @@ def train(args, booster):
args.file_ext, args.file_wildcard, args.file_type,
args.validation_data_path, args.verify_example_ids,
args.role != 'follower', args.ignore_fields,
args.cat_fields, args.label_field)
args.cat_fields, args.label_field, args.num_parallel)
assert X_names == val_X_names, \
"Train data and validation data must have same features"
'Train data and validation data must have same features'
assert cat_X_names == val_cat_X_names, \
"Train data and validation data must have same features"
'Train data and validation data must have same features'
else:
val_X = val_cat_X = val_y = val_example_ids = None

@@ -343,7 +377,7 @@ def train(args, booster):


def write_predictions(filename, pred, example_ids=None, raw_ids=None):
logging.debug("Writing predictions to %s.tmp", filename)
logging.debug('Writing predictions to %s.tmp', filename)
headers = []
lines = []
if example_ids is not None:
@@ -362,7 +396,7 @@ def write_predictions(filename, pred, example_ids=None, raw_ids=None):
fout.write(','.join([str(i) for i in line]) + '\n')
fout.close()

logging.debug("Renaming %s.tmp to %s", filename, filename)
logging.debug('Renaming %s.tmp to %s', filename, filename)
tf.io.gfile.rename(filename+'.tmp', filename, overwrite=True)

def test_one_file(args, bridge, booster, data_file, output_file):
@@ -387,7 +421,7 @@ def test_one_file(args, bridge, booster, data_file, output_file):
booster.iter_metrics_handler(metrics, 'eval')
else:
metrics = {}
logging.info("Test metrics: %s", metrics)
logging.info('Test metrics: %s', metrics)

if args.role == 'follower':
bridge.start()
@@ -492,7 +526,7 @@ def get_next_block(self):

def test(args, bridge, booster):
if not args.no_data:
assert args.data_path, "Data path must not be empty"
assert args.data_path, 'Data path must not be empty'
else:
assert not args.data_path and args.role == 'leader'

@@ -524,9 +558,9 @@ def run(args):
logging.basicConfig(level=logging.DEBUG)

assert args.role in ['leader', 'follower', 'local'], \
"role must be leader, follower, or local"
'role must be leader, follower, or local'
assert args.mode in ['train', 'test', 'eval'], \
"mode must be train, test, or eval"
'mode must be train, test, or eval'

if args.role != 'local':
bridge = Bridge(args.role, int(args.local_addr.split(':')[1]),
@@ -569,4 +603,10 @@ def run(args):


if __name__ == '__main__':
    # Experiments show the `spawn` start method is essential for
    # ProcessPoolExecutor to read HDFS data stably across processes;
    # otherwise, forked worker processes may deadlock.
# Similar cases reported: https://github.com/crs4/pydoop/issues/311
# Reason discussed: https://github.com/dask/hdfs3/issues/100
multiprocessing.set_start_method('spawn')
run(create_argument_parser().parse_args())
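The heart of the change is the rewritten read_data_dir above: file reads are fanned out to a ProcessPoolExecutor (capped at the number of files) and the per-file results are merged in submission order, so the concatenated arrays are identical to a serial read of the sorted file list. A minimal standalone sketch of that pattern follows; read_one_file and read_dir are illustrative stand-ins, not fedlearner APIs, and the CSV layout is invented for the example.

# Standalone sketch of the fan-out/merge pattern used by read_data_dir.
import multiprocessing
import os
import tempfile
from concurrent.futures import ProcessPoolExecutor

import numpy as np


def read_one_file(path):
    # Stand-in for read_data(): parse one CSV part into (features, labels).
    data = np.loadtxt(path, delimiter=',', skiprows=1, ndmin=2)
    return data[:, :-1], data[:, -1]


def read_dir(files, num_parallel=1):
    # Cap the worker count at the number of files, as the diff does.
    num_parallel = max(1, min(num_parallel, len(files)))
    features = labels = None
    with ProcessPoolExecutor(max_workers=num_parallel) as pool:
        futures = [pool.submit(read_one_file, f) for f in files]
        # Consume futures in submission order so the merged result matches
        # a serial read of the same (sorted) file list.
        for future in futures:
            ifeatures, ilabels = future.result()
            if features is None:
                features, labels = ifeatures, ilabels
            else:
                features = np.concatenate((features, ifeatures), axis=0)
                labels = np.concatenate((labels, ilabels), axis=0)
    return features, labels


if __name__ == '__main__':
    # 'spawn' is set for the same reason as in trainer.py: forked workers
    # holding inherited client state (e.g. HDFS handles) can deadlock.
    multiprocessing.set_start_method('spawn')

    tmpdir = tempfile.mkdtemp()
    paths = []
    for i in range(2):
        path = os.path.join(tmpdir, 'part-%d.csv' % i)
        with open(path, 'w') as f:
            f.write('f_0,f_1,label\n0.1,0.2,1\n0.3,0.4,0\n')
        paths.append(path)

    X, y = read_dir(sorted(paths), num_parallel=2)
    print(X.shape, y.shape)  # expected: (4, 2) (4,)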
94 changes: 93 additions & 1 deletion test/tree_model/test_trainer.py
@@ -2,12 +2,17 @@
import tempfile
import unittest
import shutil
import random
from pathlib import Path
from threading import Thread

from fedlearner.trainer.bridge import Bridge
from tensorflow.train import Feature, FloatList, Features, Example, Int64List
from fedlearner.model.tree.trainer import read_data_dir
from fedlearner.model.tree.trainer import DataBlockLoader, DataBlockInfo

import tensorflow as tf
import numpy as np
import csv

class DataBlockLoaderTest(unittest.TestCase):

@@ -60,5 +65,92 @@ def test_get_next_block(self):
self.assertIsNone(data_block)


class MultiprocessingDataReadTest(unittest.TestCase):

def _make_data(self):
root = tempfile.mkdtemp()
record_root = os.path.join(root, 'tfrecord_test')
csv_root = os.path.join(root, 'csv_test')
os.mkdir(record_root)
os.mkdir(csv_root)

for i in range(5):
file_name = 'part-' + str(i).zfill(4) + '.tfrecord'
file_path = os.path.join(record_root, file_name)
writer = tf.io.TFRecordWriter(file_path)
for _ in range(5):
features = {
'f_' + str(j): Feature(
float_list=FloatList(value=[random.random()])
) for j in range(3)
}
features['i_0'] = Feature(
int64_list=Int64List(value=[random.randint(0, 100)])
)
features['label'] = Feature(
int64_list=Int64List(value=[random.randint(0, 1)])
)
writer.write(Example(
features=Features(feature=features)
).SerializeToString()
)
writer.close()

for i in range(5):
file_name = 'part-' + str(i).zfill(4) + '.csv'
file_path = os.path.join(csv_root, file_name)
with open(file_path, 'w') as file:
csv_writer = csv.writer(file)
csv_writer.writerow(['f_0', 'f_1', 'f_2', 'i_0', 'label'])
for _ in range(5):
csv_writer.writerow([
random.random(), random.random(), random.random(),
random.randint(0, 100), random.randint(0, 1)
])
return record_root, csv_root

def test_multiprocessing_data_read(self):
record_root, csv_root = self._make_data()

record_1 = read_data_dir(
'.tfrecord', '*tfrecord', 'tfrecord',
record_root, False, True, '', 'i_0', 'label', 1)
record_4 = read_data_dir(
'.tfrecord', '*tfrecord', 'tfrecord',
record_root, False, True, '', 'i_0', 'label', 4)
# result shape:
# features, cat_features, cont_columns,
# cat_columns, labels, example_ids, raw_ids
np.testing.assert_almost_equal(record_1[0], record_4[0])
np.testing.assert_almost_equal(record_1[1], record_4[1])
np.testing.assert_almost_equal(record_1[4], record_4[4])

test_file = os.path.join(record_root, 'part-0000.tfrecord')
        record_1 = read_data_dir(
            '.tfrecord', '*tfrecord', 'tfrecord', test_file,
            False, True, '', 'i_0', 'label', 1)
        record_4 = read_data_dir(
            '.tfrecord', '*tfrecord', 'tfrecord', test_file,
            False, True, '', 'i_0', 'label', 4)
np.testing.assert_almost_equal(record_1[0], record_4[0])
np.testing.assert_almost_equal(record_1[1], record_4[1])
np.testing.assert_almost_equal(record_1[4], record_4[4])

csv_1 = read_data_dir('.csv', '*csv', 'csv', csv_root,
False, True, '', 'i_0', 'label', 1)
csv_4 = read_data_dir('.csv', '*csv', 'csv', csv_root,
False, True, '', 'i_0', 'label', 4)
np.testing.assert_almost_equal(csv_1[0], csv_4[0])
np.testing.assert_almost_equal(csv_1[1], csv_4[1])
np.testing.assert_almost_equal(csv_1[4], csv_4[4])

test_file = os.path.join(csv_root, 'part-0000.csv')
csv_1 = read_data_dir('.csv', '*csv', 'csv', test_file,
False, True, '', 'i_0', 'label', 1)
csv_4 = read_data_dir('.csv', '*csv', 'csv', test_file,
False, True, '', 'i_0', 'label', 4)
np.testing.assert_almost_equal(csv_1[0], csv_4[0])
np.testing.assert_almost_equal(csv_1[1], csv_4[1])
np.testing.assert_almost_equal(csv_1[4], csv_4[4])


if __name__ == '__main__':
unittest.main()
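For reference, the new read_data_dir can also be driven directly outside the unittest. A usage sketch, assuming fedlearner is importable and that csv_root points at a directory of part-*.csv files like the one _make_data() builds; the path below is a placeholder, not part of the diff.

# Usage sketch only; csv_root is a hypothetical path.
from fedlearner.model.tree.trainer import read_data_dir

if __name__ == '__main__':
    csv_root = '/tmp/tree_csv_parts'  # placeholder directory of part-*.csv
    features, cat_features, cont_columns, cat_columns, \
        labels, example_ids, raw_ids = read_data_dir(
            '.csv', '*csv', 'csv', csv_root,
            require_example_ids=False, require_labels=True,
            ignore_fields='', cat_fields='i_0', label_field='label',
            num_parallel=4)
    print(features.shape, cont_columns, cat_columns)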
