Commit 6c4560f

JasonOE and jason authored
feat(scheduler): scheduler change after node leave (#194)
Co-authored-by: jason <jl@gradient.network>
1 parent fd50821 commit 6c4560f

5 files changed: +39 −37 lines changed

src/backend/server/scheduler_manage.py

Lines changed: 1 addition & 1 deletion
@@ -247,7 +247,7 @@ def get_schedule_status(self):
         # todo rebalance status
         status = (
             NODE_STATUS_AVAILABLE
-            if self.scheduler.layer_allocator.has_full_active_pipeline()
+            if self.scheduler.layer_allocator.has_full_pipeline(active_only=True)
             else NODE_STATUS_WAITING
         )
         logger.debug(f"SchedulerManage status queried: {status}")

src/scheduling/layer_allocation.py

Lines changed: 31 additions & 24 deletions
@@ -200,6 +200,7 @@ def deallocate(self, node: Node) -> None:
             if layer_id in self.layer_to_load:
                 self.layer_to_load[layer_id].remove_node(node)
         node.clear_layer_allocation()
+        node.is_active = False
         self._update_layer_loads_heap()

     def declare(self, node: Node) -> None:
@@ -216,6 +217,7 @@ def join(self, node: Node) -> None:
         logger.debug("Joining node dynamically: %s", node.node_id)
         self.declare(node)
         lightest_layer = self.get_lightest_layer()
+        logger.debug("Lightest layer: %s", lightest_layer)
         if lightest_layer is None:
             raise ValueError("No layers to assign")

@@ -529,39 +531,44 @@ def _adjust_end_layer_for_tail(self, node: Node, proposed_start_layer: int) -> i

         return end_layer

-    def has_full_pipeline(self) -> bool:
+    def has_full_pipeline(self, active_only: bool = False) -> bool:
         """Return True if there exists at least one pipeline covering [0, num_total_layers).

         Checks whether we can chain contiguous node allocations starting at 0 to reach L.
+        This requires that there exists at least one node starting at layer 0 and a chain
+        of contiguous node ranges that reaches num_total_layers.
         """
         total_layers = self.num_total_layers
-        layer_count: Dict[int, int] = {}
-        for _, (s, e) in self.node_allocation.items():
-            if s is None or e is None:
-                continue
-            for layer in range(s, e):
-                layer_count[layer] = layer_count.get(layer, 0) + 1
-
-        for layer in range(total_layers):
-            if layer not in layer_count or layer_count[layer] == 0:
-                return False
-        return True

-    def has_full_active_pipeline(self) -> bool:
-        """Return True if there exists at least one active pipeline covering [0, num_total_layers)."""
-        total_layers = self.num_total_layers
-        layer_count: Dict[int, int] = {}
+        # Build index of nodes by start_layer
+        start_to_nodes: Dict[int, List[Node]] = {}
         for node_id, (s, e) in self.node_allocation.items():
-            if self.node_id_to_node[node_id].is_active is False:
-                continue
             if s is None or e is None:
                 continue
-            for layer in range(s, e):
-                layer_count[layer] = layer_count.get(layer, 0) + 1
-        for layer in range(total_layers):
-            if layer not in layer_count or layer_count[layer] == 0:
-                return False
-        return True
+            node = self.node_id_to_node.get(node_id)
+            if node is None or (active_only and not node.is_active):
+                continue
+            start_to_nodes.setdefault(s, []).append(node)
+
+        # Must have at least one node starting at layer 0
+        if not start_to_nodes.get(0):
+            return False
+
+        # DFS to check if we can reach total_layers from any head node
+        def can_reach_target(current_end: int) -> bool:
+            if current_end >= total_layers:
+                return current_end == total_layers
+
+            for nxt in start_to_nodes.get(current_end, []):
+                if nxt.end_layer and nxt.end_layer > current_end:
+                    if can_reach_target(nxt.end_layer):
+                        return True
+            return False
+
+        return any(
+            head.end_layer and can_reach_target(head.end_layer)
+            for head in start_to_nodes.get(0, [])
+        )

     def layer_replication_stats(self) -> Tuple[int, int, float]:
         """Return (min, max, avg) number of nodes hosting each layer.

src/scheduling/model_info.py

Lines changed: 0 additions & 6 deletions
@@ -184,12 +184,6 @@ def decoder_layer_io_bytes(
             ffn_params *= self.num_local_experts
         kv_cache_size = 0

-        logger.debug(
-            "Model Info ffn_params=%d, kv_cache_size=%d, attention_params=%d",
-            ffn_params,
-            kv_cache_size,
-            attention_params,
-        )
         return round(ffn_params + kv_cache_size + attention_params)

     def lm_head_flops(self, target_seq_len: int = 1) -> int:

src/scheduling/node.py

Lines changed: 0 additions & 5 deletions
@@ -280,11 +280,6 @@ def get_decoder_layer_capacity(
         if not (include_input_embed and self.model_info.tie_embedding):
             available_memory_bytes -= self.model_info.embedding_io_bytes

-        logger.debug(
-            "Node available_memory_bytes=%d, decoder_layer_io_bytes=%d",
-            available_memory_bytes,
-            self.model_info.decoder_layer_io_bytes(roofline=False),
-        )
         if self.hardware.device == "mlx":
             # For mlx, consider mlx bit factor
             return floor(

src/scheduling/scheduler.py

Lines changed: 7 additions & 1 deletion
@@ -322,7 +322,13 @@ def leave(self, node_id: str) -> None:
             for n in self.nodes:
                 if n.start_layer is not None and n.end_layer is not None:
                     self.layer_allocator.deallocate(n)
-            self.layer_allocator.global_allocation()
+            success = self.layer_allocator.global_allocation()
+            if not success:
+                logger.warning("Global rebalance failed to produce a full pipeline")
+            else:
+                logger.debug("Global rebalance completed successfully")
+                self._bootstrapped = True
+                self._bootstrapped_event.set()

         with self._node_count_cv:
             self._node_count_cv.notify_all()
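
With the return value of global_allocation() now inspected, leave() can tell a rebalance that restored a full pipeline apart from one that did not, and only then signals the bootstrapped event. A small hypothetical sketch of how a caller might consume that signal; the stub below models only the attributes visible in the diff, and the rest of the wiring (including the event being set only on success) is an assumption for illustration:

import threading

class SchedulerStub:
    # Illustrative stand-in: only the pieces touched by the diff above are modeled.
    def __init__(self, layer_allocator):
        self.layer_allocator = layer_allocator
        self._bootstrapped = False
        self._bootstrapped_event = threading.Event()

    def rebalance_after_leave(self) -> None:
        # Mirrors the updated tail of leave(): mark the scheduler bootstrapped and
        # set the event only when the rebalance yields a full pipeline again.
        success = self.layer_allocator.global_allocation()
        if success:
            self._bootstrapped = True
            self._bootstrapped_event.set()

    def wait_until_ready(self, timeout: float = 5.0) -> bool:
        # Callers can block here until a full pipeline is available again.
        return self._bootstrapped_event.wait(timeout)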
