Skip to content

Commit

Permalink
fix load and warmup for kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
strint committed Sep 22, 2023
1 parent b93d13f commit 4e60628
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions python/oneflow/nn/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,18 +1074,20 @@ def _fill_sub_destination(dest_dict, name_list, tensor_tuple):
# This is original outputs is needed to build output buffer.
tuple_idx = -1

def gen_index_in_tuple(eager_out):
def gen_index_in_tuple(item):
nonlocal tuple_idx
tuple_idx += 1
return "_OFTPI" + str(tuple_idx) # oneflow tuple index
if isinstance(item, Tensor):
tuple_idx += 1
return "_OFTPI" + str(tuple_idx) # oneflow tuple index
else:
return item

inputs_sub_destination = OrderedDict()
_fill_sub_destination(
inputs_sub_destination, self._input_op_names, self._inputs_tensor_tuple
)

_eager_inputs_args, _eager_inputs_kwargs = self.__map_io(
"input",
_eager_inputs_args, _eager_inputs_kwargs = self.__map_io_lite(
gen_index_in_tuple,
*self.inputs_original[0],
**self.inputs_original[1],
Expand All @@ -1094,8 +1096,8 @@ def gen_index_in_tuple(eager_out):
destination["inputs_original"] = (_eager_inputs_args, _eager_inputs_kwargs)

tuple_idx = -1
_eager_outputs, _ = self.__map_io(
"output", gen_index_in_tuple, *self._eager_outputs
_eager_outputs, _ = self.__map_io_lite(
gen_index_in_tuple, *self._eager_outputs
)
destination["outputs_original"] = _eager_outputs
assert len(self._outputs_tensor_tuple) == tuple_idx + 1
Expand Down Expand Up @@ -1146,7 +1148,7 @@ def load_runtime_state_dict(
Dict[str, Dict[str, Union[Dict[str, Tensor], str]]],
],
*,
warmup_with_run: bool = False,
warmup_with_run: bool = True,
) -> None:
if self._run_with_cache == True:
return self._dynamic_input_graph_cache.load_runtime_state_dict(
Expand Down Expand Up @@ -1293,6 +1295,7 @@ def get_tensor_in_tuple(tensor_tuple, map_item):
self.__run(
*_eager_inputs_args, **_eager_inputs_kwargs
) # pre-run to warm up
oneflow._oneflow_internal.eager.Sync()
build_graph_end = time.perf_counter()
self.__print(
0,
Expand Down

0 comments on commit 4e60628

Please sign in to comment.