Prefix decoding #649
base: master
@@ -97,6 +97,7 @@ def __init__(self,
                unroll=False, back_prop=None,
                use_global_rec_step_offset=False,
                include_eos=False,
+               padded_data_keys=None,
                debug=None,
                _time_dim_tag=None,
                **kwargs):
@@ -119,6 +120,9 @@ def __init__(self,
     :param bool|None back_prop: for tf.while_loop. the default will use self.network.train_flag
     :param bool use_global_rec_step_offset:
     :param bool include_eos: for search, whether we should include the frame where "end" is True
+    :param list[str] padded_data_keys: list of data keys for which the data is allowed to have a different length
+      than the other data, even though they are on the same time axis. If accessed at time steps greater than the
+      data length, the value will be zero.
     :param bool|None debug:
     :param DimensionTag|None _time_dim_tag:
     """
@@ -149,6 +153,9 @@ def __init__(self,
     self._max_seq_len = max_seq_len
     self.time_dim_tag = _time_dim_tag
     self.include_eos = include_eos
+    self.padded_data_keys = padded_data_keys or []
+    if padded_data_keys:
+      assert isinstance(unit, _SubnetworkRecCell), "'padded_data_keys' only implemented for custom unit."
     if optimize_move_layers_out is None:
       optimize_move_layers_out = self.network.get_config().bool("optimize_move_layers_out", True)
     self._optimize_move_layers_out = optimize_move_layers_out
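For context, a minimal sketch of how this new option could appear in a network config. The layer names and the data key "alignment" are made-up placeholders, not part of this PR; the key is assumed to share the time axis with the target but may be shorter, and the assert above requires a custom unit (a subnetwork dict):

network = {
  "output": {
    "class": "rec", "from": ["encoder"], "target": "classes",
    # hypothetical extra data stream; steps beyond its length are read as zeros inside the loop
    "padded_data_keys": ["alignment"],
    "unit": {
      # ... custom recurrent subnetwork ...
    },
  },
}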
@@ -2003,34 +2010,36 @@ def get_output(self):
         data_placeholder = data.get_placeholder_as_time_major()
         with tf.name_scope("check_data_len"):
           data_len = tf.shape(data_placeholder)[0]
-          if common_data_len is None:
-            # Check for first key if input length matches data length
-            if input_seq_len is not None:
-              with tf.control_dependencies(
-                  [tf_compat.v1.assert_equal(
-                    tf.reduce_max(input_seq_len), data_len,
-                    ["RecLayer %r with sources %r:" % (rec_layer.name, rec_layer.sources),
-                     " The length of the sources (", tf.reduce_max(input_seq_len),
-                     ") differ from the length of the target ", key, "(", data_len, ")."])]):
-                data_len = tf.identity(data_len)
-            if fixed_seq_len is not None:
-              with tf.control_dependencies(
-                  [tf_compat.v1.assert_equal(
-                    tf.reduce_max(fixed_seq_len), data_len,
-                    ["RecLayer %r:" % (rec_layer.get_absolute_name(),),
-                     " The predefined length (", tf.reduce_max(fixed_seq_len),
-                     ") differs from the length of the target ", key, "(", data_len, ")."])]):
-                data_len = tf.identity(data_len)
+          # padded data keys are allowed to have any length, so don't check
+          if key not in self.parent_rec_layer.padded_data_keys:
+            if common_data_len is None:
+              # Check for first key if input length matches data length
+              if input_seq_len is not None:
+                with tf.control_dependencies(
+                    [tf_compat.v1.assert_equal(
+                      tf.reduce_max(input_seq_len), data_len,
+                      ["RecLayer %r with sources %r:" % (rec_layer.name, rec_layer.sources),
+                       " The length of the sources (", tf.reduce_max(input_seq_len),
+                       ") differ from the length of the target ", key, "(", data_len, ")."])]):
+                  data_len = tf.identity(data_len)
+              if fixed_seq_len is not None:
+                with tf.control_dependencies(
+                    [tf_compat.v1.assert_equal(
+                      tf.reduce_max(fixed_seq_len), data_len,
+                      ["RecLayer %r:" % (rec_layer.get_absolute_name(),),
+                       " The predefined length (", tf.reduce_max(fixed_seq_len),
+                       ") differs from the length of the target ", key, "(", data_len, ")."])]):
+                  data_len = tf.identity(data_len)
-            common_data_len = data_len
-          else:
-            # Check from second key on if data length is equal for all external data
-            with tf.control_dependencies([
-              tf_compat.v1.assert_equal(
-                common_data_len, data_len,
-                ["RecLayer %r:" % rec_layer.name, " The length of all targets (%s) " % ", ".join(used_keys),
-                 " has to be the same. Found length ", data_len, " for %s, which does not match length " % key,
-                 common_data_len, " of the other data."])]):
-              data_len = tf.identity(data_len)
+              common_data_len = data_len
+            else:
+              # Check from second key on if data length is equal for all external data
+              with tf.control_dependencies([
+                tf_compat.v1.assert_equal(
+                  common_data_len, data_len,
+                  ["RecLayer %r:" % rec_layer.name, " The length of all targets (%s) " % ", ".join(used_keys),
+                   " has to be the same. Found length ", data_len, " for %s, which does not match length " % key,
+                   common_data_len, " of the other data."])]):
+                data_len = tf.identity(data_len)
         data_ta = tf.TensorArray(
           name=key + "_ta",
           dtype=data.dtype,
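The checks above all follow the same pattern; here is a standalone sketch in plain TensorFlow (not RETURNN code) of why the checked length is routed through tf.identity under a control dependency:

import tensorflow as tf

def checked_length(data_len, expected_len):
  # The assert op only executes if something depends on it, so attach it as a
  # control dependency and pass the length through tf.identity.
  assert_op = tf.compat.v1.assert_equal(
    expected_len, data_len, data=["length mismatch:", expected_len, "vs", data_len])
  with tf.control_dependencies([assert_op]):
    return tf.identity(data_len)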
@@ -2343,8 +2352,21 @@ def body(i, net_vars, acc_tas, seq_len_info=None):
            for (k, v) in zip(sorted(self._initial_extra_outputs), prev_extra_flat)}
          with tf.name_scope("prev_extra"):
            prev_extra = identity_op_nested(prev_extra)
-         data_ = {
-           key_: ta.read(i, name="{}_ta_read".format(key_)) for key_, ta in data_tensor_arrays.items()}
+         data_ = {}
+         for key_, ta in data_tensor_arrays.items():
+           if key_ in self.parent_rec_layer.padded_data_keys:
+             # batch_dim in ta.element_shape is undefined, so replace it
+             element_shape = tf.concat([tf.expand_dims(batch_dim, axis=0), ta.element_shape[1:]], axis=0)
+
+             # when trying to access tensor array beyond sequence length, use zeros as padding value
+             data_[key_] = tf.cond(
+               pred=(i < ta.size()),
+               true_fn=lambda: ta.read(i, name="{}_ta_read".format(key_)),
+               false_fn=lambda: tf.zeros(shape=element_shape, dtype=ta.dtype, name="{}_ta_padding".format(key_)))
+           else:
+             data_[key_] = ta.read(i, name="{}_ta_read".format(key_))
+
          # noinspection PyProtectedMember
          with reuse_name_scope(self.parent_rec_layer._rec_scope):
            self._construct(
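The padded read above can be illustrated in isolation; a minimal sketch with an assumed fixed feature dimension (the actual code instead rebuilds the element shape dynamically from the batch dim):

import tensorflow as tf

def read_or_pad(ta, i, batch_dim, feature_dim):
  # Reading past the end of a tf.TensorArray would fail, so return zeros of the
  # element shape instead.
  return tf.cond(
    i < ta.size(),
    lambda: ta.read(i),
    lambda: tf.zeros([batch_dim, feature_dim], dtype=ta.dtype))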
@@ -4482,7 +4504,7 @@ def __init__(self, beam_size, keep_beams=False,
                length_normalization=True,
                length_normalization_exponent=1.0,
                custom_score_combine=None,
-               source_beam_sizes=None, scheduled_sampling=False, cheating=False,
+               source_beam_sizes=None, scheduled_sampling=False, cheating=False, prefix_target=None,
                explicit_search_sources=None,
                **kwargs):
     """
@@ -4503,6 +4525,8 @@ def __init__(self, beam_size, keep_beams=False,
     :param dict|None scheduled_sampling:
     :param bool|str cheating: if True, will always add the true target in the beam.
       if "exclusive", enables cheating_exclusive. see :func:`TFUtil.beam_search`.
+    :param str|None prefix_target: if given, this data stream will be enforced to be the prefix of the layer output,
+      i.e. for the first n positions, the beam choices will be overwritten by the labels from "prefix_target".
     :param list[LayerBase]|None explicit_search_sources: will mark it as an additional dependency.
       You might use these also in custom_score_combine.
     :param callable|None custom_score_combine:
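A hedged sketch of how the option might be used inside a rec unit; the layer names and the data key "prefix_labels" are illustrative placeholders, not from this PR:

"output": {
  "class": "choice", "from": ["output_prob"], "target": "classes", "beam_size": 12,
  # hypothetical data stream holding the forced prefix labels, zero-padded per sequence
  "prefix_target": "prefix_labels",
},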
@@ -4524,6 +4548,7 @@ def __init__(self, beam_size, keep_beams=False,
     self.search_scores_base = None
     self.search_scores_combined = None
     # We assume log-softmax here, inside the rec layer.
+    self.prefix_target = prefix_target

     if self.search_flag:
       if cheating:
@@ -4677,6 +4702,13 @@ def __init__(self, beam_size, keep_beams=False,
             cheating_exclusive=cheating_exclusive)
           self.search_choices.set_src_beams(src_beams)  # (batch, beam) -> beam_in idx
           labels = tf.reshape(labels, [net_batch_dim * beam_size])  # (batch * beam)
+
+          if self.prefix_target:
+            assert len(self.sources) == 1, "Prefix decoding not yet implemented for multiple sources."
+            labels, scores = self._enforce_prefixes(
+              top_k_labels=labels, all_scores=scores_comb, top_k_scores=scores, batch_dim=net_batch_dim,
+              beam_size=beam_size)
+
           labels = tf.cast(labels, self.output.dtype)

           if len(self.sources) > 1:
@@ -4986,6 +5018,59 @@ def _get_cheating_targets_and_src_beam_idxs(self, scores):
     src_beams = src_beams[:, 0]  # (batch,)
     return cheating_gold_targets, src_beams

+  def _enforce_prefixes(self, top_k_labels, all_scores, top_k_scores, batch_dim, beam_size):
+    """
+    This function replaces the target labels from beam search by the ones predefined by the target prefixes, as long
+    as search is still at a position within the prefix. We also replace the scores such that they correspond to a
+    prediction of the prefixes.
+
+    :param tf.Tensor top_k_labels: target labels from beam search, shape (batch * beam,)
+    :param tf.Tensor all_scores: scores before beam pruning, used to look up prefix scores, shape (batch, beam, dim)
+    :param tf.Tensor top_k_scores: scores after beam pruning, shape (batch, beam)
+    :param tf.Tensor|int batch_dim: number of sequences in the batch (without beam)
+    :param int beam_size: outgoing beam size of this layer
+    :return: labels (batch * beam,) and scores (batch, beam) of self.prefix_target as long as within the prefix,
+      after that top_k_labels and top_k_scores from beam search
+    :rtype: (tf.Tensor, tf.Tensor)
+    """
+    assert self.prefix_target
+
+    # Get the labels of the prefixes which should be enforced. They are padded with zeros, therefore we will
+    # get zeros for those sequences where the current time step is beyond the length of the prefix.
+    target_prefix_labels = self._get_target_value(
+      target=self.prefix_target).get_placeholder_as_batch_major()  # (batch * beam,), int32
+
+    # Get prefixes that have already ended (i.e. have a smaller length than the current time step).
+    target_prefix_ended = tf.equal(target_prefix_labels, 0)
+
+    # Select between the prefixes and the labels from free decoding, depending on whether the prefix
+    # still has to be enforced.
+    labels = tf.where(target_prefix_ended, top_k_labels, target_prefix_labels)
+
+    # Get rid of the redundant beam: all entries are the same, so only keep the first entry.
+    target_prefix_labels = tf.reshape(target_prefix_labels, [batch_dim, beam_size])[:, 0]  # (batch,)
+
+    # Now also get the scores for the prefixes. First, select only the first entry of the incoming beam, as all
+    # entries are the same if we are still within the prefix.
+    all_scores = all_scores[:, 0, :]  # (batch, dim)
+
+    # Gather scores for the prefix labels.
+    from returnn.tf.util.basic import nd_indices
+    target_prefix_nd_indices = nd_indices(target_prefix_labels)
+    prefix_scores = tf.expand_dims(tf.gather_nd(all_scores, target_prefix_nd_indices), axis=-1)  # (batch, 1)
+
+    # Create an artificial beam where all but the first score are effectively minus infinity. Tiling the one entry
+    # we have would lead to a beam of all equal hypotheses for the rest of the search.
+    # Conceptually similar to TFUtil.filter_ended_scores().
+    prefix_scores_padding = tf.fill([batch_dim, beam_size - 1], -1.e30)
+    prefix_scores = tf.concat([prefix_scores, prefix_scores_padding], axis=1)
+
+    # Use prefix scores for sequences where the prefix has not ended yet.
+    target_prefix_ended = tf.reshape(target_prefix_ended, [batch_dim, beam_size])
+    scores = tf.where(target_prefix_ended, top_k_scores, prefix_scores)  # (batch, beam)
+
+    return labels, scores
+
   @classmethod
   def transform_config_dict(cls, d, network, get_layer):
     """
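The artificial-beam trick in _enforce_prefixes can be shown in isolation; a small sketch with assumed shapes (not the actual layer code):

import tensorflow as tf

def make_prefix_beam_scores(prefix_score, batch_dim, beam_size):
  # Only the enforced prefix hypothesis keeps a finite score; the remaining beam
  # entries get -1e30 (effectively minus infinity), so they are never selected
  # while the prefix is still being enforced.
  prefix_score = tf.expand_dims(prefix_score, axis=-1)     # (batch, 1)
  padding = tf.fill([batch_dim, beam_size - 1], -1.e30)    # (batch, beam - 1)
  return tf.concat([prefix_score, padding], axis=1)        # (batch, beam)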
@@ -5041,8 +5126,8 @@ def _create_search_beam(cls, name, beam_size, sources, network):
       name="%s%s" % (network.get_absolute_name_prefix(), name))

   @classmethod
-  def get_out_data_from_opts(cls, name, sources, target, network,
-                             beam_size, search=NotSpecified, scheduled_sampling=False, cheating=False, **kwargs):
+  def get_out_data_from_opts(cls, name, sources, target, network, beam_size, search=NotSpecified,
+                             scheduled_sampling=False, cheating=False, prefix_target=None, **kwargs):
""" | ||
:param str name: | ||
:param list[LayerBase] sources: | ||
|
@@ -5052,6 +5137,7 @@ def get_out_data_from_opts(cls, name, sources, target, network,
     :param NotSpecified|bool search:
     :param dict|bool scheduled_sampling:
     :param bool cheating:
+    :param str prefix_target:
     :rtype: Data
     """
     search = NotSpecified.resolve(search, network.search_flag)
@@ -5077,6 +5163,8 @@ def get_out_data_from_opts(cls, name, sources, target, network,
       out_data.batch = out_data.batch.copy_set_beam(out_data.beam)
     if cheating or scheduled_sampling or not search:
       cls._static_get_target_value(target=target, network=network, mark_data_key_as_used=True)  # mark as used
+    if search and prefix_target:
+      cls._static_get_target_value(target=prefix_target, network=network, mark_data_key_as_used=True)  # mark as used
     return out_data

   def get_sub_layer(self, layer_name):
Review comment: This has the hardcoded assumption that EOS=0. I would avoid that. Better rely on the target seq len (using rec step info to check the current pos).
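A hedged sketch of that suggestion (not part of this PR): derive "prefix ended" from the prefix sequence lengths and the current decoder step instead of testing for label 0. The names current_step and prefix_seq_len are assumptions here.

import tensorflow as tf

def prefix_ended(current_step, prefix_seq_len):
  # True for sequences whose prefix is exhausted at the current position,
  # independent of which label index is used for EOS or padding.
  return tf.greater_equal(current_step, prefix_seq_len)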