Skip to content

Commit

Permalink
fix(trainer): fix data path visitor (#1045)
Browse files Browse the repository at this point in the history
* fix(trainer): fix data path visitor

* fix(trainer): fix data path visitor

* fix(trainer): fix data path visitor

Co-authored-by: 杭卫强 <[email protected]>
  • Loading branch information
hangweiqiang-uestc and 杭卫强 authored Sep 22, 2022
1 parent 9318944 commit 1cf0b69
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
7 changes: 4 additions & 3 deletions deploy/scripts/trainer/run_trainer_master.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export_model=$(normalize_env_to_args "--export-model" $EXPORT_MODEL)
shuffle=$(normalize_env_to_args "--shuffle" $SUFFLE_DATA_BLOCK)
shuffle_in_day=$(normalize_env_to_args "--shuffle-in-day" $SHUFFLE_IN_DAY)
local_data_source=$(normalize_env_to_args "--local-data-source" $LOCAL_DATA_SOURCE)
local_data_path=$(normalize_env_to_args "--local-data-path" $LOCAL_DATA_PATH)
local_start_date=$(normalize_env_to_args "--local-start-date" $LOCAL_START_DATE)
local_end_date=$(normalize_env_to_args "--local-end-date" $LOCAL_END_DATE)

Expand Down Expand Up @@ -109,6 +110,6 @@ python main.py --master \
$mode $sparse_estimator \
$save_checkpoint_steps $save_checkpoint_secs \
$summary_save_steps $summary_save_secs \
$local_data_source $local_start_date $local_end_date \
$epoch_num $start_date $end_date $shuffle $shuffle_in_day \
$extra_params $export_model
$local_data_source $local_data_path $local_start_date \
$local_end_date $epoch_num $start_date $end_date \
$shuffle $shuffle_in_day $extra_params $export_model
18 changes: 8 additions & 10 deletions fedlearner/trainer/data_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,11 @@ def __init__(self,
continue
subdirname = os.path.relpath(dirname, data_path)
block_id = os.path.join(subdirname, filename)
datablock = _RawDataBlock(block_id,
os.path.join(dirname, filename),
None, None,
tm_pb.JOINED)
datablock = _RawDataBlock(
id=block_id, data_path=os.path.join(dirname, filename),
start_time=None, end_time=None, type=tm_pb.JOINED)
datablocks.append(datablock)
datablocks.sort(key=lambda x: x.end_time)
datablocks.sort(key=lambda x: x.id)

fl_logging.info("create DataVisitor by local_data_path: %s",
local_data_path)
Expand All @@ -366,12 +365,11 @@ def __init__(self,
continue
subdirname = os.path.relpath(dirname, local_data_path)
block_id = os.path.join(subdirname, filename)
datablock = _RawDataBlock(block_id,
os.path.join(dirname, filename),
None, None,
tm_pb.LOCAL)
datablock = _RawDataBlock(
id=block_id, data_path=os.path.join(dirname, filename),
start_time=None, end_time=None, type=tm_pb.LOCAL)
local_datablocks.append(datablock)
local_datablocks.sort(key=lambda x: x.end_time)
local_datablocks.sort(key=lambda x: x.id)

super(DataPathVisitor, self).__init__(datablocks, local_datablocks,
epoch_num, shuffle_type)
Expand Down

0 comments on commit 1cf0b69

Please sign in to comment.