"""Async scheduler for cron-like recurring tasks.
This module provides a Scheduler class for executing tasks on a
cron-like schedule with timezone support, and utilities for converting
human-readable schedule parameters into cron expressions.
"""
from __future__ import annotations
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional, Union, final
from zoneinfo import ZoneInfo
from croniter import croniter
from kelvin.logs import logger
if TYPE_CHECKING:
from kelvin.application.clock import ClockInterface
# Mapping from human-readable `every` values to base cron expressions.
# Only the trailing day-of-month, month and day-of-week fields are used;
# the leading minute and hour fields ("0 0") are placeholders that
# build_cron_expressions replaces with the time(s) given in `at`.
_EVERY_TO_CRON: dict[str, str] = {
    "day": "0 0 * * *",
    # Individual weekdays map to their three-letter cron abbreviation
    # ("monday" -> "MON", ..., "sunday" -> "SUN").
    **{
        day_name: f"0 0 * * {day_name[:3].upper()}"
        for day_name in (
            "monday",
            "tuesday",
            "wednesday",
            "thursday",
            "friday",
            "saturday",
            "sunday",
        )
    },
    "weekday": "0 0 * * MON-FRI",
    "weekend": "0 0 * * SAT,SUN",
}
[docs]
def parse_at(at: Union[str, list[str]]) -> list[tuple[int, int]]:
    """Parse and validate `at` time values.

    Each value must look like ``"HH:MM"``: two groups of ASCII digits
    separated by a single colon. Whitespace around either group is
    tolerated, and single-digit groups such as ``"9:5"`` are accepted.

    Args:
        at: A single ``"HH:MM"`` string or a list of them.

    Returns:
        The parsed ``(hour, minute)`` tuples, sorted ascending.

    Raises:
        ValueError: If ``at`` is empty; if any value is not a string, has an
            invalid format, or is out of range; or if two values denote the
            same time (e.g. ``"09:00"`` and ``"9:00"``) — duplicates are
            rejected, not silently merged.
    """
    raw = [at] if isinstance(at, str) else list(at)
    if not raw:
        raise ValueError("Parameter 'at' must not be empty.")
    times: list[tuple[int, int]] = []
    for value in raw:
        # Report non-string entries as a ValueError, like every other invalid
        # value, instead of letting `.split` raise AttributeError.
        if not isinstance(value, str):
            raise ValueError(f"Invalid 'at' format: {value!r}. Expected 'HH:MM'.")
        parts = value.split(":")
        if len(parts) != 2:
            raise ValueError(f"Invalid 'at' format: {value!r}. Expected 'HH:MM'.")
        hour_text, minute_text = (part.strip() for part in parts)
        # int() alone would also accept signs ("+9", "-0"), digit-group
        # underscores ("1_0") and non-ASCII digits; only plain ASCII digits
        # are valid here.
        if not all(text.isascii() and text.isdigit() for text in (hour_text, minute_text)):
            raise ValueError(f"Invalid 'at' format: {value!r}. Expected 'HH:MM' with numeric values.")
        hour, minute = int(hour_text), int(minute_text)
        if not (0 <= hour <= 23 and 0 <= minute <= 59):
            raise ValueError(f"Invalid 'at' value: {value!r}. Hour must be 0-23, minute must be 0-59.")
        times.append((hour, minute))
    if len(times) != len(set(times)):
        raise ValueError("Parameter 'at' contains duplicate entries.")
    return sorted(times)
