feat: introduce trigger functionality (#27644)
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: Stream <Stream_2@qq.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: lyzno1 <yuanyouhuilyz@gmail.com>
Co-authored-by: zhsama <torvalds@linux.do>
Co-authored-by: Harry <xh001x@hotmail.com>
Co-authored-by: yessenia <yessenia.contact@gmail.com>
Co-authored-by: hjlarry <hjlarry@163.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: WTW0313 <twwu@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
api/schedule/trigger_provider_refresh_task.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import logging
import math
import time
from collections.abc import Iterable, Sequence

from sqlalchemy import ColumnElement, and_, func, or_, select
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session

import app
from configs import dify_config
from core.trigger.utils.locks import build_trigger_refresh_lock_keys
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from tasks.trigger_subscription_refresh_tasks import trigger_subscription_refresh

logger = logging.getLogger(__name__)


def _now_ts() -> int:
    return int(time.time())


def _build_due_filter(now_ts: int):
    """Build SQLAlchemy filter for due credential or subscription refresh."""
    credential_due: ColumnElement[bool] = and_(
        TriggerSubscription.credential_expires_at != -1,
        TriggerSubscription.credential_expires_at
        <= now_ts + int(dify_config.TRIGGER_PROVIDER_CREDENTIAL_THRESHOLD_SECONDS),
    )
    subscription_due: ColumnElement[bool] = and_(
        TriggerSubscription.expires_at != -1,
        TriggerSubscription.expires_at <= now_ts + int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS),
    )
    return or_(credential_due, subscription_due)


def _acquire_locks(keys: Iterable[str], ttl_seconds: int) -> list[bool]:
    """Attempt to acquire locks in a single pipelined round-trip.

    Returns a list of booleans indicating which locks were acquired.
    """
    pipe = redis_client.pipeline(transaction=False)
    for key in keys:
        pipe.set(key, b"1", ex=ttl_seconds, nx=True)
    results = pipe.execute()
    return [bool(r) for r in results]


@app.celery.task(queue="trigger_refresh_publisher")
def trigger_provider_refresh() -> None:
    """
    Scan due trigger subscriptions and enqueue refresh tasks with in-flight locks.
    """
    now: int = _now_ts()

    batch_size: int = int(dify_config.TRIGGER_PROVIDER_REFRESH_BATCH_SIZE)
    lock_ttl: int = max(300, int(dify_config.TRIGGER_PROVIDER_SUBSCRIPTION_THRESHOLD_SECONDS))

    with Session(db.engine, expire_on_commit=False) as session:
        filter: ColumnElement[bool] = _build_due_filter(now_ts=now)
        total_due: int = int(session.scalar(statement=select(func.count()).where(filter)) or 0)
        logger.info("Trigger refresh scan start: due=%d", total_due)
        if total_due == 0:
            return

        pages: int = math.ceil(total_due / batch_size)
        for page in range(pages):
            offset: int = page * batch_size
            subscription_rows: Sequence[Row[tuple[str, str]]] = session.execute(
                select(TriggerSubscription.tenant_id, TriggerSubscription.id)
                .where(filter)
                .order_by(TriggerSubscription.updated_at.asc())
                .offset(offset)
                .limit(batch_size)
            ).all()
            if not subscription_rows:
                logger.debug("Trigger refresh page %d/%d empty", page + 1, pages)
                continue

            subscriptions: list[tuple[str, str]] = [
                (str(tenant_id), str(subscription_id)) for tenant_id, subscription_id in subscription_rows
            ]
            lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions)
            acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl)

            enqueued: int = 0
            for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired):
                if not is_locked:
                    continue
                trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id)
                enqueued += 1

            logger.info(
                "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d",
                page + 1,
                pages,
                len(subscriptions),
                sum(1 for x in acquired if x),
                enqueued,
            )

    logger.info("Trigger refresh scan done: due=%d", total_due)
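The locking step leans on Redis SET with NX and EX in a non-transactional pipeline: one round-trip, one boolean per key, and only the first scan to reach a key wins until the TTL lapses. A tiny runnable demo of that idiom against a local Redis; the key names are illustrative, not the exact output of build_trigger_refresh_lock_keys:

import redis

r = redis.Redis()  # assumes a local Redis on the default port
keys = ["trigger:refresh:lock:tenant-a:sub-1", "trigger:refresh:lock:tenant-a:sub-2"]

def acquire(lock_keys: list[str], ttl: int) -> list[bool]:
    # Mirror of _acquire_locks: all SETs batched into one round-trip.
    pipe = r.pipeline(transaction=False)
    for key in lock_keys:
        pipe.set(key, b"1", ex=ttl, nx=True)  # NX: only set if absent; EX: auto-expire
    return [bool(ok) for ok in pipe.execute()]

print(acquire(keys, ttl=300))  # [True, True] on the first scan
print(acquire(keys, ttl=300))  # [False, False] on an overlapping second scan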
api/schedule/workflow_schedule_task.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import logging

from celery import group, shared_task
from sqlalchemy import and_, select
from sqlalchemy.orm import Session, sessionmaker

from configs import dify_config
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.schedule_utils import calculate_next_run_at
from models.trigger import AppTrigger, AppTriggerStatus, AppTriggerType, WorkflowSchedulePlan
from services.workflow.queue_dispatcher import QueueDispatcherManager
from tasks.workflow_schedule_tasks import run_schedule_trigger

logger = logging.getLogger(__name__)


@shared_task(queue="schedule_poller")
def poll_workflow_schedules() -> None:
    """
    Poll and process due workflow schedules.

    Streaming flow:
    1. Fetch due schedules in batches
    2. Process each batch until all due schedules are handled
    3. Optional: limit total dispatches per tick as a circuit breaker
    """
    session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

    with session_factory() as session:
        total_dispatched = 0
        total_rate_limited = 0

        # Process in batches until we've handled all due schedules or hit the limit
        while True:
            due_schedules = _fetch_due_schedules(session)

            if not due_schedules:
                break

            dispatched_count, rate_limited_count = _process_schedules(session, due_schedules)
            total_dispatched += dispatched_count
            total_rate_limited += rate_limited_count

            logger.debug("Batch processed: %d dispatched, %d rate limited", dispatched_count, rate_limited_count)

            # Circuit breaker: check if we've hit the per-tick limit (if enabled)
            if (
                dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0
                and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK
            ):
                logger.warning(
                    "Circuit breaker activated: reached dispatch limit (%d), will continue next tick",
                    dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK,
                )
                break

        if total_dispatched > 0 or total_rate_limited > 0:
            logger.info("Total processed: %d dispatched, %d rate limited", total_dispatched, total_rate_limited)


def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]:
    """
    Fetch a batch of due schedules, sorted by most overdue first.

    Returns up to WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE schedules per call.
    Used in a loop to progressively process all due schedules.
    """
    now = naive_utc_now()

    due_schedules = session.scalars(
        select(WorkflowSchedulePlan)
        .join(
            AppTrigger,
            and_(
                AppTrigger.app_id == WorkflowSchedulePlan.app_id,
                AppTrigger.node_id == WorkflowSchedulePlan.node_id,
                AppTrigger.trigger_type == AppTriggerType.TRIGGER_SCHEDULE,
            ),
        )
        .where(
            WorkflowSchedulePlan.next_run_at <= now,
            WorkflowSchedulePlan.next_run_at.isnot(None),
            AppTrigger.status == AppTriggerStatus.ENABLED,
        )
        .order_by(WorkflowSchedulePlan.next_run_at.asc())
        .with_for_update(skip_locked=True)
        .limit(dify_config.WORKFLOW_SCHEDULE_POLLER_BATCH_SIZE)
    )

    return list(due_schedules)


def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> tuple[int, int]:
    """Process schedules: check quota, update next run time and dispatch to Celery in parallel."""
    if not schedules:
        return 0, 0

    dispatcher_manager = QueueDispatcherManager()
    tasks_to_dispatch: list[str] = []
    rate_limited_count = 0

    for schedule in schedules:
        next_run_at = calculate_next_run_at(
            schedule.cron_expression,
            schedule.timezone,
        )
        schedule.next_run_at = next_run_at

        dispatcher = dispatcher_manager.get_dispatcher(schedule.tenant_id)
        if not dispatcher.check_daily_quota(schedule.tenant_id):
            logger.info("Tenant %s rate limited, skipping schedule_plan %s", schedule.tenant_id, schedule.id)
            rate_limited_count += 1
        else:
            tasks_to_dispatch.append(schedule.id)

    if tasks_to_dispatch:
        job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch)
        job.apply_async()

        logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch))

    session.commit()

    return len(tasks_to_dispatch), rate_limited_count
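
A note on _fetch_due_schedules: with_for_update(skip_locked=True) is what lets several poller instances run side by side, since each SELECT row-locks its own batch and skips rows another poller already holds. A hedged illustration of that behavior using a stand-in table, not Dify's models; it needs a live PostgreSQL and the DSN is a placeholder:

from sqlalchemy import Column, DateTime, Integer, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Session

class Base(DeclarativeBase):
    pass

class Plan(Base):
    __tablename__ = "demo_plans"  # stand-in for workflow_schedule_plans
    id = Column(Integer, primary_key=True)
    next_run_at = Column(DateTime)

engine = create_engine("postgresql+psycopg2://user:pass@localhost/demo")  # placeholder DSN

with Session(engine) as s1, Session(engine) as s2:
    stmt = select(Plan).order_by(Plan.next_run_at).with_for_update(skip_locked=True).limit(5)
    batch_1 = s1.scalars(stmt).all()  # locks up to 5 rows
    batch_2 = s2.scalars(stmt).all()  # skips s1's locked rows, takes the next 5
    assert not {p.id for p in batch_1} & {p.id for p in batch_2}  # disjoint claims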
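calculate_next_run_at is imported from libs.schedule_utils and not shown in this diff. Below is a plausible sketch of what such a helper does, built on croniter; the real implementation may differ, and the naive-UTC return value is an assumption chosen to match the comparison against naive_utc_now() in _fetch_due_schedules:

from datetime import datetime, timezone as dt_timezone
from zoneinfo import ZoneInfo

from croniter import croniter

def calculate_next_run_at_sketch(cron_expression: str, tz_name: str) -> datetime:
    # Evaluate the cron expression in the schedule's own timezone...
    local_now = datetime.now(ZoneInfo(tz_name))
    next_local = croniter(cron_expression, local_now).get_next(datetime)
    # ...then store as naive UTC so it compares cleanly with naive_utc_now().
    return next_local.astimezone(dt_timezone.utc).replace(tzinfo=None)

print(calculate_next_run_at_sketch("*/5 * * * *", "Asia/Shanghai"))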
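Finally, the dispatch step in _process_schedules batches one Celery signature per schedule into a single group, so everything is enqueued in one apply_async call and workers consume the tasks in parallel. A self-contained sketch of that fan-out pattern; the app name, broker URL, and task body are stand-ins for Dify's run_schedule_trigger:

from celery import Celery, group

# Stand-in Celery app; the broker URL is a placeholder.
app = Celery("sketch", broker="redis://localhost:6379/0")

@app.task
def run_schedule_sketch(schedule_id: str) -> str:
    # Stand-in for run_schedule_trigger: would load the plan and fire the workflow.
    return f"ran {schedule_id}"

def dispatch(schedule_ids: list[str]) -> None:
    # One signature per schedule, enqueued together; workers execute in parallel.
    job = group(run_schedule_sketch.s(sid) for sid in schedule_ids)
    job.apply_async()

dispatch(["plan-1", "plan-2", "plan-3"])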