зеркало из https://github.com/mozilla/kaldi.git
Merge pull request #754 from vdp/fix-compute-epochs2
nnet3: Undo the damage done by 2310a19ca
and fix the original problem properly
This commit is contained in:
Коммит
27dc00c53e
|
@ -135,14 +135,14 @@ void BuildTestTopSortOrder(std::vector<int32> *node_to_order) {
|
|||
|
||||
// The topological sorting order of the above SCC graph is as follows (from
|
||||
// our particular algorithm):
|
||||
// 0 --> 0
|
||||
// 1 --> 1
|
||||
// 2 --> 2
|
||||
// 3 --> 3
|
||||
(*node_to_order)[0] = 0;
|
||||
(*node_to_order)[1] = 1;
|
||||
(*node_to_order)[2] = 2;
|
||||
(*node_to_order)[3] = 3;
|
||||
// 0 --> 3
|
||||
// 1 --> 2
|
||||
// 2 --> 1
|
||||
// 3 --> 0
|
||||
(*node_to_order)[0] = 3;
|
||||
(*node_to_order)[1] = 2;
|
||||
(*node_to_order)[2] = 1;
|
||||
(*node_to_order)[3] = 0;
|
||||
}
|
||||
|
||||
void UnitTestComputeGraphTranspose() {
|
||||
|
@ -196,6 +196,35 @@ void UnitTestComputeTopSortOrder() {
|
|||
KALDI_ASSERT(AssertVectorEqual(node_to_order, ref_node_to_order));
|
||||
}
|
||||
|
||||
void UnitTestComputeTopSortOrder2() {
|
||||
// The outer vector is indexed by node ID, and each nested vector contains
|
||||
// the node IDs for its successors in the graph. For example, if there are
|
||||
// arcs from node 0 to nodes 1 and 2, then the vector at graph[0] will be (1, 2)
|
||||
std::vector<std::vector<int32> > graph;
|
||||
|
||||
// Build a test graph:
|
||||
// 0 ---> 1 ---> 2 ---> 4
|
||||
// `--> 3 -----^
|
||||
graph.resize(5);
|
||||
graph[0].push_back(1); graph[0].push_back(3);
|
||||
graph[1].push_back(2);
|
||||
graph[2].push_back(4);
|
||||
graph[3].push_back(2);
|
||||
// graph[4] is empty(has no successors)
|
||||
|
||||
// fill in the desired(topological) mapping node->order
|
||||
std::vector<int32> ref_node_to_order;
|
||||
ref_node_to_order.push_back(0); // node 0 comes first
|
||||
ref_node_to_order.push_back(2); // node 1 comes third
|
||||
ref_node_to_order.push_back(3); // node 2 comes fourth
|
||||
ref_node_to_order.push_back(1); // node 3 comes second
|
||||
ref_node_to_order.push_back(4); // node 4 comes last
|
||||
|
||||
std::vector<int32> computed_node_to_order;
|
||||
ComputeTopSortOrder(graph, &computed_node_to_order);
|
||||
KALDI_ASSERT(AssertVectorEqual(ref_node_to_order, computed_node_to_order));
|
||||
}
|
||||
|
||||
} // namespace nnet3
|
||||
} // namespace kaldi
|
||||
|
||||
|
@ -207,6 +236,7 @@ int main() {
|
|||
UnitTestFindSccs();
|
||||
UnitTestMakeSccGraph();
|
||||
UnitTestComputeTopSortOrder();
|
||||
UnitTestComputeTopSortOrder2();
|
||||
|
||||
KALDI_LOG << "Nnet graph tests succeeded.";
|
||||
|
||||
|
|
|
@ -230,18 +230,18 @@ void ComputeTopSortOrder(const std::vector<std::vector<int32> > &graph,
|
|||
std::vector<bool> cycle_detector(graph.size(), false);
|
||||
std::vector<bool> is_visited(graph.size(), false);
|
||||
|
||||
std::vector<int32> orders;
|
||||
std::vector<int32> reversed_orders;
|
||||
for(int32 i = 0; i < graph.size(); ++i) {
|
||||
if (!is_visited[i]) {
|
||||
ComputeTopSortOrderRecursive(i, graph, &cycle_detector,
|
||||
&is_visited, &orders);
|
||||
&is_visited, &reversed_orders);
|
||||
}
|
||||
}
|
||||
|
||||
KALDI_ASSERT(node_to_order->size() == orders.size());
|
||||
for (int32 i = 0; i < orders.size(); ++i) {
|
||||
KALDI_ASSERT(orders[i] >= 0 && orders[i] < graph.size());
|
||||
(*node_to_order)[orders[i]] = i;
|
||||
KALDI_ASSERT(node_to_order->size() == reversed_orders.size());
|
||||
for (int32 i = 0; i < reversed_orders.size(); ++i) {
|
||||
KALDI_ASSERT(reversed_orders[i] >= 0 && reversed_orders[i] < graph.size());
|
||||
(*node_to_order)[reversed_orders[i]] = graph.size() - i - 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -277,13 +277,8 @@ void ComputeNnetComputationEpochs(const Nnet &nnet,
|
|||
MakeSccGraph(graph, sccs, &scc_graph);
|
||||
KALDI_VLOG(6) << "scc graph is: " << PrintGraphToString(scc_graph);
|
||||
|
||||
std::vector<std::vector<int32> > scc_graph_transpose;
|
||||
// compute transpose because we actually want the reverse of the topological
|
||||
// order so inputs come first and then outputs.
|
||||
ComputeGraphTranspose(scc_graph, &scc_graph_transpose);
|
||||
|
||||
std::vector<int32> scc_node_to_epoch;
|
||||
ComputeTopSortOrder(scc_graph_transpose, &scc_node_to_epoch);
|
||||
ComputeTopSortOrder(scc_graph, &scc_node_to_epoch);
|
||||
if (GetVerboseLevel() >= 6) {
|
||||
std::ostringstream os;
|
||||
for (int32 i = 0; i < scc_node_to_epoch.size(); i++)
|
||||
|
|
Загрузка…
Ссылка в новой задаче