diff --git a/.claude/skills/git-spice-merge-my-stack/SKILL.md b/.claude/skills/git-spice-merge-my-stack/SKILL.md new file mode 100644 index 0000000000..b11c4b78c8 --- /dev/null +++ b/.claude/skills/git-spice-merge-my-stack/SKILL.md @@ -0,0 +1,60 @@ +--- +name: git-spice-merge-my-stack +description: Merge a stacked git-spice branch chain with gh by retargeting each PR to main and merging bottom to top, including conflict recovery via rebase. +license: MIT +compatibility: Requires GitHub CLI (gh), git, and push access. +metadata: + author: rivet + version: "1.0" +--- + +Merge a stacked PR chain. + +**Input**: A target branch in the stack (usually the top branch to merge through). + +**Goal**: Merge all PRs from the bottom of that stack up to the target branch. + +## Steps + +1. **Resolve the target PR** + - Find PR for the provided branch: + - `gh pr list --state open --head "" --json number,headRefName,baseRefName,url` + - If no open PR exists, stop and report. + +2. **Build the stack chain down to main** + - Start at target PR. + - Repeatedly find the PR whose `headRefName` equals the current PR `baseRefName`. + - Continue until base is `main` or no parent PR exists. + - If chain is ambiguous, stop and ask the user which branch to follow. + +3. **Determine merge order** + - Merge from **bottom to top**. + - Example: `[bottom, ..., target]`. + +4. **For each PR in order** + - Retarget to `main` before merge: + - `gh pr edit --base main` + - Merge with repository-compatible strategy: + - Try `gh pr merge --squash --delete-branch=false` + - If merge fails due conflicts: + - `gh pr checkout ` + - `git fetch origin main` + - `git rebase origin/main` + - Resolve conflicts. If replaying already-upstream commits from lower stack layers, prefer `git rebase --skip`. + - Continue with `GIT_EDITOR=true git rebase --continue` when needed. + - `git push --force-with-lease origin ` + - Retry `gh pr merge ... --squash`. + +5. **Verify completion** + - Confirm each PR in chain is merged: + - `gh pr view --json state,mergedAt,url` + - Report final ordered merge list with PR numbers and timestamps. + +## Guardrails + +- Always merge in bottom-to-top order. +- Do not use merge commits if the repo disallows them. +- Do not delete remote branches unless explicitly requested. +- If a conflict cannot be safely resolved, stop and ask the user. +- If force-push is required, use `--force-with-lease`, never `--force`. +- After finishing, return to the user's original branch unless they asked otherwise. diff --git a/Cargo.lock b/Cargo.lock index 18d5593551..988081ac69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4180,6 +4180,7 @@ dependencies = [ "futures-util", "gasoline", "include_dir", + "indexmap 2.10.0", "namespace", "pegboard", "reqwest", diff --git a/README.md b/README.md index d99925967a..a473acc5aa 100644 --- a/README.md +++ b/README.md @@ -270,4 +270,4 @@ Serverless, containers, or your own servers — Rivet Actors work with your exis ## License -[Apache 2.0](LICENSE) +[Apache 2.0](LICENSE) \ No newline at end of file diff --git a/engine/artifacts/config-schema.json b/engine/artifacts/config-schema.json index 320ab0b625..535f01a143 100644 --- a/engine/artifacts/config-schema.json +++ b/engine/artifacts/config-schema.json @@ -132,16 +132,12 @@ "default": { "allow_version_rollback": null, "force_shutdown_duration": null, - "gasoline": { - "prune_eligibility_duration": null, - "prune_interval_duration": null - }, + "gasoline_prune_eligibility_duration": null, + "gasoline_prune_interval_duration": null, "guard_shutdown_duration": null, - "worker": { - "cpu_max": null, - "load_shedding_curve": null, - "shutdown_duration": null - } + "worker_cpu_max": null, + "worker_load_shedding_curve": null, + "worker_shutdown_duration": null }, "allOf": [ { @@ -352,30 +348,6 @@ }, "additionalProperties": false }, - "Gasoline": { - "type": "object", - "properties": { - "prune_eligibility_duration": { - "description": "Time (in seconds) after completion before considering a workflow eligible for pruning. Defaults to 7 days. Set to 0 to never prune workflow data.", - "type": [ - "integer", - "null" - ], - "format": "uint64", - "minimum": 0.0 - }, - "prune_interval_duration": { - "description": "Time (in seconds) to periodically check for workflows to prune. Defaults to 12 hours.", - "type": [ - "integer", - "null" - ], - "format": "uint64", - "minimum": 0.0 - } - }, - "additionalProperties": false - }, "Guard": { "type": "object", "properties": { @@ -893,16 +865,23 @@ "format": "uint32", "minimum": 0.0 }, - "gasoline": { - "default": { - "prune_eligibility_duration": null, - "prune_interval_duration": null - }, - "allOf": [ - { - "$ref": "#/definitions/Gasoline" - } - ] + "gasoline_prune_eligibility_duration": { + "description": "Time (in seconds) after completion before considering a workflow eligible for pruning. Defaults to 7 days. Set to 0 to never prune workflow data.", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0.0 + }, + "gasoline_prune_interval_duration": { + "description": "Time (in seconds) to periodically check for workflows to prune. Defaults to 12 hours.", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0.0 }, "guard_shutdown_duration": { "description": "Time (in seconds) to allow for guard to wait for pending requests after receiving SIGTERM. Defaults to 1 hour.", @@ -913,17 +892,49 @@ "format": "uint32", "minimum": 0.0 }, - "worker": { - "default": { - "cpu_max": null, - "load_shedding_curve": null, - "shutdown_duration": null + "worker_cpu_max": { + "description": "Adjusts worker curve around this value (in millicores, i.e. 1000 = 1 core). Is not a hard limit. When unset, uses /sys/fs/cgroup/cpu.max, and if that is unset uses total host cpu.", + "type": [ + "integer", + "null" + ], + "format": "uint", + "minimum": 0.0 + }, + "worker_load_shedding_curve": { + "description": "Determine load shedding ratio based on linear mapping on cpu usage. We will gradually pull less workflows as the cpu usage increases. Units are in (permilli overall cpu usage, permilli) Default: | . . 100% | _____ . | .\\ . % wfs | . \\ . | . \\. 5% | . \\_____ |_____.___.______ 0 70% 90% avg cpu usage", + "type": [ + "array", + "null" + ], + "items": { + "type": "array", + "items": [ + { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + }, + { + "type": "integer", + "format": "uint64", + "minimum": 0.0 + } + ], + "maxItems": 2, + "minItems": 2 }, - "allOf": [ - { - "$ref": "#/definitions/Worker" - } - ] + "maxItems": 2, + "minItems": 2 + }, + "worker_shutdown_duration": { + "description": "Time (in seconds) to allow for the gasoline worker engine to stop gracefully after receiving SIGTERM. Defaults to 30 seconds.", + "type": [ + "integer", + "null" + ], + "format": "uint32", + "minimum": 0.0 } }, "additionalProperties": false @@ -989,56 +1000,6 @@ } }, "additionalProperties": false - }, - "Worker": { - "type": "object", - "properties": { - "cpu_max": { - "description": "Adjusts worker curve around this value (in millicores, i.e. 1000 = 1 core). Is not a hard limit. When unset, uses /sys/fs/cgroup/cpu.max, and if that is unset uses total host cpu.", - "type": [ - "integer", - "null" - ], - "format": "uint", - "minimum": 0.0 - }, - "load_shedding_curve": { - "description": "Determine load shedding ratio based on linear mapping on cpu usage. We will gradually pull less workflows as the cpu usage increases. Units are in (permilli overall cpu usage, permilli) Default: | . . 100% | _____ . | .\\ . % wfs | . \\ . | . \\. 5% | . \\_____ |_____.___.______ 0 70% 90% avg cpu usage", - "type": [ - "array", - "null" - ], - "items": { - "type": "array", - "items": [ - { - "type": "integer", - "format": "uint64", - "minimum": 0.0 - }, - { - "type": "integer", - "format": "uint64", - "minimum": 0.0 - } - ], - "maxItems": 2, - "minItems": 2 - }, - "maxItems": 2, - "minItems": 2 - }, - "shutdown_duration": { - "description": "Time (in seconds) to allow for the gasoline worker engine to stop gracefully after receiving SIGTERM. Defaults to 30 seconds.", - "type": [ - "integer", - "null" - ], - "format": "uint32", - "minimum": 0.0 - } - }, - "additionalProperties": false } } } \ No newline at end of file diff --git a/engine/packages/api-peer/src/actors/delete.rs b/engine/packages/api-peer/src/actors/delete.rs index b0dda1619e..135e4f0825 100644 --- a/engine/packages/api-peer/src/actors/delete.rs +++ b/engine/packages/api-peer/src/actors/delete.rs @@ -18,13 +18,21 @@ use rivet_util::Id; )] #[tracing::instrument(skip_all)] pub async fn delete(ctx: ApiCtx, path: DeletePath, query: DeleteQuery) -> Result { - // Get the actor first to verify it exists - let actors_res = ctx - .op(pegboard::ops::actor::get::Input { + // Subscribe before fetching actor data + let mut destroy_sub = ctx + .subscribe::(("actor_id", path.actor_id)) + .await?; + + let (actors_res, namespace_res) = tokio::try_join!( + // Get the actor to verify it exists + ctx.op(pegboard::ops::actor::get::Input { actor_ids: vec![path.actor_id], fetch_error: false, - }) - .await?; + }), + ctx.op(namespace::ops::resolve_for_name_global::Input { + name: query.namespace, + }), + )?; let actor = actors_res .actors @@ -32,14 +40,14 @@ pub async fn delete(ctx: ApiCtx, path: DeletePath, query: DeleteQuery) -> Result .next() .ok_or_else(|| pegboard::errors::Actor::NotFound.build())?; - // Verify the actor belongs to the specified namespace - let namespace = ctx - .op(namespace::ops::resolve_for_name_global::Input { - name: query.namespace, - }) - .await? - .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; + // Already destroyed + if actor.destroy_ts.is_some() { + return Err(pegboard::errors::Actor::NotFound.build()); + } + let namespace = namespace_res.ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; + + // Verify the actor belongs to the specified namespace if actor.namespace_id != namespace.namespace_id { return Err(pegboard::errors::Actor::NotFound.build()); } @@ -56,6 +64,8 @@ pub async fn delete(ctx: ApiCtx, path: DeletePath, query: DeleteQuery) -> Result actor_id=?path.actor_id, "actor workflow not found, likely already stopped" ); + } else { + destroy_sub.next().await?; } Ok(DeleteResponse {}) diff --git a/engine/packages/api-public/Cargo.toml b/engine/packages/api-public/Cargo.toml index 18d04cb394..1d84a529b2 100644 --- a/engine/packages/api-public/Cargo.toml +++ b/engine/packages/api-public/Cargo.toml @@ -12,6 +12,7 @@ epoxy.workspace = true futures-util.workspace = true gas.workspace = true include_dir.workspace = true +indexmap.workspace = true namespace.workspace = true pegboard.workspace = true reqwest.workspace = true diff --git a/engine/packages/api-public/src/actors/list.rs b/engine/packages/api-public/src/actors/list.rs index 72e2d808ea..36b80523ec 100644 --- a/engine/packages/api-public/src/actors/list.rs +++ b/engine/packages/api-public/src/actors/list.rs @@ -107,24 +107,18 @@ async fn list_inner(ctx: ApiCtx, query: ListQuery) -> Result { .await? .ok_or_else(|| namespace::errors::Namespace::NotFound.build())?; + let limit = query.limit.unwrap_or(100); + // Fetch actors let mut actors = fetch_actors_by_ids( &ctx, actor_ids, query.namespace.clone(), query.include_destroyed, - None, // Don't apply limit in fetch, we'll apply it after cursor filtering + Some(limit), ) .await?; - // Apply cursor filtering if provided - if let Some(cursor_str) = &query.cursor { - let cursor_ts: i64 = cursor_str.parse().context("invalid cursor format")?; - actors.retain(|actor| actor.create_ts < cursor_ts); - } - - // Apply limit after cursor filtering - let limit = query.limit.unwrap_or(100); actors.truncate(limit); let cursor = actors.last().map(|x| x.create_ts.to_string()); @@ -208,8 +202,6 @@ async fn list_inner(ctx: ApiCtx, query: ListQuery) -> Result { // Sort by create ts desc actors.sort_by_cached_key(|x| std::cmp::Reverse(x.create_ts)); - // Shorten array since returning all actors from all regions could end up returning `regions * - // limit` results, which is a lot. actors.truncate(limit); let cursor = actors.last().map(|x| x.create_ts.to_string()); diff --git a/engine/packages/api-public/src/actors/list_names.rs b/engine/packages/api-public/src/actors/list_names.rs index d68240816a..9815d39a63 100644 --- a/engine/packages/api-public/src/actors/list_names.rs +++ b/engine/packages/api-public/src/actors/list_names.rs @@ -1,5 +1,6 @@ use anyhow::Result; use axum::response::{IntoResponse, Response}; +use indexmap::IndexMap; use rivet_api_builder::{ ApiError, extract::{Extension, Json, Query}, @@ -49,7 +50,7 @@ async fn list_names_inner(ctx: ApiCtx, query: ListNamesQuery) -> Result>( + fanout_to_datacenters::>( ctx.into(), "/actors/names", peer_query, @@ -58,18 +59,18 @@ async fn list_names_inner(ctx: ApiCtx, query: ListNamesQuery) -> Result>(); - // Sort by name for consistency - all_names.sort_by(|a, b| a.0.cmp(&b.0)); - - // Truncate to the requested limit - all_names.truncate(query.limit.unwrap_or(100)); + // Sort for consistency + all_names.sort_keys(); let cursor = all_names.last().map(|(name, _)| name.to_string()); Ok(ListNamesResponse { - // TODO: Implement ComposeSchema for FakeMap so we don't have to reallocate names: all_names.into_iter().collect(), pagination: Pagination { cursor }, }) diff --git a/engine/packages/api-public/src/runners.rs b/engine/packages/api-public/src/runners.rs index 7b2b8667cd..d876f03623 100644 --- a/engine/packages/api-public/src/runners.rs +++ b/engine/packages/api-public/src/runners.rs @@ -1,5 +1,6 @@ use anyhow::Result; use axum::response::{IntoResponse, Response}; +use indexmap::IndexSet; use rivet_api_builder::{ ApiError, extract::{Extension, Json, Query}, @@ -90,25 +91,26 @@ async fn list_names_inner(ctx: ApiCtx, query: ListNamesQuery) -> Result>( + let mut all_names = fanout_to_datacenters::>( ctx.into(), "/runners/names", query, |ctx, query| async move { rivet_api_peer::runners::list_names(ctx, (), query).await }, |_, res, agg| agg.extend(res.names), ) - .await?; + .await? + .into_iter() + // Apply limit + .take(limit) + .collect::>(); // Sort by name for consistency all_names.sort(); - // Truncate to the requested limit - all_names.truncate(limit); - let cursor = all_names.last().map(|x: &String| x.to_string()); Ok(ListNamesResponse { - names: all_names, + names: all_names.into_iter().collect(), pagination: Pagination { cursor }, }) } diff --git a/engine/packages/config/src/config/mod.rs b/engine/packages/config/src/config/mod.rs index 88c60d782d..63eb3630fa 100644 --- a/engine/packages/config/src/config/mod.rs +++ b/engine/packages/config/src/config/mod.rs @@ -204,7 +204,7 @@ impl Root { } // Validate force_shutdown_duration is greater than worker and guard shutdown durations - let worker = self.runtime.worker.shutdown_duration(); + let worker = self.runtime.worker_shutdown_duration(); let guard = self.runtime.guard_shutdown_duration(); let force = self.runtime.force_shutdown_duration(); let max_graceful = worker.max(guard); diff --git a/engine/packages/config/src/config/runtime.rs b/engine/packages/config/src/config/runtime.rs index 31b9e92ac5..60919abfea 100644 --- a/engine/packages/config/src/config/runtime.rs +++ b/engine/packages/config/src/config/runtime.rs @@ -6,10 +6,25 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] #[serde(deny_unknown_fields)] pub struct Runtime { - #[serde(default)] - pub worker: Worker, - #[serde(default)] - pub gasoline: Gasoline, + /// Adjusts worker curve around this value (in millicores, i.e. 1000 = 1 core). Is not a hard limit. When + /// unset, uses /sys/fs/cgroup/cpu.max, and if that is unset uses total host cpu. + pub worker_cpu_max: Option, + /// Determine load shedding ratio based on linear mapping on cpu usage. We will gradually + /// pull less workflows as the cpu usage increases. Units are in (permilli overall cpu usage, permilli) + /// Default: + /// | . . + /// 100% | _____ . + /// | .\ . + /// % wfs | . \ . + /// | . \. + /// 5% | . \_____ + /// |_____.___.______ + /// 0 70% 90% + /// avg cpu usage + worker_load_shedding_curve: Option<[(u64, u64); 2]>, + /// Time (in seconds) to allow for the gasoline worker engine to stop gracefully after receiving SIGTERM. + /// Defaults to 30 seconds. + worker_shutdown_duration: Option, /// Time (in seconds) to allow for guard to wait for pending requests after receiving SIGTERM. Defaults /// to 1 hour. guard_shutdown_duration: Option, @@ -20,9 +35,23 @@ pub struct Runtime { /// Whether or not to allow running the engine when the previous version that was run is higher than /// the current version. allow_version_rollback: Option, + /// Time (in seconds) after completion before considering a workflow eligible for pruning. Defaults to 7 + /// days. Set to 0 to never prune workflow data. + gasoline_prune_eligibility_duration: Option, + /// Time (in seconds) to periodically check for workflows to prune. Defaults to 12 hours. + gasoline_prune_interval_duration: Option, } impl Runtime { + pub fn worker_load_shedding_curve(&self) -> [(u64, u64); 2] { + self.worker_load_shedding_curve + .unwrap_or([(700, 1000), (900, 50)]) + } + + pub fn worker_shutdown_duration(&self) -> Duration { + Duration::from_secs(self.worker_shutdown_duration.unwrap_or(30) as u64) + } + pub fn guard_shutdown_duration(&self) -> Duration { Duration::from_secs(self.guard_shutdown_duration.unwrap_or(60 * 60) as u64) } @@ -38,55 +67,9 @@ impl Runtime { pub fn allow_version_rollback(&self) -> bool { self.allow_version_rollback.unwrap_or_default() } -} -#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] -#[serde(deny_unknown_fields)] -pub struct Worker { - /// Adjusts worker curve around this value (in millicores, i.e. 1000 = 1 core). Is not a hard limit. When - /// unset, uses /sys/fs/cgroup/cpu.max, and if that is unset uses total host cpu. - pub cpu_max: Option, - /// Determine load shedding ratio based on linear mapping on cpu usage. We will gradually - /// pull less workflows as the cpu usage increases. Units are in (permilli overall cpu usage, permilli) - /// Default: - /// | . . - /// 100% | _____ . - /// | .\ . - /// % wfs | . \ . - /// | . \. - /// 5% | . \_____ - /// |_____.___.______ - /// 0 70% 90% - /// avg cpu usage - load_shedding_curve: Option<[(u64, u64); 2]>, - /// Time (in seconds) to allow for the gasoline worker engine to stop gracefully after receiving SIGTERM. - /// Defaults to 30 seconds. - shutdown_duration: Option, -} - -impl Worker { - pub fn load_shedding_curve(&self) -> [(u64, u64); 2] { - self.load_shedding_curve.unwrap_or([(700, 1000), (900, 50)]) - } - - pub fn shutdown_duration(&self) -> Duration { - Duration::from_secs(self.shutdown_duration.unwrap_or(30) as u64) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] -#[serde(deny_unknown_fields)] -pub struct Gasoline { - /// Time (in seconds) after completion before considering a workflow eligible for pruning. Defaults to 7 - /// days. Set to 0 to never prune workflow data. - prune_eligibility_duration: Option, - /// Time (in seconds) to periodically check for workflows to prune. Defaults to 12 hours. - prune_interval_duration: Option, -} - -impl Gasoline { - pub fn prune_eligibility_duration(&self) -> Option { - if let Some(prune_eligibility_duration) = self.prune_eligibility_duration { + pub fn gasoline_prune_eligibility_duration(&self) -> Option { + if let Some(prune_eligibility_duration) = self.gasoline_prune_eligibility_duration { if prune_eligibility_duration == 0 { None } else { @@ -97,7 +80,10 @@ impl Gasoline { } } - pub fn prune_interval_duration(&self) -> Duration { - Duration::from_secs(self.prune_interval_duration.unwrap_or(60 * 60 * 12)) + pub fn gasoline_prune_interval_duration(&self) -> Duration { + Duration::from_secs( + self.gasoline_prune_interval_duration + .unwrap_or(60 * 60 * 12), + ) } } diff --git a/engine/packages/epoxy/src/ops/propose.rs b/engine/packages/epoxy/src/ops/propose.rs index 87e05f278f..08397940b5 100644 --- a/engine/packages/epoxy/src/ops/propose.rs +++ b/engine/packages/epoxy/src/ops/propose.rs @@ -356,6 +356,7 @@ async fn purge_optimistic_cache(ctx: OperationCtx, keys: Vec) -> Result< .workflow(crate::workflows::purger::Input { replica_id: dc.datacenter_label as u64, }) + .bypass_signal_from_workflow_I_KNOW_WHAT_IM_DOING() .tag("replica_id", dc.datacenter_label as u64) .unique() .dispatch() diff --git a/engine/packages/gasoline-runtime/src/workflows/pruner.rs b/engine/packages/gasoline-runtime/src/workflows/pruner.rs index a0419a1cc6..68ada307d5 100644 --- a/engine/packages/gasoline-runtime/src/workflows/pruner.rs +++ b/engine/packages/gasoline-runtime/src/workflows/pruner.rs @@ -13,7 +13,7 @@ pub async fn gasoline_pruner(ctx: &mut WorkflowCtx, input: &Input) -> Result<()> async move { ctx.activity(PruneInput {}).await?; - ctx.sleep(ctx.config().runtime.gasoline.prune_interval_duration()) + ctx.sleep(ctx.config().runtime.gasoline_prune_interval_duration()) .await?; Ok(Loop::<()>::Continue) @@ -41,7 +41,7 @@ async fn prune(ctx: &ActivityCtx, _input: &PruneInput) -> Result { // Check if pruning is enabled let Some(prune_eligibility_duration) = - ctx.config().runtime.gasoline.prune_eligibility_duration() + ctx.config().runtime.gasoline_prune_eligibility_duration() else { return Ok(PruneOutput { prune_count: 0 }); }; diff --git a/engine/packages/gasoline/src/builder/common/workflow.rs b/engine/packages/gasoline/src/builder/common/workflow.rs index f8404df9ac..6699e9c0e4 100644 --- a/engine/packages/gasoline/src/builder/common/workflow.rs +++ b/engine/packages/gasoline/src/builder/common/workflow.rs @@ -50,6 +50,17 @@ where } } + // TODO: Get rid of this (RVT-5281) + // NOTE: This is a bad implementation because it disregards other errors that may have happened earlier + #[allow(non_snake_case)] + pub fn bypass_signal_from_workflow_I_KNOW_WHAT_IM_DOING(mut self) -> Self { + if let Some(BuilderError::CannotDispatchFromOpInWorkflow) = &self.error { + self.error = None; + } + + self + } + pub fn tags(mut self, tags: serde_json::Value) -> Self { if self.error.is_some() { return self; diff --git a/engine/packages/gasoline/src/ctx/message.rs b/engine/packages/gasoline/src/ctx/message.rs index f61fbbb723..c32a779061 100644 --- a/engine/packages/gasoline/src/ctx/message.rs +++ b/engine/packages/gasoline/src/ctx/message.rs @@ -1,5 +1,6 @@ use std::{ - fmt::{self, Debug}, + borrow::Cow, + fmt::{self, Debug, Display}, marker::PhantomData, sync::Arc, }; @@ -8,7 +9,7 @@ use rivet_pools::UpsPool; use rivet_util::Id; use tokio_util::sync::{CancellationToken, DropGuard}; use tracing::Instrument; -use universalpubsub::{NextOutput, Subscriber}; +use universalpubsub::{NextOutput, Subject, Subscriber}; use crate::{ error::{WorkflowError, WorkflowResult}, @@ -61,7 +62,7 @@ impl MessageCtx { let client = self.clone(); let topic = topic.to_string(); - let spawn_res = tokio::task::Builder::new() + tokio::task::Builder::new() .name("gasoline::message_async") .spawn( async move { @@ -73,13 +74,9 @@ impl MessageCtx { } } .instrument(tracing::info_span!("message_bg")), - ); - - if let Err(err) = spawn_res { - tracing::error!(?err, "failed to spawn message_async task"); - } - - Ok(()) + ) + .map_err(|err| WorkflowError::PublishMessage(err.into())) + .map(|_| ()) } /// Same as `message` but waits for the message to successfully publish. @@ -92,7 +89,10 @@ impl MessageCtx { where M: Message, { - let subject = format!("{}:{topic}", M::subject()); + let subject = MsgSubject { + topic, + msg_marker: PhantomData::, + }; let duration_since_epoch = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_else(|err| unreachable!("time is broken: {}", err)); @@ -126,15 +126,14 @@ impl MessageCtx { // It's important to write to the stream as fast as possible in order to // ensure messages are handled quickly. let message_buf = Arc::new(message_buf); - self.message_publish_pubsub::(&subject, message_buf) - .await; + self.message_publish_pubsub::(subject, message_buf).await; Ok(()) } /// Publishes the message to pubsub. #[tracing::instrument(level = "debug", skip_all)] - async fn message_publish_pubsub(&self, subject: &str, message_buf: Arc>) + async fn message_publish_pubsub(&self, subject: MsgSubject<'_, M>, message_buf: Arc>) where M: Message, { @@ -144,8 +143,6 @@ impl MessageCtx { // Ignore for infinite backoff backoff.tick().await; - let subject = subject.to_owned(); - tracing::trace!( %subject, message_len = message_buf.len(), @@ -154,7 +151,7 @@ impl MessageCtx { if let Err(err) = self .pubsub .publish( - &subject, + subject.clone(), &(*message_buf), universalpubsub::PublishOpts::broadcast(), ) @@ -332,3 +329,30 @@ where }) } } + +// Helper struct +struct MsgSubject<'a, M: Message> { + topic: &'a str, + msg_marker: PhantomData, +} + +impl Clone for MsgSubject<'_, M> { + fn clone(&self) -> Self { + MsgSubject { + topic: self.topic, + msg_marker: PhantomData::, + } + } +} + +impl Display for MsgSubject<'_, M> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}:{}", M::subject(), self.topic) + } +} + +impl Subject for MsgSubject<'_, M> { + fn root<'a>() -> Option> { + Some(Cow::Owned(M::subject())) + } +} diff --git a/engine/packages/gasoline/src/db/kv/mod.rs b/engine/packages/gasoline/src/db/kv/mod.rs index 3be187cd45..0346510cbc 100644 --- a/engine/packages/gasoline/src/db/kv/mod.rs +++ b/engine/packages/gasoline/src/db/kv/mod.rs @@ -1068,9 +1068,9 @@ impl Database for DatabaseKv { self.system .lock() .await - .cpu_usage_ratio(self.config.runtime.worker.cpu_max) + .cpu_usage_ratio(self.config.runtime.worker_cpu_max) }; - let load_shed_curve = self.config.runtime.worker.load_shedding_curve(); + let load_shed_curve = self.config.runtime.worker_load_shedding_curve(); let load_shed_ratio_x1000 = calc_pull_ratio( (cpu_usage_ratio * 1000.0) as u64, load_shed_curve[0].0, diff --git a/engine/packages/gasoline/src/error.rs b/engine/packages/gasoline/src/error.rs index 69fa449b30..30edb2f7f7 100644 --- a/engine/packages/gasoline/src/error.rs +++ b/engine/packages/gasoline/src/error.rs @@ -112,6 +112,9 @@ pub enum WorkflowError { #[error("failed to flush pubsub: {0}")] FlushPubsub(#[source] anyhow::Error), + #[error("failed to publish message: {0}")] + PublishMessage(#[source] anyhow::Error), + #[error("subscription unsubscribed")] SubscriptionUnsubscribed, diff --git a/engine/packages/gasoline/src/worker.rs b/engine/packages/gasoline/src/worker.rs index 6631b401bb..5d196eed60 100644 --- a/engine/packages/gasoline/src/worker.rs +++ b/engine/packages/gasoline/src/worker.rs @@ -289,7 +289,7 @@ impl Worker { #[tracing::instrument(skip_all)] async fn shutdown(mut self, mut term_signal: TermSignal) { - let shutdown_duration = self.config.runtime.worker.shutdown_duration(); + let shutdown_duration = self.config.runtime.worker_shutdown_duration(); tracing::info!( duration=?shutdown_duration, diff --git a/engine/packages/guard-core/src/server.rs b/engine/packages/guard-core/src/server.rs index b9b591a7fc..59c16b1d3a 100644 --- a/engine/packages/guard-core/src/server.rs +++ b/engine/packages/guard-core/src/server.rs @@ -112,7 +112,7 @@ pub async fn run_server( tokio::spawn( async move { if let Err(err) = conn.await { - tracing::error!("{} connection error: {}", port_type_str, err); + tracing::warn!("{} connection error: {}", port_type_str, err); } tracing::debug!("{} connection dropped: {}", port_type_str, remote_addr); diff --git a/engine/packages/pegboard/src/workflows/actor/runtime.rs b/engine/packages/pegboard/src/workflows/actor/runtime.rs index 3655fc0079..d95d379696 100644 --- a/engine/packages/pegboard/src/workflows/actor/runtime.rs +++ b/engine/packages/pegboard/src/workflows/actor/runtime.rs @@ -478,6 +478,7 @@ async fn allocate_actor_v2( state.for_serverless = res.serverless; state.allocated_serverless_slot = res.serverless; + state.reschedule_ts = None; match &res.status { AllocateActorStatus::Allocated { diff --git a/engine/packages/pegboard/src/workflows/serverless/receiver.rs b/engine/packages/pegboard/src/workflows/serverless/receiver.rs index e3c3dcbec8..6f0af7bc16 100644 --- a/engine/packages/pegboard/src/workflows/serverless/receiver.rs +++ b/engine/packages/pegboard/src/workflows/serverless/receiver.rs @@ -33,6 +33,8 @@ pub async fn pegboard_serverless_receiver(ctx: &mut WorkflowCtx, input: &Input) namespace_id: input.namespace_id, runner_name: input.runner_name.clone(), }) + .tag("namespace_id", input.namespace_id) + .tag("runner_name", input.runner_name.clone()) .dispatch() .await?; diff --git a/engine/packages/pools/src/db/ups.rs b/engine/packages/pools/src/db/ups.rs index 9d098c2bf9..b10a4faa75 100644 --- a/engine/packages/pools/src/db/ups.rs +++ b/engine/packages/pools/src/db/ups.rs @@ -53,7 +53,8 @@ pub async fn setup(config: &Config, client_name: &str) -> Result { tracing::warn!(?server_addrs, "nats draining"); } async_nats::Event::Closed => { - tracing::error!(?server_addrs, "nats closed"); + // Engine is shutting down, not an error + tracing::info!(?server_addrs, "nats closed"); } async_nats::Event::SlowConsumer(sid) => { tracing::warn!(?server_addrs, ?sid, "nats slow consumer"); diff --git a/engine/packages/runtime/src/traces.rs b/engine/packages/runtime/src/traces.rs index d07c4155cb..2f0103696a 100644 --- a/engine/packages/runtime/src/traces.rs +++ b/engine/packages/runtime/src/traces.rs @@ -118,6 +118,7 @@ fn build_filter_from_spec(filter_spec: &str) -> anyhow::Result { .add_directive("info".parse()?) // Disable verbose logs .add_directive("tokio_cron_scheduler=warn".parse()?) + .add_directive("async_nats=warn".parse()?) .add_directive("tokio=warn".parse()?) .add_directive("hyper=warn".parse()?) .add_directive("h2=warn".parse()?); diff --git a/engine/packages/universaldb/src/driver/mod.rs b/engine/packages/universaldb/src/driver/mod.rs index c00b9ef12d..d261be246d 100644 --- a/engine/packages/universaldb/src/driver/mod.rs +++ b/engine/packages/universaldb/src/driver/mod.rs @@ -78,6 +78,11 @@ pub trait TransactionDriver: Send + Sync { end: &'a [u8], ) -> Pin> + Send + 'a>>; + fn tag(&self, _tag: &str) -> Result<()> { + // No-op unless implemented + Ok(()) + } + // Helper for committing without consuming self (for database drivers that need it) fn commit_ref(&self) -> Pin> + Send + '_>> { Box::pin(async move { diff --git a/engine/packages/universaldb/src/transaction.rs b/engine/packages/universaldb/src/transaction.rs index fdc1353500..abde98e5b2 100644 --- a/engine/packages/universaldb/src/transaction.rs +++ b/engine/packages/universaldb/src/transaction.rs @@ -263,6 +263,11 @@ impl Transaction { ) -> Pin> + Send + 'a>> { self.driver.get_estimated_range_size_bytes(begin, end) } + + /// Adds a tag to the current transaction + pub fn tag(&self, tag: &str) -> Result<()> { + self.driver.tag(tag) + } } pub struct InformalTransaction<'t> { diff --git a/engine/packages/universalpubsub/src/lib.rs b/engine/packages/universalpubsub/src/lib.rs index 2e61a700f0..2685e86607 100644 --- a/engine/packages/universalpubsub/src/lib.rs +++ b/engine/packages/universalpubsub/src/lib.rs @@ -3,6 +3,8 @@ pub mod driver; pub mod errors; pub mod metrics; pub mod pubsub; +pub mod subject; pub use driver::*; pub use pubsub::{Message, NextOutput, PubSub, Response, Subscriber}; +pub use subject::Subject; diff --git a/engine/packages/universalpubsub/src/pubsub.rs b/engine/packages/universalpubsub/src/pubsub.rs index 46c379ae34..b6328945ca 100644 --- a/engine/packages/universalpubsub/src/pubsub.rs +++ b/engine/packages/universalpubsub/src/pubsub.rs @@ -13,6 +13,7 @@ use rivet_util::backoff::Backoff; use crate::chunking::{ChunkTracker, encode_chunk, split_payload_into_chunks}; use crate::driver::{PubSubDriverHandle, PublishOpts, SubscriberDriverHandle}; use crate::metrics; +use crate::subject::Subject; const GC_INTERVAL: Duration = Duration::from_secs(60); @@ -82,9 +83,11 @@ impl PubSub { } #[tracing::instrument(skip_all, fields(%subject))] - pub async fn subscribe(&self, subject: &str) -> Result { + pub async fn subscribe(&self, subject: impl Subject) -> Result { + let subject = subject.as_cow(); + // Underlying driver subscription - let driver = self.driver.subscribe(subject).await?; + let driver = self.driver.subscribe(&subject).await?; if !self.memory_optimization { return Ok(Subscriber::new(driver, self.clone(), None)); @@ -117,46 +120,33 @@ impl PubSub { } #[tracing::instrument(skip_all, fields(%subject))] - pub async fn publish(&self, subject: &str, payload: &[u8], opts: PublishOpts) -> Result<()> { - let message_id = *Uuid::new_v4().as_bytes(); - let chunks = - split_payload_into_chunks(payload, self.driver.max_message_size(), message_id, None)?; - let chunk_count = chunks.len() as u32; - - let use_local = self - .should_use_local_subscriber(subject, opts.behavior) - .await; - - for (chunk_idx, chunk_payload) in chunks.into_iter().enumerate() { - let encoded = encode_chunk( - chunk_payload, - chunk_idx as u32, - chunk_count, - message_id, - None, - )?; - - if use_local { - if let Some(sender) = self.local_subscribers.get_async(subject).await { - let _ = sender.send(encoded); - } else { - tracing::warn!(%subject, "local subscriber disappeared"); - break; - } - } else { - // Use backoff when publishing through the driver - self.publish_with_backoff(subject, &encoded).await?; - } - } - Ok(()) + pub async fn publish( + &self, + subject: impl Subject, + payload: &[u8], + opts: PublishOpts, + ) -> Result<()> { + self.publish_inner(subject, payload, None::<&str>, opts) + .await } #[tracing::instrument(skip_all, fields(%subject, %reply_subject))] pub async fn publish_with_reply( &self, - subject: &str, + subject: impl Subject, + payload: &[u8], + reply_subject: impl Subject, + opts: PublishOpts, + ) -> Result<()> { + self.publish_inner(subject, payload, Some(reply_subject), opts) + .await + } + + async fn publish_inner( + &self, + subject: impl Subject, payload: &[u8], - reply_subject: &str, + reply_subject: Option, opts: PublishOpts, ) -> Result<()> { let message_id = *Uuid::new_v4().as_bytes(); @@ -164,25 +154,27 @@ impl PubSub { payload, self.driver.max_message_size(), message_id, - Some(reply_subject), + reply_subject.as_ref().map(|x| x.as_cow()).as_deref(), )?; let chunk_count = chunks.len() as u32; let use_local = self - .should_use_local_subscriber(subject, opts.behavior) + .should_use_local_subscriber(&subject, opts.behavior) .await; + let subject_cow = subject.as_cow(); + for (chunk_idx, chunk_payload) in chunks.into_iter().enumerate() { let encoded = encode_chunk( chunk_payload, chunk_idx as u32, chunk_count, message_id, - Some(reply_subject.to_string()), + reply_subject.as_ref().map(|x| x.to_string()), )?; if use_local { - if let Some(sender) = self.local_subscribers.get_async(subject).await { + if let Some(sender) = self.local_subscribers.get_async(&*subject_cow).await { let _ = sender.send(encoded); } else { tracing::warn!(%subject, "local subscriber disappeared"); @@ -190,17 +182,19 @@ impl PubSub { } } else { // Use backoff when publishing through the driver - self.publish_with_backoff(subject, &encoded).await?; + self.publish_with_backoff(&subject, &encoded).await?; } } Ok(()) } #[tracing::instrument(skip_all, fields(%subject))] - async fn publish_with_backoff(&self, subject: &str, encoded: &[u8]) -> Result<()> { + async fn publish_with_backoff(&self, subject: &impl Subject, encoded: &[u8]) -> Result<()> { + let subject = subject.as_cow(); + let mut backoff = Backoff::default(); loop { - match self.driver.publish(subject, encoded).await { + match self.driver.publish(&subject, encoded).await { Result::Ok(_) => break, Err(err) if !backoff.tick().await => { tracing::warn!(?err, "error publishing, cannot retry again"); @@ -221,7 +215,7 @@ impl PubSub { } #[tracing::instrument(skip_all, fields(%subject))] - pub async fn request(&self, subject: &str, payload: &[u8]) -> Result { + pub async fn request(&self, subject: impl Subject, payload: &[u8]) -> Result { self.request_with_timeout(subject, payload, Duration::from_secs(30)) .await } @@ -229,7 +223,7 @@ impl PubSub { #[tracing::instrument(skip_all, fields(%subject))] pub async fn request_with_timeout( &self, - subject: &str, + subject: impl Subject, payload: &[u8], timeout: Duration, ) -> Result { @@ -299,7 +293,7 @@ impl PubSub { #[tracing::instrument(skip_all, fields(%subject))] async fn should_use_local_subscriber( &self, - subject: &str, + subject: &impl Subject, behavior: crate::driver::PublishBehavior, ) -> bool { // Local fast-path for one-subscriber behavior: @@ -317,7 +311,7 @@ impl PubSub { if !matches!(behavior, crate::driver::PublishBehavior::OneSubscriber) { return false; } - if let Some(sender) = self.local_subscribers.get_async(subject).await { + if let Some(sender) = self.local_subscribers.get_async(&*subject.as_cow()).await { sender.receiver_count() > 0 } else { false diff --git a/engine/packages/universalpubsub/src/subject.rs b/engine/packages/universalpubsub/src/subject.rs new file mode 100644 index 0000000000..7cda65d28d --- /dev/null +++ b/engine/packages/universalpubsub/src/subject.rs @@ -0,0 +1,32 @@ +use std::{borrow::Cow, fmt::Display}; + +pub trait Subject: Display { + /// Used for metrics. + fn root<'a>() -> Option> { + None + } + + fn as_str(&self) -> Option<&str> { + None + } + + fn as_cow<'a>(&'a self) -> Cow<'a, str> { + if let Some(subject) = self.as_str() { + Cow::Borrowed(subject) + } else { + Cow::Owned(self.to_string()) + } + } +} + +impl Subject for &str { + fn as_str(&self) -> Option<&str> { + Some(self) + } +} + +impl Subject for &String { + fn as_str(&self) -> Option<&str> { + Some(self) + } +} diff --git a/engine/sdks/typescript/test-runner/src/index.ts b/engine/sdks/typescript/test-runner/src/index.ts index 938b14f1fe..9955822cc5 100644 --- a/engine/sdks/typescript/test-runner/src/index.ts +++ b/engine/sdks/typescript/test-runner/src/index.ts @@ -152,7 +152,7 @@ async function autoConfigureServerless() { datacenters: { default: { serverless: { - url: `http://localhost:${INTERNAL_SERVER_PORT}`, + url: `http://localhost:${INTERNAL_SERVER_PORT}/api/rivet`, max_runners: 10000, slots_per_runner: 1, request_lifespan: 300, diff --git a/frontend/src/app/getting-started.tsx b/frontend/src/app/getting-started.tsx index bd101e7da5..f2a1315229 100644 --- a/frontend/src/app/getting-started.tsx +++ b/frontend/src/app/getting-started.tsx @@ -941,4 +941,4 @@ function PackageManagerCode(props: { )} ); -} +} \ No newline at end of file diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-kv.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-kv.ts index 0366ed41b5..586083e16c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-kv.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-kv.ts @@ -108,4 +108,4 @@ export function runActorKvTests(driverTestConfig: DriverTestConfig) { expect(values).toEqual([4, 8, 15, 16, 23, 42]); }); }); -} +} \ No newline at end of file diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 69ba2446af..fb68c85f2e 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -168,7 +168,7 @@ export class EngineActorDriver implements ActorDriver { onConnected: () => { this.#runnerStarted.resolve(undefined); }, - onDisconnected: (_code, _reason) => {}, + onDisconnected: (_code, _reason) => { }, onShutdown: () => { this.#runnerStopped.resolve(undefined); this.#isRunnerStopped = true; @@ -444,7 +444,7 @@ export class EngineActorDriver implements ActorDriver { async serverlessHandleStart(c: HonoContext): Promise { return streamSSE(c, async (stream) => { // NOTE: onAbort does not work reliably - stream.onAbort(() => {}); + stream.onAbort(() => { }); c.req.raw.signal.addEventListener("abort", () => { logger().debug("SSE aborted, shutting down runner"); @@ -572,7 +572,7 @@ export class EngineActorDriver implements ActorDriver { if (protocolMetadata.serverlessDrainGracePeriod) { const drainMax = Math.max( Number(protocolMetadata.serverlessDrainGracePeriod) - - 1000, + 1000, 0, ); handler.actor.overrides.runStopTimeout = drainMax; @@ -595,12 +595,12 @@ export class EngineActorDriver implements ActorDriver { const error = innerError instanceof Error ? new Error( - `Failed to start actor ${actorId}: ${innerError.message}`, - { cause: innerError }, - ) + `Failed to start actor ${actorId}: ${innerError.message}`, + { cause: innerError }, + ) : new Error( - `Failed to start actor ${actorId}: ${String(innerError)}`, - ); + `Failed to start actor ${actorId}: ${String(innerError)}`, + ); handler.actor = undefined; handler.actorStartError = error; handler.actorStartPromise?.reject(error); @@ -1127,4 +1127,4 @@ export class EngineActorDriver implements ActorDriver { entry.bufferedMessageSize = 0; } } -} +} \ No newline at end of file diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts index 518418316a..a97fc11c2d 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts @@ -659,7 +659,7 @@ export class FileSystemGlobalState { await actor.actor.onStop("sleep"); } finally { // Ensure any pending KV writes finish before removing the entry. - await this.#withActorWrite(actorId, async () => {}); + await this.#withActorWrite(actorId, async () => { }); this.#closeActorKvDatabase(actorId); actor.stopPromise?.resolve(); actor.stopPromise = undefined; @@ -718,7 +718,7 @@ export class FileSystemGlobalState { } // Ensure any pending KV writes finish before deleting files. - await this.#withActorWrite(actorId, async () => {}); + await this.#withActorWrite(actorId, async () => { }); this.#closeActorKvDatabase(actorId); // Clear alarm timeout if exists @@ -778,7 +778,7 @@ export class FileSystemGlobalState { } } finally { // Ensure any pending KV writes finish before clearing the entry. - await this.#withActorWrite(actorId, async () => {}); + await this.#withActorWrite(actorId, async () => { }); actor.stopPromise?.resolve(); actor.stopPromise = undefined; @@ -947,7 +947,7 @@ export class FileSystemGlobalState { try { const fs = getNodeFs(); await fs.unlink(tempPath); - } catch {} + } catch { } logger().error({ msg: "failed to write alarm", actorId, @@ -1495,19 +1495,19 @@ export class FileSystemGlobalState { const limit = options?.limit ?? DEFAULT_LIST_LIMIT; const rows = upperBound ? db.all<{ - key: Uint8Array | ArrayBuffer; - value: Uint8Array | ArrayBuffer; - }>( - `SELECT key, value FROM kv WHERE key >= ? AND key < ? ORDER BY key ${direction} LIMIT ?`, - [prefix, upperBound, limit], - ) + key: Uint8Array | ArrayBuffer; + value: Uint8Array | ArrayBuffer; + }>( + `SELECT key, value FROM kv WHERE key >= ? AND key < ? ORDER BY key ${direction} LIMIT ?`, + [prefix, upperBound, limit], + ) : db.all<{ - key: Uint8Array | ArrayBuffer; - value: Uint8Array | ArrayBuffer; - }>( - `SELECT key, value FROM kv WHERE key >= ? ORDER BY key ${direction} LIMIT ?`, - [prefix, limit], - ); + key: Uint8Array | ArrayBuffer; + value: Uint8Array | ArrayBuffer; + }>( + `SELECT key, value FROM kv WHERE key >= ? ORDER BY key ${direction} LIMIT ?`, + [prefix, limit], + ); return rows.map((row) => [ ensureUint8Array(row.key, "key"), @@ -1558,4 +1558,4 @@ export class FileSystemGlobalState { ensureUint8Array(row.value, "value"), ]); } -} +} \ No newline at end of file diff --git a/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts b/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts index cbd9c10578..1af2ec7f55 100644 --- a/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts @@ -67,15 +67,15 @@ class ActorWorkflowMessageDriver implements WorkflowMessageDriver { sentAt: message.createdAt, ...(opts.completable ? { - complete: async (response?: unknown) => { - await this.#runCtx.keepAwake( - this.#actor.queueManager.completeMessage( - message, - response, - ), - ); - }, - } + complete: async (response?: unknown) => { + await this.#runCtx.keepAwake( + this.#actor.queueManager.completeMessage( + message, + response, + ), + ); + }, + } : {}), })); } @@ -225,4 +225,4 @@ export class ActorWorkflowDriver implements EngineDriver { abortSignal, ); } -} +} \ No newline at end of file diff --git a/rivetkit-typescript/packages/workflow-engine/src/context.ts b/rivetkit-typescript/packages/workflow-engine/src/context.ts index 29fff52ccb..e049ebbb68 100644 --- a/rivetkit-typescript/packages/workflow-engine/src/context.ts +++ b/rivetkit-typescript/packages/workflow-engine/src/context.ts @@ -1,1961 +1,604 @@ -import type { Logger } from "pino"; -import type { EngineDriver } from "./driver.js"; -import { - CancelledError, - CriticalError, - EntryInProgressError, - EvictedError, - HistoryDivergedError, - JoinError, - MessageWaitError, - RaceError, - RollbackCheckpointError, - RollbackError, - RollbackStopError, - SleepError, - StepExhaustedError, - StepFailedError, -} from "./errors.js"; -import { buildLoopIterationRange, buildEntryMetadataKey } from "./keys.js"; -import { - appendLoopIteration, - appendName, - emptyLocation, - isLocationPrefix, - locationToKey, - registerName, -} from "./location.js"; -import { - createEntry, - deleteEntriesWithPrefix, - flush, - getOrCreateMetadata, - loadMetadata, - setEntry, - type PendingDeletions, -} from "./storage.js"; +import type { RunContext } from "@/actor/contexts/run"; +import type { Client } from "@/client/client"; +import type { Registry } from "@/registry"; +import type { ActorDefinition, AnyActorDefinition } from "@/actor/definition"; +import type { + AnyDatabaseProvider, + InferDatabaseClient, +} from "@/actor/database"; +import type { + QueueFilterName, + QueueNextBatchOptions, + QueueNextOptions, + QueueResultMessageForName, +} from "@/actor/instance/queue"; +import type { + EventSchemaConfig, + InferEventArgs, + InferSchemaMap, + QueueSchemaConfig, +} from "@/actor/schema"; +import type { WorkflowContextInterface } from "@rivetkit/workflow-engine"; import type { BranchConfig, BranchOutput, - BranchStatus, - Entry, EntryKindType, - EntryMetadata, - Location, LoopConfig, - LoopIterationResult, LoopResult, - Message, - RollbackContextInterface, StepConfig, - Storage, - WorkflowContextInterface, - WorkflowQueue, WorkflowQueueMessage, - WorkflowQueueNextBatchOptions, - WorkflowQueueNextOptions, - WorkflowMessageDriver, -} from "./types.js"; -import { sleep } from "./utils.js"; - -/** - * Default values for step configuration. - * These are exported so users can reference them when overriding. - */ -export const DEFAULT_MAX_RETRIES = 3; -export const DEFAULT_RETRY_BACKOFF_BASE = 100; -export const DEFAULT_RETRY_BACKOFF_MAX = 30000; -export const DEFAULT_LOOP_HISTORY_PRUNE_INTERVAL = 20; -export const DEFAULT_STEP_TIMEOUT = 30000; // 30 seconds +} from "@rivetkit/workflow-engine"; +import { WORKFLOW_GUARD_KV_KEY } from "./constants"; + +type WorkflowActorQueueNextOptions< + TName extends string, + TCompletable extends boolean, +> = Omit, "signal">; + +type WorkflowActorQueueNextOptionsFallback = Omit< + QueueNextOptions, + "signal" +>; + +type WorkflowActorQueueNextBatchOptions< + TName extends string, + TCompletable extends boolean, +> = Omit, "signal">; + +type WorkflowActorQueueNextBatchOptionsFallback = + Omit, "signal">; + +type ActorWorkflowLoopConfig< + S, + T, + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, + TEvents extends EventSchemaConfig, + TQueues extends QueueSchemaConfig, +> = Omit, "run"> & { + run: ( + ctx: ActorWorkflowContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, + state: S, + ) => Promise | (S extends undefined ? void : never)>; +}; + +type ActorWorkflowBranchConfig< + TOutput, + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, + TEvents extends EventSchemaConfig, + TQueues extends QueueSchemaConfig, +> = { + run: ( + ctx: ActorWorkflowContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, + ) => Promise; +}; + +export class ActorWorkflowContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase extends AnyDatabaseProvider, + TEvents extends EventSchemaConfig = Record, + TQueues extends QueueSchemaConfig = Record, +> implements WorkflowContextInterface { + #inner: WorkflowContextInterface; + #runCtx: RunContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >; + #actorAccessDepth = 0; + #allowActorAccess = false; + #guardViolation = false; -const QUEUE_HISTORY_MESSAGE_MARKER = "__rivetWorkflowQueueMessage"; - -/** - * Calculate backoff delay with exponential backoff. - * Uses deterministic calculation (no jitter) for replay consistency. - */ -function calculateBackoff(attempts: number, base: number, max: number): number { - // Exponential backoff without jitter for determinism - return Math.min(max, base * 2 ** attempts); -} - -/** - * Error thrown when a step times out. - */ -export class StepTimeoutError extends Error { constructor( - public readonly stepName: string, - public readonly timeoutMs: number, + inner: WorkflowContextInterface, + runCtx: RunContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, ) { - super(`Step "${stepName}" timed out after ${timeoutMs}ms`); - this.name = "StepTimeoutError"; + this.#inner = inner; + this.#runCtx = runCtx; } -} -/** - * Internal representation of a rollback handler. - */ -export interface RollbackAction { - entryId: string; - name: string; - output: T; - rollback: (ctx: RollbackContextInterface, output: T) => Promise; -} - -/** - * Internal implementation of WorkflowContext. - */ -export class WorkflowContextImpl implements WorkflowContextInterface { - private entryInProgress = false; - private abortController: AbortController; - private currentLocation: Location; - private visitedKeys = new Set(); - private mode: "forward" | "rollback"; - private rollbackActions?: RollbackAction[]; - private rollbackCheckpointSet: boolean; - /** Track names used in current execution to detect duplicates */ - private usedNamesInExecution = new Set(); - private pendingCompletableMessageIds = new Set(); - private historyNotifier?: () => void; - private logger?: Logger; - - constructor( - public readonly workflowId: string, - private storage: Storage, - private driver: EngineDriver, - private messageDriver: WorkflowMessageDriver, - location: Location = emptyLocation(), - abortController?: AbortController, - mode: "forward" | "rollback" = "forward", - rollbackActions?: RollbackAction[], - rollbackCheckpointSet = false, - historyNotifier?: () => void, - logger?: Logger, - ) { - this.currentLocation = location; - this.abortController = abortController ?? new AbortController(); - this.mode = mode; - this.rollbackActions = rollbackActions; - this.rollbackCheckpointSet = rollbackCheckpointSet; - this.historyNotifier = historyNotifier; - this.logger = logger; + get workflowId(): string { + return this.#inner.workflowId; } get abortSignal(): AbortSignal { - return this.abortController.signal; - } + return this.#inner.abortSignal; + } + + get queue() { + const self = this; + function next< + const TName extends QueueFilterName, + const TCompletable extends boolean = false, + >( + name: string, + opts?: WorkflowActorQueueNextOptions, + ): Promise>; + function next( + name: string, + opts?: WorkflowActorQueueNextOptionsFallback, + ): Promise< + QueueResultMessageForName< + TQueues, + QueueFilterName, + TCompletable + > + >; + async function next( + name: string, + opts?: WorkflowActorQueueNextOptions, + ): Promise> { + const message = await self.#inner.queue.next(name, opts); + return self.#toActorQueueMessage(message); + } + + function nextBatch< + const TName extends QueueFilterName, + const TCompletable extends boolean = false, + >( + name: string, + opts?: WorkflowActorQueueNextBatchOptions, + ): Promise< + Array> + >; + function nextBatch( + name: string, + opts?: WorkflowActorQueueNextBatchOptionsFallback, + ): Promise< + Array< + QueueResultMessageForName< + TQueues, + QueueFilterName, + TCompletable + > + > + >; + async function nextBatch( + name: string, + opts?: WorkflowActorQueueNextBatchOptions, + ): Promise>> { + const messages = await self.#inner.queue.nextBatch(name, opts); + return messages.map((message) => + self.#toActorQueueMessage(message), + ); + } + + function send( + name: K, + body: InferSchemaMap[K], + ): Promise; + function send( + name: keyof TQueues extends never ? string : never, + body: unknown, + ): Promise; + async function send(name: string, body: unknown): Promise { + self.#ensureActorAccess("queue.send"); + await self.#runCtx.queue.send(name as never, body as never); + } - get queue(): WorkflowQueue { return { - next: async (name, opts) => await this.queueNext(name, opts), - nextBatch: async (name, opts) => - await this.queueNextBatch(name, opts), - send: async (name, body) => await this.queueSend(name, body), + next, + nextBatch, + send, }; } - isEvicted(): boolean { - return this.abortSignal.aborted; - } - - private assertNotInProgress(): void { - if (this.entryInProgress) { - throw new EntryInProgressError(); + async step( + nameOrConfig: string | Parameters[0], + run?: () => Promise, + ): Promise { + if (typeof nameOrConfig === "string") { + if (!run) { + throw new Error("Step run function missing"); + } + return await this.#wrapActive(() => + this.#inner.step(nameOrConfig, () => + this.#withActorAccess(run), + ), + ); } + const stepConfig = nameOrConfig as StepConfig; + const config: StepConfig = { + ...stepConfig, + run: () => this.#withActorAccess(stepConfig.run), + }; + return await this.#wrapActive(() => this.#inner.step(config)); } - private checkEvicted(): void { - if (this.abortSignal.aborted) { - throw new EvictedError(); + async loop( + name: string, + run: ( + ctx: ActorWorkflowContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, + ) => Promise | void>, + ): Promise; + async loop( + name: string, + run: ( + ctx: WorkflowContextInterface, + ) => Promise | void>, + ): Promise; + async loop( + config: ActorWorkflowLoopConfig< + S, + T, + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, + ): Promise; + async loop(config: LoopConfig): Promise; + async loop( + nameOrConfig: + | string + | LoopConfig + | ActorWorkflowLoopConfig< + any, + any, + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, + run?: ( + ctx: ActorWorkflowContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, + ) => Promise | void>, + ): Promise { + if (typeof nameOrConfig === "string") { + if (!run) { + throw new Error("Loop run function missing"); + } + return await this.#wrapActive(() => + this.#inner.loop(nameOrConfig, async (ctx) => + run(this.#createChildContext(ctx)), + ), + ); } + const wrapped: LoopConfig = { + ...nameOrConfig, + run: async (ctx, state) => + nameOrConfig.run(this.#createChildContext(ctx), state), + }; + return await this.#wrapActive(() => this.#inner.loop(wrapped)); } - private async flushStorage(): Promise { - await flush(this.storage, this.driver, this.historyNotifier); + sleep(name: string, durationMs: number): Promise { + return this.#inner.sleep(name, durationMs); } - /** - * Create a new branch context for parallel/nested execution. - */ - createBranch( - location: Location, - abortController?: AbortController, - ): WorkflowContextImpl { - return new WorkflowContextImpl( - this.workflowId, - this.storage, - this.driver, - this.messageDriver, - location, - abortController ?? this.abortController, - this.mode, - this.rollbackActions, - this.rollbackCheckpointSet, - this.historyNotifier, - this.logger, - ); + sleepUntil(name: string, timestampMs: number): Promise { + return this.#inner.sleepUntil(name, timestampMs); } - /** - * Log a debug message using the configured logger. - */ - private log( - level: "debug" | "info" | "warn" | "error", - data: Record, - ): void { - if (!this.logger) return; - this.logger[level](data); + async rollbackCheckpoint(name: string): Promise { + await this.#wrapActive(() => this.#inner.rollbackCheckpoint(name)); + } + + async join< + T extends Record< + string, + ActorWorkflowBranchConfig< + unknown, + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > + >, + >( + name: string, + branches: T, + ): Promise<{ [K in keyof T]: Awaited> }>; + async join>>( + name: string, + branches: T, + ): Promise<{ [K in keyof T]: BranchOutput }>; + async join(name: string, branches: Record>) { + const wrappedBranches = Object.fromEntries( + Object.entries(branches).map(([key, branch]) => [ + key, + { + run: async (ctx: WorkflowContextInterface) => + branch.run(this.#createChildContext(ctx)), + }, + ]), + ) as Record>; + return await this.#wrapActive(() => + this.#inner.join(name, wrappedBranches), + ); } - /** - * Mark a key as visited. - */ - private markVisited(key: string): void { - this.visitedKeys.add(key); + async race( + name: string, + branches: Array<{ + name: string; + run: ( + ctx: ActorWorkflowContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + >, + ) => Promise; + }>, + ): Promise<{ winner: string; value: T }>; + async race( + name: string, + branches: Array<{ + name: string; + run: (ctx: WorkflowContextInterface) => Promise; + }>, + ): Promise<{ winner: string; value: T }> { + const wrappedBranches = branches.map((branch) => ({ + name: branch.name, + run: (ctx: WorkflowContextInterface) => + branch.run(this.#createChildContext(ctx)), + })); + return (await this.#wrapActive(() => + this.#inner.race(name, wrappedBranches), + )) as { winner: string; value: T }; } - /** - * Check if a name has already been used at the current location in this execution. - * Throws HistoryDivergedError if duplicate detected. - */ - private checkDuplicateName(name: string): void { - const fullKey = - locationToKey(this.storage, this.currentLocation) + "/" + name; - if (this.usedNamesInExecution.has(fullKey)) { - throw new HistoryDivergedError( - `Duplicate entry name "${name}" at location "${locationToKey(this.storage, this.currentLocation)}". ` + - `Each step/loop/sleep/queue.next/join/race must have a unique name within its scope.`, - ); - } - this.usedNamesInExecution.add(fullKey); + async removed(name: string, originalType: EntryKindType): Promise { + await this.#wrapActive(() => this.#inner.removed(name, originalType)); } - private stopRollback(): never { - throw new RollbackStopError(); + isEvicted(): boolean { + return this.#inner.isEvicted(); } - private stopRollbackIfMissing(entry: Entry | undefined): void { - if (this.mode === "rollback" && !entry) { - this.stopRollback(); - } + get state(): TState extends never ? never : TState { + this.#ensureActorAccess("state"); + return this.#runCtx.state as TState extends never ? never : TState; } - private stopRollbackIfIncomplete(condition: boolean): void { - if (this.mode === "rollback" && condition) { - this.stopRollback(); - } + get vars(): TVars extends never ? never : TVars { + this.#ensureActorAccess("vars"); + return this.#runCtx.vars as TVars extends never ? never : TVars; } - private registerRollbackAction( - config: StepConfig, - entryId: string, - output: T, - metadata: EntryMetadata, - ): void { - if (!config.rollback) { - return; - } - if (metadata.rollbackCompletedAt !== undefined) { - return; - } - this.rollbackActions?.push({ - entryId, - name: config.name, - output: output as unknown, - rollback: config.rollback as ( - ctx: RollbackContextInterface, - output: unknown, - ) => Promise, - }); + client = Registry>(): Client { + this.#ensureActorAccess("client"); + return this.#runCtx.client(); } - /** - * Ensure a rollback checkpoint exists before registering rollback handlers. - */ - private ensureRollbackCheckpoint(config: StepConfig): void { - if (!config.rollback) { - return; - } - if (!this.rollbackCheckpointSet) { - throw new RollbackCheckpointError(); - } + get db(): TDatabase extends never ? never : InferDatabaseClient { + this.#ensureActorAccess("db"); + return this.#runCtx.db as TDatabase extends never + ? never + : InferDatabaseClient; } - /** - * Validate that all expected entries in the branch were visited. - * Throws HistoryDivergedError if there are unvisited entries. - */ - validateComplete(): void { - const prefix = locationToKey(this.storage, this.currentLocation); - - for (const key of this.storage.history.entries.keys()) { - // Check if this key is under our current location prefix - // Handle root prefix (empty string) specially - all keys are under root - const isUnderPrefix = - prefix === "" - ? true // Root: all keys are children - : key.startsWith(prefix + "/") || key === prefix; - - if (isUnderPrefix) { - if (!this.visitedKeys.has(key)) { - // Entry exists in history but wasn't visited - // This means workflow code may have changed - throw new HistoryDivergedError( - `Entry "${key}" exists in history but was not visited. ` + - `Workflow code may have changed. Use ctx.removed() to handle migrations.`, - ); - } - } - } + get log() { + return this.#runCtx.log; } - /** - * Evict the workflow. - */ - evict(): void { - this.abortController.abort(new EvictedError()); + keepAwake(promise: Promise): Promise { + this.#ensureActorAccess("keepAwake"); + return this.#runCtx.keepAwake(promise); } - /** - * Wait for eviction message. - * - * The event listener uses { once: true } to auto-remove after firing, - * preventing memory leaks if this method is called multiple times. - */ - waitForEviction(): Promise { - return new Promise((_, reject) => { - if (this.abortSignal.aborted) { - reject(new EvictedError()); - return; - } - this.abortSignal.addEventListener( - "abort", - () => { - reject(new EvictedError()); - }, - { once: true }, - ); - }); + waitUntil(promise: Promise): void { + this.#ensureActorAccess("waitUntil"); + this.#runCtx.waitUntil(promise); } - // === Step === - - async step( - nameOrConfig: string | StepConfig, - run?: () => Promise, - ): Promise { - this.assertNotInProgress(); - this.checkEvicted(); - - const config: StepConfig = - typeof nameOrConfig === "string" - ? { name: nameOrConfig, run: run! } - : nameOrConfig; - - this.entryInProgress = true; - try { - return await this.executeStep(config); - } finally { - this.entryInProgress = false; - } + get actorId(): string { + return this.#runCtx.actorId; } - private async executeStep(config: StepConfig): Promise { - this.ensureRollbackCheckpoint(config); - if (this.mode === "rollback") { - return await this.executeStepRollback(config); - } - - // Check for duplicate name in current execution - this.checkDuplicateName(config.name); - - const location = appendName( - this.storage, - this.currentLocation, - config.name, + broadcast( + name: K, + ...args: InferEventArgs[K]> + ): void; + broadcast( + name: keyof TEvents extends never ? string : never, + ...args: Array + ): void; + broadcast(name: string, ...args: Array): void { + this.#ensureActorAccess("broadcast"); + this.#runCtx.broadcast( + name as never, + ...(args as unknown[] as never[]), ); - const key = locationToKey(this.storage, location); - const existing = this.storage.history.entries.get(key); - - // Mark this entry as visited for validateComplete - this.markVisited(key); - - if (existing) { - if (existing.kind.type !== "step") { - throw new HistoryDivergedError( - `Expected step "${config.name}" at ${key}, found ${existing.kind.type}`, - ); - } - - const stepData = existing.kind.data; - - const metadata = await loadMetadata( - this.storage, - this.driver, - existing.id, - ); - - // Replay successful result (including void steps). - if ( - metadata.status === "completed" || - stepData.output !== undefined - ) { - return stepData.output as T; - } - - // Check if we should retry - const maxRetries = config.maxRetries ?? DEFAULT_MAX_RETRIES; - - if (metadata.attempts >= maxRetries) { - // Prefer step history error, but fall back to metadata since - // driver implementations may persist metadata without the history - // entry error (e.g. partial writes/crashes between attempts). - const lastError = stepData.error ?? metadata.error; - throw new StepExhaustedError(config.name, lastError); - } - - // Calculate backoff and yield to scheduler - // This allows the workflow to be evicted during backoff - const backoffDelay = calculateBackoff( - metadata.attempts, - config.retryBackoffBase ?? DEFAULT_RETRY_BACKOFF_BASE, - config.retryBackoffMax ?? DEFAULT_RETRY_BACKOFF_MAX, - ); - const retryAt = metadata.lastAttemptAt + backoffDelay; - const now = Date.now(); - - if (now < retryAt) { - // Yield to scheduler - will be woken up at retryAt - throw new SleepError(retryAt); - } - } - - // Execute the step - const entry = - existing ?? createEntry(location, { type: "step", data: {} }); - if (!existing) { - // New entry - register name - this.log("debug", { - msg: "executing new step", - step: config.name, - key, - }); - const nameIndex = registerName(this.storage, config.name); - entry.location = [...location]; - entry.location[entry.location.length - 1] = nameIndex; - setEntry(this.storage, location, entry); - } else { - this.log("debug", { msg: "retrying step", step: config.name, key }); - } - - const metadata = getOrCreateMetadata(this.storage, entry.id); - metadata.status = "running"; - metadata.attempts++; - metadata.lastAttemptAt = Date.now(); - metadata.dirty = true; - - // Get timeout configuration - const timeout = config.timeout ?? DEFAULT_STEP_TIMEOUT; + } + #toActorQueueMessage( + message: WorkflowQueueMessage, + ): WorkflowQueueMessage & { id: bigint } { + let id: bigint; try { - // Execute with timeout - const output = await this.executeWithTimeout( - config.run(), - timeout, - config.name, - ); - - if (entry.kind.type === "step") { - entry.kind.data.output = output; - } - entry.dirty = true; - metadata.status = "completed"; - metadata.error = undefined; - metadata.completedAt = Date.now(); - - // Ephemeral steps don't trigger an immediate flush. This avoids the - // synchronous write overhead for transient operations. Note that the - // step's entry is still marked dirty and WILL be persisted on the - // next flush from a non-ephemeral operation. The purpose of ephemeral - // is to batch writes, not to avoid persistence entirely. - if (!config.ephemeral) { - this.log("debug", { - msg: "flushing step", - step: config.name, - key, - }); - await this.flushStorage(); - } - - this.log("debug", { - msg: "step completed", - step: config.name, - key, - }); - return output; - } catch (error) { - // Timeout errors are treated as critical (no retry) - if (error instanceof StepTimeoutError) { - if (entry.kind.type === "step") { - entry.kind.data.error = String(error); - } - entry.dirty = true; - metadata.status = "exhausted"; - metadata.error = String(error); - await this.flushStorage(); - throw new CriticalError(error.message); - } - - if ( - error instanceof CriticalError || - error instanceof RollbackError - ) { - if (entry.kind.type === "step") { - entry.kind.data.error = String(error); - } - entry.dirty = true; - metadata.status = "exhausted"; - metadata.error = String(error); - await this.flushStorage(); - throw error; - } - - if (entry.kind.type === "step") { - entry.kind.data.error = String(error); - } - entry.dirty = true; - metadata.status = "failed"; - metadata.error = String(error); - - await this.flushStorage(); - - throw new StepFailedError(config.name, error, metadata.attempts); + id = BigInt(message.id); + } catch { + throw new Error(`Invalid queue message id "${message.id}"`); } + return { + id, + name: message.name, + body: message.body, + createdAt: message.createdAt, + ...(message.complete ? { complete: message.complete } : {}), + }; } - /** - * Execute a promise with timeout. - * - * Note: This does NOT cancel the underlying operation. JavaScript Promises - * cannot be cancelled once started. When a timeout occurs: - * - The step is marked as failed with StepTimeoutError - * - The underlying async operation continues running in the background - * - Any side effects from the operation may still occur - * - * For cancellable operations, pass ctx.abortSignal to APIs that support AbortSignal: - * - * return fetch(url, { signal: ctx.abortSignal }); - - * }); - * - * Or check ctx.isEvicted() periodically in long-running loops. - */ - private async executeStepRollback(config: StepConfig): Promise { - this.checkDuplicateName(config.name); - this.ensureRollbackCheckpoint(config); - - const location = appendName( - this.storage, - this.currentLocation, - config.name, - ); - const key = locationToKey(this.storage, location); - const existing = this.storage.history.entries.get(key); - - this.markVisited(key); - - if (!existing || existing.kind.type !== "step") { - this.stopRollback(); - } - - const metadata = await loadMetadata( - this.storage, - this.driver, - existing.id, - ); - if (metadata.status !== "completed") { - this.stopRollback(); - } - - const output = existing.kind.data.output as T; - this.registerRollbackAction(config, existing.id, output, metadata); - - return output; + async #wrapActive(run: () => Promise): Promise { + return await this.#runCtx.keepAwake(run()); } - private async executeWithTimeout( - promise: Promise, - timeoutMs: number, - stepName: string, - ): Promise { - if (timeoutMs <= 0) { - return promise; + async #withActorAccess(run: () => Promise): Promise { + this.#actorAccessDepth++; + if (this.#actorAccessDepth === 1) { + this.#allowActorAccess = true; } - - let timeoutId: ReturnType | undefined; - const timeoutPromise = new Promise((_, reject) => { - timeoutId = setTimeout(() => { - reject(new StepTimeoutError(stepName, timeoutMs)); - }, timeoutMs); - }); - try { - return await Promise.race([promise, timeoutPromise]); + return await run(); } finally { - if (timeoutId !== undefined) { - clearTimeout(timeoutId); + this.#actorAccessDepth--; + if (this.#actorAccessDepth === 0) { + this.#allowActorAccess = false; } } } - // === Loop === - - async loop( - nameOrConfig: string | LoopConfig, - run?: ( - ctx: WorkflowContextInterface, - ) => LoopIterationResult, - ): Promise { - this.assertNotInProgress(); - this.checkEvicted(); - - const config: LoopConfig = - typeof nameOrConfig === "string" - ? { name: nameOrConfig, run: run as LoopConfig["run"] } - : nameOrConfig; - - this.entryInProgress = true; - try { - return await this.executeLoop(config); - } finally { - this.entryInProgress = false; - } - } - - private async executeLoop(config: LoopConfig): Promise { - // Check for duplicate name in current execution - this.checkDuplicateName(config.name); - - const location = appendName( - this.storage, - this.currentLocation, - config.name, - ); - const key = locationToKey(this.storage, location); - const existing = this.storage.history.entries.get(key); - - // Mark this entry as visited for validateComplete - this.markVisited(key); - - let entry: Entry; - let metadata: EntryMetadata | undefined; - let state: S; - let iteration: number; - let rollbackSingleIteration = false; - let rollbackIterationRan = false; - let rollbackOutput: T | undefined; - const rollbackMode = this.mode === "rollback"; - - if (existing) { - if (existing.kind.type !== "loop") { - throw new HistoryDivergedError( - `Expected loop "${config.name}" at ${key}, found ${existing.kind.type}`, - ); - } - - const loopData = existing.kind.data; - metadata = await loadMetadata( - this.storage, - this.driver, - existing.id, - ); - - if (rollbackMode) { - if (loopData.output !== undefined) { - return loopData.output as T; - } - rollbackSingleIteration = true; - rollbackIterationRan = false; - rollbackOutput = undefined; - } - - if (metadata.status === "completed") { - return loopData.output as T; - } - - // Loop already completed - if (loopData.output !== undefined) { - return loopData.output as T; - } - - // Resume from saved state - entry = existing; - state = loopData.state as S; - iteration = loopData.iteration; - if (rollbackMode) { - rollbackOutput = loopData.output as T | undefined; - rollbackIterationRan = rollbackOutput !== undefined; - } - } else { - this.stopRollbackIfIncomplete(true); - - // New loop - state = config.state as S; - iteration = 0; - entry = createEntry(location, { - type: "loop", - data: { state, iteration }, - }); - setEntry(this.storage, location, entry); - metadata = getOrCreateMetadata(this.storage, entry.id); - } - - if (metadata) { - metadata.status = "running"; - metadata.error = undefined; - metadata.dirty = true; - } - - const historyPruneInterval = - config.historyPruneInterval ?? DEFAULT_LOOP_HISTORY_PRUNE_INTERVAL; - const historySize = config.historySize ?? historyPruneInterval; - - // Track the last iteration we pruned up to so we only delete - // newly-expired iterations instead of re-scanning from 0. - let lastPrunedUpTo = 0; - - // Deferred flush promise from the previous prune cycle. Awaited at the - // start of the next iteration so the flush runs in parallel with user code. - let deferredFlush: Promise | null = null; - - // Execute loop iterations - while (true) { - // Await any deferred flush from the previous prune cycle - if (deferredFlush) { - await deferredFlush; - deferredFlush = null; - } - - if (rollbackMode && rollbackSingleIteration) { - if (rollbackIterationRan) { - return rollbackOutput as T; - } - this.stopRollbackIfIncomplete(true); - } - this.checkEvicted(); - - // Create branch for this iteration - const iterationLocation = appendLoopIteration( - this.storage, - location, - config.name, - iteration, + #ensureActorAccess(feature: string): void { + if (!this.#allowActorAccess) { + this.#guardViolation = true; + this.#markGuardTriggered(); + throw new Error( + `${feature} is only available inside workflow steps`, ); - const branchCtx = this.createBranch(iterationLocation); - - // Execute iteration - const iterationResult = await config.run(branchCtx, state); - if (iterationResult === undefined && state !== undefined) { - throw new Error( - `Loop "${config.name}" returned undefined for a stateful iteration. Return Loop.continue(state) or Loop.break(value).`, - ); - } - const result = - iterationResult === undefined - ? ({ continue: true, state } as LoopResult) - : iterationResult; - - // Validate branch completed cleanly - branchCtx.validateComplete(); - - if ("break" in result && result.break) { - // Loop complete - if (entry.kind.type === "loop") { - entry.kind.data.output = result.value; - entry.kind.data.state = state; - entry.kind.data.iteration = iteration; - } - entry.dirty = true; - if (metadata) { - metadata.status = "completed"; - metadata.completedAt = Date.now(); - metadata.dirty = true; - } - - // Collect pruning deletions and flush - const deletions = this.collectLoopPruning( - location, - iteration + 1, - historySize, - lastPrunedUpTo, - ); - await this.flushStorageWithDeletions(deletions); - - if (rollbackMode && rollbackSingleIteration) { - rollbackOutput = result.value; - rollbackIterationRan = true; - continue; - } - - return result.value; - } - - // Continue with new state - if ("continue" in result && result.continue) { - state = result.state; - } - iteration++; - - if (!rollbackMode) { - if (entry.kind.type === "loop") { - entry.kind.data.state = state; - entry.kind.data.iteration = iteration; - } - entry.dirty = true; - } - - // Periodically defer the flush so the next iteration can overlap - // with loop pruning and any pending dirty state writes. - if (iteration % historyPruneInterval === 0) { - const deletions = this.collectLoopPruning( - location, - iteration, - historySize, - lastPrunedUpTo, - ); - lastPrunedUpTo = Math.max(0, iteration - historySize); - - // Defer the flush to run in parallel with the next iteration - deferredFlush = this.flushStorageWithDeletions(deletions); - } } } - /** - * Collect pending deletions for loop history pruning. - * - * Only deletes iterations in the range [fromIteration, keepFrom) where - * keepFrom = currentIteration - historySize. This avoids re-scanning - * already-deleted iterations. - */ - private collectLoopPruning( - loopLocation: Location, - currentIteration: number, - historySize: number, - fromIteration: number, - ): PendingDeletions | undefined { - if (currentIteration <= historySize) { - return undefined; - } - - const keepFrom = Math.max(0, currentIteration - historySize); - if (fromIteration >= keepFrom) { - return undefined; - } - - const loopSegment = loopLocation[loopLocation.length - 1]; - if (typeof loopSegment !== "number") { - throw new Error("Expected loop location to end with a name index"); - } - - const range = buildLoopIterationRange( - loopLocation, - loopSegment, - fromIteration, - keepFrom, - ); - const metadataKeys: Uint8Array[] = []; - - for (const [key, entry] of this.storage.history.entries) { - if (!isLocationPrefix(loopLocation, entry.location)) { - continue; - } + consumeGuardViolation(): boolean { + const violated = this.#guardViolation; + this.#guardViolation = false; + return violated; + } - const iterationSegment = entry.location[loopLocation.length]; + #markGuardTriggered(): void { + try { + const state = this.#runCtx.state as Record; if ( - !iterationSegment || - typeof iterationSegment === "number" || - iterationSegment.loop !== loopSegment || - iterationSegment.iteration < fromIteration || - iterationSegment.iteration >= keepFrom + state && + typeof state === "object" && + "guardTriggered" in state ) { - continue; + (state as Record).guardTriggered = true; } - - metadataKeys.push(buildEntryMetadataKey(entry.id)); - this.storage.entryMetadata.delete(entry.id); - this.storage.history.entries.delete(key); + } catch { + // Ignore if state is unavailable } - return { - prefixes: [], - keys: metadataKeys, - ranges: [range], - }; - } - - /** - * Flush storage with optional pending deletions so pruning - * happens alongside the state write. - */ - private async flushStorageWithDeletions( - deletions?: PendingDeletions, - ): Promise { - await flush(this.storage, this.driver, this.historyNotifier, deletions); - } - - // === Sleep === - - async sleep(name: string, durationMs: number): Promise { - const deadline = Date.now() + durationMs; - return this.sleepUntil(name, deadline); + this.#runCtx.waitUntil( + (async () => { + try { + await this.#runCtx.kv.put(WORKFLOW_GUARD_KV_KEY, "true"); + } catch (error) { + this.#runCtx.log.error({ + msg: "failed to persist workflow guard flag", + error, + }); + } + })(), + ); } - async sleepUntil(name: string, timestampMs: number): Promise { - this.assertNotInProgress(); - this.checkEvicted(); - - this.entryInProgress = true; - try { - await this.executeSleep(name, timestampMs); - } finally { - this.entryInProgress = false; - } + #createChildContext( + ctx: WorkflowContextInterface, + ): ActorWorkflowContext< + TState, + TConnParams, + TConnState, + TVars, + TInput, + TDatabase, + TEvents, + TQueues + > { + return new ActorWorkflowContext(ctx, this.#runCtx); } +} - private async executeSleep(name: string, deadline: number): Promise { - // Check for duplicate name in current execution - this.checkDuplicateName(name); - - const location = appendName(this.storage, this.currentLocation, name); - const key = locationToKey(this.storage, location); - const existing = this.storage.history.entries.get(key); - - // Mark this entry as visited for validateComplete - this.markVisited(key); - - let entry: Entry; - - if (existing) { - if (existing.kind.type !== "sleep") { - throw new HistoryDivergedError( - `Expected sleep "${name}" at ${key}, found ${existing.kind.type}`, - ); - } - - const sleepData = existing.kind.data; - - if (this.mode === "rollback") { - this.stopRollbackIfIncomplete(sleepData.state === "pending"); - return; - } - - // Already completed or interrupted - if (sleepData.state !== "pending") { - return; - } - - // Use stored deadline - deadline = sleepData.deadline; - entry = existing; - } else { - this.stopRollbackIfIncomplete(true); - - entry = createEntry(location, { - type: "sleep", - data: { deadline, state: "pending" }, - }); - setEntry(this.storage, location, entry); - entry.dirty = true; - await this.flushStorage(); - } - - const now = Date.now(); - const remaining = deadline - now; - - if (remaining <= 0) { - // Deadline passed - if (entry.kind.type === "sleep") { - entry.kind.data.state = "completed"; - } - entry.dirty = true; - await this.flushStorage(); - return; - } - - // Short sleep: wait in memory - if (remaining < this.driver.workerPollInterval) { - await Promise.race([sleep(remaining), this.waitForEviction()]); - - this.checkEvicted(); - - if (entry.kind.type === "sleep") { - entry.kind.data.state = "completed"; - } - entry.dirty = true; - await this.flushStorage(); - return; - } - - // Long sleep: yield to scheduler - throw new SleepError(deadline); - } - - // === Rollback Checkpoint === - - async rollbackCheckpoint(name: string): Promise { - this.assertNotInProgress(); - this.checkEvicted(); - - this.entryInProgress = true; - try { - await this.executeRollbackCheckpoint(name); - } finally { - this.entryInProgress = false; - } - } - - private async executeRollbackCheckpoint(name: string): Promise { - this.checkDuplicateName(name); - - const location = appendName(this.storage, this.currentLocation, name); - const key = locationToKey(this.storage, location); - const existing = this.storage.history.entries.get(key); - - this.markVisited(key); - - if (existing) { - if (existing.kind.type !== "rollback_checkpoint") { - throw new HistoryDivergedError( - `Expected rollback checkpoint "${name}" at ${key}, found ${existing.kind.type}`, - ); - } - this.rollbackCheckpointSet = true; - return; - } - - if (this.mode === "rollback") { - throw new HistoryDivergedError( - `Missing rollback checkpoint "${name}" at ${key}`, - ); - } - - const entry = createEntry(location, { - type: "rollback_checkpoint", - data: { name }, - }); - setEntry(this.storage, location, entry); - entry.dirty = true; - await this.flushStorage(); - - this.rollbackCheckpointSet = true; - } - - // === Queue === - - private async queueSend(name: string, body: unknown): Promise { - const message: Message = { - id: crypto.randomUUID(), - name, - data: body, - sentAt: Date.now(), - }; - await this.messageDriver.addMessage(message); - } - - private async queueNext( - name: string, - opts?: WorkflowQueueNextOptions, - ): Promise> { - const messages = await this.queueNextBatch(name, { - ...(opts ?? {}), - count: 1, - }); - const message = messages[0]; - if (!message) { - throw new Error( - `queue.next("${name}") timed out before receiving a message. Use queue.nextBatch(...) for optional/time-limited reads.`, - ); - } - return message; - } - - private async queueNextBatch( - name: string, - opts?: WorkflowQueueNextBatchOptions, - ): Promise>> { - this.assertNotInProgress(); - this.checkEvicted(); - - this.entryInProgress = true; - try { - return await this.executeQueueNextBatch(name, opts); - } finally { - this.entryInProgress = false; - } - } - - private async executeQueueNextBatch( - name: string, - opts?: WorkflowQueueNextBatchOptions, - ): Promise>> { - if (this.pendingCompletableMessageIds.size > 0) { - throw new Error( - "Previous completable queue message is not completed. Call `message.complete(...)` before receiving the next message.", - ); - } - - const resolvedOpts = opts ?? {}; - const messageNames = this.normalizeQueueNames(resolvedOpts.names); - const messageNameLabel = this.messageNamesLabel(messageNames); - const count = Math.max(1, resolvedOpts.count ?? 1); - const completable = resolvedOpts.completable === true; - - this.checkDuplicateName(name); - - const countLocation = appendName( - this.storage, - this.currentLocation, - `${name}:count`, - ); - const countKey = locationToKey(this.storage, countLocation); - const existingCount = this.storage.history.entries.get(countKey); - this.markVisited(countKey); - this.stopRollbackIfMissing(existingCount); - - let deadline: number | undefined; - let deadlineEntry: Entry | undefined; - if (resolvedOpts.timeout !== undefined) { - const deadlineLocation = appendName( - this.storage, - this.currentLocation, - `${name}:deadline`, - ); - const deadlineKey = locationToKey(this.storage, deadlineLocation); - deadlineEntry = this.storage.history.entries.get(deadlineKey); - this.markVisited(deadlineKey); - this.stopRollbackIfMissing(deadlineEntry); - - if (deadlineEntry && deadlineEntry.kind.type === "sleep") { - deadline = deadlineEntry.kind.data.deadline; - } else { - deadline = Date.now() + resolvedOpts.timeout; - const created = createEntry(deadlineLocation, { - type: "sleep", - data: { deadline, state: "pending" }, - }); - setEntry(this.storage, deadlineLocation, created); - created.dirty = true; - await this.flushStorage(); - deadlineEntry = created; - } - } - - if (existingCount && existingCount.kind.type === "message") { - const replayCount = existingCount.kind.data.data as number; - return await this.readReplayQueueMessages( - name, - replayCount, - completable, - ); - } - - const now = Date.now(); - if (deadline !== undefined && now >= deadline) { - if (deadlineEntry && deadlineEntry.kind.type === "sleep") { - deadlineEntry.kind.data.state = "completed"; - deadlineEntry.dirty = true; - } - await this.recordQueueCountEntry( - countLocation, - `${messageNameLabel}:count`, - 0, - ); - return []; - } - - const received = await this.receiveMessagesNow( - messageNames, - count, - completable, - ); - if (received.length > 0) { - const historyMessages = received.map((message) => - this.toWorkflowQueueMessage(message), - ); - if (deadlineEntry && deadlineEntry.kind.type === "sleep") { - deadlineEntry.kind.data.state = "interrupted"; - deadlineEntry.dirty = true; - } - await this.recordQueueMessages( - name, - countLocation, - messageNames, - historyMessages, - ); - const queueMessages = received.map((message, index) => - this.createQueueMessage(message, completable, { - historyLocation: appendName( - this.storage, - this.currentLocation, - `${name}:${index}`, - ), - }), - ); - return queueMessages; - } - - if (deadline === undefined) { - throw new MessageWaitError(messageNames); - } - throw new SleepError(deadline, messageNames); - } - - private normalizeQueueNames(names?: readonly string[]): string[] { - if (!names || names.length === 0) { - return []; - } - const deduped: string[] = []; - const seen = new Set(); - for (const name of names) { - if (seen.has(name)) { - continue; - } - seen.add(name); - deduped.push(name); - } - return deduped; - } - - private messageNamesLabel(messageNames: string[]): string { - if (messageNames.length === 0) { - return "*"; - } - return messageNames.length === 1 - ? messageNames[0] - : messageNames.join("|"); - } - - private async receiveMessagesNow( - messageNames: string[], - count: number, - completable: boolean, - ): Promise { - return await this.messageDriver.receiveMessages({ - names: messageNames.length > 0 ? messageNames : undefined, - count, - completable, - }); - } - - private async recordQueueMessages( - name: string, - countLocation: Location, - messageNames: string[], - messages: Array>, - ): Promise { - for (let i = 0; i < messages.length; i++) { - const messageLocation = appendName( - this.storage, - this.currentLocation, - `${name}:${i}`, - ); - const messageEntry = createEntry(messageLocation, { - type: "message", - data: { - name: messages[i].name, - data: this.toHistoryQueueMessage(messages[i]), - }, - }); - setEntry(this.storage, messageLocation, messageEntry); - this.markVisited(locationToKey(this.storage, messageLocation)); - } - await this.recordQueueCountEntry( - countLocation, - `${this.messageNamesLabel(messageNames)}:count`, - messages.length, - ); - } - - private async recordQueueCountEntry( - countLocation: Location, - countLabel: string, - count: number, - ): Promise { - const countEntry = createEntry(countLocation, { - type: "message", - data: { - name: countLabel, - data: count, - }, - }); - setEntry(this.storage, countLocation, countEntry); - await this.flushStorage(); - } - - private async readReplayQueueMessages( - name: string, - count: number, - completable: boolean, - ): Promise>> { - const results: Array> = []; - for (let i = 0; i < count; i++) { - const messageLocation = appendName( - this.storage, - this.currentLocation, - `${name}:${i}`, - ); - const messageKey = locationToKey(this.storage, messageLocation); - this.markVisited(messageKey); - const existingMessage = - this.storage.history.entries.get(messageKey); - if (!existingMessage || existingMessage.kind.type !== "message") { - throw new HistoryDivergedError( - `Expected queue message "${name}:${i}" in history`, - ); - } - const parsed = this.fromHistoryQueueMessage( - existingMessage.kind.data.name, - existingMessage.kind.data.data, - ); - results.push( - this.createQueueMessage(parsed.message, completable, { - historyLocation: messageLocation, - completed: parsed.completed, - replay: true, - }), - ); - } - return results; - } - - private toWorkflowQueueMessage( - message: Message, - ): WorkflowQueueMessage { - return { - id: message.id, - name: message.name, - body: message.data as T, - createdAt: message.sentAt, - }; - } - - private createQueueMessage( - message: Message, - completable: boolean, - opts?: { - historyLocation?: Location; - completed?: boolean; - replay?: boolean; - }, - ): WorkflowQueueMessage { - const queueMessage = this.toWorkflowQueueMessage(message); - if (!completable) { - return queueMessage; - } - - if (opts?.replay && opts.completed) { - return { - ...queueMessage, - complete: async () => { - // No-op: this message was already completed in a prior run. - }, - }; - } - - const messageId = message.id; - this.pendingCompletableMessageIds.add(messageId); - let completed = false; - - return { - ...queueMessage, - complete: async (response?: unknown) => { - if (completed) { - throw new Error("Queue message already completed"); - } - completed = true; - try { - await this.completeMessage(message, response); - await this.markQueueMessageCompleted(opts?.historyLocation); - this.pendingCompletableMessageIds.delete(messageId); - } catch (error) { - completed = false; - throw error; - } - }, - }; - } - - private async markQueueMessageCompleted( - historyLocation: Location | undefined, - ): Promise { - if (!historyLocation) { - return; - } - const key = locationToKey(this.storage, historyLocation); - const entry = this.storage.history.entries.get(key); - if (!entry || entry.kind.type !== "message") { - return; - } - const parsed = this.fromHistoryQueueMessage( - entry.kind.data.name, - entry.kind.data.data, - ); - entry.kind.data.data = this.toHistoryQueueMessage( - this.toWorkflowQueueMessage(parsed.message), - true, - ); - entry.dirty = true; - await this.flushStorage(); - } - - private async completeMessage( - message: Message, - response?: unknown, - ): Promise { - if (message.complete) { - await message.complete(response); - return; - } - await this.messageDriver.completeMessage(message.id, response); - } - - private toHistoryQueueMessage( - message: WorkflowQueueMessage, - completed = false, - ): unknown { - return { - [QUEUE_HISTORY_MESSAGE_MARKER]: 1, - id: message.id, - name: message.name, - body: message.body, - createdAt: message.createdAt, - completed, - }; - } - - private fromHistoryQueueMessage( - name: string, - value: unknown, - ): { message: Message; completed: boolean } { - if ( - typeof value === "object" && - value !== null && - (value as Record)[QUEUE_HISTORY_MESSAGE_MARKER] === - 1 - ) { - const serialized = value as Record; - const id = typeof serialized.id === "string" ? serialized.id : ""; - const serializedName = - typeof serialized.name === "string" ? serialized.name : name; - const createdAt = - typeof serialized.createdAt === "number" - ? serialized.createdAt - : 0; - const completed = - typeof serialized.completed === "boolean" - ? serialized.completed - : false; - return { - message: { - id, - name: serializedName, - data: serialized.body, - sentAt: createdAt, - }, - completed, - }; - } - return { - message: { - id: "", - name, - data: value, - sentAt: 0, - }, - completed: false, - }; - } - - // === Join === - - async join>>( - name: string, - branches: T, - ): Promise<{ [K in keyof T]: BranchOutput }> { - this.assertNotInProgress(); - this.checkEvicted(); - - this.entryInProgress = true; - try { - return await this.executeJoin(name, branches); - } finally { - this.entryInProgress = false; - } - } - - private async executeJoin>>( - name: string, - branches: T, - ): Promise<{ [K in keyof T]: BranchOutput }> { - // Check for duplicate name in current execution - this.checkDuplicateName(name); - - const location = appendName(this.storage, this.currentLocation, name); - const key = locationToKey(this.storage, location); - const existing = this.storage.history.entries.get(key); - - // Mark this entry as visited for validateComplete - this.markVisited(key); - - this.stopRollbackIfMissing(existing); - - let entry: Entry; - - if (existing) { - if (existing.kind.type !== "join") { - throw new HistoryDivergedError( - `Expected join "${name}" at ${key}, found ${existing.kind.type}`, - ); - } - entry = existing; - } else { - entry = createEntry(location, { - type: "join", - data: { - branches: Object.fromEntries( - Object.keys(branches).map((k) => [ - k, - { status: "pending" as const }, - ]), - ), - }, - }); - setEntry(this.storage, location, entry); - entry.dirty = true; - // Flush immediately to persist entry before branches execute - await this.flushStorage(); - } - - if (entry.kind.type !== "join") { - throw new HistoryDivergedError("Entry type mismatch"); - } - - this.stopRollbackIfIncomplete( - Object.values(entry.kind.data.branches).some( - (branch) => branch.status !== "completed", - ), - ); - - const joinData = entry.kind.data; - const results: Record = {}; - const errors: Record = {}; - - // Execute all branches in parallel - const branchPromises = Object.entries(branches).map( - async ([branchName, config]) => { - const branchStatus = joinData.branches[branchName]; - - // Already completed - if (branchStatus.status === "completed") { - results[branchName] = branchStatus.output; - return; - } - - // Already failed - if (branchStatus.status === "failed") { - errors[branchName] = new Error(branchStatus.error); - return; - } - - // Execute branch - const branchLocation = appendName( - this.storage, - location, - branchName, - ); - const branchCtx = this.createBranch(branchLocation); - - branchStatus.status = "running"; - entry.dirty = true; - - try { - const output = await config.run(branchCtx); - branchCtx.validateComplete(); - - branchStatus.status = "completed"; - branchStatus.output = output; - results[branchName] = output; - } catch (error) { - branchStatus.status = "failed"; - branchStatus.error = String(error); - errors[branchName] = error as Error; - } - - entry.dirty = true; - }, - ); - - // Wait for ALL branches (no short-circuit on error) - await Promise.allSettled(branchPromises); - await this.flushStorage(); - - // Throw if any branches failed - if (Object.keys(errors).length > 0) { - throw new JoinError(errors); - } - - return results as { [K in keyof T]: BranchOutput }; - } - - // === Race === - - async race( - name: string, - branches: Array<{ - name: string; - run: (ctx: WorkflowContextInterface) => Promise; - }>, - ): Promise<{ winner: string; value: T }> { - this.assertNotInProgress(); - this.checkEvicted(); - - this.entryInProgress = true; - try { - return await this.executeRace(name, branches); - } finally { - this.entryInProgress = false; - } - } - - private async executeRace( - name: string, - branches: Array<{ - name: string; - run: (ctx: WorkflowContextInterface) => Promise; - }>, - ): Promise<{ winner: string; value: T }> { - // Check for duplicate name in current execution - this.checkDuplicateName(name); - - const location = appendName(this.storage, this.currentLocation, name); - const key = locationToKey(this.storage, location); - const existing = this.storage.history.entries.get(key); - - // Mark this entry as visited for validateComplete - this.markVisited(key); - - this.stopRollbackIfMissing(existing); - - let entry: Entry; - - if (existing) { - if (existing.kind.type !== "race") { - throw new HistoryDivergedError( - `Expected race "${name}" at ${key}, found ${existing.kind.type}`, - ); - } - entry = existing; - - // Check if we already have a winner - const raceKind = existing.kind; - if (raceKind.data.winner !== null) { - const winnerStatus = - raceKind.data.branches[raceKind.data.winner]; - return { - winner: raceKind.data.winner, - value: winnerStatus.output as T, - }; - } - - this.stopRollbackIfIncomplete(true); - } else { - entry = createEntry(location, { - type: "race", - data: { - winner: null, - branches: Object.fromEntries( - branches.map((b) => [ - b.name, - { status: "pending" as const }, - ]), - ), - }, - }); - setEntry(this.storage, location, entry); - entry.dirty = true; - // Flush immediately to persist entry before branches execute - await this.flushStorage(); - } - - if (entry.kind.type !== "race") { - throw new HistoryDivergedError("Entry type mismatch"); - } - - const raceData = entry.kind.data; - - // Create abort controller for cancellation - const raceAbortController = new AbortController(); - - // Track all branch promises to wait for cleanup - const branchPromises: Promise[] = []; - - // Track winner info - let winnerName: string | null = null; - let winnerValue: T | null = null; - let settled = false; - let pendingCount = branches.length; - const errors: Record = {}; - const lateErrors: Array<{ name: string; error: string }> = []; - // Track scheduler yield errors - we need to propagate these after allSettled - let yieldError: SleepError | MessageWaitError | null = null; - - // Check for replay winners first - for (const branch of branches) { - const branchStatus = raceData.branches[branch.name]; - if ( - branchStatus.status !== "pending" && - branchStatus.status !== "running" - ) { - pendingCount--; - if (branchStatus.status === "completed" && !settled) { - settled = true; - winnerName = branch.name; - winnerValue = branchStatus.output as T; - } - } - } - - // If we found a replay winner, return immediately - if (settled && winnerName !== null && winnerValue !== null) { - return { winner: winnerName, value: winnerValue }; - } - - // Execute branches that need to run - for (const branch of branches) { - const branchStatus = raceData.branches[branch.name]; - - // Skip already completed/cancelled - if ( - branchStatus.status !== "pending" && - branchStatus.status !== "running" - ) { - continue; - } - - const branchLocation = appendName( - this.storage, - location, - branch.name, - ); - const branchCtx = this.createBranch( - branchLocation, - raceAbortController, - ); - - branchStatus.status = "running"; - entry.dirty = true; - - const branchPromise = branch.run(branchCtx).then( - async (output) => { - if (settled) { - // This branch completed after a winner was determined - // Still record the completion for observability - branchStatus.status = "completed"; - branchStatus.output = output; - entry.dirty = true; - return; - } - settled = true; - winnerName = branch.name; - winnerValue = output; - - branchCtx.validateComplete(); - - branchStatus.status = "completed"; - branchStatus.output = output; - raceData.winner = branch.name; - entry.dirty = true; - - // Cancel other branches - raceAbortController.abort(); - }, - (error) => { - pendingCount--; - - // Track sleep/message errors - they need to bubble up to the scheduler - // We'll re-throw after allSettled to allow cleanup - if (error instanceof SleepError) { - // Track the earliest deadline - if ( - !yieldError || - !(yieldError instanceof SleepError) || - error.deadline < yieldError.deadline - ) { - yieldError = error; - } - branchStatus.status = "running"; // Keep as running since we'll resume - entry.dirty = true; - return; - } - if (error instanceof MessageWaitError) { - // Track message wait errors, prefer sleep errors with deadlines - if ( - !yieldError || - !(yieldError instanceof SleepError) - ) { - if (!yieldError) { - yieldError = error; - } else if (yieldError instanceof MessageWaitError) { - // Merge message names - yieldError = new MessageWaitError([ - ...yieldError.messageNames, - ...error.messageNames, - ]); - } - } - branchStatus.status = "running"; // Keep as running since we'll resume - entry.dirty = true; - return; - } - - if ( - error instanceof CancelledError || - error instanceof EvictedError - ) { - branchStatus.status = "cancelled"; - } else { - branchStatus.status = "failed"; - branchStatus.error = String(error); - - if (settled) { - // Track late errors for observability - lateErrors.push({ - name: branch.name, - error: String(error), - }); - } else { - errors[branch.name] = error; - } - } - entry.dirty = true; - - // All branches failed (only if no winner yet) - if (pendingCount === 0 && !settled) { - settled = true; - } - }, - ); - - branchPromises.push(branchPromise); - } - - // Wait for all branches to complete or be cancelled - await Promise.allSettled(branchPromises); - - // If any branch needs to yield to the scheduler (sleep/message wait), - // save state and re-throw the error to exit the workflow execution - if (yieldError && !settled) { - await this.flushStorage(); - throw yieldError; - } - - // Clean up entries from non-winning branches - if (winnerName !== null) { - for (const branch of branches) { - if (branch.name !== winnerName) { - const branchLocation = appendName( - this.storage, - location, - branch.name, - ); - await deleteEntriesWithPrefix( - this.storage, - this.driver, - branchLocation, - this.historyNotifier, - ); - } - } - } - - // Flush final state - await this.flushStorage(); - - // Log late errors if any (these occurred after a winner was determined) - if (lateErrors.length > 0) { - console.warn( - `Race "${name}" had ${lateErrors.length} branch(es) fail after winner was determined:`, - lateErrors, - ); - } - - // Return result or throw error - if (winnerName !== null && winnerValue !== null) { - return { winner: winnerName, value: winnerValue }; - } - - // All branches failed - throw new RaceError( - "All branches failed", - Object.entries(errors).map(([name, error]) => ({ - name, - error: String(error), - })), - ); - } - - // === Removed === - - async removed(name: string, originalType: EntryKindType): Promise { - this.assertNotInProgress(); - this.checkEvicted(); - - this.entryInProgress = true; - try { - await this.executeRemoved(name, originalType); - } finally { - this.entryInProgress = false; - } - } - - private async executeRemoved( - name: string, - originalType: EntryKindType, - ): Promise { - // Check for duplicate name in current execution - this.checkDuplicateName(name); - - const location = appendName(this.storage, this.currentLocation, name); - const key = locationToKey(this.storage, location); - const existing = this.storage.history.entries.get(key); - - // Mark this entry as visited for validateComplete - this.markVisited(key); - - this.stopRollbackIfMissing(existing); - - if (existing) { - // Validate the existing entry matches what we expect - if ( - existing.kind.type !== "removed" && - existing.kind.type !== originalType - ) { - throw new HistoryDivergedError( - `Expected ${originalType} or removed at ${key}, found ${existing.kind.type}`, - ); - } - - // If it's not already marked as removed, we just skip it - return; - } - - // Create a removed entry placeholder - const entry = createEntry(location, { - type: "removed", - data: { originalType, originalName: name }, - }); - setEntry(this.storage, location, entry); - await this.flushStorage(); - } -} +export type WorkflowContextOf = + AD extends ActorDefinition< + infer S, + infer CP, + infer CS, + infer V, + infer I, + infer DB extends AnyDatabaseProvider, + infer E extends EventSchemaConfig, + infer Q extends QueueSchemaConfig, + any + > + ? ActorWorkflowContext + : never; + +export type WorkflowLoopContextOf = + WorkflowContextOf; + +export type WorkflowBranchContextOf = + WorkflowContextOf; + +export type WorkflowStepContextOf = + WorkflowContextOf; \ No newline at end of file diff --git a/rivetkit-typescript/packages/workflow-engine/src/storage.ts b/rivetkit-typescript/packages/workflow-engine/src/storage.ts index cdde5d371f..c10865a3f6 100644 --- a/rivetkit-typescript/packages/workflow-engine/src/storage.ts +++ b/rivetkit-typescript/packages/workflow-engine/src/storage.ts @@ -402,4 +402,4 @@ export function setEntry( ): void { const key = locationToKey(storage, location); storage.history.entries.set(key, entry); -} +} \ No newline at end of file diff --git a/rivetkit-typescript/packages/workflow-engine/src/types.ts b/rivetkit-typescript/packages/workflow-engine/src/types.ts index 80cc7ef5d4..d3bdefa140 100644 --- a/rivetkit-typescript/packages/workflow-engine/src/types.ts +++ b/rivetkit-typescript/packages/workflow-engine/src/types.ts @@ -525,4 +525,4 @@ export interface WorkflowHandle { * Get the current workflow state. */ getState(): Promise; -} +} \ No newline at end of file diff --git a/rivetkit-typescript/packages/workflow-engine/tests/messages.test.ts b/rivetkit-typescript/packages/workflow-engine/tests/messages.test.ts index 123909c637..858fc00566 100644 --- a/rivetkit-typescript/packages/workflow-engine/tests/messages.test.ts +++ b/rivetkit-typescript/packages/workflow-engine/tests/messages.test.ts @@ -670,4 +670,4 @@ for (const mode of modes) { }); }); }); -} +} \ No newline at end of file diff --git a/website/src/components/marketing/sections/ProblemSection.tsx b/website/src/components/marketing/sections/ProblemSection.tsx index d692b18180..f5fc7e7231 100644 --- a/website/src/components/marketing/sections/ProblemSection.tsx +++ b/website/src/components/marketing/sections/ProblemSection.tsx @@ -443,4 +443,4 @@ export const ProblemSection = () => { ); -}; +}; \ No newline at end of file