@@ -200,6 +200,7 @@ def deallocate(self, node: Node) -> None:
200200 if layer_id in self .layer_to_load :
201201 self .layer_to_load [layer_id ].remove_node (node )
202202 node .clear_layer_allocation ()
203+ node .is_active = False
203204 self ._update_layer_loads_heap ()
204205
205206 def declare (self , node : Node ) -> None :
@@ -216,6 +217,7 @@ def join(self, node: Node) -> None:
216217 logger .debug ("Joining node dynamically: %s" , node .node_id )
217218 self .declare (node )
218219 lightest_layer = self .get_lightest_layer ()
220+ logger .debug ("Lightest layer: %s" , lightest_layer )
219221 if lightest_layer is None :
220222 raise ValueError ("No layers to assign" )
221223
@@ -529,39 +531,44 @@ def _adjust_end_layer_for_tail(self, node: Node, proposed_start_layer: int) -> i
529531
530532 return end_layer
531533
532- def has_full_pipeline (self ) -> bool :
534+ def has_full_pipeline (self , active_only : bool = False ) -> bool :
533535 """Return True if there exists at least one pipeline covering [0, num_total_layers).
534536
535537 Checks whether we can chain contiguous node allocations starting at 0 to reach L.
538+ This requires that there exists at least one node starting at layer 0 and a chain
539+ of contiguous node ranges that reaches num_total_layers.
536540 """
537541 total_layers = self .num_total_layers
538- layer_count : Dict [int , int ] = {}
539- for _ , (s , e ) in self .node_allocation .items ():
540- if s is None or e is None :
541- continue
542- for layer in range (s , e ):
543- layer_count [layer ] = layer_count .get (layer , 0 ) + 1
544-
545- for layer in range (total_layers ):
546- if layer not in layer_count or layer_count [layer ] == 0 :
547- return False
548- return True
549542
550- def has_full_active_pipeline (self ) -> bool :
551- """Return True if there exists at least one active pipeline covering [0, num_total_layers)."""
552- total_layers = self .num_total_layers
553- layer_count : Dict [int , int ] = {}
543+ # Build index of nodes by start_layer
544+ start_to_nodes : Dict [int , List [Node ]] = {}
554545 for node_id , (s , e ) in self .node_allocation .items ():
555- if self .node_id_to_node [node_id ].is_active is False :
556- continue
557546 if s is None or e is None :
558547 continue
559- for layer in range (s , e ):
560- layer_count [layer ] = layer_count .get (layer , 0 ) + 1
561- for layer in range (total_layers ):
562- if layer not in layer_count or layer_count [layer ] == 0 :
563- return False
564- return True
548+ node = self .node_id_to_node .get (node_id )
549+ if node is None or (active_only and not node .is_active ):
550+ continue
551+ start_to_nodes .setdefault (s , []).append (node )
552+
553+ # Must have at least one node starting at layer 0
554+ if not start_to_nodes .get (0 ):
555+ return False
556+
557+ # DFS to check if we can reach total_layers from any head node
558+ def can_reach_target (current_end : int ) -> bool :
559+ if current_end >= total_layers :
560+ return current_end == total_layers
561+
562+ for nxt in start_to_nodes .get (current_end , []):
563+ if nxt .end_layer and nxt .end_layer > current_end :
564+ if can_reach_target (nxt .end_layer ):
565+ return True
566+ return False
567+
568+ return any (
569+ head .end_layer and can_reach_target (head .end_layer )
570+ for head in start_to_nodes .get (0 , [])
571+ )
565572
566573 def layer_replication_stats (self ) -> Tuple [int , int , float ]:
567574 """Return (min, max, avg) number of nodes hosting each layer.
0 commit comments