← 返回首页
refactor: optimize GPU scheduler — critical path priority, batch SSH,… · Sibyl-Research-Team/AutoResearch-SibylSystem@5624332 · GitHub
Skip to content

Navigation Menu

Toggle navigation
Sign in
Appearance settings
Search or jump to...

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Include my email address so I can be contacted

Saved searches

Use saved searches to filter your results more quickly

Appearance settings
Resetting focus

Commit 5624332

Browse files
refactor: optimize GPU scheduler — critical path priority, batch SSH, auto-retry
- Add compute_downstream_counts() for critical path priority in assign_gpus() - Merge stuck detection into single batched SSH call (was per-task SSH) - Add auto-retry for recoverable failures (OOM, SSH, timeout) with retry_count tracking - Make _load_progress() prefer experiment_state.json as authoritative source - Deprecate get_next_batch() as thin wrapper around get_batch_info()
1 parent 3c73b52 commit 5624332

3 files changed

Lines changed: 341 additions & 85 deletions

File tree

‎sibyl/experiment_recovery.py‎

Lines changed: 98 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ class RecoveryResult:
223223
recovered_completed: list = field(default_factory=list)
224224
still_running: list = field(default_factory=list)
225225
recovered_failed: list = field(default_factory=list)
226+
retried: list = field(default_factory=list)
226227
ssh_unreachable: bool = False
227228
needs_monitor: bool = False
228229
progress: dict = field(default_factory=dict)
@@ -233,14 +234,57 @@ def get_running_tasks(state: ExperimentState) -> list[str]:
233234
return [tid for tid, info in state.tasks.items() if info.get("status") == "running"]
234235

235236

