3 files changed
@@ -223,6 +223,7 @@ class RecoveryResult: | |||
| 223 | 223 | recovered_completed: list = field(default_factory=list) | |
| 224 | 224 | still_running: list = field(default_factory=list) | |
| 225 | 225 | recovered_failed: list = field(default_factory=list) | |
| 226 | + retried: list = field(default_factory=list) | ||
| 226 | 227 | ssh_unreachable: bool = False | |
| 227 | 228 | needs_monitor: bool = False | |
| 228 | 229 | progress: dict = field(default_factory=dict) | |
@@ -233,14 +234,57 @@ def get_running_tasks(state: ExperimentState) -> list[str]: | |||
| 233 | 234 | return [tid for tid, info in state.tasks.items() if info.get("status") == "running"] | |
| 234 | 235 | ||
| 235 | 236 | ||
# Substrings that hint at a transient failure cause. They are compared against
# a lower-cased error summary, so every entry must itself be lower-case.
_RECOVERABLE_ERROR_KEYWORDS = (
    # memory exhaustion
    "oom",
    "out of memory",
    "cuda out of memory",
    # connectivity problems and timeouts
    "ssh",
    "connection",
    "timeout",
    "timed out",
    # low-level socket / OS errors
    "broken pipe",
    "reset by peer",
    "errno",
)

# Retry budget per task, used as the default for the ``max_retries`` parameters.
_DEFAULT_MAX_RETRIES = 1
| 245 | + | ||
| 246 | + | ||
def _is_recoverable_failure(error_summary: str) -> bool:
    """Decide whether a failure description looks transient enough to retry.

    A non-empty summary is lower-cased and scanned for the substrings listed
    in ``_RECOVERABLE_ERROR_KEYWORDS``. An empty or missing summary counts as
    recoverable, and so does a summary that matches no keyword, because an
    unclassified failure cannot be proven permanent. As a consequence the
    keyword match does not change the outcome today: every path returns
    ``True``.

    Args:
        error_summary: Free-form failure text; may be empty or ``None``.

    Returns:
        ``True`` in all cases (see above).
    """
    if error_summary:
        normalized = error_summary.lower()
        if any(keyword in normalized for keyword in _RECOVERABLE_ERROR_KEYWORDS):
            # Recognised transient cause (OOM, SSH, timeout, ...).
            return True
    # Nothing to classify, or an unrecognised cause: prefer one more attempt
    # over giving up immediately.
    return True
| 262 | + | ||
| 263 | + | ||
def _should_retry_task(task: dict, max_retries: int = _DEFAULT_MAX_RETRIES) -> bool:
    """Tell whether a failing task is eligible for another attempt.

    Args:
        task: Task record; ``retry_count`` (default 0) and ``error_summary``
            (default ``""``) are read from it, nothing is written.
        max_retries: Retry budget; once ``retry_count`` reaches it the task
            is no longer retried.

    Returns:
        ``False`` when the budget is used up, otherwise the verdict of
        ``_is_recoverable_failure`` for the task's error summary.
    """
    budget_left = not (task.get("retry_count", 0) >= max_retries)
    # Short-circuit: the error summary is only inspected while budget remains.
    return budget_left and _is_recoverable_failure(task.get("error_summary", ""))
