Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LabelScorer base class #80

Open
wants to merge 7 commits into
base: collapsed-vector
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/Flow/Vector.hh
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public:
: Timestamp(type()), std::vector<T>(n, t) {}
Vector(const std::vector<T>& v)
: Timestamp(type()), std::vector<T>(v) {}
Vector(const std::vector<T>& v, Time start, Time end)
: Timestamp(start, end), std::vector<T>(v) {}
template<class InputIterator>
Vector(InputIterator begin, InputIterator end)
: Timestamp(type()), std::vector<T>(begin, end) {}
Expand Down
52 changes: 52 additions & 0 deletions src/Nn/LabelScorer/LabelScorer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/** Copyright 2024 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "LabelScorer.hh"
#include <Flow/Timestamp.hh>

namespace Nn {

/*
* =============================
* === LabelScorer =============
* =============================
*/

LabelScorer::LabelScorer(const Core::Configuration& config)
: Core::Component(config) {}

void LabelScorer::addInput(Core::Ref<const Speech::Feature> input) {
addInput(Flow::dataPtr(new FeatureVector(*input->mainStream(), input->timestamp().startTime(), input->timestamp().endTime())));
}

std::optional<LabelScorer::ScoresWithTimes> LabelScorer::getScoresWithTimes(const std::vector<LabelScorer::Request>& requests) {
// By default, just loop over the non-batched `getScoreWithTime` and collect the results
ScoresWithTimes result;

result.scores.reserve(requests.size());
result.timesteps.reserve(requests.size());
for (auto& request : requests) {
auto singleResult = getScoreWithTime(request);
if (not singleResult.has_value()) {
return {};
}
result.scores.push_back(singleResult->score);
result.timesteps.push_back(singleResult->timeframe);
}

return result;
}

} // namespace Nn
144 changes: 144 additions & 0 deletions src/Nn/LabelScorer/LabelScorer.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/** Copyright 2024 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LABEL_SCORER_HH
#define LABEL_SCORER_HH

#include <Core/CollapsedVector.hh>
#include <Core/Component.hh>
#include <Core/Configuration.hh>
#include <Core/Parameter.hh>
#include <Core/ReferenceCounting.hh>
#include <Core/Types.hh>
#include <Flow/Timestamp.hh>
#include <Flow/Vector.hh>
#include <Mm/FeatureScorer.hh>
#include <Nn/Types.hh>
#include <Search/Types.hh>
#include <Speech/Feature.hh>
#include <Speech/Types.hh>
#include <optional>
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved
#include "ScoringContext.hh"

namespace Nn {

typedef Search::Score Score;
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved
typedef Flow::Vector<f32> FeatureVector;
typedef Flow::DataPtr<FeatureVector> FeatureVectorRef;

/*
* Abstract base class for scoring tokens within an ASR search algorithm.
*
* This class provides an interface for different types of label scorers in an ASR system.
* Label Scorers compute the scores of tokens based on input features and a scoring context.
* Children of this base class should represent various ASR model architectures and cover a
* wide range of possibilities such as CTC, transducer, AED or other models.
*
* The usage is intended as follows:
* - Before or during the search, features can be added
* - At the beginning of search, `getInitialScoringContext` should be called
* and used for the first hypotheses
* - For a given hypothesis in search, its search context together with a successor token and
* transition type are packed into a request and scored via `getScoreWithTime`. This also returns
* the timestamp of the successor.
* Note: The scoring function may return no value, in this case it is not ready yet
* and needs more input features.
* Note: There is also the function `getScoresWithTimes` which can handle an entire batch of
* requests at once and might be implemented more efficiently (e.g. using batched model forwarding).
* - For all hypotheses that survive pruning, the LabelScorer can compute a new scoring context
* that extends the previous scoring context of that hypothesis with a given successor token. This new
* scoring context can then be used as context in subsequent search steps.
* - After all features have been passed, the `signalNoMoreFeatures` function is called to inform
* the label scorer that it doesn't need to wait for more features and can score as much as possible.
* This is especially important when the label scorer internally uses an encoder or window with right
* context.
* - When all necessary scores for the current segment have been computed, the `reset` function is called
* to clean up any internal data (e.g. feature buffer) or reset flags of the LabelScorer. Afterwards
* it is ready to receive features for the next segment.
*
* Each concrete subclass internally implements a concrete type of scoring context which the outside
* search algorithm is agnostic to. Depending on the model, this scoring context can consist of things like
* the current timestep, a label history, a hidden state or other values.
*/
class LabelScorer : public virtual Core::Component,
public Core::ReferenceCounted {
public:
// Transition type as part of scoring or context extension requests
enum TransitionType {
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved
LABEL_TO_LABEL,
LABEL_LOOP,
LABEL_TO_BLANK,
BLANK_TO_LABEL,
BLANK_LOOP,
};

// Request for scoring or context extension
struct Request {
ScoringContextRef context;
LabelIndex nextToken;
TransitionType transitionType;
};

// Return value of scoring function
struct ScoreWithTime {
Score score;
Speech::TimeframeIndex timeframe;
};

// Return value of batched scoring function
struct ScoresWithTimes {
std::vector<Score> scores;
Core::CollapsedVector<Speech::TimeframeIndex> timesteps; // Timesteps vector is internally collapsed if all timesteps are the same (e.g. time-sync decoding)
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved
};
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved

LabelScorer(const Core::Configuration& config);
virtual ~LabelScorer() = default;

// Prepares the LabelScorer to receive new inputs
// e.g. by resetting input buffers and segmentEnd flags
virtual void reset() = 0;

// Tells the LabelScorer that there will be no more input features coming in the current segment
virtual void signalNoMoreFeatures() = 0;

// Gets initial scoring context to use for the hypotheses in the first search step
virtual ScoringContextRef getInitialScoringContext() = 0;

// Creates a copy of the context in the request that is extended using the given token and transition type
virtual ScoringContextRef extendedScoringContext(Request request) = 0;

// Function that returns the mapping of each timeframe index (returned in the scoring functions)
// to actual flow timestamps with start-/ and end-time in seconds.
virtual const std::vector<Flow::Timestamp>& getTimestamps() const = 0;

// Add a single input feature
virtual void addInput(FeatureVectorRef input) = 0;
virtual void addInput(Core::Ref<const Speech::Feature> input);
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved

// Perform scoring computation for a single request
// Return score and timeframe index of the corresponding output
// May not return a value if the LabelScorer is not ready to score the request yet
// (e.g. not enough features received)
virtual std::optional<ScoreWithTime> getScoreWithTime(const Request request) = 0;
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved

// Perform scoring computation for a batch of requests
// May be implemented more efficiently than iterated calls of `getScoreWithTime`
// Return two vectors: one vector with scores and one vector with times
virtual std::optional<ScoresWithTimes> getScoresWithTimes(const std::vector<Request>& requests);
};

} // namespace Nn