[docs]
def build_cron_expressions(
    *,
    every: Optional[str] = None,
    at: Optional[Union[str, list[str]]] = None,
    cron: Optional[str] = None,
) -> list[str]:
    """Translate schedule parameters into one or more cron expression strings.

    Either ``cron`` or ``every`` must be supplied, not both:

    * A non-empty ``cron`` is checked with ``croniter.is_valid`` and returned
      unchanged as a one-element list; ``at`` may not accompany it.
    * ``every`` (a case-insensitive key of ``_EVERY_TO_CRON``) requires
      ``at``. Times that share a minute are merged into one expression with a
      comma-separated hour field; each distinct minute yields its own
      expression, because a single cron line cannot pair different minutes
      with different hours (e.g. ``["09:30", "17:45"]`` gives two).

    Args:
        every: Human-readable recurrence value.
        at: Time(s) of day in ``"HH:MM"`` format.
        cron: A 5-field cron expression.

    Returns:
        Cron expression strings, ordered by ascending minute.

    Raises:
        ValueError: On invalid or conflicting parameters, including any
            ValueError raised by ``parse_at`` while parsing ``at``.
    """
    if cron:
        if every:
            raise ValueError("Parameters 'cron' and 'every' are mutually exclusive.")
        if at:
            raise ValueError("Parameter 'at' cannot be used with 'cron'.")
        if not croniter.is_valid(cron):
            raise ValueError(f"Invalid cron expression: {cron!r}")
        return [cron]
    if not every:
        raise ValueError("Either 'cron' or 'every' must be provided.")

    key = every.lower()
    if key not in _EVERY_TO_CRON:
        raise ValueError(f"Invalid value for 'every': {every!r}. Must be one of: {', '.join(sorted(_EVERY_TO_CRON))}.")
    if at is None:
        raise ValueError(f"Parameter 'at' is required when every={every!r}.")

    times = parse_at(at)
    # Keep day-of-month, month and day-of-week; minute/hour come from `at`.
    _minute_placeholder, _hour_placeholder, *date_fields = _EVERY_TO_CRON[key].split()
    tail = " ".join(date_fields)

    expressions: list[str] = []
    for minute in sorted({m for _, m in times}):
        hours = sorted(h for h, m in times if m == minute)
        expressions.append(f"{minute} {','.join(map(str, hours))} {tail}")
    return expressions
