Merged
src/backend/server/scheduler_manage.py (2 changes: 1 addition & 1 deletion)

@@ -247,7 +247,7 @@ def get_schedule_status(self):
         # todo rebalance status
         status = (
             NODE_STATUS_AVAILABLE
-            if self.scheduler.layer_allocator.has_full_active_pipeline()
+            if self.scheduler.layer_allocator.has_full_pipeline(active_only=True)
             else NODE_STATUS_WAITING
         )
         logger.debug(f"SchedulerManage status queried: {status}")
src/scheduling/layer_allocation.py (55 changes: 31 additions & 24 deletions)

@@ -200,6 +200,7 @@ def deallocate(self, node: Node) -> None:
             if layer_id in self.layer_to_load:
                 self.layer_to_load[layer_id].remove_node(node)
         node.clear_layer_allocation()
+        node.is_active = False
         self._update_layer_loads_heap()
 
     def declare(self, node: Node) -> None:
@@ -216,6 +217,7 @@ def join(self, node: Node) -> None:
         logger.debug("Joining node dynamically: %s", node.node_id)
         self.declare(node)
         lightest_layer = self.get_lightest_layer()
+        logger.debug("Lightest layer: %s", lightest_layer)
         if lightest_layer is None:
             raise ValueError("No layers to assign")
 
@@ -529,39 +531,44 @@ def _adjust_end_layer_for_tail(self, node: Node, proposed_start_layer: int) -> int:
 
         return end_layer
 
-    def has_full_pipeline(self) -> bool:
+    def has_full_pipeline(self, active_only: bool = False) -> bool:
         """Return True if there exists at least one pipeline covering [0, num_total_layers).
+
+        Checks whether we can chain contiguous node allocations starting at 0 to reach L.
+        This requires that there exists at least one node starting at layer 0 and a chain
+        of contiguous node ranges that reaches num_total_layers.
         """
-        total_layers = self.num_total_layers
-        layer_count: Dict[int, int] = {}
-        for _, (s, e) in self.node_allocation.items():
-            if s is None or e is None:
-                continue
-            for layer in range(s, e):
-                layer_count[layer] = layer_count.get(layer, 0) + 1
-
-        for layer in range(total_layers):
-            if layer not in layer_count or layer_count[layer] == 0:
-                return False
-        return True
-
-    def has_full_active_pipeline(self) -> bool:
-        """Return True if there exists at least one active pipeline covering [0, num_total_layers)."""
         total_layers = self.num_total_layers
-        layer_count: Dict[int, int] = {}
+        # Build index of nodes by start_layer
+        start_to_nodes: Dict[int, List[Node]] = {}
         for node_id, (s, e) in self.node_allocation.items():
-            if self.node_id_to_node[node_id].is_active is False:
-                continue
+            node = self.node_id_to_node.get(node_id)
+            if node is None or (active_only and not node.is_active):
+                continue
             if s is None or e is None:
                 continue
-            for layer in range(s, e):
-                layer_count[layer] = layer_count.get(layer, 0) + 1
-        for layer in range(total_layers):
-            if layer not in layer_count or layer_count[layer] == 0:
-                return False
-        return True
+            start_to_nodes.setdefault(s, []).append(node)
+
+        # Must have at least one node starting at layer 0
+        if not start_to_nodes.get(0):
+            return False
+
+        # DFS to check if we can reach total_layers from any head node
+        def can_reach_target(current_end: int) -> bool:
+            if current_end >= total_layers:
+                return current_end == total_layers
+
+            for nxt in start_to_nodes.get(current_end, []):
+                if nxt.end_layer and nxt.end_layer > current_end:
+                    if can_reach_target(nxt.end_layer):
+                        return True
+            return False
+
+        return any(
+            head.end_layer and can_reach_target(head.end_layer)
+            for head in start_to_nodes.get(0, [])
+        )

     def layer_replication_stats(self) -> Tuple[int, int, float]:
         """Return (min, max, avg) number of nodes hosting each layer.
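To make the behavioral change concrete, here is a minimal standalone sketch of the new chaining semantics. It uses plain (start, end) tuples in place of the repo's Node objects; the function name and simplified signature are illustrative, not the PR's API.

# Minimal standalone sketch of the chaining check (illustrative names;
# plain tuples stand in for Node objects).
from typing import Dict, List, Optional, Tuple

def has_full_pipeline_sketch(
    allocations: Dict[str, Tuple[Optional[int], Optional[int]]],
    total_layers: int,
) -> bool:
    # Index end layers by start layer, skipping unallocated entries.
    start_to_ends: Dict[int, List[int]] = {}
    for s, e in allocations.values():
        if s is None or e is None:
            continue
        start_to_ends.setdefault(s, []).append(e)

    # DFS: from an end layer, try to chain a node that starts exactly there.
    def can_reach(end: int) -> bool:
        if end >= total_layers:
            return end == total_layers
        return any(nxt > end and can_reach(nxt) for nxt in start_to_ends.get(end, []))

    return any(can_reach(e) for e in start_to_ends.get(0, []))

# Every layer in [0, 8) is hosted, but [0, 4) and [3, 8) cannot be chained:
print(has_full_pipeline_sketch({"a": (0, 4), "b": (3, 8)}, 8))  # False
print(has_full_pipeline_sketch({"a": (0, 4), "b": (4, 8)}, 8))  # True

The first case shows why the rewrite matters: the old per-layer coverage count would return True for overlapping ranges that no request could actually traverse end to end, while the chaining check requires contiguous handoffs from layer 0 to num_total_layers.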
src/scheduling/model_info.py (6 changes: 0 additions & 6 deletions)

@@ -184,12 +184,6 @@ def decoder_layer_io_bytes(
             ffn_params *= self.num_local_experts
         kv_cache_size = 0
 
-        logger.debug(
-            "Model Info ffn_params=%d, kv_cache_size=%d, attention_params=%d",
-            ffn_params,
-            kv_cache_size,
-            attention_params,
-        )
         return round(ffn_params + kv_cache_size + attention_params)
 
     def lm_head_flops(self, target_seq_len: int = 1) -> int:
src/scheduling/node.py (5 changes: 0 additions & 5 deletions)

@@ -280,11 +280,6 @@ def get_decoder_layer_capacity(
         if not (include_input_embed and self.model_info.tie_embedding):
             available_memory_bytes -= self.model_info.embedding_io_bytes
 
-        logger.debug(
-            "Node available_memory_bytes=%d, decoder_layer_io_bytes=%d",
-            available_memory_bytes,
-            self.model_info.decoder_layer_io_bytes(roofline=False),
-        )
         if self.hardware.device == "mlx":
             # For mlx, consider mlx bit factor
             return floor(
src/scheduling/scheduler.py (8 changes: 7 additions & 1 deletion)

@@ -322,7 +322,13 @@ def leave(self, node_id: str) -> None:
         for n in self.nodes:
             if n.start_layer is not None and n.end_layer is not None:
                 self.layer_allocator.deallocate(n)
-        self.layer_allocator.global_allocation()
+        success = self.layer_allocator.global_allocation()
+        if not success:
+            logger.warning("Global rebalance failed to produce a full pipeline")
+        else:
+            logger.debug("Global rebalance completed successfully")
+        self._bootstrapped = True
+        self._bootstrapped_event.set()
 
         with self._node_count_cv:
             self._node_count_cv.notify_all()
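Note that the rebalance path now signals self._bootstrapped_event whether or not the rebalance succeeded, so anything blocked on bootstrap unblocks and can re-check status instead of hanging. Below is a minimal wait-side sketch of that pattern, assuming _bootstrapped_event is a threading.Event (the diff only shows .set() being called) and using an invented wait_until_bootstrapped helper.

import threading

class SchedulerSketch:
    """Illustrative stand-in only; not the repo's Scheduler class."""

    def __init__(self) -> None:
        self._bootstrapped = False
        self._bootstrapped_event = threading.Event()  # assumed type

    def wait_until_bootstrapped(self, timeout: float = 30.0) -> bool:
        # Returns True once the event is set, even after a failed
        # rebalance; callers should still verify pipeline status.
        return self._bootstrapped_event.wait(timeout)

Signaling the event on the failure path too is the design choice worth noting: a waiter learns that a rebalance attempt finished, not that a full pipeline exists, which is why the status endpoint re-checks has_full_pipeline(active_only=True) separately.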