Refactor rollout topology binding#1933
Conversation
53c3a19 to
71a0a3a
Compare
|
@claude review |
| @dataclass(frozen=True) | ||
| class RolloutTopology: | ||
| """Immutable rollout engine layout after dist-init addresses are resolved. | ||
|
|
||
| Actor handles, server URLs, session URLs, and lifecycle state belong to RolloutWorkerRegistry. | ||
| """ | ||
|
|
||
| engines: tuple[_Engine, ...] | ||
| training_engine_mesh: tuple[tuple[int, ...], ...] |
There was a problem hiding this comment.
Claude: [Warning] The old validate_engine_launch_specs() (~90 lines) enforced a comprehensive set of invariants: no empty engines, no duplicate ranks within/across engines, server ranks ⊆ engine ranks, no duplicate bundle assignments, nnodes >= 1, node_rank < nnodes, at least one request entrypoint per engine, etc.
None of these checks exist on the new RolloutTopology. Invalid topologies will produce opaque downstream failures — e.g., engine.engine_ranks.index(server.worker_rank) in server_launch_specs() raises a bare ValueError if a server rank is not in engine_ranks.
Consider adding a __post_init__ with at least the key structural invariants (no duplicate ranks across engines, server ranks ⊆ engine ranks, at least one server per engine). Even a subset of the old checks would catch misconfiguration early with clear error messages instead of cryptic runtime failures.
|
Claude: ## Review Summary Clean refactor that converges scattered topology fields into ProduceBatchResult impact: not affected IssuesCritical
Warning
Nit
VerdictREQUEST_CHANGES — the removed session_url validation is a silent safety regression that should be restored. |
|
@claude review |
| lifecycle_group = tuple(server.worker_rank for server in engine.server_processes) | ||
| for server in engine.server_processes: | ||
| if server.worker_rank in server_process_by_rank: | ||
| raise ValueError(f"Duplicate rollout server process worker_rank={server.worker_rank}.") | ||
| server_process_by_rank[server.worker_rank] = server | ||
| lifecycle_group_by_rank[server.worker_rank] = lifecycle_group | ||
|
|
||
| object.__setattr__(self, "_server_process_by_rank", server_process_by_rank) | ||
| object.__setattr__(self, "_lifecycle_group_by_rank", lifecycle_group_by_rank) | ||
|
|
||
| @staticmethod | ||
| def engine( | ||
| *, | ||
| engine_ranks: tuple[int, ...], | ||
| dist_init_addr: str, | ||
| server_processes: tuple[_ServerProcess, ...], |
There was a problem hiding this comment.
Claude: The __post_init__ validates duplicate server process ranks, which is good — but the old validate_engine_launch_specs also checked these invariants that are now unguarded:
- Server process
worker_rankmust be part of its engine'sengine_ranks - Each engine must have at least one request entrypoint (
accepts_rollout_requests=True) - No duplicate engine ranks across engines
- Valid
node_rankbounds (0 <= node_rank < nnodes)
Consider adding at least the first two checks — they catch backend construction bugs early with clear error messages, rather than surfacing as confusing downstream failures (e.g., is_request_entrypoint_rank returning False for all ranks in an engine).
| lifecycle_group = tuple(server.worker_rank for server in engine.server_processes) | |
| for server in engine.server_processes: | |
| if server.worker_rank in server_process_by_rank: | |
| raise ValueError(f"Duplicate rollout server process worker_rank={server.worker_rank}.") | |
| server_process_by_rank[server.worker_rank] = server | |
| lifecycle_group_by_rank[server.worker_rank] = lifecycle_group | |
| object.__setattr__(self, "_server_process_by_rank", server_process_by_rank) | |
| object.__setattr__(self, "_lifecycle_group_by_rank", lifecycle_group_by_rank) | |
| @staticmethod | |
| def engine( | |
| *, | |
| engine_ranks: tuple[int, ...], | |
| dist_init_addr: str, | |
| server_processes: tuple[_ServerProcess, ...], | |
| def __post_init__(self) -> None: | |
| server_process_by_rank: dict[int, _ServerProcess] = {} | |
| lifecycle_group_by_rank: dict[int, tuple[int, ...]] = {} | |
| for engine in self.engines: | |
| lifecycle_group = tuple(server.worker_rank for server in engine.server_processes) | |
| engine_rank_set = set(engine.engine_ranks) | |
| has_entrypoint = False | |
| for server in engine.server_processes: | |
| if server.worker_rank in server_process_by_rank: | |
| raise ValueError(f"Duplicate rollout server process worker_rank={server.worker_rank}.") | |
| if server.worker_rank not in engine_rank_set: | |
| raise ValueError( | |
| f"Server worker_rank={server.worker_rank} is not part of " | |
| f"engine_ranks={engine.engine_ranks}." | |
| ) | |
| if server.accepts_rollout_requests: | |
| has_entrypoint = True | |
| server_process_by_rank[server.worker_rank] = server | |
| lifecycle_group_by_rank[server.worker_rank] = lifecycle_group | |
| if not has_entrypoint: | |
| raise ValueError( | |
| f"Engine with ranks={engine.engine_ranks} must have at least one " | |
| f"request entrypoint (accepts_rollout_requests=True)." | |
| ) | |
| object.__setattr__(self, "_server_process_by_rank", server_process_by_rank) | |
| object.__setattr__(self, "_lifecycle_group_by_rank", lifecycle_group_by_rank) |
| with self._lock: | ||
| return tuple(sorted(self._rollout_topology.lifecycle_groups())) | ||
|
|
||
| def _build_worker_groups(self) -> dict[tuple[int, ...], WorkerGroup]: |
There was a problem hiding this comment.
Claude: The _build_worker_groups method calls lifecycle_group_for_server_rank(worker.rank) which raises KeyError if the rank isn't in the topology. While registered ranks should always match the topology's server_launch_specs in normal operation, a defensive guard would make debugging easier if registrations and topology drift:
| def _build_worker_groups(self) -> dict[tuple[int, ...], WorkerGroup]: | |
| def _build_worker_groups(self) -> dict[tuple[int, ...], WorkerGroup]: | |
| grouped_ranks: set[tuple[int, ...]] = set() | |
| for worker in self._workers.values(): | |
| try: | |
| grouped_ranks.add(self._rollout_topology.lifecycle_group_for_server_rank(worker.rank)) | |
| except KeyError: | |
| grouped_ranks.add((worker.rank,)) |
Alternatively, leave as-is if you prefer a hard failure — the current KeyError message from lifecycle_group_for_server_rank is reasonably descriptive.
| self.config, | ||
| rank_bundle_idx_list, | ||
| rank_to_dist_init_addr, | ||
| ) |
There was a problem hiding this comment.
Claude: Nit — rank_to_worker is built by iterating zip(workers, dist_init_results, strict=True). This correctly pairs each worker actor with the (rank, dist_init_addr) tuple it returned. Compared to the old code that relied on index-aligned ordering between workers and rank_bundle_idx_list, the new approach is safer since it uses the rank returned by the worker itself.
However, the variable name _dist_init_addr in the comprehension could be clearer — it looks like a bug at first glance (why discard?). A brief comment would help:
| ) | |
| rank_to_worker = { | |
| rank: worker | |
| for worker, (rank, _dist_init_addr) in zip(workers, dist_init_results, strict=True) # rank from worker | |
| } |
|
Claude: ## Review Summary The follow-up commit (patch 2/2) addresses all critical and major findings from the prior review:
The refactor is well-structured. ProduceBatchResult impact: not affected Remaining IssuesWarning
Nit
VerdictAPPROVE — the core design is sound and all prior critical findings are resolved. The remaining items are defensive suggestions, not blockers. |
Summary
这个 PR 重构了 rollout worker 启动时的拓扑表达方式,新增
RolloutTopology,并用ServerLaunchSpec作为 worker 启动 server 的输入。核心目标是把 rollout 启动拓扑相关逻辑从 controller / worker 的散装字段里收敛出来:
RolloutTopology表达 rollout 启动拓扑_Engine表达一个 logical inference engine_ServerProcess表达一个实际会启动 server 的 worker processServerLaunchSpec表达传给 worker 的启动参数另外,额外修改进行以下修改:
RolloutHealthManager中将start等函数命名为start_background_checks,明确为后台检查RolloutWorkerEndpointMetadata来统一rank, server_url, session_url 及 lifecycle_state,用于替换rank_to_serverl_url, rank_to_session_url及其status本PR暂未修改:
RolloutWorkerMetadata中增加to_legacy方法,适配原来的权重更新相关接口的输入,等权重更新重构的PR合入后,再进行修改关键结构
RolloutTopology: rollout worker 启动时使用的整体拓扑。它包含:
_Engine(内部结构):表示一个 logical inference engine。它包含:
engine_ranks:这个 engine 覆盖哪些 worker rankdist_init_addr:这个 engine 内部通信使用的 rendezvous 地址server_processes:这个 engine 里哪些 worker 会启动 rollout server_ServerProcess(内部结构):表示一个 worker 上实际要启动的 server process。它包含:
worker_rankplacement_group_bundle_idxsaccepts_rollout_requestsnode_ranknnodes例如:
RolloutTopology中ServerLaunchSpec:从RolloutTopology投影出来、真正传给 worker 启动 infer server的变量-worker 不再需要理解完整 topology,只需要保存自己的
ServerLaunchSpec。init_worker流程
RolloutController 先通过 placement group 创建所有 RolloutWorker actors。
每个 worker 调用 init_dist_port(),申请本地端口,并返回 (rank, dist_init_addr)。
controller 汇总得到 rank_to_dist_init_addr。
backend worker class 根据:
构造 RolloutTopology。
RolloutTopology 根据内部 _Engine / _ServerProcess 生成 ServerLaunchSpec 列表。
controller 对每个 ServerLaunchSpec 找到对应 worker,调用:
worker.bind_server_launch_spec.remote(launch_spec)