Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored TreeBuilder #85

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/Search/AdvancedTreeSearch/PersistentStateTree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct ConvertTree {
TreeIndex masterTreeIndex;
StateId rootSubTree;
StateId ciRootNode;
std::map<StateTree::Exit, u32> exits; //Maps exits to label-indices @todo Make this a hash_map
std::map<StateTree::Exit, u32> exits; // Maps exits to label-indices @todo Make this a hash_map
std::vector<PersistentStateTree::Exit> exitVector;
Core::HashMap<StateId, StateTree::StateId> statesForNodes;
Core::HashMap<StateTree::StateId, StateId> nodesForStates;
Expand Down Expand Up @@ -73,7 +73,7 @@ struct ConvertTree {
}
}

///Make sure a node is created for every single state, so that also the coarticulated roots are respected
/// Make sure a node is created for every single state, so that also the coarticulated roots are respected

for (std::set<StateTree::StateId>::iterator stateIt = coarticulatedRootStates.begin(); stateIt != coarticulatedRootStates.end(); ++stateIt) {
StateTree::StateId state = *stateIt;
Expand Down Expand Up @@ -121,7 +121,7 @@ struct ConvertTree {
exitIndices.insert(exitEntry->second);
}

//Add connections to the attached outputs/exits
// Add connections to the attached outputs/exits
for (std::set<u32>::iterator it = exitIndices.begin(); it != exitIndices.end(); ++it)
subtrees.addOutputToEdge(subtrees.state(node).successors, *it);
}
Expand Down Expand Up @@ -150,10 +150,10 @@ struct ConvertTree {

subtrees.state(node).stateDesc = state;

//Build successor structure
// Build successor structure
std::pair<StateTree::SuccessorIterator, StateTree::SuccessorIterator> successors = tree->successors(stateId);

StateId current = node; //Just to verify the order
StateId current = node; // Just to verify the order

for (; successors.first != successors.second; ++successors.first) {
std::unordered_map<StateTree::StateId, StateId>::iterator nodeIt = nodesForStates.find(*successors.first);
Expand All @@ -166,14 +166,15 @@ struct ConvertTree {
}
};

PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon)
PersistentStateTree::PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory)
: masterTree(0),
rootState(0),
ciRootState(0),
archive_(paramCacheArchive(Core::Configuration(config, "search-network"))),
acousticModel_(acousticModel),
lexicon_(lexicon),
config_(config) {
config_(config),
treeBuilderFactory_(treeBuilderFactory) {
if (acousticModel_.get() && lexicon_.get()) {
const Am::ClassicAcousticModel* am = required_cast(const Am::ClassicAcousticModel*, acousticModel.get());
Core::DependencySet d;
Expand Down Expand Up @@ -320,7 +321,7 @@ bool PersistentStateTree::read(Core::MappedArchiveReader in) {
in >> masterTree >> dependenciesChecksum;

if (dependenciesChecksum != dependencies_.getChecksum()) {
Core::Application::us()->log() << "dependencies of the network image don't equal the requiered dependencies with checksum " << dependenciesChecksum;
Core::Application::us()->log() << "dependencies of the network image don't equal the required dependencies with checksum " << dependenciesChecksum;
return false;
}

Expand Down Expand Up @@ -436,7 +437,7 @@ HMMStateNetwork::CleanupResult PersistentStateTree::cleanup(bool cleanupExits) {

Core::HashMap<StateId, StateId>::const_iterator targetNodeIt;
if (rootState) {
verify(cleanupResult.nodeMap.find(rootState) != cleanupResult.nodeMap.end()); //Root-node must stay unchanged
verify(cleanupResult.nodeMap.find(rootState) != cleanupResult.nodeMap.end()); // Root-node must stay unchanged
verify(cleanupResult.nodeMap.find(rootState)->second == rootState);
targetNodeIt = cleanupResult.nodeMap.find(rootState);
verify(targetNodeIt != cleanupResult.nodeMap.end());
Expand Down Expand Up @@ -512,7 +513,7 @@ void PersistentStateTree::dumpDotGraph(std::string file, const std::vector<int>&
int depth = 0;
if (!nodeDepths.empty())
depth = nodeDepths[node];
os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d", node, node, depth, structure.state(node).stateDesc.acousticModel);
os << Core::form("n%d [label=\"%d\\nd=%d\\nm=%d\\nt=%d", node, node, depth, structure.state(node).stateDesc.acousticModel, structure.state(node).stateDesc.transitionModelIndex);

for (HMMStateNetwork::SuccessorIterator target = structure.successors(node); target; ++target)
if (target.isLabel() && exits[target.label()].pronunciation != Bliss::LemmaPronunciation::invalidId)
Expand Down
29 changes: 17 additions & 12 deletions src/Search/AdvancedTreeSearch/PersistentStateTree.hh
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,38 @@ struct MyStandardValueHash {
}
};

class AbstractTreeBuilder;

namespace Search {
class HMMStateNetwork;
class StateTree;

class PersistentStateTree {
public:
using TreeBuilderFactory = std::function<std::unique_ptr<AbstractTreeBuilder>(Core::Configuration, const Bliss::Lexicon&, const Am::AcousticModel&, PersistentStateTree&, bool)>;

///@param lexicon This must be given if the resulting exits are supposed to be functional
PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon);
PersistentStateTree(Core::Configuration config, Core::Ref<const Am::AcousticModel> acousticModel, Bliss::LexiconRef lexicon, TreeBuilderFactory treeBuilderFactory);

///Builds this state tree.
/// Builds this state tree.
void build();

///Writes the current state of the state tree into the file,
///Returns whether writing was successful
/// Writes the current state of the state tree into the file,
/// Returns whether writing was successful
bool write(int transformation = 0);

///Reads the state tree from the file.
/// Reads the state tree from the file.
///@return Whether the reading was successful.
bool read(int transformation = 0);

///Cleans up the structure, saving memory and allowing a more efficient iteration.
///Node and tree IDs may be changed.
/// Cleans up the structure, saving memory and allowing a more efficient iteration.
/// Node and tree IDs may be changed.
///@return An object that contains a mapping representing the index changes.
HMMStateNetwork::CleanupResult cleanup(bool cleanupExits = true);

///Removes all outputs from the network
///Also performs a cleanup, so the search network must already be clean
///for indices to stay equal
/// Removes all outputs from the network
/// Also performs a cleanup, so the search network must already be clean
/// for indices to stay equal
void removeOutputs();

u32 getChecksum() const;
Expand Down Expand Up @@ -128,11 +132,12 @@ private:
Core::Ref<const Am::AcousticModel> acousticModel_;
Bliss::LexiconRef lexicon_;
Core::Configuration config_;
TreeBuilderFactory treeBuilderFactory_;

//Writes the whole state network into the given stream
// Writes the whole state network into the given stream
void write(Core::MappedArchiveWriter writer);

//Reads the state network from the given stream.
// Reads the state network from the given stream.
//@return Whether the reading was successful.
bool read(Core::MappedArchiveReader reader);
};
Expand Down
Loading