Перейти к содержанию

Checkpoint

Система чекпоинтов для возобновления pipeline после сбоев.

Возможности

  • Сохранение состояния на каждом шаге
  • Автоматическое восстановление
  • TTL для устаревших чекпоинтов
  • Поддержка вложенных данных
  • Атомарные операции

Использование

Базовое использование

from kit.pipeline.checkpoint import CheckpointManager

# Create a manager that keeps its checkpoints under ./checkpoints
checkpoint = CheckpointManager(storage_path="./checkpoints")

# Save the state of task "task_123"
await checkpoint.save("task_123", {
    "step": "rendering",
    "progress": 45,
    "frames_done": 120,
    "data": {"output_path": "/tmp/video.mp4"}
})

# Load the state back; load() is declared as returning Optional[Dict],
# so the result is checked before it is used
state = await checkpoint.load("task_123")
if state:
    print(f"Resuming from step: {state['step']}")

Восстановление после сбоя

async def process_video(task_id: str, input_path: str):
    """Run the video steps in order, resuming from a stored checkpoint.

    After every finished step a checkpoint is written that holds the index
    of the next step to run together with the data produced so far. When a
    checkpoint for ``task_id`` is found, the steps before that index are
    skipped. The checkpoint is deleted once the loop has finished.

    Args:
        task_id: Key under which the checkpoint is stored.
        input_path: Source path; becomes ``data["input"]`` for a new task.

    Returns:
        The data returned by the last step that ran, or the stored data
        when no steps remained.
    """
    checkpoint = CheckpointManager()

    saved = await checkpoint.load(task_id)
    if not saved:
        # Nothing stored for this task: start from the first step.
        current_step = 0
        data = {"input": input_path}
    else:
        # Continue from where the previous run stopped.
        current_step = saved["step"]
        data = saved["data"]
        print(f"Resuming from step {current_step}")

    steps = [extract_audio, generate_script, create_images, render_video]

    for offset, run_step in enumerate(steps[current_step:]):
        data = await run_step(data)

        # Store the index of the NEXT step, so a resumed run skips this one.
        await checkpoint.save(task_id, {
            "step": current_step + offset + 1,
            "data": data
        })

    # All steps are done; the checkpoint is no longer needed.
    await checkpoint.delete(task_id)

    return data

TTL для чекпоинтов

checkpoint = CheckpointManager(
    storage_path="./checkpoints",
    default_ttl=86400  # 24 hours, expressed in seconds
)

# This checkpoint will be removed automatically after 24 hours
await checkpoint.save("task_123", {"step": 1})

# Custom TTL for a single checkpoint
await checkpoint.save("important_task", {"step": 1}, ttl=604800)  # 7 days

Атомарные обновления

# Update part of the state
await checkpoint.update("task_123", {
    "progress": 50,
    "last_frame": 250
})

# Atomic increment of the "frames_done" field by 10
await checkpoint.increment("task_123", "frames_done", 10)

Список чекпоинтов

# All checkpoints
all_checkpoints = await checkpoint.list_all()

# By pattern
video_checkpoints = await checkpoint.find("video_*")

# With a filter; the callback takes two arguments and the second one is
# used like a dict here (presumably key and state — confirm against find())
pending = await checkpoint.find(
    "*",
    filter_fn=lambda k, v: v.get("status") == "pending"
)

API Reference

CheckpointManager

# Signature overview of the checkpoint manager; method bodies are not shown.
class CheckpointManager:
    # storage_path: location where checkpoints are kept.
    # default_ttl: TTL used when save() is called without ttl; the usage
    #   examples pass seconds (86400 for 24 hours). What None means is not
    #   shown here — presumably "no expiry"; confirm.
    def __init__(
        self,
        storage_path: str = "./checkpoints",
        default_ttl: int = None
    )

    # Basic operations on a single checkpoint identified by key.
    # load() is declared Optional, so callers check the result before use.
    async def save(self, key: str, state: Dict, ttl: int = None) -> None
    async def load(self, key: str) -> Optional[Dict]
    async def delete(self, key: str) -> bool
    async def exists(self, key: str) -> bool

    # Partial modification of a stored state; increment() returns an int
    # (presumably the new value of the field — confirm).
    async def update(self, key: str, updates: Dict) -> None
    async def increment(self, key: str, field: str, amount: int = 1) -> int

    # Enumeration: list_all() yields keys, find() yields (key, state) pairs
    # according to the declared return types.
    async def list_all(self) -> List[str]
    async def find(self, pattern: str, filter_fn: Callable = None) -> List[Tuple[str, Dict]]

    async def cleanup_expired(self) -> int  # Returns count of deleted

Паттерны использования

Pipeline с автоматическими чекпоинтами

