diff --git a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/dynamic-registry.ts b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/dynamic-registry.ts index f5772d9368..180f1e5ad1 100644 --- a/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/dynamic-registry.ts +++ b/rivetkit-typescript/packages/rivetkit/fixtures/driver-test-suite/dynamic-registry.ts @@ -130,10 +130,10 @@ const dynamicFromUrl = dynamicActor(async () => { }); const dynamicFromActor = dynamicActor(async (c) => { - const source = await c + const source = (await c .client() .sourceCode.getOrCreate(["dynamic-source"]) - .getCode(); + .getCode()) as string; return { source, nodeProcess: { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts index ae64722afe..1bfe4fcb34 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/config.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/config.ts @@ -175,6 +175,71 @@ export function getRunInspectorConfig( // We don't use Zod generics with `z.custom` because: // (a) there seems to be a weird bug in either Zod, tsup, or TSC that causese external packages to have different types from `z.infer` than from within the same package and // (b) it makes the type definitions incredibly difficult to read as opposed to vanilla TypeScript. +const GlobalActorOptionsBaseSchema = z + .object({ + /** Display name for the actor in the Inspector UI. */ + name: z.string().optional(), + /** Icon for the actor in the Inspector UI. Can be an emoji or FontAwesome icon name. */ + icon: z.string().optional(), + /** + * Can hibernate WebSockets for onWebSocket. + * + * WebSockets using actions/events are hibernatable by default. + * + * @experimental + **/ + canHibernateWebSocket: z + .union([z.boolean(), zFunction<(request: Request) => boolean>()]) + .default(false), + }) + .strict(); + +export const GlobalActorOptionsSchema = GlobalActorOptionsBaseSchema.prefault( + () => ({}), +); + +export type GlobalActorOptions = z.infer; +export type GlobalActorOptionsInput = z.input; + +const InstanceActorOptionsBaseSchema = z + .object({ + createVarsTimeout: z.number().positive().default(5000), + createConnStateTimeout: z.number().positive().default(5000), + onConnectTimeout: z.number().positive().default(5000), + onSleepTimeout: z.number().positive().default(5000), + onDestroyTimeout: z.number().positive().default(5000), + stateSaveInterval: z.number().positive().default(10_000), + actionTimeout: z.number().positive().default(60_000), + // Max time to wait for waitUntil background promises during shutdown + waitUntilTimeout: z.number().positive().default(15_000), + // Max time to wait for run handler to stop during shutdown + runStopTimeout: z.number().positive().default(15_000), + connectionLivenessTimeout: z.number().positive().default(2500), + connectionLivenessInterval: z.number().positive().default(5000), + noSleep: z.boolean().default(false), + sleepTimeout: z.number().positive().default(30_000), + maxQueueSize: z.number().positive().default(1000), + maxQueueMessageSize: z.number().positive().default(64 * 1024), + }) + .strict(); + +export const InstanceActorOptionsSchema = + InstanceActorOptionsBaseSchema.prefault(() => ({})); + +export type InstanceActorOptions = z.infer; +export type InstanceActorOptionsInput = z.input< + typeof InstanceActorOptionsSchema +>; + +export const ActorOptionsSchema = GlobalActorOptionsBaseSchema.extend( + InstanceActorOptionsBaseSchema.shape, +) + .strict() + .prefault(() => ({})); + +export type ActorOptions = z.infer; +export type ActorOptionsInput = z.input; + export const ActorConfigSchema = z .object({ onCreate: zFunction().optional(), @@ -199,48 +264,7 @@ export const ActorConfigSchema = z vars: z.any().optional(), db: z.any().optional(), createVars: zFunction().optional(), - options: z - .object({ - /** Display name for the actor in the Inspector UI. */ - name: z.string().optional(), - /** Icon for the actor in the Inspector UI. Can be an emoji or FontAwesome icon name. */ - icon: z.string().optional(), - createVarsTimeout: z.number().positive().default(5000), - createConnStateTimeout: z.number().positive().default(5000), - onConnectTimeout: z.number().positive().default(5000), - onSleepTimeout: z.number().positive().default(5000), - onDestroyTimeout: z.number().positive().default(5000), - stateSaveInterval: z.number().positive().default(10_000), - actionTimeout: z.number().positive().default(60_000), - // Max time to wait for waitUntil background promises during shutdown - waitUntilTimeout: z.number().positive().default(15_000), - // Max time to wait for run handler to stop during shutdown - runStopTimeout: z.number().positive().default(15_000), - connectionLivenessTimeout: z.number().positive().default(2500), - connectionLivenessInterval: z.number().positive().default(5000), - noSleep: z.boolean().default(false), - sleepTimeout: z.number().positive().default(30_000), - maxQueueSize: z.number().positive().default(1000), - maxQueueMessageSize: z - .number() - .positive() - .default(64 * 1024), - /** - * Can hibernate WebSockets for onWebSocket. - * - * WebSockets using actions/events are hibernatable by default. - * - * @experimental - **/ - canHibernateWebSocket: z - .union([ - z.boolean(), - zFunction<(request: Request) => boolean>(), - ]) - .default(false), - }) - .strict() - .prefault(() => ({})), + options: ActorOptionsSchema, }) .strict() .refine( diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts index 3543b20b40..2bb8dee19f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/driver.ts @@ -1,5 +1,5 @@ import type { AnyConn } from "@/actor/conn/mod"; -import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { AnyStaticActorInstance } from "@/actor/instance/mod"; import type { CachedSerializer } from "@/actor/protocol/serde"; export enum DriverReadyState { @@ -24,7 +24,7 @@ export interface ConnDriver { rivetKitProtocol?: { /** Sends a RivetKit client message. */ sendMessage( - actor: AnyActorInstance, + actor: AnyStaticActorInstance, conn: AnyConn, message: CachedSerializer, ): void; @@ -42,20 +42,20 @@ export interface ConnDriver { * This returns a promise since we commonly disconnect at the end of a program, and not waiting will cause the socket to not close cleanly. */ disconnect( - actor: AnyActorInstance, + actor: AnyStaticActorInstance, conn: AnyConn, reason?: string, ): Promise; /** Terminates the connection without graceful handling. */ - terminate?(actor: AnyActorInstance, conn: AnyConn): void; + terminate?(actor: AnyStaticActorInstance, conn: AnyConn): void; /** * Returns the ready state of the connection. * This is used to determine if the connection is ready to send messages, or if the connection is stale. */ getConnectionReadyState( - actor: AnyActorInstance, + actor: AnyStaticActorInstance, conn: AnyConn, ): DriverReadyState | undefined; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts index 34608b6cb2..80aeadacb2 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/raw-websocket.ts @@ -1,5 +1,5 @@ import type { AnyConn } from "@/actor/conn/mod"; -import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { AnyStaticActorInstance } from "@/actor/instance/mod"; import type { UniversalWebSocket } from "@/common/websocket-interface"; import { loggerWithoutContext } from "../../log"; import { type ConnDriver, DriverReadyState } from "../driver"; @@ -26,7 +26,7 @@ export function createRawWebSocketDriver( // handle messages from the RivetKit protocol disconnect: async ( - _actor: AnyActorInstance, + _actor: AnyStaticActorInstance, _conn: AnyConn, reason?: string, ) => { @@ -49,7 +49,7 @@ export function createRawWebSocketDriver( }, getConnectionReadyState: ( - _actor: AnyActorInstance, + _actor: AnyStaticActorInstance, _conn: AnyConn, ): DriverReadyState | undefined => { return websocket?.readyState ?? DriverReadyState.CONNECTING; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts index fe5652de58..017c325a1c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/drivers/websocket.ts @@ -1,6 +1,6 @@ import type { WSContext } from "hono/ws"; import type { AnyConn } from "@/actor/conn/mod"; -import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { AnyStaticActorInstance } from "@/actor/instance/mod"; import type { CachedSerializer, Encoding } from "@/actor/protocol/serde"; import * as errors from "@/actor/errors"; import { loggerWithoutContext } from "../../log"; @@ -27,7 +27,7 @@ export function createWebSocketDriver( hibernatable, rivetKitProtocol: { sendMessage: ( - actor: AnyActorInstance, + actor: AnyStaticActorInstance, conn: AnyConn, message: CachedSerializer, ) => { @@ -105,7 +105,7 @@ export function createWebSocketDriver( }, disconnect: async ( - _actor: AnyActorInstance, + _actor: AnyStaticActorInstance, _conn: AnyConn, reason?: string, ) => { @@ -128,7 +128,7 @@ export function createWebSocketDriver( }, getConnectionReadyState: ( - _actor: AnyActorInstance, + _actor: AnyStaticActorInstance, _conn: AnyConn, ): DriverReadyState | undefined => { return websocket?.readyState ?? DriverReadyState.CONNECTING; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts index a9c3c87915..abb6812753 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/conn/mod.ts @@ -11,7 +11,7 @@ import { import { bufferToArrayBuffer } from "@/utils"; import type { AnyDatabaseProvider } from "../database"; import { EventPayloadInvalid, InternalError } from "../errors"; -import type { ActorInstance } from "../instance/mod"; +import type { StaticActorInstance } from "../instance/mod"; import { CachedSerializer } from "../protocol/serde"; import { type EventSchemaConfig, @@ -52,9 +52,9 @@ export class Conn< E extends EventSchemaConfig = Record, Q extends QueueSchemaConfig = Record, > { - #actor: ActorInstance; + #actor: StaticActorInstance; - get [CONN_ACTOR_SYMBOL](): ActorInstance { + get [CONN_ACTOR_SYMBOL](): StaticActorInstance { return this.#actor; } @@ -139,7 +139,7 @@ export class Conn< * @protected */ constructor( - actor: ActorInstance, + actor: StaticActorInstance, data: ConnDataInput, ) { this.#actor = actor; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/actor.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/actor.ts index 4da92c5a33..725ec6d86c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/actor.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/actor.ts @@ -8,8 +8,8 @@ import type { ActorDefinition, AnyActorDefinition } from "../../definition"; import * as errors from "../../errors"; import { ActorKv } from "../../instance/kv"; import type { - ActorInstance, - AnyActorInstance, + StaticActorInstance, + AnyStaticActorInstance, SaveStateOptions, } from "../../instance/mod"; import { ActorQueue } from "../../instance/queue"; @@ -40,8 +40,8 @@ export class ActorContext< TEvents extends EventSchemaConfig = Record, TQueues extends QueueSchemaConfig = Record, > { - [ACTOR_CONTEXT_INTERNAL_SYMBOL]!: AnyActorInstance; - #actor: ActorInstance< + [ACTOR_CONTEXT_INTERNAL_SYMBOL]!: AnyStaticActorInstance; + #actor: StaticActorInstance< TState, TConnParams, TConnState, @@ -66,7 +66,7 @@ export class ActorContext< | undefined; constructor( - actor: ActorInstance< + actor: StaticActorInstance< TState, TConnParams, TConnState, @@ -78,7 +78,7 @@ export class ActorContext< >, ) { this.#actor = actor; - this[ACTOR_CONTEXT_INTERNAL_SYMBOL] = actor as AnyActorInstance; + this[ACTOR_CONTEXT_INTERNAL_SYMBOL] = actor as AnyStaticActorInstance; } /** diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn-init.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn-init.ts index 227ff471c9..d719e4a51f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn-init.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn-init.ts @@ -1,6 +1,6 @@ import type { AnyDatabaseProvider } from "../../database"; import type { ActorDefinition, AnyActorDefinition } from "../../definition"; -import type { ActorInstance } from "../../instance/mod"; +import type { StaticActorInstance } from "../../instance/mod"; import type { EventSchemaConfig, QueueSchemaConfig } from "../../schema"; import { ActorContext } from "./actor"; @@ -35,7 +35,7 @@ export abstract class ConnInitContext< * @internal */ constructor( - actor: ActorInstance< + actor: StaticActorInstance< TState, any, any, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn.ts index 6c6c5f3506..00c5bbf2a3 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/base/conn.ts @@ -1,7 +1,7 @@ import type { Conn } from "../../conn/mod"; import type { AnyDatabaseProvider } from "../../database"; import type { ActorDefinition, AnyActorDefinition } from "../../definition"; -import type { ActorInstance } from "../../instance/mod"; +import type { StaticActorInstance } from "../../instance/mod"; import type { EventSchemaConfig, QueueSchemaConfig } from "../../schema"; import { ActorContext } from "./actor"; @@ -32,7 +32,7 @@ export abstract class ConnContext< * @internal */ constructor( - actor: ActorInstance< + actor: StaticActorInstance< TState, TConnParams, TConnState, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts index 0b53436f64..756d8d41b8 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/request.ts @@ -1,7 +1,7 @@ import type { Conn } from "../conn/mod"; import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; -import type { ActorInstance } from "../instance/mod"; +import type { StaticActorInstance } from "../instance/mod"; import type { EventSchemaConfig, QueueSchemaConfig } from "../schema"; import { ConnContext } from "./base/conn"; @@ -37,7 +37,7 @@ export class RequestContext< * @internal */ constructor( - actor: ActorInstance< + actor: StaticActorInstance< TState, TConnParams, TConnState, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts index 17e882d841..ddb023108c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/contexts/websocket.ts @@ -1,7 +1,7 @@ import type { Conn } from "../conn/mod"; import type { AnyDatabaseProvider } from "../database"; import type { ActorDefinition, AnyActorDefinition } from "../definition"; -import type { ActorInstance } from "../instance/mod"; +import type { StaticActorInstance } from "../instance/mod"; import type { EventSchemaConfig, QueueSchemaConfig } from "../schema"; import { ConnContext } from "./base/conn"; @@ -37,7 +37,7 @@ export class WebSocketContext< * @internal */ constructor( - actor: ActorInstance< + actor: StaticActorInstance< TState, TConnParams, TConnState, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts b/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts index 9394bef2dd..409a61aad6 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/definition.ts @@ -1,11 +1,38 @@ import type { RegistryConfig } from "@/registry/config"; -import { DeepMutable } from "@/utils"; import type { Actions, ActorConfig } from "./config"; -import type { ActionContextOf, ActorContext } from "./contexts"; import type { AnyDatabaseProvider } from "./database"; -import { ActorInstance } from "./instance/mod"; +import { + StaticActorInstance, + type ActorInstance, +} from "./instance/mod"; import type { EventSchemaConfig, QueueSchemaConfig } from "./schema"; +export interface ActorDefinition< + S, + CP, + CS, + V, + I, + DB extends AnyDatabaseProvider, + E extends EventSchemaConfig = Record, + Q extends QueueSchemaConfig = Record, + R extends Actions = Actions< + S, + CP, + CS, + V, + I, + DB, + E, + Q + >, +> { + readonly config: ActorConfig; + instantiate(): + | ActorInstance + | Promise>; +} + export type AnyActorDefinition = ActorDefinition< any, any, @@ -18,7 +45,7 @@ export type AnyActorDefinition = ActorDefinition< any >; -export class ActorDefinition< +export class StaticActorDefinition< S, CP, CS, @@ -37,7 +64,7 @@ export class ActorDefinition< E, Q >, -> { +> implements ActorDefinition { #config: ActorConfig; constructor(config: ActorConfig) { @@ -48,8 +75,8 @@ export class ActorDefinition< return this.#config; } - instantiate(): ActorInstance { - return new ActorInstance(this.#config); + instantiate(): StaticActorInstance { + return new StaticActorInstance(this.#config); } } @@ -57,7 +84,6 @@ export function lookupInRegistry( config: RegistryConfig, name: string, ): AnyActorDefinition { - // Build actor const definition = config.use[name]; if (!definition) throw new Error(`no actor in registry for name ${name}`); return definition; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts b/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts index fe93bb93cb..05bee9c738 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/driver.ts @@ -2,7 +2,7 @@ import type { Context as HonoContext } from "hono"; import type { AnyClient } from "@/client/client"; import type { ManagerDriver } from "@/manager/driver"; import { type AnyConn } from "./conn/mod"; -import type { AnyActorInstance } from "./instance/mod"; +import type { AnyActorInstance, AnyStaticActorInstance } from "./instance/mod"; import type { RegistryConfig } from "@/registry/config"; import type { RawDatabaseClient, @@ -106,7 +106,7 @@ export interface ActorDriver { /** Extra properties to add to logs for each actor. */ getExtraActorLogParams?(): Record; - onBeforeActorStart?(actor: AnyActorInstance): Promise; + onBeforeActorStart?(actor: AnyStaticActorInstance): Promise; onCreateConn?(conn: AnyConn): void; onDestroyConn?(conn: AnyConn): void; onBeforePersistConn?(conn: AnyConn): void; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts index 9fb15ec7bc..7074f159f1 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/connection-manager.ts @@ -33,7 +33,7 @@ import { CachedSerializer } from "../protocol/serde"; import type { EventSchemaConfig, QueueSchemaConfig } from "../schema"; import { deadline } from "../utils"; import { makeConnKey } from "./keys"; -import type { ActorInstance } from "./mod"; +import type { StaticActorInstance } from "./mod"; /** * Manages all connection-related operations for an actor instance. * Handles connection creation, tracking, hibernation, and cleanup. @@ -48,14 +48,14 @@ export class ConnectionManager< E extends EventSchemaConfig = Record, Q extends QueueSchemaConfig = Record, > { - #actor: ActorInstance; + #actor: StaticActorInstance; #connections = new Map>(); #pendingDisconnectCount = 0; /** Connections that have had their state changed and need to be persisted. */ #connsWithPersistChanged = new Set(); - constructor(actor: ActorInstance) { + constructor(actor: StaticActorInstance) { this.#actor = actor; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts index 7e6cf21bca..3df3854a2c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/event-manager.ts @@ -19,7 +19,7 @@ import type { AnyDatabaseProvider } from "../database"; import * as errors from "../errors"; import { CachedSerializer } from "../protocol/serde"; import type { EventSchemaConfig, QueueSchemaConfig } from "../schema"; -import type { ActorInstance } from "./mod"; +import type { StaticActorInstance } from "./mod"; /** * Manages event subscriptions and broadcasting for actor instances. @@ -35,13 +35,13 @@ export class EventManager< E extends EventSchemaConfig = Record, Q extends QueueSchemaConfig = Record, > { - #actor: ActorInstance; + #actor: StaticActorInstance; #subscriptionIndex = new Map< string, Set> >(); - constructor(actor: ActorInstance) { + constructor(actor: StaticActorInstance) { this.#actor = actor; } diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts index 41d0d048f0..d807e52952 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/mod.ts @@ -87,7 +87,29 @@ enum CanSleep { ActiveRun, } -/** Actor type alias with all `any` types. Used for `extends` in classes referencing this actor. */ +/** + * Minimal lifecycle contract shared by static and dynamic actor instances. + * + * Runtime internals (connections, inspector, queue manager, etc) are exposed + * only on `StaticActorInstance`. + */ +export interface ActorInstance< + S = any, + CP = any, + CS = any, + V = any, + I = any, + DB extends AnyDatabaseProvider = AnyDatabaseProvider, + E extends EventSchemaConfig = Record, + Q extends QueueSchemaConfig = Record, +> { + readonly id: string; + readonly isStopping: boolean; + onStop(mode: "sleep" | "destroy"): Promise; + onAlarm(): Promise; +} + +/** Actor type alias with all `any` types. */ export type AnyActorInstance = ActorInstance< any, any, @@ -99,23 +121,68 @@ export type AnyActorInstance = ActorInstance< any >; +/** Static actor type alias with all `any` types. */ +export type AnyStaticActorInstance = StaticActorInstance< + any, + any, + any, + any, + any, + any, + any, + any +>; + +export function isStaticActorInstance( + actor: AnyActorInstance, +): actor is AnyStaticActorInstance { + return actor instanceof StaticActorInstance; +} + export type ExtractActorState = - A extends ActorInstance + A extends StaticActorInstance< + infer State, + any, + any, + any, + any, + any, + any, + any + > ? State : never; export type ExtractActorConnParams = - A extends ActorInstance + A extends StaticActorInstance< + any, + infer ConnParams, + any, + any, + any, + any, + any, + any + > ? ConnParams : never; export type ExtractActorConnState = - A extends ActorInstance + A extends StaticActorInstance< + any, + any, + infer ConnState, + any, + any, + any, + any, + any + > ? ConnState : never; -// MARK: - Main ActorInstance Class -export class ActorInstance< +// MARK: - Main Static ActorInstance Class +export class StaticActorInstance< S, CP, CS, @@ -124,7 +191,7 @@ export class ActorInstance< DB extends AnyDatabaseProvider, E extends EventSchemaConfig = Record, Q extends QueueSchemaConfig = Record, -> { +> implements ActorInstance { // MARK: - Core Properties actorContext: ActorContext; #config: ActorConfig; @@ -1450,7 +1517,7 @@ export class ActorInstance< this.#sqliteVfs = await this.driver.createSqliteVfs(); } - client = await this.#config.db.createClient({ + client = await this.#config.db.createClient({ actorId: this.#actorId, overrideRawDatabaseClient: this.driver.overrideRawDatabaseClient ? () => this.driver.overrideRawDatabaseClient!(this.#actorId) @@ -1458,11 +1525,14 @@ export class ActorInstance< overrideDrizzleDatabaseClient: this.driver.overrideDrizzleDatabaseClient ? () => this.driver.overrideDrizzleDatabaseClient!(this.#actorId) : undefined, - kv: { - batchPut: (entries) => this.driver.kvBatchPut(this.#actorId, entries), - batchGet: (keys) => this.driver.kvBatchGet(this.#actorId, keys), - batchDelete: (keys) => this.driver.kvBatchDelete(this.#actorId, keys), - }, + kv: { + batchPut: (entries: [Uint8Array, Uint8Array][]) => + this.driver.kvBatchPut(this.#actorId, entries), + batchGet: (keys: Uint8Array[]) => + this.driver.kvBatchGet(this.#actorId, keys), + batchDelete: (keys: Uint8Array[]) => + this.driver.kvBatchDelete(this.#actorId, keys), + }, sqliteVfs: this.#sqliteVfs, }); this.#rLog.info({ msg: "database migration starting" }); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts index e5a92e4696..92801cd5ac 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/queue-manager.ts @@ -16,7 +16,7 @@ import { queueMessagesPrefix, queueMetadataKey, } from "./keys"; -import type { ActorInstance } from "./mod"; +import type { StaticActorInstance } from "./mod"; export interface QueueMessage { id: bigint; @@ -74,7 +74,7 @@ export class QueueManager< E extends EventSchemaConfig = Record, Q extends QueueSchemaConfig = Record, > { - #actor: ActorInstance; + #actor: StaticActorInstance; #driver: ActorDriver; #waiters = new Map(); #metadata: QueueMetadata = { ...DEFAULT_METADATA }; @@ -82,7 +82,7 @@ export class QueueManager< #pendingCompletions = new Map(); constructor( - actor: ActorInstance, + actor: StaticActorInstance, driver: ActorDriver, ) { this.#actor = actor; diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts index 666b044e5d..c615bf2a1f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/schedule-manager.ts @@ -7,7 +7,7 @@ import { import type { AnyDatabaseProvider } from "../database"; import type { ActorDriver } from "../driver"; import type { EventSchemaConfig, QueueSchemaConfig } from "../schema"; -import type { ActorInstance } from "./mod"; +import type { StaticActorInstance } from "./mod"; import type { PersistedScheduleEvent } from "./persisted"; /** @@ -24,14 +24,14 @@ export class ScheduleManager< E extends EventSchemaConfig = Record, Q extends QueueSchemaConfig = Record, > { - #actor: ActorInstance; + #actor: StaticActorInstance; #actorDriver: ActorDriver; #alarmWriteQueue = new SinglePromiseQueue(); #config: any; // ActorConfig type #persist: any; // Reference to PersistedActor constructor( - actor: ActorInstance, + actor: StaticActorInstance, actorDriver: ActorDriver, config: any, ) { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts b/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts index 70e00eeb23..34c4bc9b03 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/instance/state-manager.ts @@ -15,7 +15,7 @@ import * as errors from "../errors"; import type { EventSchemaConfig, QueueSchemaConfig } from "../schema"; import { isConnStatePath, isStatePath } from "../utils"; import { KEYS, makeConnKey } from "./keys"; -import type { ActorInstance } from "./mod"; +import type { StaticActorInstance } from "./mod"; import { convertActorToBarePersisted, type PersistedActor } from "./persisted"; export interface SaveStateOptions { @@ -45,7 +45,7 @@ export class StateManager< E extends EventSchemaConfig = Record, Q extends QueueSchemaConfig = Record, > { - #actor: ActorInstance; + #actor: StaticActorInstance; #actorDriver: ActorDriver; // State tracking @@ -66,7 +66,7 @@ export class StateManager< #stateSaveInterval: number; constructor( - actor: ActorInstance, + actor: StaticActorInstance, actorDriver: ActorDriver, config: any, ) { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts index 88ba4c8a27..6f50cf7b36 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/mod.ts @@ -6,7 +6,10 @@ import { ActorTypes, } from "./config"; import type { AnyDatabaseProvider } from "./database"; -import { ActorDefinition } from "./definition"; +import { + type ActorDefinition, + StaticActorDefinition, +} from "./definition"; import { event as schemaEvent, queue as schemaQueue } from "./schema"; import type { EventSchemaConfig, QueueSchemaConfig } from "./schema"; @@ -71,7 +74,7 @@ export function actor< TEvents, TQueues >; - return new ActorDefinition(config); + return new StaticActorDefinition(config); } export type { Encoding } from "@/actor/protocol/serde"; export { @@ -93,13 +96,16 @@ export type { } from "@/common/websocket-interface"; export type { ActorKey } from "@/manager/protocol/query"; export type * from "./config"; +export { CONN_STATE_MANAGER_SYMBOL } from "./conn/mod"; export type { AnyConn, Conn } from "./conn/mod"; export type { ActorDefinition, AnyActorDefinition } from "./definition"; +export { StaticActorDefinition } from "./definition"; export { lookupInRegistry } from "./definition"; export { UserError, type UserErrorOptions } from "./errors"; export { KEYS as KV_KEYS } from "./instance/keys"; export { ActorKv } from "./instance/kv"; export type { AnyActorInstance } from "./instance/mod"; +export { StaticActorInstance } from "./instance/mod"; export { type ActorRouter, createActorRouter, diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts index 3afe595932..853b928e3f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/protocol/old.ts @@ -28,7 +28,7 @@ import { } from "../../utils"; import { CONN_SEND_MESSAGE_SYMBOL, type Conn } from "../conn/mod"; import { ActionContext } from "../contexts"; -import type { ActorInstance } from "../instance/mod"; +import type { StaticActorInstance } from "../instance/mod"; import type { EventSchemaConfig, QueueSchemaConfig } from "../schema"; interface MessageEventOpts { @@ -179,7 +179,7 @@ export async function processMessage< val: { eventName: string; subscribe: boolean }; }; }, - actor: ActorInstance, + actor: StaticActorInstance, conn: Conn, handler: ProcessMessageHandler, ) { diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts index f18632c045..c09668200f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-endpoints.ts @@ -3,7 +3,10 @@ import type { Context as HonoContext, HonoRequest } from "hono"; import type { AnyConn } from "@/actor/conn/mod"; import { ActionContext } from "@/actor/contexts"; import * as errors from "@/actor/errors"; -import type { AnyActorInstance } from "@/actor/instance/mod"; +import { + type AnyStaticActorInstance, + isStaticActorInstance, +} from "@/actor/instance/mod"; import { type Encoding, EncodingSchema } from "@/actor/protocol/serde"; import { hasSchemaConfigKey } from "@/actor/schema"; import { @@ -44,6 +47,19 @@ import { createRawRequestDriver } from "./conn/drivers/raw-request"; import type { ActorDriver } from "./driver"; import { loggerWithoutContext } from "./log"; +async function loadStaticActor( + actorDriver: ActorDriver, + actorId: string, +): Promise { + const actor = await actorDriver.loadActor(actorId); + if (!isStaticActorInstance(actor)) { + throw new errors.InternalError( + "dynamic actor cannot be handled by static actor router endpoints", + ); + } + return actor; +} + export interface ActionOpts { req?: HonoRequest; params: unknown; @@ -116,10 +132,10 @@ export async function handleAction( let outputReady = false; const maxAttempts = 3; for (let attempt = 0; attempt < maxAttempts; attempt++) { - let actor: AnyActorInstance | undefined; + let actor: AnyStaticActorInstance | undefined; let conn: AnyConn | undefined; try { - actor = await actorDriver.loadActor(actorId); + actor = await loadStaticActor(actorDriver, actorId); actor.rLog.debug({ msg: "handling action", actionName, encoding }); @@ -224,7 +240,7 @@ export async function handleQueueSend( throw new errors.InvalidRequest("missing queue name"); } - const actor = await actorDriver.loadActor(actorId); + const actor = await loadStaticActor(actorDriver, actorId); if (!hasSchemaConfigKey(actor.config.queues, name)) { actor.rLog.warn({ msg: "ignoring incoming queue message for undefined queue", @@ -311,7 +327,7 @@ export async function handleRawRequest( actorDriver: ActorDriver, actorId: string, ): Promise { - const actor = await actorDriver.loadActor(actorId); + const actor = await loadStaticActor(actorDriver, actorId); const parameters = getRequestConnParams(c.req); // Track connection outside of scope for cleanup diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts index dbc9dec7dc..31edc3eb4f 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router-websocket-endpoints.ts @@ -1,7 +1,10 @@ import type { WSContext } from "hono/ws"; import invariant from "invariant"; import type { AnyConn } from "@/actor/conn/mod"; -import type { AnyActorInstance } from "@/actor/instance/mod"; +import { + type AnyStaticActorInstance, + isStaticActorInstance, +} from "@/actor/instance/mod"; import type { InputData } from "@/actor/protocol/serde"; import { type Encoding, EncodingSchema } from "@/actor/protocol/serde"; import { @@ -33,7 +36,7 @@ import { getRequestExposeInternalError } from "./router-endpoints"; // TODO: Merge with ConnectWebSocketOutput interface export interface UpgradeWebSocketArgs { conn?: AnyConn; - actor?: AnyActorInstance; + actor?: AnyStaticActorInstance; onRestore?: (ws: WSContext) => void; onOpen: (event: any, ws: WSContext) => void; onMessage: (event: any, ws: WSContext) => void; @@ -45,7 +48,7 @@ interface WebSocketHandlerOpts { config: RegistryConfig; request: Request | undefined; encoding: Encoding; - actor: AnyActorInstance; + actor: AnyStaticActorInstance; closePromiseResolvers: ReturnType>; conn: AnyConn; exposeInternalError: boolean; @@ -56,6 +59,19 @@ type WebSocketHandler = ( opts: WebSocketHandlerOpts, ) => Promise; +async function loadStaticActor( + actorDriver: ActorDriver, + actorId: string, +): Promise { + const actor = await actorDriver.loadActor(actorId); + if (!isStaticActorInstance(actor)) { + throw new Error( + "dynamic actor cannot be handled by static websocket router", + ); + } + return actor; +} + export async function routeWebSocket( request: Request | undefined, requestPath: string, @@ -76,7 +92,7 @@ export async function routeWebSocket( let createdConn: AnyConn | undefined; try { - const actor = await actorDriver.loadActor(actorId); + const actor = await loadStaticActor(actorDriver, actorId); actor.rLog.debug({ msg: "new websocket connection", diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts index a8c409dc9b..da24687275 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/router.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/router.ts @@ -26,6 +26,7 @@ import { timingSafeEqual } from "@/utils/crypto"; import { isDev } from "@/utils/env-vars"; import { CONN_DRIVER_SYMBOL } from "./conn/mod"; import type { ActorDriver } from "./driver"; +import { isStaticActorInstance } from "./instance/mod"; import { loggerWithoutContext } from "./log"; import { parseWebSocketProtocols, @@ -45,6 +46,14 @@ export interface MetadataResponse { version: string; } +async function loadStaticActor(actorDriver: ActorDriver, actorId: string) { + const actor = await actorDriver.loadActor(actorId); + if (!isStaticActorInstance(actor)) { + throw new Error("dynamic actor cannot be handled by static actor router"); + } + return actor; +} + /** * Creates a router that runs on the partitioned instance. * @@ -67,7 +76,7 @@ export function createActorRouter( // Track all HTTP requests to prevent actor from sleeping during active requests router.use("*", async (c, next) => { - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); actor.beginHonoHttpRequest(); try { await next(); @@ -102,7 +111,7 @@ export function createActorRouter( return c.text("Missing conn query parameter", 400); } - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const conn = actor.connectionManager.getConnForId(connId); if (!conn) { @@ -196,7 +205,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const isStateEnabled = actor.inspector.isStateEnabled(); const state = isStateEnabled ? actor.inspector.getStateJson() @@ -208,7 +217,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const body = await c.req.json<{ state: unknown }>(); await actor.inspector.setStateJson(body.state); return c.json({ ok: true }); @@ -218,7 +227,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const connections = actor.inspector.getConnectionsJson(); return c.json({ connections }); }); @@ -227,7 +236,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const rpcs = actor.inspector.getRpcs(); return c.json({ rpcs }); }); @@ -236,7 +245,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const name = c.req.param("name"); const body = await c.req.json<{ args: unknown[] }>(); const output = await actor.inspector.executeActionJson( @@ -250,7 +259,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const limit = parseInt(c.req.query("limit") ?? "50", 10); const status = await actor.inspector.getQueueStatusJson(limit); return c.json(status); @@ -260,7 +269,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const startMs = parseInt(c.req.query("startMs") ?? "0", 10); const endMs = parseInt( c.req.query("endMs") ?? String(Date.now()), @@ -281,7 +290,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const result = actor.inspector.getWorkflowHistoryJson(); return c.json(result); }); @@ -290,7 +299,7 @@ export function createActorRouter( const authResponse = await inspectorAuth(c); if (authResponse) return authResponse; - const actor = await actorDriver.loadActor(c.env.actorId); + const actor = await loadStaticActor(actorDriver, c.env.actorId); const isStateEnabled = actor.inspector.isStateEnabled(); const isDatabaseEnabled = actor.inspector.isDatabaseEnabled(); diff --git a/rivetkit-typescript/packages/rivetkit/src/actor/schedule.ts b/rivetkit-typescript/packages/rivetkit/src/actor/schedule.ts index b512208b88..fc65617872 100644 --- a/rivetkit-typescript/packages/rivetkit/src/actor/schedule.ts +++ b/rivetkit-typescript/packages/rivetkit/src/actor/schedule.ts @@ -1,9 +1,9 @@ -import type { AnyActorInstance } from "./instance/mod"; +import type { AnyStaticActorInstance } from "./instance/mod"; export class Schedule { - #actor: AnyActorInstance; + #actor: AnyStaticActorInstance; - constructor(actor: AnyActorInstance) { + constructor(actor: AnyStaticActorInstance) { this.#actor = actor; } diff --git a/rivetkit-typescript/packages/rivetkit/src/common/inline-websocket-adapter.ts b/rivetkit-typescript/packages/rivetkit/src/common/inline-websocket-adapter.ts index 4dfcb545a3..0013693003 100644 --- a/rivetkit-typescript/packages/rivetkit/src/common/inline-websocket-adapter.ts +++ b/rivetkit-typescript/packages/rivetkit/src/common/inline-websocket-adapter.ts @@ -31,17 +31,7 @@ export class InlineWebSocketAdapter { // Actor's send() -> Client's message event this.#clientWs = new VirtualWebSocket({ getReadyState: () => this.#readyState, - onSend: (data) => { - try { - // Call handler.onMessage for protocol-based connections (RPC) - this.#handler.onMessage({ data }, this.#wsContext); - // Also trigger message event on actor's websocket for raw websocket handlers - this.#actorWs.triggerMessage(data); - } catch (err) { - this.#handleError(err); - this.#close(1011, "Internal error processing message"); - } - }, + onSend: (data) => this.dispatchClientMessageWithMetadata(data), onClose: (code, reason) => this.#close(code, reason), }); @@ -81,6 +71,25 @@ export class InlineWebSocketAdapter { return this.#actorWs; } + /** + * Dispatch a client->actor message with optional transport metadata. + * + * This is used by dynamic actor host bridges to preserve + * `rivetMessageIndex` on hibernatable engine websocket paths. + */ + dispatchClientMessageWithMetadata( + data: string | ArrayBufferLike | Blob | ArrayBufferView, + rivetMessageIndex?: number, + ): void { + try { + this.#handler.onMessage({ data, rivetMessageIndex }, this.#wsContext); + (this.#actorWs as any).triggerMessage(data, rivetMessageIndex); + } catch (err) { + this.#handleError(err); + this.#close(1011, "Internal error processing message"); + } + } + async #initialize(): Promise { try { logger().debug({ msg: "websocket initializing" }); diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts index 741069965c..ff95f95b94 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-helpers/mod.ts @@ -1,6 +1,11 @@ export type { ActorDriver } from "@/actor/driver"; export { KEYS, makeConnKey } from "@/actor/instance/keys"; -export type { ActorInstance, AnyActorInstance } from "@/actor/instance/mod"; +export type { + ActorInstance, + AnyActorInstance, + AnyStaticActorInstance, + StaticActorInstance, +} from "@/actor/instance/mod"; export { ALLOWED_PUBLIC_HEADERS, HEADER_ACTOR_ID, diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index 4adb8959fa..f92a6b7cfe 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -12,8 +12,13 @@ import { WSContext, type WSContextInit } from "hono/ws"; import invariant from "invariant"; import { type AnyConn, CONN_STATE_MANAGER_SYMBOL } from "@/actor/conn/mod"; import { lookupInRegistry } from "@/actor/definition"; +import { + isStaticActorInstance, + type AnyStaticActorInstance, +} from "@/actor/instance/mod"; import { KEYS } from "@/actor/instance/keys"; import { deserializeActorKey } from "@/actor/keys"; +import type { Encoding } from "@/actor/protocol/serde"; import { getValueLength } from "@/actor/protocol/old"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { @@ -40,6 +45,9 @@ import { getInitialActorKvState, type ManagerDriver, } from "@/driver-helpers/mod"; +import { DynamicActorInstance } from "@/dynamic/instance"; +import { DynamicActorHostRuntime } from "@/dynamic/host-runtime"; +import { isDynamicActorDefinition } from "@/dynamic/internal"; import { buildActorNames, type RegistryConfig } from "@/registry/config"; import { getEndpoint } from "@/remote-manager-driver/api-utils"; import { @@ -83,6 +91,7 @@ export class EngineActorDriver implements ActorDriver { #inlineClient: Client; #runner: Runner; #actors: Map = new Map(); + #dynamicRuntimes = new Map(); #actorRouter: ActorRouter; #runnerStarted: PromiseWithResolvers = promiseWithResolvers((reason) => logger().warn({ msg: "unhandled runner started promise rejection", reason })); @@ -201,6 +210,18 @@ export class EngineActorDriver implements ActorDriver { return {}; } + #isDynamicActor(actorId: string): boolean { + return this.#dynamicRuntimes.has(actorId); + } + + #requireDynamicRuntime(actorId: string): DynamicActorHostRuntime { + const runtime = this.#dynamicRuntimes.get(actorId); + if (!runtime) { + throw new Error(`dynamic runtime missing for actor ${actorId}`); + } + return runtime; + } + async setAlarm(actor: AnyActorInstance, timestamp: number): Promise { const handler = this.#actors.get(actor.id); if (!handler) { @@ -350,18 +371,20 @@ export class EngineActorDriver implements ActorDriver { }); const stopPromises: Promise[] = []; for (const [_actorId, handler] of this.#actors.entries()) { - if (handler.actor) { - stopPromises.push( - handler.actor.onStop("sleep").catch((err) => { - handler.actor?.rLog.error({ - msg: "onStop errored", - error: stringifyError(err), - }); - }), - ); - } + if (handler.actor) { + stopPromises.push( + handler.actor.onStop("sleep").catch((err) => { + logger().error({ + msg: "onStop errored", + actorId: _actorId, + error: stringifyError(err), + }); + }), + ); + } } await Promise.all(stopPromises); + this.#dynamicRuntimes.clear(); logger().debug({ msg: "all actors stopped" }); try { @@ -497,40 +520,82 @@ export class EngineActorDriver implements ActorDriver { // Create actor instance const definition = lookupInRegistry(this.#config, actorConfig.name); + if (isDynamicActorDefinition(definition)) { + let runtime = this.#dynamicRuntimes.get(actorId); + if (!runtime) { + runtime = new DynamicActorHostRuntime({ + actorId, + actorName: name, + actorKey: key, + input, + region: "unknown", + loader: definition.loader, + actorDriver: this, + inlineClient: this.#inlineClient, + }); + await runtime.start(); + this.#dynamicRuntimes.set(actorId, runtime); + } - handler.actor = await definition.instantiate(); + const dynamicActor = new DynamicActorInstance(actorId, runtime); + handler.actor = dynamicActor; - // Apply protocol limits as per-instance overrides without mutating the shared definition - const protocolMetadata = this.#runner.getProtocolMetadata(); - if (protocolMetadata) { - logger().debug({ - msg: "applying config limits from protocol", - protocolMetadata, - }); + handler.actorStartError = undefined; + handler.actorStartPromise?.resolve(); + handler.actorStartPromise = undefined; - const stopThresholdMax = Math.max(Number(protocolMetadata.actorStopThreshold) - 1000, 0); - handler.actor.overrides.onSleepTimeout = stopThresholdMax; - handler.actor.overrides.onDestroyTimeout = stopThresholdMax; + const metaEntries = await dynamicActor.getHibernatingWebSockets(); + await this.#runner.restoreHibernatingRequests(actorId, metaEntries); + } else { + const staticActor = + (await definition.instantiate()) as AnyStaticActorInstance; + handler.actor = staticActor; + + // Apply protocol limits as per-instance overrides without mutating the shared definition + const protocolMetadata = this.#runner.getProtocolMetadata(); + if (protocolMetadata) { + logger().debug({ + msg: "applying config limits from protocol", + protocolMetadata, + }); - if (protocolMetadata.serverlessDrainGracePeriod) { - const drainMax = Math.max(Number(protocolMetadata.serverlessDrainGracePeriod) - 1000, 0); - handler.actor.overrides.runStopTimeout = drainMax; - handler.actor.overrides.waitUntilTimeout = drainMax; + const stopThresholdMax = Math.max( + Number(protocolMetadata.actorStopThreshold) - 1000, + 0, + ); + staticActor.overrides.onSleepTimeout = stopThresholdMax; + staticActor.overrides.onDestroyTimeout = stopThresholdMax; + + if (protocolMetadata.serverlessDrainGracePeriod) { + const drainMax = Math.max( + Number(protocolMetadata.serverlessDrainGracePeriod) - 1000, + 0, + ); + staticActor.overrides.runStopTimeout = drainMax; + staticActor.overrides.waitUntilTimeout = drainMax; + } } - } - // Start actor - await handler.actor.start( - this, - this.#inlineClient, - actorId, - name, - key, - "unknown", // TODO: Add regions - ); + // Start actor + await staticActor.start( + this, + this.#inlineClient, + actorId, + name, + key, + "unknown", // TODO: Add regions + ); + } logger().debug({ msg: "runner actor started", actorId, name, key }); } catch (innerError) { + const dynamicRuntime = this.#dynamicRuntimes.get(actorId); + if (dynamicRuntime) { + try { + await dynamicRuntime.dispose(); + } catch {} + this.#dynamicRuntimes.delete(actorId); + } const error = innerError instanceof Error ? new Error( @@ -608,6 +673,7 @@ export class EngineActorDriver implements ActorDriver { }); } } + this.#dynamicRuntimes.delete(actorId); this.#actors.delete(actorId); @@ -628,6 +694,9 @@ export class EngineActorDriver implements ActorDriver { url: request.url, method: request.method, }); + if (this.#isDynamicActor(actorId)) { + return await this.#requireDynamicRuntime(actorId).fetch(request); + } return await this.#actorRouter.fetch(request, { actorId }); } @@ -668,6 +737,22 @@ export class EngineActorDriver implements ActorDriver { const protocols = request.headers.get("sec-websocket-protocol"); const { encoding, connParams } = parseWebSocketProtocols(protocols); + if (this.#isDynamicActor(actorId)) { + await this.#runnerDynamicWebSocket( + actorId, + websocket, + gatewayIdBuf, + requestIdBuf, + requestPath, + requestHeaders, + encoding, + connParams, + isHibernatable, + isRestoringHibernatable, + ); + return; + } + // Fetch WS handler // // We store the promise since we need to add WebSocket event listeners immediately that will wait for the promise to resolve @@ -851,6 +936,165 @@ export class EngineActorDriver implements ActorDriver { } } + async #runnerDynamicWebSocket( + actorId: string, + websocket: UniversalWebSocket, + gatewayIdBuf: ArrayBuffer, + requestIdBuf: ArrayBuffer, + requestPath: string, + requestHeaders: Record, + encoding: Encoding, + connParams: unknown, + isHibernatable: boolean, + isRestoringHibernatable: boolean, + ): Promise { + let runtime: DynamicActorHostRuntime; + try { + runtime = this.#requireDynamicRuntime(actorId); + } catch (error) { + logger().error({ + msg: "dynamic runtime missing for websocket", + actorId, + error: stringifyError(error), + }); + websocket.close(1011, "dynamic.runtime_missing"); + return; + } + + let proxyToActorWs: UniversalWebSocket; + try { + proxyToActorWs = await runtime.openWebSocket( + requestPath, + encoding, + connParams, + { + headers: requestHeaders, + gatewayId: gatewayIdBuf, + requestId: requestIdBuf, + isHibernatable, + isRestoringHibernatable, + }, + ); + } catch (error) { + logger().error({ + msg: "failed to open dynamic websocket", + actorId, + error: stringifyError(error), + }); + websocket.close(1011, "dynamic.websocket_open_failed"); + return; + } + + let actorWebSocketReady = proxyToActorWs.readyState === proxyToActorWs.OPEN; + const pendingMessages: Array<{ + data: string | ArrayBufferLike | Blob | ArrayBufferView; + rivetMessageIndex?: number; + }> = []; + + const flushPendingMessages = async (): Promise => { + if (!actorWebSocketReady || pendingMessages.length === 0) { + return; + } + while (pendingMessages.length > 0) { + const next = pendingMessages.shift(); + if (!next) { + continue; + } + await runtime.forwardIncomingWebSocketMessage( + proxyToActorWs, + next.data, + next.rivetMessageIndex, + ); + if ( + isHibernatable && + typeof next.rivetMessageIndex === "number" + ) { + this.#runner.sendHibernatableWebSocketMessageAck( + gatewayIdBuf, + requestIdBuf, + next.rivetMessageIndex, + ); + } + } + }; + + proxyToActorWs.addEventListener("open", () => { + actorWebSocketReady = true; + void flushPendingMessages(); + }); + + proxyToActorWs.addEventListener("message", (event: RivetMessageEvent) => { + if (websocket.readyState !== websocket.OPEN) { + return; + } + websocket.send(event.data as any); + }); + + proxyToActorWs.addEventListener("close", (event) => { + if (websocket.readyState !== websocket.CLOSED) { + websocket.close(event.code, event.reason); + } + }); + + proxyToActorWs.addEventListener("error", (_event) => { + if (websocket.readyState !== websocket.CLOSED) { + websocket.close(1011, "dynamic.websocket_error"); + } + }); + + websocket.addEventListener("message", (event: RivetMessageEvent) => { + const actorHandler = this.#actors.get(actorId); + if (actorHandler?.actor?.isStopping) { + return; + } + if (!actorWebSocketReady) { + pendingMessages.push({ + data: event.data as any, + rivetMessageIndex: event.rivetMessageIndex, + }); + return; + } + void runtime + .forwardIncomingWebSocketMessage( + proxyToActorWs, + event.data as any, + event.rivetMessageIndex, + ) + .then(() => { + if ( + isHibernatable && + typeof event.rivetMessageIndex === "number" + ) { + this.#runner.sendHibernatableWebSocketMessageAck( + gatewayIdBuf, + requestIdBuf, + event.rivetMessageIndex, + ); + } + }) + .catch((error) => { + logger().error({ + msg: "failed forwarding websocket message to dynamic actor", + actorId, + error: stringifyError(error), + }); + websocket.close(1011, "dynamic.websocket_forward_failed"); + }); + }); + + websocket.addEventListener("close", (event) => { + if (proxyToActorWs.readyState !== proxyToActorWs.CLOSED) { + proxyToActorWs.close(event.code, event.reason); + } + }); + + websocket.addEventListener("error", () => { + if (proxyToActorWs.readyState !== proxyToActorWs.CLOSED) { + proxyToActorWs.close(1011, "dynamic.gateway_error"); + } + }); + } + // MARK: - Hibernating WebSockets #hwsCanHibernate( actorId: string, @@ -950,6 +1194,21 @@ export class EngineActorDriver implements ActorDriver { actorId: string, ): Promise { const actor = await this.loadActor(actorId); + if (!isStaticActorInstance(actor)) { + const runtime = this.#dynamicRuntimes.get(actorId); + if (!runtime) { + return []; + } + const entries = await runtime.getHibernatingWebSockets(); + return entries.map((entry) => ({ + gatewayId: entry.gatewayId, + requestId: entry.requestId, + serverMessageIndex: entry.serverMessageIndex, + clientMessageIndex: entry.clientMessageIndex, + path: entry.path, + headers: entry.headers, + })); + } return actor.conns .values() .map((conn) => { @@ -969,7 +1228,7 @@ export class EngineActorDriver implements ActorDriver { .toArray(); } - async onBeforeActorStart(actor: AnyActorInstance): Promise { + async onBeforeActorStart(actor: AnyStaticActorInstance): Promise { // Resolve promise if waiting const handler = this.#actors.get(actor.id); invariant(handler, "missing actor handler in onBeforeActorReady"); diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/actor.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/actor.ts index d5d6f12735..e2080b61eb 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/actor.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/actor.ts @@ -5,6 +5,7 @@ import { importSqliteVfs, type ActorDriver, type AnyActorInstance, + type AnyStaticActorInstance, type ManagerDriver, } from "@/driver-helpers/mod"; import type { FileSystemGlobalState } from "./global-state"; @@ -96,7 +97,7 @@ export class FileSystemActorDriver implements ActorDriver { await this.#state.destroyActor(actorId); } - async onBeforeActorStart(actor: AnyActorInstance): Promise { + async onBeforeActorStart(actor: AnyStaticActorInstance): Promise { await actor.cleanupPersistedConnections("file-system-driver.start"); } } diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts index 6294e91969..f9c7abe53c 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/file-system/global-state.ts @@ -2,13 +2,17 @@ import invariant from "invariant"; import { lookupInRegistry } from "@/actor/definition"; import { ActorDuplicateKey } from "@/actor/errors"; import type { Encoding } from "@/actor/protocol/serde"; -import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { + AnyActorInstance, + AnyStaticActorInstance, +} from "@/actor/instance/mod"; import type { ActorKey } from "@/actor/mod"; import type { AnyClient } from "@/client/client"; import type { UniversalWebSocket } from "@/common/websocket-interface"; import { type ActorDriver, getInitialActorKvState } from "@/driver-helpers/mod"; -import { getDynamicActorMetadata } from "@/dynamic/internal"; +import { DynamicActorInstance } from "@/dynamic/instance"; import { DynamicActorHostRuntime } from "@/dynamic/host-runtime"; +import { isDynamicActorDefinition } from "@/dynamic/internal"; import type { RegistryConfig } from "@/registry/config"; import type * as schema from "@/schemas/file-system-driver/mod"; import { @@ -87,36 +91,6 @@ interface ActorEntry { generation: string; } -class DynamicActorInstanceAdapter { - #actorId: string; - #runtime: DynamicActorHostRuntime; - #isStopping = false; - - constructor(actorId: string, runtime: DynamicActorHostRuntime) { - this.#actorId = actorId; - this.#runtime = runtime; - } - - get id(): string { - return this.#actorId; - } - - get isStopping(): boolean { - return this.#isStopping; - } - - async onStop(mode: "sleep" | "destroy"): Promise { - if (this.#isStopping) return; - this.#isStopping = true; - await this.#runtime.stop(mode); - await this.#runtime.dispose(); - } - - async onAlarm(): Promise { - await this.#runtime.dispatchAlarm(); - } -} - export interface FileSystemDriverOptions { /** Whether to persist data to disk */ persist?: boolean; @@ -1118,8 +1092,7 @@ export class FileSystemGlobalState { try { // Create actor const definition = lookupInRegistry(config, entry.state.name); - const dynamicMetadata = getDynamicActorMetadata(definition); - if (dynamicMetadata) { + if (isDynamicActorDefinition(definition)) { let runtime = this.#dynamicRuntimes.get(actorId); if (!runtime) { runtime = new DynamicActorHostRuntime({ @@ -1128,24 +1101,26 @@ export class FileSystemGlobalState { actorKey: entry.state.key as string[], input: this.#actorInitialInputs.get(actorId), region: "unknown", - metadata: dynamicMetadata, + loader: definition.loader, actorDriver, inlineClient, }); await runtime.start(); this.#dynamicRuntimes.set(actorId, runtime); } - entry.actor = new DynamicActorInstanceAdapter( + entry.actor = new DynamicActorInstance( actorId, runtime, - ) as unknown as AnyActorInstance; + ); entry.lifecycleState = ActorLifecycleState.AWAKE; } else { - entry.actor = await definition.instantiate(); + const staticActor = + (await definition.instantiate()) as AnyStaticActorInstance; + entry.actor = staticActor; entry.lifecycleState = ActorLifecycleState.AWAKE; // Start actor - await entry.actor.start( + await staticActor.start( actorDriver, inlineClient, actorId, @@ -1213,7 +1188,7 @@ export class FileSystemGlobalState { ): Promise { const state = await this.loadActorStateOrError(actorId); const definition = lookupInRegistry(config, state.name); - return getDynamicActorMetadata(definition) !== undefined; + return isDynamicActorDefinition(definition); } async dynamicFetch(actorId: string, request: Request): Promise { diff --git a/rivetkit-typescript/packages/rivetkit/src/dynamic/host-runtime.ts b/rivetkit-typescript/packages/rivetkit/src/dynamic/host-runtime.ts index 9d3d1718c7..2a2fde51c9 100644 --- a/rivetkit-typescript/packages/rivetkit/src/dynamic/host-runtime.ts +++ b/rivetkit-typescript/packages/rivetkit/src/dynamic/host-runtime.ts @@ -20,8 +20,8 @@ import type { UniversalWebSocket } from "@/common/websocket-interface"; import type { AnyClient } from "@/client/client"; import { createDynamicActorLoaderContext, + type DynamicActorLoader, type DynamicActorLoadResult, - type DynamicActorMetadata, } from "./internal"; const DYNAMIC_RUNTIME_ROOT = path.join( @@ -126,6 +126,11 @@ interface WebSocketOpenEnvelopeInput { path: string; encoding: Encoding; params: unknown; + headers?: Record; + gatewayIdBase64?: string; + requestIdBase64?: string; + isHibernatable?: boolean; + isRestoringHibernatable?: boolean; } interface WebSocketSendEnvelopeInput { @@ -133,6 +138,7 @@ interface WebSocketSendEnvelopeInput { kind: "text" | "binary"; text?: string; dataBase64?: string; + rivetMessageIndex?: number; } interface WebSocketCloseEnvelopeInput { @@ -152,6 +158,7 @@ type IsolateDispatchPayload = kind: "text" | "binary"; text?: string; dataBase64?: string; + rivetMessageIndex?: number; } | { type: "close"; @@ -183,11 +190,28 @@ interface DynamicActorHostRuntimeConfig { actorKey: ActorKey; input: unknown; region: string; - metadata: DynamicActorMetadata; + loader: DynamicActorLoader; actorDriver: ActorDriver; inlineClient: AnyClient; } +export interface DynamicWebSocketOpenOptions { + headers?: Record; + gatewayId?: ArrayBuffer; + requestId?: ArrayBuffer; + isHibernatable?: boolean; + isRestoringHibernatable?: boolean; +} + +export interface DynamicHibernatingWebSocketMetadata { + gatewayId: ArrayBuffer; + requestId: ArrayBuffer; + serverMessageIndex: number; + clientMessageIndex: number; + path: string; + headers: Record; +} + export class DynamicActorHostRuntime { #config: DynamicActorHostRuntimeConfig; #runtimeDir: string; @@ -207,12 +231,18 @@ export class DynamicActorHostRuntime { #closeWebSocketRef: | ReferenceLike<(input: WebSocketCloseEnvelopeInput) => Promise> | undefined; + #getHibernatingWebSocketsRef: + | ReferenceLike< + () => Promise> + > + | undefined; #disposeRef: ReferenceLike<() => Promise> | undefined; #startSleepRef: ReferenceLike<(actorId: string) => void> | undefined; #startDestroyRef: ReferenceLike<(actorId: string) => void> | undefined; #referenceHandles: Array<{ release?: () => void }> = []; #webSocketSessions = new Map(); + #sessionIdsByWebSocket = new WeakMap(); #nextWebSocketSessionId = 1; #started = false; #disposed = false; @@ -238,7 +268,7 @@ export class DynamicActorHostRuntime { actorId: this.#config.actorId, }); - const loadResult = await this.#config.metadata.loader( + const loadResult = await this.#config.loader( createDynamicActorLoaderContext( this.#config.inlineClient, this.#config.actorId, @@ -368,6 +398,7 @@ export class DynamicActorHostRuntime { pathValue: string, encoding: Encoding, params: unknown, + options: DynamicWebSocketOpenOptions = {}, ): Promise { if (!this.#openWebSocketRef || !this.#sendWebSocketRef || !this.#closeWebSocketRef) { throw new Error("dynamic runtime websocket bridge is not started"); @@ -394,18 +425,31 @@ export class DynamicActorHostRuntime { pendingMessages: [], }; this.#webSocketSessions.set(session.id, session); + this.#sessionIdsByWebSocket.set(session.websocket, session.id); try { + const gatewayIdBase64 = options.gatewayId + ? Buffer.from(new Uint8Array(options.gatewayId)).toString("base64") + : undefined; + const requestIdBase64 = options.requestId + ? Buffer.from(new Uint8Array(options.requestId)).toString("base64") + : undefined; + await this.#openWebSocketRef.apply( undefined, [ { sessionId, - path: pathValue, - encoding, - params, - } satisfies WebSocketOpenEnvelopeInput, - ], + path: pathValue, + encoding, + params, + headers: options.headers, + gatewayIdBase64, + requestIdBase64, + isHibernatable: options.isHibernatable, + isRestoringHibernatable: options.isRestoringHibernatable, + } satisfies WebSocketOpenEnvelopeInput, + ], { arguments: { copy: true, @@ -442,6 +486,37 @@ export class DynamicActorHostRuntime { }); } + async getHibernatingWebSockets(): Promise< + Array + > { + if (!this.#getHibernatingWebSocketsRef) { + return []; + } + const entries = await this.#getHibernatingWebSocketsRef.apply( + undefined, + [], + { + result: { + copy: true, + promise: true, + }, + }, + ); + return entries as Array; + } + + async forwardIncomingWebSocketMessage( + websocket: UniversalWebSocket, + data: string | ArrayBufferLike | Blob | ArrayBufferView, + rivetMessageIndex?: number, + ): Promise { + const sessionId = this.#sessionIdsByWebSocket.get(websocket); + if (!sessionId) { + throw new Error("dynamic runtime websocket session not found"); + } + await this.#sendWebSocketMessage(sessionId, data, rivetMessageIndex); + } + async stop(mode: "sleep" | "destroy"): Promise { if (!this.#stopRef) return; try { @@ -514,6 +589,7 @@ export class DynamicActorHostRuntime { this.#openWebSocketRef = undefined; this.#sendWebSocketRef = undefined; this.#closeWebSocketRef = undefined; + this.#getHibernatingWebSocketsRef = undefined; this.#disposeRef = undefined; this.#startSleepRef = undefined; this.#startDestroyRef = undefined; @@ -523,6 +599,7 @@ export class DynamicActorHostRuntime { async #sendWebSocketMessage( sessionId: number, data: string | ArrayBufferLike | Blob | ArrayBufferView, + rivetMessageIndex?: number, ): Promise { if (!this.#sendWebSocketRef) return; @@ -530,12 +607,13 @@ export class DynamicActorHostRuntime { await this.#sendWebSocketRef.apply( undefined, [ - { - sessionId, - kind: "text", - text: data, - } satisfies WebSocketSendEnvelopeInput, - ], + { + sessionId, + kind: "text", + text: data, + rivetMessageIndex, + } satisfies WebSocketSendEnvelopeInput, + ], { arguments: { copy: true }, result: { copy: true, promise: true }, @@ -548,12 +626,13 @@ export class DynamicActorHostRuntime { await this.#sendWebSocketRef.apply( undefined, [ - { - sessionId, - kind: "binary", - dataBase64: Buffer.from(binary).toString("base64"), - } satisfies WebSocketSendEnvelopeInput, - ], + { + sessionId, + kind: "binary", + dataBase64: Buffer.from(binary).toString("base64"), + rivetMessageIndex, + } satisfies WebSocketSendEnvelopeInput, + ], { arguments: { copy: true }, result: { copy: true, promise: true }, @@ -715,11 +794,12 @@ export class DynamicActorHostRuntime { globalThis.__dynamicFetchEnvelope = bootstrap.dynamicFetchEnvelope; globalThis.__dynamicDispatchAlarmEnvelope = bootstrap.dynamicDispatchAlarmEnvelope; globalThis.__dynamicStopEnvelope = bootstrap.dynamicStopEnvelope; - globalThis.__dynamicOpenWebSocketEnvelope = bootstrap.dynamicOpenWebSocketEnvelope; - globalThis.__dynamicWebSocketSendEnvelope = bootstrap.dynamicWebSocketSendEnvelope; - globalThis.__dynamicWebSocketCloseEnvelope = bootstrap.dynamicWebSocketCloseEnvelope; - globalThis.__dynamicDisposeEnvelope = bootstrap.dynamicDisposeEnvelope; - `, + globalThis.__dynamicOpenWebSocketEnvelope = bootstrap.dynamicOpenWebSocketEnvelope; + globalThis.__dynamicWebSocketSendEnvelope = bootstrap.dynamicWebSocketSendEnvelope; + globalThis.__dynamicWebSocketCloseEnvelope = bootstrap.dynamicWebSocketCloseEnvelope; + globalThis.__dynamicGetHibernatingWebSocketsEnvelope = bootstrap.dynamicGetHibernatingWebSocketsEnvelope; + globalThis.__dynamicDisposeEnvelope = bootstrap.dynamicDisposeEnvelope; + `, { filename: path.join(this.#runtimeDir, "dynamic-bootstrap-entry.cjs"), }, @@ -760,6 +840,9 @@ export class DynamicActorHostRuntime { this.#closeWebSocketRef = await getRef( "__dynamicWebSocketCloseEnvelope", ); + this.#getHibernatingWebSocketsRef = await getRef( + "__dynamicGetHibernatingWebSocketsEnvelope", + ); this.#disposeRef = await getRef("__dynamicDisposeEnvelope"); } @@ -854,13 +937,19 @@ export class DynamicActorHostRuntime { payload: Extract, ): void { if (payload.kind === "text") { - session.websocket.triggerMessage(payload.text ?? ""); + (session.websocket as any).triggerMessage( + payload.text ?? "", + payload.rivetMessageIndex, + ); return; } const bytes = payload.dataBase64 ? Buffer.from(payload.dataBase64, "base64") : Buffer.alloc(0); - session.websocket.triggerMessage(bytes); + (session.websocket as any).triggerMessage( + bytes, + payload.rivetMessageIndex, + ); } } @@ -908,12 +997,14 @@ function resolveRivetkitPackageRoot(): string { let current = path.dirname(entryPath); while (true) { - const candidate = path.join(current, "package.json"); - try { - const packageJsonRaw = requireJsonSync(candidate); - if (packageJsonRaw?.name === "rivetkit") { - return current; - } + const candidate = path.join(current, "package.json"); + try { + const packageJsonRaw = requireJsonSync(candidate) as { + name?: string; + }; + if (packageJsonRaw?.name === "rivetkit") { + return current; + } } catch { // Continue walking up until package root is found. } @@ -1507,13 +1598,23 @@ function resolveSecureExecPackageDir(distEntryPath: string): string { } async function nativeDynamicImport(specifier: string): Promise { - // Vite SSR rewrites import() and cannot resolve file:// paths outside the - // project graph. Using Function() forces the runtime's native loader. - const importer = new Function( - "moduleSpecifier", - "return import(moduleSpecifier);", - ) as (moduleSpecifier: string) => Promise; - return await importer(specifier); + // Try direct dynamic import first because VM-backed test runners may reject + // import() from Function() with ERR_VM_DYNAMIC_IMPORT_CALLBACK_MISSING. + try { + return (await import(specifier)) as T; + } catch (directError) { + // Vite SSR can rewrite import() and fail to resolve file:// specifiers + // outside the project graph. Function() forces the runtime native loader. + const importer = new Function( + "moduleSpecifier", + "return import(moduleSpecifier);", + ) as (moduleSpecifier: string) => Promise; + try { + return await importer(specifier); + } catch { + throw directError; + } + } } function buildLockedDownPermissions(rootPath: string): { @@ -1578,8 +1679,15 @@ let setup; let createActorRouter; let routeWebSocket; let InlineWebSocketAdapter; +let CONN_STATE_MANAGER_SYMBOL; try { - ({ setup, createActorRouter, routeWebSocket, InlineWebSocketAdapter } = require("rivetkit")); + ({ + setup, + createActorRouter, + routeWebSocket, + InlineWebSocketAdapter, + CONN_STATE_MANAGER_SYMBOL, + } = require("rivetkit")); } catch (error) { const details = error && error.stack ? error.stack : String(error); throw new Error(\`dynamic runtime failed to require rivetkit: \${details}\`); @@ -1653,6 +1761,14 @@ function fromBase64(value) { return new Uint8Array(Buffer.from(value, "base64")); } +function fromBase64ToArrayBuffer(value) { + const buffer = Buffer.from(value, "base64"); + return buffer.buffer.slice( + buffer.byteOffset, + buffer.byteOffset + buffer.byteLength, + ); +} + function responseHeadersToEntries(headers) { if (!headers) { return []; @@ -1881,13 +1997,19 @@ async function dynamicStopEnvelope(mode) { } async function dynamicOpenWebSocketEnvelope(input) { - const headers = {}; + const headers = input.headers || {}; const requestPath = input.path || "/connect"; const pathOnly = requestPath.split("?")[0]; const request = new Request( requestPath.startsWith("http") ? requestPath : \`http://actor\${requestPath}\`, { method: "GET", headers }, ); + const gatewayId = input.gatewayIdBase64 + ? fromBase64ToArrayBuffer(input.gatewayIdBase64) + : undefined; + const requestId = input.requestIdBase64 + ? fromBase64ToArrayBuffer(input.requestIdBase64) + : undefined; const handler = await routeWebSocket( request, pathOnly, @@ -1897,14 +2019,14 @@ async function dynamicOpenWebSocketEnvelope(input) { actorId, input.encoding, input.params, - undefined, - undefined, - false, - false, + gatewayId, + requestId, + !!input.isHibernatable, + !!input.isRestoringHibernatable, ); const adapter = new InlineWebSocketAdapter(handler); const ws = adapter.clientWebSocket; - webSocketSessions.set(input.sessionId, ws); + webSocketSessions.set(input.sessionId, { ws, adapter }); ws.addEventListener("open", () => { bridgeCallSync(globalThis.__dynamicHostDispatch, [{ @@ -1920,6 +2042,7 @@ async function dynamicOpenWebSocketEnvelope(input) { sessionId: input.sessionId, kind: "text", text: data, + rivetMessageIndex: event.rivetMessageIndex, }]); return; } @@ -1931,6 +2054,7 @@ async function dynamicOpenWebSocketEnvelope(input) { sessionId: input.sessionId, kind: "binary", dataBase64: Buffer.from(new Uint8Array(buffer)).toString("base64"), + rivetMessageIndex: event.rivetMessageIndex, }]); }) .catch((error) => { @@ -1948,6 +2072,7 @@ async function dynamicOpenWebSocketEnvelope(input) { sessionId: input.sessionId, kind: "binary", dataBase64: Buffer.from(new Uint8Array(data.buffer, data.byteOffset, data.byteLength)).toString("base64"), + rivetMessageIndex: event.rivetMessageIndex, }]); return; } @@ -1957,6 +2082,7 @@ async function dynamicOpenWebSocketEnvelope(input) { sessionId: input.sessionId, kind: "binary", dataBase64: Buffer.from(new Uint8Array(data)).toString("base64"), + rivetMessageIndex: event.rivetMessageIndex, }]); } }); @@ -1981,30 +2107,64 @@ async function dynamicOpenWebSocketEnvelope(input) { } async function dynamicWebSocketSendEnvelope(input) { - const ws = webSocketSessions.get(input.sessionId); - if (!ws) return false; - if (input.kind === "text") { - ws.send(input.text || ""); + const session = webSocketSessions.get(input.sessionId); + if (!session) return false; + const payload = + input.kind === "text" + ? input.text || "" + : input.dataBase64 + ? Buffer.from(input.dataBase64, "base64") + : undefined; + if (payload === undefined) { + return false; + } + if ( + typeof session.adapter.dispatchClientMessageWithMetadata === "function" + ) { + session.adapter.dispatchClientMessageWithMetadata( + payload, + input.rivetMessageIndex, + ); return true; } - if (!input.dataBase64) { - return false; + if (input.kind === "text") { + session.ws.send(input.text || ""); + return true; } - ws.send(Buffer.from(input.dataBase64, "base64")); + session.ws.send(Buffer.from(input.dataBase64, "base64")); return true; } async function dynamicWebSocketCloseEnvelope(input) { - const ws = webSocketSessions.get(input.sessionId); - if (!ws) return false; - ws.close(input.code, input.reason); + const session = webSocketSessions.get(input.sessionId); + if (!session) return false; + session.ws.close(input.code, input.reason); return true; } +async function dynamicGetHibernatingWebSocketsEnvelope() { + const actor = await loadActor(actorId); + return Array.from(actor.conns.values()) + .map((conn) => { + const connStateManager = conn[CONN_STATE_MANAGER_SYMBOL]; + const hibernatable = connStateManager?.hibernatableData; + if (!hibernatable) return undefined; + return { + gatewayId: hibernatable.gatewayId, + requestId: hibernatable.requestId, + serverMessageIndex: hibernatable.serverMessageIndex, + clientMessageIndex: hibernatable.clientMessageIndex, + path: hibernatable.requestPath, + headers: hibernatable.requestHeaders, + }; + }) + .filter((entry) => entry !== undefined); +} + async function dynamicDisposeEnvelope() { - for (const ws of webSocketSessions.values()) { + for (const session of webSocketSessions.values()) { try { - ws.close(1001, "dynamic.runtime.disposed"); + session.ws.close(1001, "dynamic.runtime.disposed"); } catch {} } webSocketSessions.clear(); @@ -2018,6 +2178,7 @@ module.exports = { dynamicOpenWebSocketEnvelope, dynamicWebSocketSendEnvelope, dynamicWebSocketCloseEnvelope, + dynamicGetHibernatingWebSocketsEnvelope, dynamicDisposeEnvelope, }; `; diff --git a/rivetkit-typescript/packages/rivetkit/src/dynamic/instance.ts b/rivetkit-typescript/packages/rivetkit/src/dynamic/instance.ts new file mode 100644 index 0000000000..c6b9afb619 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/src/dynamic/instance.ts @@ -0,0 +1,66 @@ +import type { ActorInstance } from "@/actor/instance/mod"; +import type { Encoding } from "@/actor/protocol/serde"; +import type { UniversalWebSocket } from "@/common/websocket-interface"; +import { + DynamicActorHostRuntime, + type DynamicWebSocketOpenOptions, +} from "./host-runtime"; + +export class DynamicActorInstance implements ActorInstance { + #actorId: string; + #runtime: DynamicActorHostRuntime; + #isStopping = false; + + constructor(actorId: string, runtime: DynamicActorHostRuntime) { + this.#actorId = actorId; + this.#runtime = runtime; + } + + get id(): string { + return this.#actorId; + } + + get isStopping(): boolean { + return this.#isStopping; + } + + async onStop(mode: "sleep" | "destroy"): Promise { + if (this.#isStopping) return; + this.#isStopping = true; + await this.#runtime.stop(mode); + await this.#runtime.dispose(); + } + + async onAlarm(): Promise { + await this.#runtime.dispatchAlarm(); + } + + async fetch(request: Request): Promise { + return await this.#runtime.fetch(request); + } + + async openWebSocket( + path: string, + encoding: Encoding, + params: unknown, + options?: DynamicWebSocketOpenOptions, + ): Promise { + return await this.#runtime.openWebSocket(path, encoding, params, options); + } + + async getHibernatingWebSockets() { + return await this.#runtime.getHibernatingWebSockets(); + } + + async forwardIncomingWebSocketMessage( + websocket: UniversalWebSocket, + data: string | ArrayBufferLike | Blob | ArrayBufferView, + rivetMessageIndex?: number, + ): Promise { + await this.#runtime.forwardIncomingWebSocketMessage( + websocket, + data, + rivetMessageIndex, + ); + } +} diff --git a/rivetkit-typescript/packages/rivetkit/src/dynamic/internal.ts b/rivetkit-typescript/packages/rivetkit/src/dynamic/internal.ts index 33a0689683..fc5a4946b2 100644 --- a/rivetkit-typescript/packages/rivetkit/src/dynamic/internal.ts +++ b/rivetkit-typescript/packages/rivetkit/src/dynamic/internal.ts @@ -1,5 +1,13 @@ import type { ActorKey } from "@/actor/mod"; -import type { AnyActorDefinition } from "@/actor/definition"; +import type { + ActorConfig, + GlobalActorOptionsInput, +} from "@/actor/config"; +import { ActorConfigSchema } from "@/actor/config"; +import type { AnyActorDefinition, ActorDefinition } from "@/actor/definition"; +import type { AnyDatabaseProvider } from "@/actor/database"; +import type { EventSchemaConfig, QueueSchemaConfig } from "@/actor/schema"; +import type { ActorInstance } from "@/actor/instance/mod"; import type { AnyClient, Client } from "@/client/client"; import type { Registry } from "@/registry"; @@ -26,39 +34,81 @@ export type DynamicActorLoader = ( context: DynamicActorLoaderContext, ) => Promise | DynamicActorLoadResult; -export interface DynamicActorMetadata { - loader: DynamicActorLoader; +export interface DynamicActorConfigInput { + options?: GlobalActorOptionsInput; } -export const DYNAMIC_ACTOR_METADATA_SYMBOL = Symbol.for( - "rivetkit.dynamic_actor.metadata", -); +export class DynamicActorDefinition + implements + ActorDefinition< + any, + any, + any, + any, + any, + AnyDatabaseProvider, + EventSchemaConfig, + QueueSchemaConfig, + Record unknown> + > +{ + #loader: DynamicActorLoader; + #config: ActorConfig< + any, + any, + any, + any, + any, + AnyDatabaseProvider, + EventSchemaConfig, + QueueSchemaConfig + >; -export function attachDynamicActorMetadata( - definition: AnyActorDefinition, - metadata: DynamicActorMetadata, -): void { - ( - definition as AnyActorDefinition & { - [DYNAMIC_ACTOR_METADATA_SYMBOL]?: DynamicActorMetadata; - } - )[DYNAMIC_ACTOR_METADATA_SYMBOL] = metadata; -} + constructor(loader: DynamicActorLoader, input: DynamicActorConfigInput = {}) { + this.#loader = loader; + this.#config = ActorConfigSchema.parse({ + actions: {}, + options: input.options ?? {}, + }) as ActorConfig< + any, + any, + any, + any, + any, + AnyDatabaseProvider, + EventSchemaConfig, + QueueSchemaConfig + >; + } -export function getDynamicActorMetadata( - definition: AnyActorDefinition, -): DynamicActorMetadata | undefined { - return ( - definition as AnyActorDefinition & { - [DYNAMIC_ACTOR_METADATA_SYMBOL]?: DynamicActorMetadata; - } - )[DYNAMIC_ACTOR_METADATA_SYMBOL]; + get loader(): DynamicActorLoader { + return this.#loader; + } + + get config(): ActorConfig< + any, + any, + any, + any, + any, + AnyDatabaseProvider, + EventSchemaConfig, + QueueSchemaConfig + > { + return this.#config; + } + + instantiate(): ActorInstance { + throw new Error( + "dynamic actor definitions are instantiated by the runtime driver", + ); + } } export function isDynamicActorDefinition( definition: AnyActorDefinition, -): boolean { - return getDynamicActorMetadata(definition) !== undefined; +): definition is DynamicActorDefinition { + return definition instanceof DynamicActorDefinition; } export function createDynamicActorLoaderContext( diff --git a/rivetkit-typescript/packages/rivetkit/src/dynamic/mod.ts b/rivetkit-typescript/packages/rivetkit/src/dynamic/mod.ts index 668324d1c2..8ddc0b2fa7 100644 --- a/rivetkit-typescript/packages/rivetkit/src/dynamic/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/dynamic/mod.ts @@ -1,27 +1,22 @@ -import { actor } from "@/actor/mod"; import type { AnyActorDefinition } from "@/actor/definition"; import { - attachDynamicActorMetadata, + DynamicActorDefinition, + type DynamicActorConfigInput, type DynamicActorLoader, type DynamicActorLoaderContext, type DynamicActorLoadResult, type DynamicNodeProcessConfig, } from "./internal"; -export function dynamicActor(loader: DynamicActorLoader): AnyActorDefinition { - const definition = actor({ - // Keep the host-side placeholder actor awake. Sleep/wake semantics - // are handled by the evaluated actor inside the isolate runtime. - options: { - noSleep: true, - }, - }) as AnyActorDefinition; - - attachDynamicActorMetadata(definition, { loader }); - return definition; +export function dynamicActor( + loader: DynamicActorLoader, + config: DynamicActorConfigInput = {}, +): AnyActorDefinition { + return new DynamicActorDefinition(loader, config); } export type { + DynamicActorConfigInput, DynamicActorLoader, DynamicActorLoaderContext, DynamicActorLoadResult, diff --git a/rivetkit-typescript/packages/rivetkit/src/inspector/actor-inspector.ts b/rivetkit-typescript/packages/rivetkit/src/inspector/actor-inspector.ts index 5806aba71f..cf9ed5ae68 100644 --- a/rivetkit-typescript/packages/rivetkit/src/inspector/actor-inspector.ts +++ b/rivetkit-typescript/packages/rivetkit/src/inspector/actor-inspector.ts @@ -8,7 +8,7 @@ import { import { getRunInspectorConfig } from "@/actor/config"; import { ActionContext } from "@/actor/contexts/action"; import * as actorErrors from "@/actor/errors"; -import type { AnyActorInstance } from "@/mod"; +import type { AnyStaticActorInstance } from "@/actor/instance/mod"; import type * as schema from "@/schemas/actor-inspector/mod"; import { bufferToArrayBuffer } from "@/utils"; @@ -35,7 +35,7 @@ export class ActorInspector { ReturnType >["workflow"]; - constructor(private readonly actor: AnyActorInstance) { + constructor(private readonly actor: AnyStaticActorInstance) { this.#lastQueueSize = actor.queueManager?.size ?? 0; const runInspector = getRunInspectorConfig(actor.config.run); this.#workflowInspector = runInspector?.workflow; @@ -349,4 +349,3 @@ export class ActorInspector { function escapeDoubleQuotes(value: string): string { return value.replace(/"/g, '""'); } - diff --git a/rivetkit-typescript/packages/rivetkit/src/inspector/handler.ts b/rivetkit-typescript/packages/rivetkit/src/inspector/handler.ts index e62ea714e7..5359fe2648 100644 --- a/rivetkit-typescript/packages/rivetkit/src/inspector/handler.ts +++ b/rivetkit-typescript/packages/rivetkit/src/inspector/handler.ts @@ -1,7 +1,8 @@ import type { WSContext } from "hono/ws"; import type { Unsubscribe } from "nanoevents"; +import type { AnyStaticActorInstance } from "@/actor/instance/mod"; import type { UpgradeWebSocketArgs } from "@/actor/router-websocket-endpoints"; -import type { AnyActorInstance, RivetMessageEvent } from "@/mod"; +import type { RivetMessageEvent } from "@/mod"; import type { ToClient } from "@/schemas/actor-inspector/mod"; import { encodeReadRangeWire } from "@rivetkit/traces/encoding"; import { @@ -15,7 +16,7 @@ import { inspectorLogger } from "./log"; export async function handleWebSocketInspectorConnect({ actor, }: { - actor: AnyActorInstance; + actor: AnyStaticActorInstance; }): Promise { const inspector = actor.inspector; const maxQueueStatusLimit = 200; diff --git a/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts b/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts index edc77e4b08..ad811a2a28 100644 --- a/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/workflow/driver.ts @@ -1,5 +1,5 @@ import type { RunContext } from "@/actor/contexts/run"; -import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { AnyStaticActorInstance } from "@/actor/instance/mod"; import { makeWorkflowKey, workflowStoragePrefix } from "@/actor/instance/keys"; import type { EngineDriver, @@ -16,11 +16,11 @@ function stripWorkflowKey(prefixed: Uint8Array): Uint8Array { } class ActorWorkflowMessageDriver implements WorkflowMessageDriver { - #actor: AnyActorInstance; + #actor: AnyStaticActorInstance; #runCtx: RunContext; constructor( - actor: AnyActorInstance, + actor: AnyStaticActorInstance, runCtx: RunContext, ) { this.#actor = actor; @@ -84,11 +84,11 @@ class ActorWorkflowMessageDriver implements WorkflowMessageDriver { export class ActorWorkflowDriver implements EngineDriver { readonly workerPollInterval = 100; readonly messageDriver: WorkflowMessageDriver; - #actor: AnyActorInstance; + #actor: AnyStaticActorInstance; #runCtx: RunContext; constructor( - actor: AnyActorInstance, + actor: AnyStaticActorInstance, runCtx: RunContext, ) { this.#actor = actor; diff --git a/rivetkit-typescript/packages/rivetkit/src/workflow/mod.ts b/rivetkit-typescript/packages/rivetkit/src/workflow/mod.ts index 9edc1e422e..4d6eec7ff7 100644 --- a/rivetkit-typescript/packages/rivetkit/src/workflow/mod.ts +++ b/rivetkit-typescript/packages/rivetkit/src/workflow/mod.ts @@ -1,7 +1,7 @@ import { ACTOR_CONTEXT_INTERNAL_SYMBOL } from "@/actor/contexts/base/actor"; import type { RunContext } from "@/actor/contexts/run"; import type { AnyDatabaseProvider } from "@/actor/database"; -import type { AnyActorInstance } from "@/actor/instance/mod"; +import type { AnyStaticActorInstance } from "@/actor/instance/mod"; import type { EventSchemaConfig, QueueSchemaConfig } from "@/actor/schema"; import { RUN_FUNCTION_CONFIG_SYMBOL } from "@/actor/config"; import { stringifyError } from "@/utils"; @@ -70,7 +70,7 @@ export function workflow< ): Promise { const actor = ( runCtx as unknown as { - [ACTOR_CONTEXT_INTERNAL_SYMBOL]?: AnyActorInstance; + [ACTOR_CONTEXT_INTERNAL_SYMBOL]?: AnyStaticActorInstance; } )[ACTOR_CONTEXT_INTERNAL_SYMBOL]; invariant(actor, "workflow() requires an actor instance"); diff --git a/rivetkit-typescript/packages/rivetkit/tests/driver-engine-dynamic.test.ts b/rivetkit-typescript/packages/rivetkit/tests/driver-engine-dynamic.test.ts new file mode 100644 index 0000000000..21e44471a1 --- /dev/null +++ b/rivetkit-typescript/packages/rivetkit/tests/driver-engine-dynamic.test.ts @@ -0,0 +1,374 @@ +import { createServer, type IncomingMessage, type ServerResponse } from "node:http"; +import { existsSync } from "node:fs"; +import { join } from "node:path"; +import { pathToFileURL } from "node:url"; +import { createClient } from "@/client/mod"; +import { createTestRuntime } from "@/driver-test-suite/mod"; +import { createEngineDriver } from "@/drivers/engine/mod"; +import invariant from "invariant"; +import { createClientWithDriver } from "@/client/client"; +import { convertRegistryConfigToClientConfig } from "@/client/config"; +import { afterEach, describe, expect, test } from "vitest"; +import { DYNAMIC_SOURCE } from "../fixtures/driver-test-suite/dynamic-registry"; +import type { registry } from "../fixtures/driver-test-suite/dynamic-registry"; + +const SECURE_EXEC_DIST_PATH = join( + process.env.HOME ?? "", + "secure-exec-rivet/packages/sandboxed-node/dist/index.js", +); +const hasSecureExecDist = existsSync(SECURE_EXEC_DIST_PATH); +const hasEngineEndpointEnv = !!( + process.env.RIVET_ENDPOINT || + process.env.RIVET_NAMESPACE_ENDPOINT || + process.env.RIVET_API_ENDPOINT +); +const initialDynamicSourceUrlEnv = + process.env.RIVETKIT_DYNAMIC_TEST_SOURCE_URL; +const initialSecureExecSpecifierEnv = + process.env.RIVETKIT_DYNAMIC_SECURE_EXEC_SPECIFIER; + +type DynamicHandle = { + increment: (amount?: number) => Promise; + getState: () => Promise<{ + count: number; + wakeCount: number; + sleepCount: number; + alarmCount: number; + }>; + putText: (key: string, value: string) => Promise; + getText: (key: string) => Promise; + listText: (prefix: string) => Promise>; + triggerSleep: () => Promise; + scheduleAlarm: (duration: number) => Promise; + webSocket: (path?: string) => Promise; +}; + +describe.skipIf(!hasSecureExecDist || !hasEngineEndpointEnv)( + "engine dynamic actor runtime", + () => { + let sourceServer: + | { + url: string; + close: () => Promise; + } + | undefined; + + afterEach(async () => { + if (sourceServer) { + await sourceServer.close(); + sourceServer = undefined; + } + if (initialDynamicSourceUrlEnv === undefined) { + delete process.env.RIVETKIT_DYNAMIC_TEST_SOURCE_URL; + } else { + process.env.RIVETKIT_DYNAMIC_TEST_SOURCE_URL = + initialDynamicSourceUrlEnv; + } + if (initialSecureExecSpecifierEnv === undefined) { + delete process.env.RIVETKIT_DYNAMIC_SECURE_EXEC_SPECIFIER; + } else { + process.env.RIVETKIT_DYNAMIC_SECURE_EXEC_SPECIFIER = + initialSecureExecSpecifierEnv; + } + }); + + test("loads dynamic actor source from URL", async () => { + sourceServer = await startSourceServer(DYNAMIC_SOURCE); + process.env.RIVETKIT_DYNAMIC_TEST_SOURCE_URL = sourceServer.url; + process.env.RIVETKIT_DYNAMIC_SECURE_EXEC_SPECIFIER = pathToFileURL( + SECURE_EXEC_DIST_PATH, + ).href; + + const runtime = await createDynamicEngineRuntime(); + const client = createClient({ + endpoint: runtime.endpoint, + namespace: runtime.namespace, + runnerName: runtime.runnerName, + encoding: "json", + disableMetadataLookup: true, + }); + + try { + const actor = client.dynamicFromUrl.getOrCreate([ + "url-loader", + ]) as unknown as DynamicHandle; + expect(await actor.increment(2)).toBe(2); + expect(await actor.increment(3)).toBe(5); + + const state = await actor.getState(); + expect(state.count).toBe(5); + expect(state.wakeCount).toBeGreaterThanOrEqual(1); + } finally { + await client.dispose(); + await runtime.cleanup(); + } + }, 180_000); + + test("supports actions, kv, websockets, alarms, and sleep/wake from actor-loaded source", async () => { + sourceServer = await startSourceServer(DYNAMIC_SOURCE); + process.env.RIVETKIT_DYNAMIC_TEST_SOURCE_URL = sourceServer.url; + process.env.RIVETKIT_DYNAMIC_SECURE_EXEC_SPECIFIER = pathToFileURL( + SECURE_EXEC_DIST_PATH, + ).href; + + const runtime = await createDynamicEngineRuntime(); + const client = createClient({ + endpoint: runtime.endpoint, + namespace: runtime.namespace, + runnerName: runtime.runnerName, + encoding: "json", + disableMetadataLookup: true, + }); + + let ws: WebSocket | undefined; + + try { + const actor = client.dynamicFromActor.getOrCreate([ + "actor-loader", + ]) as unknown as DynamicHandle; + + expect(await actor.increment(1)).toBe(1); + + await actor.putText("prefix-a", "alpha"); + await actor.putText("prefix-b", "beta"); + expect(await actor.getText("prefix-a")).toBe("alpha"); + expect( + (await actor.listText("prefix-")).sort((a, b) => + a.key.localeCompare(b.key), + ), + ).toEqual([ + { key: "prefix-a", value: "alpha" }, + { key: "prefix-b", value: "beta" }, + ]); + + ws = await actor.webSocket(); + const welcome = await readWebSocketJson(ws); + expect(welcome).toMatchObject({ type: "welcome" }); + ws.send(JSON.stringify({ type: "ping" })); + expect(await readWebSocketJson(ws)).toEqual({ type: "pong" }); + ws.close(); + ws = undefined; + + const beforeSleep = await actor.getState(); + await actor.triggerSleep(); + await wait(350); + + const afterSleep = await actor.getState(); + expect(afterSleep.sleepCount).toBeGreaterThanOrEqual( + beforeSleep.sleepCount + 1, + ); + expect(afterSleep.wakeCount).toBeGreaterThanOrEqual( + beforeSleep.wakeCount + 1, + ); + + const beforeAlarm = await actor.getState(); + await actor.scheduleAlarm(500); + await wait(900); + + const afterAlarm = await actor.getState(); + expect(afterAlarm.alarmCount).toBeGreaterThanOrEqual( + beforeAlarm.alarmCount + 1, + ); + expect(afterAlarm.sleepCount).toBeGreaterThanOrEqual( + beforeAlarm.sleepCount + 1, + ); + expect(afterAlarm.wakeCount).toBeGreaterThanOrEqual( + beforeAlarm.wakeCount + 1, + ); + } finally { + ws?.close(); + await client.dispose(); + await runtime.cleanup(); + } + }, 180_000); + }, +); + +async function createDynamicEngineRuntime() { + return await createTestRuntime( + join(__dirname, "../fixtures/driver-test-suite/dynamic-registry.ts"), + async (registry) => { + const endpoint = process.env.RIVET_ENDPOINT || "http://127.0.0.1:6420"; + const namespaceEndpoint = + process.env.RIVET_NAMESPACE_ENDPOINT || + process.env.RIVET_API_ENDPOINT || + endpoint; + const namespace = `test-${crypto.randomUUID().slice(0, 8)}`; + const runnerName = "test-runner"; + const token = "dev"; + + const response = await fetch(`${namespaceEndpoint}/namespaces`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: "Bearer dev", + }, + body: JSON.stringify({ + name: namespace, + display_name: namespace, + }), + }); + if (!response.ok) { + const errorBody = await response.text().catch(() => ""); + throw new Error( + `Create namespace failed at ${namespaceEndpoint}: ${response.status} ${response.statusText} ${errorBody}`, + ); + } + + const driverConfig = createEngineDriver(); + registry.config.driver = driverConfig; + registry.config.endpoint = endpoint; + registry.config.namespace = namespace; + registry.config.token = token; + registry.config.runner = { + ...registry.config.runner, + runnerName, + }; + + const parsedConfig = registry.parseConfig(); + const managerDriver = driverConfig.manager?.(parsedConfig); + invariant(managerDriver, "missing manager driver"); + const inlineClient = createClientWithDriver( + managerDriver, + convertRegistryConfigToClientConfig(parsedConfig), + ); + const actorDriver = driverConfig.actor( + parsedConfig, + managerDriver, + inlineClient, + ); + + const runnersUrl = new URL(`${endpoint.replace(/\/$/, "")}/runners`); + runnersUrl.searchParams.set("namespace", namespace); + runnersUrl.searchParams.set("name", runnerName); + let probeError: unknown; + for (let attempt = 0; attempt < 120; attempt++) { + try { + const runnerResponse = await fetch(runnersUrl, { + method: "GET", + headers: { Authorization: `Bearer ${token}` }, + }); + if (!runnerResponse.ok) { + const errorBody = await runnerResponse.text().catch(() => ""); + probeError = new Error( + `List runners failed: ${runnerResponse.status} ${runnerResponse.statusText} ${errorBody}`, + ); + } else { + const responseJson = (await runnerResponse.json()) as { + runners?: Array<{ name?: string }>; + }; + const hasRunner = !!responseJson.runners?.some( + (runner) => runner.name === runnerName, + ); + if (hasRunner) { + probeError = undefined; + break; + } + probeError = new Error( + `Runner ${runnerName} not registered yet`, + ); + } + } catch (err) { + probeError = err; + } + if (attempt < 119) { + await new Promise((resolve) => setTimeout(resolve, 100)); + } + } + if (probeError) { + throw probeError; + } + + return { + rivetEngine: { + endpoint, + namespace, + runnerName, + token, + }, + driver: driverConfig, + cleanup: async () => { + await actorDriver.shutdownRunner?.(true); + }, + }; + }, + ); +} + +async function startSourceServer(source: string): Promise<{ + url: string; + close: () => Promise; +}> { + const server = createServer((req: IncomingMessage, res: ServerResponse) => { + if (req.url !== "/source.ts") { + res.writeHead(404); + res.end("not found"); + return; + } + + res.writeHead(200, { + "content-type": "text/plain; charset=utf-8", + }); + res.end(source); + }); + + await new Promise((resolve) => server.listen(0, "127.0.0.1", resolve)); + const address = server.address(); + if (!address || typeof address === "string") { + throw new Error("failed to get dynamic source server address"); + } + + return { + url: `http://127.0.0.1:${address.port}/source.ts`, + close: async () => { + await new Promise((resolve, reject) => { + server.close((error) => { + if (error) { + reject(error); + return; + } + resolve(); + }); + }); + }, + }; +} + +async function readWebSocketJson(websocket: WebSocket): Promise { + const message = await new Promise((resolve, reject) => { + const timeout = setTimeout(() => { + reject(new Error("timed out waiting for websocket message")); + }, 5_000); + + websocket.addEventListener( + "message", + (event) => { + clearTimeout(timeout); + resolve(String(event.data)); + }, + { once: true }, + ); + websocket.addEventListener( + "error", + (event: Event) => { + clearTimeout(timeout); + reject(event); + }, + { once: true }, + ); + websocket.addEventListener( + "close", + () => { + clearTimeout(timeout); + reject(new Error("websocket closed")); + }, + { once: true }, + ); + }); + + return JSON.parse(message); +} + +async function wait(duration: number): Promise { + await new Promise((resolve) => setTimeout(resolve, duration)); +} diff --git a/shared/typescript/virtual-websocket/src/virtual-websocket.ts b/shared/typescript/virtual-websocket/src/virtual-websocket.ts index 6dc5c616c4..993bd45b26 100644 --- a/shared/typescript/virtual-websocket/src/virtual-websocket.ts +++ b/shared/typescript/virtual-websocket/src/virtual-websocket.ts @@ -194,10 +194,11 @@ export class VirtualWebSocket implements UniversalWebSocket { this.#dispatch("open", event); } - triggerMessage(data: any): void { + triggerMessage(data: any, rivetMessageIndex?: number): void { const event = { type: "message", data, + rivetMessageIndex, target: this, currentTarget: this, } as unknown as RivetMessageEvent;