237+
# Keywords indicating a potentially recoverable failure (OOM, SSH, timeout, etc.)
238+
_RECOVERABLE_ERROR_KEYWORDS = (
239+
"oom", "out of memory", "cuda out of memory",
240+
"ssh", "connection", "timeout", "timed out",
241+
"broken pipe", "reset by peer", "errno",
242+
)
243+
244+
_DEFAULT_MAX_RETRIES = 1
245+
246+
247+
def _is_recoverable_failure(error_summary: str) -> bool:
248+
"""Check whether a failure reason looks recoverable (OOM, SSH, timeout, etc.).
249+
250+
If the error summary is empty or cannot be classified, returns True
251+
(default to retrying once rather than giving up immediately).
252+
"""
253+
if not error_summary:
254+
return True # unknown failure → default to retry
255+
lower = error_summary.lower()
256+
for kw in _RECOVERABLE_ERROR_KEYWORDS:
257+
if kw in lower:
258+
return True
259+
# If we have an explicit error summary but it doesn't match known keywords,
260+
# still default to retry since we can't be sure it's permanent
261+
return True
262+
263+
264+
def _should_retry_task(task: dict, max_retries: int = _DEFAULT_MAX_RETRIES) -> bool:
265+
"""Determine whether a failing task should be retried."""
266+
retry_count = task.get("retry_count", 0)
267+
if retry_count >= max_retries:
268+
return False
269+
error_summary = task.get("error_summary", "")
270+
return _is_recoverable_failure(error_summary)
271+
272+
236273
def recover_from_detection(
237-
state: ExperimentState, detection: dict
274+
state: ExperimentState, detection: dict,
275+
*, max_retries: int = _DEFAULT_MAX_RETRIES,
238276
) -> RecoveryResult:
239277
"""Apply detection results to experiment state in-place.
240278
279+
Failed tasks with ``retry_count < max_retries`` are automatically reset
280+
to ``pending`` for re-scheduling. Only tasks that have exhausted their
281+
retries (or whose failure is deemed non-recoverable) are finally marked
282+
``failed``.
283+
241284
Args:
242285
state: ExperimentState to update (modified in-place)
243286
detection: Output from parse_detection_output()
287+
max_retries: Maximum automatic retries per task (default 1)
244288
245289
Returns:
246290
RecoveryResult summarizing what happened
@@ -251,36 +295,72 @@ def recover_from_detection(
251295

252296
for task_id, info in detection.items():
253297
status = info.get("detected_status", "unknown")
298+
task = state.tasks.setdefault(task_id, {})
254299

255300
if status == "done":
256301
done_info = info.get("done_info", {})
257302
exit_code = done_info.get("exit_code", 0)
258303
if exit_code == 0:
259-
state.tasks[task_id]["status"] = "completed"
304+
task["status"] = "completed"
260305
result.recovered_completed.append(task_id)
261306
log_entries.append(f"[{now}] {task_id}: recovered as completed")
262307
else:
263-
state.tasks[task_id]["status"] = "failed"
264-
result.recovered_failed.append(task_id)
265-
log_entries.append(
266-
f"[{now}] {task_id}: recovered as failed (exit_code={exit_code})"
267-
)
308+
# Non-zero exit code — check if we should retry
309+
task["error_summary"] = f"exit_code={exit_code}"
310+
if _should_retry_task(task, max_retries):
311+
task["status"] = "pending"
312+
task["retry_count"] = task.get("retry_count", 0) + 1
313+
result.retried.append(task_id)
314+
log_entries.append(
315+
f"[{now}] {task_id}: exit_code={exit_code}, "
316+
f"retry #{task['retry_count']} (reset to pending)"
317+
)
318+
else:
319+
task["status"] = "failed"
320+
result.recovered_failed.append(task_id)
321+
log_entries.append(
322+
f"[{now}] {task_id}: recovered as failed "
323+
f"(exit_code={exit_code}, retries exhausted)"
324+
)
268325
elif status == "running":
269326
result.still_running.append(task_id)
270327
result.progress[task_id] = info.get("progress", {})
271328
elif status == "dead":
272-
state.tasks[task_id]["status"] = "failed"
273-
state.tasks[task_id]["error_summary"] = "process_disappeared"
274-
result.recovered_failed.append(task_id)
275-
dead_pid = info.get("dead_pid", "?")
276-
log_entries.append(
277-
f"[{now}] {task_id}: process dead (pid={dead_pid}), marked failed"
278-
)
329+
task["error_summary"] = "process_disappeared"
330+
if _should_retry_task(task, max_retries):
331+
task["status"] = "pending"
332+
task["retry_count"] = task.get("retry_count", 0) + 1
333+
result.retried.append(task_id)
334+
dead_pid = info.get("dead_pid", "?")
335+
log_entries.append(
336+
f"[{now}] {task_id}: process dead (pid={dead_pid}), "
337+
f"retry #{task['retry_count']} (reset to pending)"
338+
)
339+
else:
340+
task["status"] = "failed"
341+
result.recovered_failed.append(task_id)
342+
dead_pid = info.get("dead_pid", "?")
343+
log_entries.append(
344+
f"[{now}] {task_id}: process dead (pid={dead_pid}), "
345+
f"marked failed (retries exhausted)"
346+
)
279347
else: # unknown
280-
state.tasks[task_id]["status"] = "failed"
281-
state.tasks[task_id]["error_summary"] = "unknown_status"
282-
result.recovered_failed.append(task_id)
283-
log_entries.append(f"[{now}] {task_id}: unknown status, marked failed")
348+
task["error_summary"] = "unknown_status"
349+
if _should_retry_task(task, max_retries):
350+
task["status"] = "pending"
351+
task["retry_count"] = task.get("retry_count", 0) + 1
352+
result.retried.append(task_id)
353+
log_entries.append(
354+
f"[{now}] {task_id}: unknown status, "
355+
f"retry #{task['retry_count']} (reset to pending)"
356+
)
357+
else:
358+
task["status"] = "failed"
359+
result.recovered_failed.append(task_id)
360+
log_entries.append(
361+
f"[{now}] {task_id}: unknown status, "
362+
f"marked failed (retries exhausted)"
363+
)
284364

285365
result.needs_monitor = len(result.still_running) > 0
286366

0 commit comments

Comments
 (0)

Footer

© 2026 GitHub, Inc.