#endif // LABEL_SCORER_HH
26 changes: 26 additions & 0 deletions src/Nn/LabelScorer/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#!gmake

TOPDIR = ../../..

include $(TOPDIR)/Makefile.cfg

# -----------------------------------------------------------------------------

SUBDIRS =
TARGETS = libSprintLabelScorer.$(a)

LIBSPRINTLABELSCORER_O = \
$(OBJDIR)/LabelScorer.o \
$(OBJDIR)/ScoringContext.o

# -----------------------------------------------------------------------------

all: $(TARGETS)

libSprintLabelScorer.$(a): $(LIBSPRINTLABELSCORER_O)
$(MAKELIB) $@ $^

include $(TOPDIR)/Rules.make

sinclude $(LIBSPRINTLABELSCORER_O:.o=.d)
include $(patsubst %.o,%.d,$(filter %.o,$(CHECK_O)))
34 changes: 34 additions & 0 deletions src/Nn/LabelScorer/ScoringContext.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/** Copyright 2024 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ScoringContext.hh"

namespace Nn {

typedef Mm::EmissionIndex LabelIndex;

/*
* =============================
* === ScoringContext ============
SimBe195 marked this conversation as resolved.
Show resolved Hide resolved
* =============================
*/
size_t ScoringContextHash::operator()(ScoringContextRef history) const {
return 0ul;
}

bool ScoringContextEq::operator()(ScoringContextRef lhs, ScoringContextRef rhs) const {
return true;
}

} // namespace Nn
45 changes: 45 additions & 0 deletions src/Nn/LabelScorer/ScoringContext.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/** Copyright 2024 RWTH Aachen University. All rights reserved.
*
* Licensed under the RWTH ASR License (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.hltpr.rwth-aachen.de/rwth-asr/rwth-asr-license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef SCORING_CONTEXT_HH
#define SCORING_CONTEXT_HH

#include <Core/ReferenceCounting.hh>
#include <Mm/Types.hh>

namespace Nn {

typedef Mm::EmissionIndex LabelIndex;

/*
* Empty scoring context base class
*/
struct ScoringContext : public Core::ReferenceCounted {
virtual ~ScoringContext() = default;
};

typedef Core::Ref<const ScoringContext> ScoringContextRef;

struct ScoringContextHash {
size_t operator()(ScoringContextRef history) const;
};

struct ScoringContextEq {
bool operator()(ScoringContextRef lhs, ScoringContextRef rhs) const;
};

} // namespace Nn

#endif // SCORING_CONTEXT_HH
5 changes: 5 additions & 0 deletions src/Nn/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ ifdef MODULE_GENERIC_SEQ2SEQ_TREE_SEARCH
endif
endif

SUBDIRS += LabelScorer

# -----------------------------------------------------------------------------
all: $(TARGETS)

Expand All @@ -99,6 +101,9 @@ interpol:
libSprintNn.$(a): $(SUBDIRS) $(LIBSPRINTNN_O)
$(MAKELIB) $@ $(LIBSPRINTNN_O) $(patsubst %,%/$(OBJDIR)/*.o,$(SUBDIRS))

LabelScorer:
$(MAKE) -C $@ libSprintLabelScorer.$(a)

check$(exe): $(CHECK_O)
$(LD) $(LD_START_GROUP) $(CHECK_O) $(LD_END_GROUP) -o $@ $(LDFLAGS)

Expand Down