Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 95 additions & 89 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1255,15 +1255,22 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue,

void exec_graph_impl::duplicateNodes() {
// Map of original modifiable nodes (keys) to new duplicated nodes (values)
std::map<node_impl *, node_impl *> NodesMap;

std::unordered_map<node_impl *, node_impl *> NodesMap;
nodes_range ModifiableNodes{MGraphImpl->MNodeStorage};
std::deque<std::shared_ptr<node_impl>> NewNodes;
std::vector<std::shared_ptr<node_impl>> NewNodes;

const size_t NodeCount = ModifiableNodes.size();
NodesMap.reserve(NodeCount);
NewNodes.reserve(NodeCount);

bool foundSubgraph = false;

for (node_impl &OriginalNode : ModifiableNodes) {
NewNodes.push_back(std::make_shared<node_impl>(OriginalNode));
node_impl &NodeCopy = *NewNodes.back();

foundSubgraph |= (NodeCopy.MNodeType == node_type::subgraph);

// Associate the ID of the original node with the node copy for later quick
// access
MIDCache.insert(std::make_pair(OriginalNode.MID, &NodeCopy));
Expand Down Expand Up @@ -1292,110 +1299,109 @@ void exec_graph_impl::duplicateNodes() {

// Subgraph nodes need special handling, we extract all subgraph nodes and
// merge them into the main node list
if (foundSubgraph) {
for (auto NewNodeIt = NewNodes.rbegin(); NewNodeIt != NewNodes.rend();
++NewNodeIt) {
auto NewNode = *NewNodeIt;
if (NewNode->MNodeType != node_type::subgraph) {
continue;
}
nodes_range SubgraphNodes{NewNode->MSubGraphImpl->MNodeStorage};
std::deque<std::shared_ptr<node_impl>> NewSubgraphNodes{};

// Map of original subgraph nodes (keys) to new duplicated nodes (values)
std::map<node_impl *, node_impl *> SubgraphNodesMap;

// Copy subgraph nodes
for (node_impl &SubgraphNode : SubgraphNodes) {
NewSubgraphNodes.push_back(std::make_shared<node_impl>(SubgraphNode));
node_impl &NodeCopy = *NewSubgraphNodes.back();
// Associate the ID of the original subgraph node with all extracted
// node copies for future quick access.
MIDCache.insert(std::make_pair(SubgraphNode.MID, &NodeCopy));

SubgraphNodesMap.insert({&SubgraphNode, &NodeCopy});
NodeCopy.MSuccessors.clear();
NodeCopy.MPredecessors.clear();
}

for (auto NewNodeIt = NewNodes.rbegin(); NewNodeIt != NewNodes.rend();
++NewNodeIt) {
auto NewNode = *NewNodeIt;
if (NewNode->MNodeType != node_type::subgraph) {
continue;
}
nodes_range SubgraphNodes{NewNode->MSubGraphImpl->MNodeStorage};
std::deque<std::shared_ptr<node_impl>> NewSubgraphNodes{};

// Map of original subgraph nodes (keys) to new duplicated nodes (values)
std::map<node_impl *, node_impl *> SubgraphNodesMap;

// Copy subgraph nodes
for (node_impl &SubgraphNode : SubgraphNodes) {
NewSubgraphNodes.push_back(std::make_shared<node_impl>(SubgraphNode));
node_impl &NodeCopy = *NewSubgraphNodes.back();
// Associate the ID of the original subgraph node with all extracted node
// copies for future quick access.
MIDCache.insert(std::make_pair(SubgraphNode.MID, &NodeCopy));

SubgraphNodesMap.insert({&SubgraphNode, &NodeCopy});
NodeCopy.MSuccessors.clear();
NodeCopy.MPredecessors.clear();
}

// Rebuild edges for new subgraph nodes
auto OrigIt = SubgraphNodes.begin(), OrigEnd = SubgraphNodes.end();
for (auto NewIt = NewSubgraphNodes.begin(); OrigIt != OrigEnd;
++OrigIt, ++NewIt) {
node_impl &SubgraphNode = *OrigIt;
node_impl &NodeCopy = **NewIt;
// Rebuild edges for new subgraph nodes
auto OrigIt = SubgraphNodes.begin(), OrigEnd = SubgraphNodes.end();
for (auto NewIt = NewSubgraphNodes.begin(); OrigIt != OrigEnd;
++OrigIt, ++NewIt) {
node_impl &SubgraphNode = *OrigIt;
node_impl &NodeCopy = **NewIt;

for (node_impl &NextNode : SubgraphNode.successors()) {
node_impl &Successor = *SubgraphNodesMap.at(&NextNode);
NodeCopy.registerSuccessor(Successor);
for (node_impl &NextNode : SubgraphNode.successors()) {
node_impl &Successor = *SubgraphNodesMap.at(&NextNode);
NodeCopy.registerSuccessor(Successor);
}
}
}

// Collect input and output nodes for the subgraph
std::vector<node_impl *> Inputs;
std::vector<node_impl *> Outputs;
for (std::shared_ptr<node_impl> &NodeImpl : NewSubgraphNodes) {
if (NodeImpl->MPredecessors.size() == 0) {
Inputs.push_back(&*NodeImpl);
}
if (NodeImpl->MSuccessors.size() == 0) {
Outputs.push_back(&*NodeImpl);
// Collect input and output nodes for the subgraph
std::vector<node_impl *> Inputs;
std::vector<node_impl *> Outputs;
for (std::shared_ptr<node_impl> &NodeImpl : NewSubgraphNodes) {
if (NodeImpl->MPredecessors.size() == 0) {
Inputs.push_back(&*NodeImpl);
}
if (NodeImpl->MSuccessors.size() == 0) {
Outputs.push_back(&*NodeImpl);
}
}
}

// Update the predecessors and successors of the nodes which reference the
// original subgraph node
// Update the predecessors and successors of the nodes which reference the
// original subgraph node

// Predecessors
for (node_impl &PredNode : NewNode->predecessors()) {
auto &Successors = PredNode.MSuccessors;
// Predecessors
for (node_impl &PredNode : NewNode->predecessors()) {
auto &Successors = PredNode.MSuccessors;

// Remove the subgraph node from this nodes successors
Successors.erase(
std::remove(Successors.begin(), Successors.end(), NewNode.get()),
Successors.end());
// Remove the subgraph node from this nodes successors
Successors.erase(
std::remove(Successors.begin(), Successors.end(), NewNode.get()),
Successors.end());

// Add all input nodes from the subgraph as successors for this node
// instead
for (node_impl *Input : Inputs) {
PredNode.registerSuccessor(*Input);
// Add all input nodes from the subgraph as successors for this node
// instead
for (node_impl *Input : Inputs) {
PredNode.registerSuccessor(*Input);
}
}
}

// Successors
for (node_impl &SuccNode : NewNode->successors()) {
auto &Predecessors = SuccNode.MPredecessors;
// Successors
for (node_impl &SuccNode : NewNode->successors()) {
auto &Predecessors = SuccNode.MPredecessors;

// Remove the subgraph node from this nodes successors
Predecessors.erase(
std::remove(Predecessors.begin(), Predecessors.end(), NewNode.get()),
Predecessors.end());
// Remove the subgraph node from this nodes successors
Predecessors.erase(std::remove(Predecessors.begin(), Predecessors.end(),
NewNode.get()),
Predecessors.end());

// Add all Output nodes from the subgraph as predecessors for this node
// instead
for (node_impl *Output : Outputs) {
Output->registerSuccessor(SuccNode);
// Add all Output nodes from the subgraph as predecessors for this node
// instead
for (node_impl *Output : Outputs) {
Output->registerSuccessor(SuccNode);
}
}
}

// Remove single subgraph node and add all new individual subgraph nodes
// to the node storage in its place
auto OldPositionIt =
NewNodes.erase(std::find(NewNodes.begin(), NewNodes.end(), NewNode));
// Also set the iterator to the newly added nodes so we can continue
// iterating over all remaining nodes
auto InsertIt = NewNodes.insert(
OldPositionIt, std::make_move_iterator(NewSubgraphNodes.begin()),
std::make_move_iterator(NewSubgraphNodes.end()));
// Since the new reverse_iterator will be at i - 1 we need to advance it
// when constructing
NewNodeIt = std::make_reverse_iterator(std::next(InsertIt));
// Remove single subgraph node and add all new individual subgraph nodes
// to the node storage in its place
auto OldPositionIt =
NewNodes.erase(std::find(NewNodes.begin(), NewNodes.end(), NewNode));
// Also set the iterator to the newly added nodes so we can continue
// iterating over all remaining nodes
auto InsertIt = NewNodes.insert(
OldPositionIt, std::make_move_iterator(NewSubgraphNodes.begin()),
std::make_move_iterator(NewSubgraphNodes.end()));
// Since the new reverse_iterator will be at i - 1 we need to advance it
// when constructing
NewNodeIt = std::make_reverse_iterator(std::next(InsertIt));
}
}

// Store all the new nodes locally
MNodeStorage.insert(MNodeStorage.begin(),
std::make_move_iterator(NewNodes.begin()),
std::make_move_iterator(NewNodes.end()));
MNodeStorage = std::move(NewNodes);
}

void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
Expand Down