Prefix decoding #649

Open · wants to merge 2 commits into master
152 changes: 120 additions & 32 deletions returnn/tf/layers/rec.py
@@ -97,6 +97,7 @@ def __init__(self,
unroll=False, back_prop=None,
use_global_rec_step_offset=False,
include_eos=False,
padded_data_keys=None,
debug=None,
_time_dim_tag=None,
**kwargs):
@@ -119,6 +120,9 @@ def __init__(self,
:param bool|None back_prop: for tf.while_loop. the default will use self.network.train_flag
:param bool use_global_rec_step_offset:
:param bool include_eos: for search, whether we should include the frame where "end" is True
:param list[str]|None padded_data_keys: list of data keys for which the data is allowed to have a different length
than the other data, even though they are on the same time axis. When accessed at time steps beyond the data length,
the value will be zero.
:param bool|None debug:
:param DimensionTag|None _time_dim_tag:
"""
@@ -149,6 +153,9 @@ def __init__(self,
self._max_seq_len = max_seq_len
self.time_dim_tag = _time_dim_tag
self.include_eos = include_eos
self.padded_data_keys = padded_data_keys or []
if padded_data_keys:
assert isinstance(unit, _SubnetworkRecCell), "'padded_data_keys' only implemented for custom unit."
if optimize_move_layers_out is None:
optimize_move_layers_out = self.network.get_config().bool("optimize_move_layers_out", True)
self._optimize_move_layers_out = optimize_move_layers_out
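
To make the intended usage concrete, here is a hypothetical config excerpt; the layer layout and the data key "prefix_labels" are invented for illustration and are not part of this PR:

network = {
  "output": {
    "class": "rec", "from": "encoder", "target": "classes",
    # "prefix_labels" is an assumed extern_data key; its sequences may be
    # shorter than the other data on the same time axis, and reads beyond
    # their length yield zeros inside the loop.
    "padded_data_keys": ["prefix_labels"],
    "unit": {
      # ... recurrent subnetwork (a custom unit, as the assert above requires) ...
    },
  },
}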
@@ -2003,34 +2010,36 @@ def get_output(self):
data_placeholder = data.get_placeholder_as_time_major()
with tf.name_scope("check_data_len"):
data_len = tf.shape(data_placeholder)[0]
if common_data_len is None:
# Check for first key if input length matches data length
if input_seq_len is not None:
with tf.control_dependencies(
[tf_compat.v1.assert_equal(
tf.reduce_max(input_seq_len), data_len,
["RecLayer %r with sources %r:" % (rec_layer.name, rec_layer.sources),
" The length of the sources (", tf.reduce_max(input_seq_len),
") differ from the length of the target ", key, "(", data_len, ")."])]):
data_len = tf.identity(data_len)
if fixed_seq_len is not None:
with tf.control_dependencies(
[tf_compat.v1.assert_equal(
tf.reduce_max(fixed_seq_len), data_len,
["RecLayer %r:" % (rec_layer.get_absolute_name(),),
" The predefined length (", tf.reduce_max(fixed_seq_len),
") differs from the length of the target ", key, "(", data_len, ")."])]):
# padded data keys are allowed to have any length, so don't check
if key not in self.parent_rec_layer.padded_data_keys:
if common_data_len is None:
# Check for first key if input length matches data length
if input_seq_len is not None:
with tf.control_dependencies(
[tf_compat.v1.assert_equal(
tf.reduce_max(input_seq_len), data_len,
["RecLayer %r with sources %r:" % (rec_layer.name, rec_layer.sources),
" The length of the sources (", tf.reduce_max(input_seq_len),
") differs from the length of the target ", key, " (", data_len, ")."])]):
data_len = tf.identity(data_len)
if fixed_seq_len is not None:
with tf.control_dependencies(
[tf_compat.v1.assert_equal(
tf.reduce_max(fixed_seq_len), data_len,
["RecLayer %r:" % (rec_layer.get_absolute_name(),),
" The predefined length (", tf.reduce_max(fixed_seq_len),
") differs from the length of the target ", key, "(", data_len, ")."])]):
data_len = tf.identity(data_len)
common_data_len = data_len
else:
# Check from second key on if data length is equal for all external data
with tf.control_dependencies([
tf_compat.v1.assert_equal(
common_data_len, data_len,
["RecLayer %r:" % rec_layer.name, " The length of all targets (%s) " % ", ".join(used_keys),
" has to be the same. Found length ", data_len, " for %s, which does not match length " % key,
common_data_len, " of the other data."])]):
data_len = tf.identity(data_len)
common_data_len = data_len
else:
# Check from second key on if data length is equal for all external data
with tf.control_dependencies([
tf_compat.v1.assert_equal(
common_data_len, data_len,
["RecLayer %r:" % rec_layer.name, " The length of all targets (%s) " % ", ".join(used_keys),
" has to be the same. Found length ", data_len, " for %s, which does not match length " % key,
common_data_len, " of the other data."])]):
data_len = tf.identity(data_len)
data_ta = tf.TensorArray(
name=key + "_ta",
dtype=data.dtype,
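
As an aside, all the length checks above follow the same TF1 graph-mode idiom: attach the assertion as a control dependency of a pass-through identity, so the check executes whenever the value is consumed. A minimal sketch of that pattern (names are illustrative):

import tensorflow as tf

def checked_len(expected_len, data_len):
  # The assert op only runs when the returned `data_len` is used downstream.
  with tf.control_dependencies([
      tf.compat.v1.assert_equal(expected_len, data_len, ["length mismatch"])]):
    return tf.identity(data_len)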
Expand Down Expand Up @@ -2343,8 +2352,21 @@ def body(i, net_vars, acc_tas, seq_len_info=None):
for (k, v) in zip(sorted(self._initial_extra_outputs), prev_extra_flat)}
with tf.name_scope("prev_extra"):
prev_extra = identity_op_nested(prev_extra)
data_ = {
key_: ta.read(i, name="{}_ta_read".format(key_)) for key_, ta in data_tensor_arrays.items()}

data_ = {}
for key_, ta in data_tensor_arrays.items():
if key_ in self.parent_rec_layer.padded_data_keys:
# batch_dim in ta.element_shape is undefined, so replace it
element_shape = tf.concat([tf.expand_dims(batch_dim, axis=0), ta.element_shape[1:]], axis=0)

# when trying to access tensor array beyond sequence length, use zeros as padding value
data_[key_] = tf.cond(
pred=(i < ta.size()),
true_fn=lambda: ta.read(i, name="{}_ta_read".format(key_)),
false_fn=lambda: tf.zeros(shape=element_shape, dtype=ta.dtype, name="{}_ta_padding".format(key_)))
else:
data_[key_] = ta.read(i, name="{}_ta_read".format(key_))

# noinspection PyProtectedMember
with reuse_name_scope(self.parent_rec_layer._rec_scope):
self._construct(
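
For intuition, a standalone TF2-eager toy version of this guarded read, with a fixed feature dim of 4 chosen purely for illustration:

import tensorflow as tf

batch_dim = 2
ta = tf.TensorArray(dtype=tf.float32, size=3)
for t in range(3):
  ta = ta.write(t, tf.fill([batch_dim, 4], float(t)))

def padded_read(ta, i):
  # Past the stored length, return zeros of the same per-step shape
  # instead of letting ta.read() fail.
  return tf.cond(
    i < ta.size(),
    lambda: ta.read(i),
    lambda: tf.zeros([batch_dim, 4], dtype=ta.dtype))

print(padded_read(ta, tf.constant(1)))  # stored step
print(padded_read(ta, tf.constant(5)))  # zero padding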
Expand Down Expand Up @@ -4482,7 +4504,7 @@ def __init__(self, beam_size, keep_beams=False,
length_normalization=True,
length_normalization_exponent=1.0,
custom_score_combine=None,
source_beam_sizes=None, scheduled_sampling=False, cheating=False,
source_beam_sizes=None, scheduled_sampling=False, cheating=False, prefix_target=None,
explicit_search_sources=None,
**kwargs):
"""
@@ -4503,6 +4525,8 @@ def __init__(self, beam_size, keep_beams=False,
:param dict|None scheduled_sampling:
:param bool|str cheating: if True, will always add the true target in the beam.
if "exclusive", enables cheating_exclusive. see :func:`TFUtil.beam_search`.
:param str|None prefix_target: If given, this data stream will be enforced to be the prefix of the layer output,
i.e. for the first n positions, the beam choices will be overwritten by the labels from "prefix_target".
:param list[LayerBase]|None explicit_search_sources: will mark it as an additional dependency.
You might use these also in custom_score_combine.
:param callable|None custom_score_combine:
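
Correspondingly, a hypothetical ChoiceLayer config excerpt using the new option ("prefix_labels" again being an assumed data key, not from this PR):

"output_choice": {
  "class": "choice", "target": "classes", "beam_size": 12,
  "from": "output_prob",
  # For the first len(prefix) positions, the beam choices are overwritten
  # by the labels of this data stream.
  "prefix_target": "prefix_labels",
},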
@@ -4524,6 +4548,7 @@ def __init__(self, beam_size, keep_beams=False,
self.search_scores_base = None
self.search_scores_combined = None
# We assume log-softmax here, inside the rec layer.
self.prefix_target = prefix_target

if self.search_flag:
if cheating:
@@ -4677,6 +4702,13 @@ def __init__(self, beam_size, keep_beams=False,
cheating_exclusive=cheating_exclusive)
self.search_choices.set_src_beams(src_beams) # (batch, beam) -> beam_in idx
labels = tf.reshape(labels, [net_batch_dim * beam_size]) # (batch * beam)

if self.prefix_target:
assert len(self.sources) == 1, "Prefix decoding not yet implemented for multiple sources."
labels, scores = self._enforce_prefixes(
top_k_labels=labels, all_scores=scores_comb, top_k_scores=scores, batch_dim=net_batch_dim,
beam_size=beam_size)

labels = tf.cast(labels, self.output.dtype)

if len(self.sources) > 1:
@@ -4986,6 +5018,59 @@ def _get_cheating_targets_and_src_beam_idxs(self, scores):
src_beams = src_beams[:, 0] # (batch,)
return cheating_gold_targets, src_beams

def _enforce_prefixes(self, top_k_labels, all_scores, top_k_scores, batch_dim, beam_size):
"""
This function replaces the target labels from beam search with the ones predefined by the target prefixes, as long
as search is still at a position within the prefix. We also replace the scores such that they correspond to a
prediction of the prefixes.

:param tf.Tensor top_k_labels: target labels from beam search, shape (batch * beam,)
:param tf.Tensor all_scores: scores before beam pruning, used to lookup prefix scores, shape (batch, beam, dim)
:param tf.Tensor top_k_scores: scores after beam pruning, shape (batch, beam)
:param tf.Tensor|int batch_dim: number of sequences in batch (without beam)
:param int beam_size: outgoing beam size of this layer
:return: labels (batch * beam,) and scores (batch, beam) of self.prefix_target as long as within prefix, after
that top_k_labels and top_k_scores from beam search
:rtype: (tf.Tensor, tf.Tensor)
"""
assert self.prefix_target

# Get the labels of the prefixes which should be enforced. They are padded with zeros, so we will
# get zeros for those sequences where the current timestep is beyond the length of the prefix.
target_prefix_labels = self._get_target_value(
target=self.prefix_target).get_placeholder_as_batch_major() # (batch * beam,), int32

# Get prefixes that have already ended (i.e. have a smaller length than the current time step).
target_prefix_ended = tf.equal(target_prefix_labels, 0)
Member:
This has the hardcoded assumption that EOS=0. I would avoid that. Better rely on the target seq len (using rec step info to check the current pos).


# Select between the prefixes and the labels from free decoding, depending on whether the prefix
# has still got to be enforced.
labels = tf.where(target_prefix_ended, top_k_labels, target_prefix_labels)

# Get rid of the redundant beam: all entries are the same, so keep only the first one.
target_prefix_labels = tf.reshape(target_prefix_labels, [batch_dim, beam_size])[:, 0] # (batch,)

# Now also get the scores for the prefixes. First, select only the first entry of the incoming beam as all entries
# are the same if we are still within the prefix.
all_scores = all_scores[:, 0, :] # (batch, dim)

# Gather scores for the prefix labels.
from returnn.tf.util.basic import nd_indices
target_prefix_nd_indices = nd_indices(target_prefix_labels)
prefix_scores = tf.expand_dims(tf.gather_nd(all_scores, target_prefix_nd_indices), axis=-1) # (batch, 1)

# Create an artificial beam, where all but the first scores are infinite. Tiling the one entry we have would
# lead to a beam of all equal hypotheses for the rest of the search.
# Conceptually similar to TFUtil.filter_ended_scores().
prefix_scores_padding = tf.fill([batch_dim, beam_size - 1], -1.e30)
prefix_scores = tf.concat([prefix_scores, prefix_scores_padding], axis=1)

# Use prefix scores for sequences where the prefix has not ended yet.
target_prefix_ended = tf.reshape(target_prefix_ended, [batch_dim, beam_size])
scores = tf.where(target_prefix_ended, top_k_scores, prefix_scores) # (batch, beam)

return labels, scores
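
For intuition, a toy eager-mode walkthrough of the two steps above (picking the prefix score, then masking the rest of the beam); shapes and values are invented (batch=1, beam=3, dim=4):

import tensorflow as tf

# Scores before pruning, first beam entry only: shape (batch, dim) = (1, 4).
all_scores = tf.constant([[-0.1, -0.7, -0.3, -2.0]])
prefix_labels = tf.constant([2])  # enforced prefix label per sequence

# What nd_indices + gather_nd do above: pick the score of the prefix label.
nd_idx = tf.stack([tf.range(tf.shape(prefix_labels)[0]), prefix_labels], axis=1)
prefix_score = tf.expand_dims(tf.gather_nd(all_scores, nd_idx), axis=-1)  # [[-0.3]]

# Artificial beam: all but the first entry get -1e30, i.e. effectively pruned.
prefix_scores = tf.concat([prefix_score, tf.fill([1, 2], -1.e30)], axis=1)

# While still inside the prefix, take the prefix scores, else the top-k ones.
top_k_scores = tf.constant([[-0.2, -0.5, -0.9]])
prefix_ended = tf.constant([[False, False, False]])
print(tf.where(prefix_ended, top_k_scores, prefix_scores))
# -> [[-0.3, -1e30, -1e30]]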

@classmethod
def transform_config_dict(cls, d, network, get_layer):
"""
@@ -5041,8 +5126,8 @@ def _create_search_beam(cls, name, beam_size, sources, network):
name="%s%s" % (network.get_absolute_name_prefix(), name))

@classmethod
def get_out_data_from_opts(cls, name, sources, target, network,
beam_size, search=NotSpecified, scheduled_sampling=False, cheating=False, **kwargs):
def get_out_data_from_opts(cls, name, sources, target, network, beam_size, search=NotSpecified,
scheduled_sampling=False, cheating=False, prefix_target=None, **kwargs):
Member:
Instead of prefix_target, this could be a layer, like prefix. And you could directly refer to the whole sequence. So you would have prefix="base:data:prefix" or so in your config. This would avoid the whole padding logic.

Contributor Author:
I don't understand. You mean ChoiceLayer gets the whole prefix and selects the label of the current timestep itself? Or would a layer somehow handle the padding for me?

Member:
Yea, prefix="base:data:prefix" would be the whole prefix.

"""
:param str name:
:param list[LayerBase] sources:
Expand All @@ -5052,6 +5137,7 @@ def get_out_data_from_opts(cls, name, sources, target, network,
:param NotSpecified|bool search:
:param dict|bool scheduled_sampling:
:param bool cheating:
:param str|None prefix_target:
:rtype: Data
"""
search = NotSpecified.resolve(search, network.search_flag)
@@ -5077,6 +5163,8 @@ def get_out_data_from_opts(cls, name, sources, target, network,
out_data.batch = out_data.batch.copy_set_beam(out_data.beam)
if cheating or scheduled_sampling or not search:
cls._static_get_target_value(target=target, network=network, mark_data_key_as_used=True) # mark as used
if search and prefix_target:
cls._static_get_target_value(target=prefix_target, network=network, mark_data_key_as_used=True) # mark as used
return out_data

def get_sub_layer(self, layer_name):