[docs]
@final
class Scheduler:
    """An async iterator that yields at each scheduled fire time.

    Computes the next fire time using one or more cron expressions with
    timezone support. Supports optional ``start_time`` for deterministic
    alignment and ``interval`` for Nth-occurrence scheduling.

    On startup the scheduler uses ``start_time`` as the scheduling base when
    provided. If ``start_time`` is in the past, it fast-forwards through all
    past fire times (counting them for interval alignment) without yielding,
    so a far-past ``start_time`` never triggers past executions. If
    ``start_time`` is in the future, scheduling is aligned from that future
    time and no fire time will be yielded before it.

    Attributes:
        cron_expressions: The cron expressions defining the schedule.
        tz: The timezone for interpreting the schedule.
        name: A descriptive name for the scheduler (used in logging).
        interval: Fire every Nth cron match (``None`` means every match).
        iteration: Number of times the scheduler has actually fired
            (cumulative across repeated iterations of the same instance).
    """

    def __init__(
        self,
        cron_expressions: list[str],
        name: str,
        tz: ZoneInfo,
        start_time: Optional[datetime] = None,
        interval: Optional[int] = None,
        clock: Optional[ClockInterface] = None,
    ) -> None:
        """Create a scheduler.

        Args:
            cron_expressions: One or more cron expressions; the earliest next
                match across all of them is the next fire time.
            name: Descriptive name included in log records.
            tz: Timezone in which the cron expressions are evaluated.
            start_time: Optional scheduling base used for interval alignment.
                NOTE(review): it is converted with ``astimezone(tz)``, so a
                naive value is interpreted as system local time rather than
                ``tz`` — confirm this is intended.
            interval: Fire only on every Nth match; ``None`` fires on every
                match. Must be at least 1 when given.
            clock: Clock providing ``now``/``sleep``; defaults to ``RealClock``.

        Raises:
            ValueError: If ``cron_expressions`` is empty or ``interval`` < 1.
        """
        if not cron_expressions:
            raise ValueError("cron_expressions must contain at least one cron expression")
        if interval is not None and interval < 1:
            # interval=0 would otherwise surface only later, as a
            # ZeroDivisionError in _should_fire once iteration starts; a
            # negative interval has no sensible meaning.
            raise ValueError(f"interval must be a positive integer, got {interval!r}")

        if clock is None:
            # Imported lazily (presumably to avoid an import cycle) and only
            # when no clock was injected.
            from kelvin.application.clock import RealClock

            clock = RealClock()

        self.cron_expressions = cron_expressions
        self.tz = tz
        self.name = name
        self.interval = interval
        self.iteration = 0
        # Index of the next cron match, counted from start_time (or from the
        # first future match when no start_time is given).
        self._match_count = 0
        self._start_time = start_time
        self._clock = clock

    def _now(self) -> datetime:
        """Return the current clock time expressed in ``tz``."""
        return self._clock.now(tz=self.tz)

    def _next_fire_time(self, base: datetime) -> datetime:
        """Return the earliest next match after *base* across all cron expressions."""
        return min(croniter(expr, base).get_next(datetime) for expr in self.cron_expressions)

    def _seconds_until(self, target: datetime) -> float:
        """Return the seconds from now until *target*, or 0.0 if it has passed."""
        now = self._clock.now(tz=timezone.utc)
        delta = (target.astimezone(timezone.utc) - now).total_seconds()
        return max(0.0, delta)

    def _should_fire(self) -> bool:
        """Return whether the match at index ``_match_count`` should fire."""
        if self.interval is None:
            return True
        return self._match_count % self.interval == 0

    def _advance_to_present(self) -> datetime:
        """Fast-forward from ``start_time`` to the present and return the first fire time.

        The match counter is reset and rebuilt on every call. When
        ``interval`` is set, the returned time is the first future match whose
        index (counted from ``start_time``) is a multiple of ``interval``;
        historical matches are counted but never fired.
        """
        # Start from a clean count: without this, matches counted by a
        # previous ``async for`` over this instance would shift the interval
        # alignment of the new iteration.
        self._match_count = 0
        now = self._now()
        if self._start_time is None:
            # No alignment anchor: the very next match is index 0 and fires.
            return self._next_fire_time(now)

        base = self._start_time.astimezone(self.tz)
        start_in_past = self._seconds_until(base) == 0.0

        if self.interval is None or self.interval == 1:
            # Every match fires, so past matches need not be counted one by
            # one. Jumping straight to the present avoids synchronously walking
            # every historical match (potentially millions for a frequent cron
            # and a far-past start_time), which would block the event loop.
            if start_in_past:
                logger.info("Scheduler advanced to present", name=self.name)
            return self._next_fire_time(max(base, now))

        # Count start_time itself if it falls on a past cron match.
        # A future start_time aligns scheduling from that future point; no
        # fire time will be yielded before it.
        # NOTE(review): a future start_time that is itself a match appears to
        # be neither counted nor fired (the first candidate comes from
        # get_next after it), whereas a past one is counted as index 0. The
        # interval phase therefore depends on whether the process starts
        # before or after start_time — confirm this is intended.
        if start_in_past and any(croniter.match(expr, base) for expr in self.cron_expressions):
            self._match_count += 1

        # Walk forward match by match: past matches and future matches that
        # the interval skips are counted; the first future match selected by
        # the interval is returned. Cost is linear in the number of matches
        # since start_time.
        while True:
            next_fire = self._next_fire_time(base)
            if self._seconds_until(next_fire) > 0 and self._should_fire():
                if self._match_count > 0:
                    logger.info(
                        "Scheduler advanced to present",
                        name=self.name,
                        skipped_matches=self._match_count,
                    )
                return next_fire
            self._match_count += 1
            base = next_fire

    async def __aiter__(self) -> AsyncIterator[float]:
        """Yield once at every fire time, forever.

        Each yielded value is the number of seconds slept before that fire.
        Iterating the same instance again restarts alignment from
        ``start_time`` (or from now). After each fire, matches that are
        already in the past when the next fire time is computed — e.g.
        because the consumer's loop body ran longer than the cron period —
        are skipped with a warning instead of being fired late.
        """
        next_fire = self._advance_to_present()
        while True:
            sleep_seconds = self._seconds_until(next_fire)
            logger.debug(
                "Scheduler waiting",
                name=self.name,
                next_fire=next_fire.isoformat(),
                sleep_seconds=round(sleep_seconds, 2),
                iteration=self.iteration,
            )
            await self._clock.sleep(sleep_seconds)
            self.iteration += 1
            # The match that just fired counts toward interval alignment.
            self._match_count += 1
            logger.debug(
                "Scheduler fired",
                name=self.name,
                iteration=self.iteration,
            )
            yield sleep_seconds

            # Compute the next fire time, skipping any already in the past.
            base = next_fire
            while True:
                candidate = self._next_fire_time(base)
                if self._should_fire():
                    if self._seconds_until(candidate) > 0:
                        next_fire = candidate
                        break
                    # Missed tick — log and skip
                    logger.warning(
                        "Scheduler skipping missed tick",
                        name=self.name,
                        missed_fire_time=candidate.isoformat(),
                    )
                self._match_count += 1
                base = candidate