Перейти к содержанию

Decision Engine

Адаптивный выбор и скоринг кандидатов для принятия решений агентом.

Возможности

  • Адаптивный порог на основе распределения scores
  • Обнаружение и слияние перекрывающихся кандидатов
  • Ранжирование по score
  • Настраиваемые лимиты выбора

Использование

Базовый выбор

from kit.agent.decision import DecisionEngine, Candidate

engine = DecisionEngine(min_quality_threshold=0.3)

candidates = [
    Candidate(id="option_a", score=0.9),
    Candidate(id="option_b", score=0.7),
    Candidate(id="option_c", score=0.5),
    Candidate(id="option_d", score=0.2),  # Ниже порога
]

selected = engine.select(candidates)

for selection in selected:
    print(f"#{selection.rank}: {selection.candidate.id} ({selection.score:.2f})")
# #1: option_a (0.90)
# #2: option_b (0.70)
# #3: option_c (0.50)

Адаптивный порог

engine = DecisionEngine(
    min_quality_threshold=0.3,  # Минимальный порог
    target_percentile=70        # Целевой percentile
)

# Порог вычисляется автоматически на основе распределения scores
# Если все scores высокие, порог тоже будет высоким

Кандидаты с временными интервалами

# Для видео-сегментов, аудио-фрагментов и т.д.
candidates = [
    Candidate(id="seg1", score=0.8, start=0.0, end=5.0),
    Candidate(id="seg2", score=0.7, start=4.0, end=9.0),   # Перекрывается с seg1
    Candidate(id="seg3", score=0.9, start=10.0, end=15.0),
]

# Engine автоматически сольёт перекрывающиеся сегменты
selected = engine.select(candidates)

Кастомный scorer

def custom_scorer(candidate: Candidate) -> float:
    # Комбинируем несколько факторов
    base_score = candidate.score
    duration_bonus = min(candidate.duration / 10, 0.2)
    return base_score + duration_bonus

selected = engine.select(candidates, scorer=custom_scorer)

Фильтрация

selected = engine.select(
    candidates,
    filter_fn=lambda c: c.data.get("type") == "highlight"
)

Лимиты выбора

engine = DecisionEngine(
    max_selections=5,        # Максимум 5 кандидатов
    selection_rate=0.5       # Или 0.5 на единицу source
)

# Rate-based лимит
selected = engine.select(candidates, source_size=20)  # max 10

API Reference

Candidate

@dataclass
class Candidate:
    id: str
    score: float = 0.0
    data: Dict[str, Any] = field(default_factory=dict)

    # Для временных интервалов
    start: Optional[float] = None
    end: Optional[float] = None

    @property
    def has_range(self) -> bool
    @property
    def duration(self) -> float

Selection

@dataclass
class Selection:
    candidate: Candidate
    rank: int
    score: float
    reason: str = ""

DecisionEngine

class DecisionEngine:
    def __init__(
        self,
        min_quality_threshold: float = 0.3,
        target_percentile: float = 70,
        max_selections: Optional[int] = None,
        selection_rate: float = 0.5
    )

    def compute_adaptive_threshold(self, scores: List[float]) -> float
    def merge_overlapping(self, candidates: List[Candidate], overlap_threshold: float = 0.5) -> List[Candidate]
    def select(
        self,
        candidates: List[Candidate],
        source_size: Optional[float] = None,
        scorer: Optional[Callable] = None,
        filter_fn: Optional[Callable] = None
    ) -> List[Selection]
    def decide_count(self, source_size: float, quality_scores: List[float]) -> int
    def rank(self, candidates: List[Candidate], weights: Optional[Dict[str, float]] = None) -> List[Candidate]

Примеры из production

Autoshorts — выбор лучших кадров

class FrameSelector:
    def __init__(self):
        self.engine = DecisionEngine(
            min_quality_threshold=0.4,
            target_percentile=80,
            selection_rate=0.2  # 1 кадр на 5 секунд
        )

    async def select_keyframes(self, video_path: str) -> List[float]:
        # Анализ всех кадров
        frames = await analyze_video(video_path)

        candidates = [
            Candidate(
                id=f"frame_{i}",
                score=frame.quality_score,
                start=frame.timestamp,
                end=frame.timestamp + 0.5,
                data={"sharpness": frame.sharpness, "faces": frame.face_count}
            )
            for i, frame in enumerate(frames)
        ]

        # Выбор лучших
        selected = self.engine.select(
            candidates,
            source_size=video_duration,
            filter_fn=lambda c: c.data.get("faces", 0) > 0  # Только с лицами
        )

        return [s.candidate.start for s in selected]

Music Video Generator — выбор музыкальных моментов

class BeatSelector:
    def __init__(self):
        self.engine = DecisionEngine(
            min_quality_threshold=0.5,
            max_selections=20
        )

    def select_beats(self, audio_analysis: dict) -> List[float]:
        candidates = []

        for beat in audio_analysis['beats']:
            # Score на основе силы бита и позиции
            score = beat['strength'] * 0.7 + beat['novelty'] * 0.3

            candidates.append(Candidate(
                id=f"beat_{beat['time']}",
                score=score,
                start=beat['time'],
                end=beat['time'] + 0.5,
                data={"strength": beat['strength']}
            ))

        selected = self.engine.select(candidates)
        return [s.candidate.start for s in selected]

Sentinel — приоритизация алертов

class AlertPrioritizer:
    def __init__(self):
        self.engine = DecisionEngine(
            min_quality_threshold=0.0,
            max_selections=10
        )

    def prioritize(self, alerts: List[dict]) -> List[dict]:
        candidates = [
            Candidate(
                id=alert['id'],
                score=self._calculate_priority(alert),
                data=alert
            )
            for alert in alerts
        ]

        # Взвешенное ранжирование
        ranked = self.engine.rank(candidates, weights={
            "severity": 0.4,
            "frequency": 0.3,
            "recency": 0.3
        })

        return [c.data for c in ranked[:10]]

    def _calculate_priority(self, alert: dict) -> float:
        severity_map = {"critical": 1.0, "warning": 0.6, "info": 0.3}
        return severity_map.get(alert['severity'], 0.5)