diff --git a/sycl/source/detail/graph/graph_impl.cpp b/sycl/source/detail/graph/graph_impl.cpp index ad5058361b6d4..b7d0a02f8f57b 100644 --- a/sycl/source/detail/graph/graph_impl.cpp +++ b/sycl/source/detail/graph/graph_impl.cpp @@ -1255,15 +1255,22 @@ exec_graph_impl::enqueue(sycl::detail::queue_impl &Queue, void exec_graph_impl::duplicateNodes() { // Map of original modifiable nodes (keys) to new duplicated nodes (values) - std::map NodesMap; - + std::unordered_map NodesMap; nodes_range ModifiableNodes{MGraphImpl->MNodeStorage}; - std::deque> NewNodes; + std::vector> NewNodes; + + const size_t NodeCount = ModifiableNodes.size(); + NodesMap.reserve(NodeCount); + NewNodes.reserve(NodeCount); + + bool foundSubgraph = false; for (node_impl &OriginalNode : ModifiableNodes) { NewNodes.push_back(std::make_shared(OriginalNode)); node_impl &NodeCopy = *NewNodes.back(); + foundSubgraph |= (NodeCopy.MNodeType == node_type::subgraph); + // Associate the ID of the original node with the node copy for later quick // access MIDCache.insert(std::make_pair(OriginalNode.MID, &NodeCopy)); @@ -1292,110 +1299,109 @@ void exec_graph_impl::duplicateNodes() { // Subgraph nodes need special handling, we extract all subgraph nodes and // merge them into the main node list + if (foundSubgraph) { + for (auto NewNodeIt = NewNodes.rbegin(); NewNodeIt != NewNodes.rend(); + ++NewNodeIt) { + auto NewNode = *NewNodeIt; + if (NewNode->MNodeType != node_type::subgraph) { + continue; + } + nodes_range SubgraphNodes{NewNode->MSubGraphImpl->MNodeStorage}; + std::deque> NewSubgraphNodes{}; + + // Map of original subgraph nodes (keys) to new duplicated nodes (values) + std::map SubgraphNodesMap; + + // Copy subgraph nodes + for (node_impl &SubgraphNode : SubgraphNodes) { + NewSubgraphNodes.push_back(std::make_shared(SubgraphNode)); + node_impl &NodeCopy = *NewSubgraphNodes.back(); + // Associate the ID of the original subgraph node with all extracted + // node copies for future quick access. 
+ MIDCache.insert(std::make_pair(SubgraphNode.MID, &NodeCopy)); + + SubgraphNodesMap.insert({&SubgraphNode, &NodeCopy}); + NodeCopy.MSuccessors.clear(); + NodeCopy.MPredecessors.clear(); + } - for (auto NewNodeIt = NewNodes.rbegin(); NewNodeIt != NewNodes.rend(); - ++NewNodeIt) { - auto NewNode = *NewNodeIt; - if (NewNode->MNodeType != node_type::subgraph) { - continue; - } - nodes_range SubgraphNodes{NewNode->MSubGraphImpl->MNodeStorage}; - std::deque> NewSubgraphNodes{}; - - // Map of original subgraph nodes (keys) to new duplicated nodes (values) - std::map SubgraphNodesMap; - - // Copy subgraph nodes - for (node_impl &SubgraphNode : SubgraphNodes) { - NewSubgraphNodes.push_back(std::make_shared(SubgraphNode)); - node_impl &NodeCopy = *NewSubgraphNodes.back(); - // Associate the ID of the original subgraph node with all extracted node - // copies for future quick access. - MIDCache.insert(std::make_pair(SubgraphNode.MID, &NodeCopy)); - - SubgraphNodesMap.insert({&SubgraphNode, &NodeCopy}); - NodeCopy.MSuccessors.clear(); - NodeCopy.MPredecessors.clear(); - } - - // Rebuild edges for new subgraph nodes - auto OrigIt = SubgraphNodes.begin(), OrigEnd = SubgraphNodes.end(); - for (auto NewIt = NewSubgraphNodes.begin(); OrigIt != OrigEnd; - ++OrigIt, ++NewIt) { - node_impl &SubgraphNode = *OrigIt; - node_impl &NodeCopy = **NewIt; + // Rebuild edges for new subgraph nodes + auto OrigIt = SubgraphNodes.begin(), OrigEnd = SubgraphNodes.end(); + for (auto NewIt = NewSubgraphNodes.begin(); OrigIt != OrigEnd; + ++OrigIt, ++NewIt) { + node_impl &SubgraphNode = *OrigIt; + node_impl &NodeCopy = **NewIt; - for (node_impl &NextNode : SubgraphNode.successors()) { - node_impl &Successor = *SubgraphNodesMap.at(&NextNode); - NodeCopy.registerSuccessor(Successor); + for (node_impl &NextNode : SubgraphNode.successors()) { + node_impl &Successor = *SubgraphNodesMap.at(&NextNode); + NodeCopy.registerSuccessor(Successor); + } } - } - // Collect input and output nodes for the subgraph - 
std::vector Inputs; - std::vector Outputs; - for (std::shared_ptr &NodeImpl : NewSubgraphNodes) { - if (NodeImpl->MPredecessors.size() == 0) { - Inputs.push_back(&*NodeImpl); - } - if (NodeImpl->MSuccessors.size() == 0) { - Outputs.push_back(&*NodeImpl); + // Collect input and output nodes for the subgraph + std::vector Inputs; + std::vector Outputs; + for (std::shared_ptr &NodeImpl : NewSubgraphNodes) { + if (NodeImpl->MPredecessors.size() == 0) { + Inputs.push_back(&*NodeImpl); + } + if (NodeImpl->MSuccessors.size() == 0) { + Outputs.push_back(&*NodeImpl); + } } - } - // Update the predecessors and successors of the nodes which reference the - // original subgraph node + // Update the predecessors and successors of the nodes which reference the + // original subgraph node - // Predecessors - for (node_impl &PredNode : NewNode->predecessors()) { - auto &Successors = PredNode.MSuccessors; + // Predecessors + for (node_impl &PredNode : NewNode->predecessors()) { + auto &Successors = PredNode.MSuccessors; - // Remove the subgraph node from this nodes successors - Successors.erase( - std::remove(Successors.begin(), Successors.end(), NewNode.get()), - Successors.end()); + // Remove the subgraph node from this nodes successors + Successors.erase( + std::remove(Successors.begin(), Successors.end(), NewNode.get()), + Successors.end()); - // Add all input nodes from the subgraph as successors for this node - // instead - for (node_impl *Input : Inputs) { - PredNode.registerSuccessor(*Input); + // Add all input nodes from the subgraph as successors for this node + // instead + for (node_impl *Input : Inputs) { + PredNode.registerSuccessor(*Input); + } } - } - // Successors - for (node_impl &SuccNode : NewNode->successors()) { - auto &Predecessors = SuccNode.MPredecessors; + // Successors + for (node_impl &SuccNode : NewNode->successors()) { + auto &Predecessors = SuccNode.MPredecessors; - // Remove the subgraph node from this nodes successors - Predecessors.erase( - 
std::remove(Predecessors.begin(), Predecessors.end(), NewNode.get()), - Predecessors.end()); + // Remove the subgraph node from this node's predecessors + Predecessors.erase(std::remove(Predecessors.begin(), Predecessors.end(), + NewNode.get()), + Predecessors.end()); - // Add all Output nodes from the subgraph as predecessors for this node - // instead - for (node_impl *Output : Outputs) { - Output->registerSuccessor(SuccNode); + // Add all Output nodes from the subgraph as predecessors for this node + // instead + for (node_impl *Output : Outputs) { + Output->registerSuccessor(SuccNode); + } } - } - // Remove single subgraph node and add all new individual subgraph nodes - // to the node storage in its place - auto OldPositionIt = - NewNodes.erase(std::find(NewNodes.begin(), NewNodes.end(), NewNode)); - // Also set the iterator to the newly added nodes so we can continue - // iterating over all remaining nodes - auto InsertIt = NewNodes.insert( - OldPositionIt, std::make_move_iterator(NewSubgraphNodes.begin()), - std::make_move_iterator(NewSubgraphNodes.end())); - // Since the new reverse_iterator will be at i - 1 we need to advance it - // when constructing - NewNodeIt = std::make_reverse_iterator(std::next(InsertIt)); + // Remove single subgraph node and add all new individual subgraph nodes + // to the node storage in its place + auto OldPositionIt = + NewNodes.erase(std::find(NewNodes.begin(), NewNodes.end(), NewNode)); + // Also set the iterator to the newly added nodes so we can continue + // iterating over all remaining nodes + auto InsertIt = NewNodes.insert( + OldPositionIt, std::make_move_iterator(NewSubgraphNodes.begin()), + std::make_move_iterator(NewSubgraphNodes.end())); + // Since the new reverse_iterator will be at i - 1 we need to advance it + // when constructing + NewNodeIt = std::make_reverse_iterator(std::next(InsertIt)); + } } // Store all the new nodes locally - MNodeStorage.insert(MNodeStorage.begin(), - 
std::make_move_iterator(NewNodes.begin()), - std::make_move_iterator(NewNodes.end())); + MNodeStorage = std::move(NewNodes); } void exec_graph_impl::update(std::shared_ptr GraphImpl) {