class CheckpointedPipeline:
    """Sequential async pipeline that writes a checkpoint after every step.

    Steps are awaited in the order they were added. Each step receives the
    data returned by the previous one. The checkpoint is stored under the
    key ``"<pipeline name>:<task_id>"`` and holds the number of completed
    steps, the name of the last completed step and the current data, so a
    later ``run`` for the same task continues after the last completed step.
    """

    def __init__(self, name: str):
        self.name = name
        self.checkpoint = CheckpointManager()
        self.steps = []

    def add_step(self, name: str, func: Callable):
        """Append a named step; ``func`` is awaited with the current data."""
        self.steps.append((name, func))

    async def run(self, task_id: str, initial_data: dict):
        """Execute the remaining steps for ``task_id`` and return the data.

        Args:
            task_id: Identifier of the task; combined with the pipeline
                name to form the checkpoint key.
            initial_data: Input for the first step when no checkpoint
                exists for this task.

        Returns:
            The data returned by the last step (or the stored/initial data
            when no steps remained).

        Raises:
            Exception: Whatever a step raises is re-raised after the
                failure has been recorded in the checkpoint.
        """
        key = f"{self.name}:{task_id}"

        state = await self.checkpoint.load(key)
        if state:
            start_step = state["completed_steps"]
            data = state["data"]
        else:
            start_step = 0
            data = initial_data

        for i, (step_name, func) in enumerate(self.steps[start_step:], start=start_step):
            # Only the step itself is guarded: a failure of the checkpoint
            # store must not be reported as a failure of the step, and
            # `data` must still be the step's input when it is recorded.
            try:
                result = await func(data)
            except Exception as e:
                await self._record_failure(key, i, data, step_name, e)
                raise

            data = result

            # Save after every step
            await self.checkpoint.save(key, {
                "completed_steps": i + 1,
                "current_step": step_name,
                "data": data
            })

        # Success: the checkpoint is no longer needed
        await self.checkpoint.delete(key)
        return data

    async def _record_failure(self, key, completed_steps, data, step_name, error):
        """Store the error details so the failed task can be inspected and resumed."""
        failure = {"error": str(error), "failed_step": step_name}

        if await self.checkpoint.exists(key):
            # Keep the stored progress and only attach the error details.
            await self.checkpoint.update(key, failure)
        else:
            # No checkpoint yet (e.g. the very first step failed), so there
            # is nothing to update; write a complete, resumable record.
            await self.checkpoint.save(key, {
                "completed_steps": completed_steps,
                "data": data,
                **failure
            })

# Usage: steps run in the order in which they are added
pipeline = CheckpointedPipeline("video_generation")
pipeline.add_step("script", generate_script)
pipeline.add_step("audio", generate_audio)
pipeline.add_step("images", generate_images)
pipeline.add_step("render", render_video)

result = await pipeline.run("task_123", {"prompt": "..."})

Batch processing с чекпоинтами

class BatchProcessor:
    """Processes a batch item by item and checkpoints progress on the way.

    ``process_item`` is awaited for every item but is not defined in this
    class; it has to be supplied elsewhere.
    """

    def __init__(self):
        self.checkpoint = CheckpointManager()

    async def process_batch(self, batch_id: str, items: List[dict]):
        """Process every item not yet handled and return all results.

        Progress is stored under ``"batch:<batch_id>"``. Items whose
        ``"id"`` is already listed in a stored checkpoint are skipped, and
        the stored results are extended with the new ones. The checkpoint
        is deleted when the whole batch has been walked.

        Args:
            batch_id: Identifier used to build the checkpoint key.
            items: Dicts carrying at least an ``"id"`` entry.

        Returns:
            The list of results (stored ones first, then the new ones).
        """
        key = f"batch:{batch_id}"
        saved = await self.checkpoint.load(key)

        processed = set(saved["processed_ids"]) if saved else set()
        results = saved["results"] if saved else []

        for entry in items:
            if entry["id"] not in processed:
                results.append(await self.process_item(entry))
                processed.add(entry["id"])

                # Write a checkpoint each time the processed count hits a
                # multiple of 10.
                if not len(processed) % 10:
                    await self.checkpoint.save(key, {
                        "processed_ids": list(processed),
                        "results": results
                    })

        await self.checkpoint.delete(key)
        return results

Примеры из production

Autoshorts — генерация видео

class VideoGeneratorWithCheckpoints:
    """Video generation in five named steps with a checkpoint before each.

    The step callables (``generate_script``, ``generate_audio``,
    ``generate_images``, ``render_video``, ``upload_result``) are looked up
    on the instance; they are not defined in this class.
    """

    def __init__(self):
        week_in_seconds = 86400 * 7  # 7 days
        self.checkpoint = CheckpointManager(
            storage_path="./video_checkpoints",
            default_ttl=week_in_seconds
        )

    async def generate(self, task_id: str, params: dict):
        """Run the remaining steps for ``task_id`` and return the result.

        A stored state with status ``"completed"`` short-circuits and
        returns its ``"result"``. Otherwise the run starts at the stored
        ``"step_index"`` (0 when absent) with the stored ``"data"``
        (``params`` when absent). Before each step a ``"running"`` state is
        saved; after the last step a ``"completed"`` state with the result
        is saved.

        Args:
            task_id: Checkpoint key of the task.
            params: Input for the first step of a fresh run.

        Returns:
            The data returned by the final step, or the stored result.
        """
        state = await self.checkpoint.load(task_id)

        if state and state.get("status") == "completed":
            return state["result"]

        steps = (
            ("script", self.generate_script),
            ("audio", self.generate_audio),
            ("images", self.generate_images),
            ("video", self.render_video),
            ("upload", self.upload_result),
        )
        total = len(steps)

        if state:
            start_idx = state.get("step_index", 0)
            data = state.get("data", params)
            logger.info(f"Resuming task {task_id} from step {start_idx}")
        else:
            start_idx, data = 0, params

        for offset, (label, runner) in enumerate(steps[start_idx:]):
            position = start_idx + offset

            # Saved BEFORE the step runs, so a resumed run repeats this step
            # with the same input.
            await self.checkpoint.save(task_id, {
                "status": "running",
                "step_index": position,
                "step_name": label,
                "data": data,
                "progress": position / total * 100
            })

            data = await runner(data)

        # Final state
        await self.checkpoint.save(task_id, {
            "status": "completed",
            "result": data
        })

        return data