Merge pull request #754 from vdp/fix-compute-epochs2

nnet3: Undo the damage done by 2310a19ca and fix the original problem properly
This commit is contained in:
Daniel Povey 2016-05-04 15:40:23 -04:00
Родитель a8ca3bf4de 048c01d048
Коммит 27dc00c53e
2 изменённых файлов: 45 добавлений и 20 удалений

Просмотреть файл

@ -135,14 +135,14 @@ void BuildTestTopSortOrder(std::vector<int32> *node_to_order) {
// The topological sorting order of the above SCC graph is as follows (from
// our particular algorithm):
// 0 --> 0
// 1 --> 1
// 2 --> 2
// 3 --> 3
(*node_to_order)[0] = 0;
(*node_to_order)[1] = 1;
(*node_to_order)[2] = 2;
(*node_to_order)[3] = 3;
// 0 --> 3
// 1 --> 2
// 2 --> 1
// 3 --> 0
(*node_to_order)[0] = 3;
(*node_to_order)[1] = 2;
(*node_to_order)[2] = 1;
(*node_to_order)[3] = 0;
}
void UnitTestComputeGraphTranspose() {
@ -196,6 +196,35 @@ void UnitTestComputeTopSortOrder() {
KALDI_ASSERT(AssertVectorEqual(node_to_order, ref_node_to_order));
}
void UnitTestComputeTopSortOrder2() {
// The outer vector is indexed by node ID, and each nested vector contains
// the node IDs for its successors in the graph. For example, if there are
// arcs from node 0 to nodes 1 and 2, then the vector at graph[0] will be (1, 2)
std::vector<std::vector<int32> > graph;
// Build a test graph:
// 0 ---> 1 ---> 2 ---> 4
// `--> 3 -----^
graph.resize(5);
graph[0].push_back(1); graph[0].push_back(3);
graph[1].push_back(2);
graph[2].push_back(4);
graph[3].push_back(2);
// graph[4] is empty(has no successors)
// fill in the desired(topological) mapping node->order
std::vector<int32> ref_node_to_order;
ref_node_to_order.push_back(0); // node 0 comes first
ref_node_to_order.push_back(2); // node 1 comes third
ref_node_to_order.push_back(3); // node 2 comes fourth
ref_node_to_order.push_back(1); // node 3 comes second
ref_node_to_order.push_back(4); // node 4 comes last
std::vector<int32> computed_node_to_order;
ComputeTopSortOrder(graph, &computed_node_to_order);
KALDI_ASSERT(AssertVectorEqual(ref_node_to_order, computed_node_to_order));
}
} // namespace nnet3
} // namespace kaldi
@ -207,6 +236,7 @@ int main() {
UnitTestFindSccs();
UnitTestMakeSccGraph();
UnitTestComputeTopSortOrder();
UnitTestComputeTopSortOrder2();
KALDI_LOG << "Nnet graph tests succeeded.";

Просмотреть файл

@ -230,18 +230,18 @@ void ComputeTopSortOrder(const std::vector<std::vector<int32> > &graph,
std::vector<bool> cycle_detector(graph.size(), false);
std::vector<bool> is_visited(graph.size(), false);
std::vector<int32> orders;
std::vector<int32> reversed_orders;
for(int32 i = 0; i < graph.size(); ++i) {
if (!is_visited[i]) {
ComputeTopSortOrderRecursive(i, graph, &cycle_detector,
&is_visited, &orders);
&is_visited, &reversed_orders);
}
}
KALDI_ASSERT(node_to_order->size() == orders.size());
for (int32 i = 0; i < orders.size(); ++i) {
KALDI_ASSERT(orders[i] >= 0 && orders[i] < graph.size());
(*node_to_order)[orders[i]] = i;
KALDI_ASSERT(node_to_order->size() == reversed_orders.size());
for (int32 i = 0; i < reversed_orders.size(); ++i) {
KALDI_ASSERT(reversed_orders[i] >= 0 && reversed_orders[i] < graph.size());
(*node_to_order)[reversed_orders[i]] = graph.size() - i - 1;
}
}
@ -277,13 +277,8 @@ void ComputeNnetComputationEpochs(const Nnet &nnet,
MakeSccGraph(graph, sccs, &scc_graph);
KALDI_VLOG(6) << "scc graph is: " << PrintGraphToString(scc_graph);
std::vector<std::vector<int32> > scc_graph_transpose;
// compute transpose because we actually want the reverse of the topological
// order so inputs come first and then outputs.
ComputeGraphTranspose(scc_graph, &scc_graph_transpose);
std::vector<int32> scc_node_to_epoch;
ComputeTopSortOrder(scc_graph_transpose, &scc_node_to_epoch);
ComputeTopSortOrder(scc_graph, &scc_node_to_epoch);
if (GetVerboseLevel() >= 6) {
std::ostringstream os;
for (int32 i = 0; i < scc_node_to_epoch.size(); i++)