| 271 | + | ||
| 272 | + | ||
| 236 | 273 | def recover_from_detection( | |
| 237 | - state: ExperimentState, detection: dict | ||
| 274 | + state: ExperimentState, detection: dict, | ||
| 275 | + *, max_retries: int = _DEFAULT_MAX_RETRIES, | ||
| 238 | 276 | ) -> RecoveryResult: | |
| 239 | 277 | """Apply detection results to experiment state in-place. | |
| 240 | 278 | ||
| 279 | + Failed tasks with ``retry_count < max_retries`` are automatically reset | ||
| 280 | + to ``pending`` for re-scheduling. Only tasks that have exhausted their | ||
| 281 | + retries (or whose failure is deemed non-recoverable) are finally marked | ||
| 282 | + ``failed``. | ||
| 283 | + | ||
| 241 | 284 | Args: | |
| 242 | 285 | state: ExperimentState to update (modified in-place) | |
| 243 | 286 | detection: Output from parse_detection_output() | |
| 287 | + max_retries: Maximum automatic retries per task (default 1) | ||
| 244 | 288 | ||
| 245 | 289 | Returns: | |
| 246 | 290 | RecoveryResult summarizing what happened | |
@@ -251,36 +295,72 @@ def recover_from_detection( | |||
| 251 | 295 | ||
| 252 | 296 | for task_id, info in detection.items(): | |
| 253 | 297 | status = info.get("detected_status", "unknown") | |
| 298 | + task = state.tasks.setdefault(task_id, {}) | ||
| 254 | 299 | ||
| 255 | 300 | if status == "done": | |
| 256 | 301 | done_info = info.get("done_info", {}) | |
| 257 | 302 | exit_code = done_info.get("exit_code", 0) | |
| 258 | 303 | if exit_code == 0: | |
| 259 | - state.tasks[task_id]["status"] = "completed" | ||
| 304 | + task["status"] = "completed" | ||
| 260 | 305 | result.recovered_completed.append(task_id) | |
| 261 | 306 | log_entries.append(f"[{now}] {task_id}: recovered as completed") | |
| 262 | 307 | else: | |
| 263 | - state.tasks[task_id]["status"] = "failed" | ||
| 264 | - result.recovered_failed.append(task_id) | ||
| 265 | - log_entries.append( | ||
| 266 | - f"[{now}] {task_id}: recovered as failed (exit_code={exit_code})" | ||
| 267 | - ) | ||
| 308 | + # Non-zero exit code — check if we should retry | ||
| 309 | + task["error_summary"] = f"exit_code={exit_code}" | ||
| 310 | + if _should_retry_task(task, max_retries): | ||
| 311 | + task["status"] = "pending" | ||
| 312 | + task["retry_count"] = task.get("retry_count", 0) + 1 | ||
| 313 | + result.retried.append(task_id) | ||
| 314 | + log_entries.append( | ||
| 315 | + f"[{now}] {task_id}: exit_code={exit_code}, " | ||
| 316 | + f"retry #{task['retry_count']} (reset to pending)" | ||
| 317 | + ) | ||
| 318 | + else: | ||
| 319 | + task["status"] = "failed" | ||
| 320 | + result.recovered_failed.append(task_id) | ||
| 321 | + log_entries.append( | ||
| 322 | + f"[{now}] {task_id}: recovered as failed " | ||
| 323 | + f"(exit_code={exit_code}, retries exhausted)" | ||
| 324 | + ) | ||
| 268 | 325 | elif status == "running": | |
| 269 | 326 | result.still_running.append(task_id) | |
| 270 | 327 | result.progress[task_id] = info.get("progress", {}) | |
| 271 | 328 | elif status == "dead": | |
| 272 | - state.tasks[task_id]["status"] = "failed" | ||
| 273 | - state.tasks[task_id]["error_summary"] = "process_disappeared" | ||
| 274 | - result.recovered_failed.append(task_id) | ||
| 275 | - dead_pid = info.get("dead_pid", "?") | ||
| 276 | - log_entries.append( | ||
| 277 | - f"[{now}] {task_id}: process dead (pid={dead_pid}), marked failed" | ||
| 278 | - ) | ||
| 329 | + task["error_summary"] = "process_disappeared" | ||
| 330 | + if _should_retry_task(task, max_retries): | ||
| 331 | + task["status"] = "pending" | ||
| 332 | + task["retry_count"] = task.get("retry_count", 0) + 1 | ||
| 333 | + result.retried.append(task_id) | ||
| 334 | + dead_pid = info.get("dead_pid", "?") | ||
| 335 | + log_entries.append( | ||
| 336 | + f"[{now}] {task_id}: process dead (pid={dead_pid}), " | ||
| 337 | + f"retry #{task['retry_count']} (reset to pending)" | ||
| 338 | + ) | ||
| 339 | + else: | ||
| 340 | + task["status"] = "failed" | ||
| 341 | + result.recovered_failed.append(task_id) | ||
| 342 | + dead_pid = info.get("dead_pid", "?") | ||
| 343 | + log_entries.append( | ||
| 344 | + f"[{now}] {task_id}: process dead (pid={dead_pid}), " | ||
| 345 | + f"marked failed (retries exhausted)" | ||
| 346 | + ) | ||
| 279 | 347 | else: # unknown | |
| 280 | - state.tasks[task_id]["status"] = "failed" | ||
| 281 | - state.tasks[task_id]["error_summary"] = "unknown_status" | ||
| 282 | - result.recovered_failed.append(task_id) | ||
| 283 | - log_entries.append(f"[{now}] {task_id}: unknown status, marked failed") | ||
| 348 | + task["error_summary"] = "unknown_status" | ||
| 349 | + if _should_retry_task(task, max_retries): | ||
| 350 | + task["status"] = "pending" | ||
| 351 | + task["retry_count"] = task.get("retry_count", 0) + 1 | ||
| 352 | + result.retried.append(task_id) | ||
| 353 | + log_entries.append( | ||
| 354 | + f"[{now}] {task_id}: unknown status, " | ||
| 355 | + f"retry #{task['retry_count']} (reset to pending)" | ||
| 356 | + ) | ||
| 357 | + else: | ||
| 358 | + task["status"] = "failed" | ||
| 359 | + result.recovered_failed.append(task_id) | ||
| 360 | + log_entries.append( | ||
| 361 | + f"[{now}] {task_id}: unknown status, " | ||
| 362 | + f"marked failed (retries exhausted)" | ||
| 363 | + ) | ||
| 284 | 364 | ||
| 285 | 365 | result.needs_monitor = len(result.still_running) > 0 | |
| 286 | 366 | ||
0 commit comments