Skip to content

Commit

Permalink
fix: 为 intoGraph 补充拓扑序判断
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Oct 16, 2023
1 parent 50a02aa commit 9fffe3f
Showing 1 changed file with 38 additions and 24 deletions.
62 changes: 38 additions & 24 deletions src/01graph_topo/include/graph_topo/linked_graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ namespace refactor::graph_topo {
while (known.size() < _nodes.size()) {
auto before = known.size();
for (auto &n : _nodes) {
// n was moved
if (!n) { continue; }
// ∀e ∈ n.inputs, e.source ∈ known
if (std::all_of(n->_inputs.begin(), n->_inputs.end(),
[&known](auto const &e) { return !e || !e->_source || known.find(e->_source.get()) != known.end(); })) {
known.insert(n.get());
Expand Down Expand Up @@ -313,40 +315,52 @@ namespace refactor::graph_topo {
nodes.reserve(_nodes.size());
edges.reserve(_inputs.size());

std::unordered_set<void *> mappedNodes;
std::unordered_map<void *, GraphTopo::OutputEdge> edgeIndices;
for (auto &e : _inputs) {
edgeIndices.try_emplace(e.get(), edges.size());
edges.emplace_back(std::move(e->_info));
}

for (auto &n : _nodes) {
nodes.emplace_back(std::move(n->_info));

idx_t newLocalCount = 0;
topology._connections.reserve(topology._connections.size() + n->_inputs.size());
for (auto &e : n->_inputs) {
ASSERT(e, "Input edge is not connected");
auto [it, ok] = edgeIndices.try_emplace(e.get(), edges.size());
if (ok) {
ASSERT(!e->_source, "Local edge should not have source node");
++newLocalCount;
while (mappedNodes.size() < _nodes.size()) {
auto before = mappedNodes.size();
for (auto &n : _nodes) {
// ∃e ∈ n.inputs, e.source ∉ mapped
if (std::any_of(n->_inputs.begin(), n->_inputs.end(),
[&mappedNodes](auto const &e) {
ASSERT(e, "Input edge is not connected");
return e->_source && mappedNodes.find(e->_source.get()) == mappedNodes.end();
})) {
continue;
}
mappedNodes.insert(n.get());
nodes.emplace_back(std::move(n->_info));

idx_t newLocalCount = 0;
topology._connections.reserve(topology._connections.size() + n->_inputs.size());
for (auto &e : n->_inputs) {
auto [it, ok] = edgeIndices.try_emplace(e.get(), edges.size());
if (ok) {
ASSERT(!e->_source, "Local edge should not have source node");
++newLocalCount;
edges.emplace_back(std::move(e->_info));
}
topology._connections.push_back(it->second);
}
for (auto &e : n->_outputs) {
edgeIndices[e.get()] = edges.size();
edges.emplace_back(std::move(e->_info));
}
topology._connections.push_back(it->second);
}

for (auto &e : n->_outputs) {
edgeIndices[e.get()] = edges.size();
edges.emplace_back(std::move(e->_info));
topology._nodes.push_back({
newLocalCount,
static_cast<idx_t>(n->_inputs.size()),
static_cast<idx_t>(n->_outputs.size()),
});
}
if (before == mappedNodes.size()) {
RUNTIME_ERROR("Graph is not topo-sortable.");
}

topology._nodes.push_back({
newLocalCount,
static_cast<idx_t>(n->_inputs.size()),
static_cast<idx_t>(n->_outputs.size()),
});
}

std::transform(_outputs.begin(), _outputs.end(),
topology._connections.begin(),
[&](auto &e) { return edgeIndices.at(e.get()); });
Expand Down

0 comments on commit 9fffe3f

Please sign in to comment.