@@ -16,6 +16,8 @@ import dev.restate.sdk.common.HandlerRequest
1616import dev.restate.sdk.common.StateKey
1717import dev.restate.sdk.common.TerminalException
1818import dev.restate.sdk.endpoint.definition.HandlerContext
19+ import dev.restate.sdk.kotlin.internal.InsideRunElement
20+ import dev.restate.sdk.kotlin.internal.InsideRunElement.Key.checkNotInsideRun
1921import dev.restate.serde.Serde
2022import dev.restate.serde.SerdeFactory
2123import dev.restate.serde.TypeTag
@@ -31,6 +33,7 @@ internal constructor(
3133 internal val handlerContext: HandlerContext ,
3234 internal val contextSerdeFactory: SerdeFactory ,
3335) : WorkflowContext {
36+
3437 override fun key (): String {
3538 return this .handlerContext.objectKey()
3639 }
@@ -39,75 +42,89 @@ internal constructor(
3942 return this .handlerContext.request()
4043 }
4144
42- override suspend fun <T : Any > get (key : StateKey <T >): T ? =
43- resolveSerde< T ?>(key.serdeInfo() )
44- . let { serde ->
45- SingleDurableFutureImpl (handlerContext.get(key.name()).await()).simpleMap {
46- it.getOrNull()?. let { serde.deserialize(it) }
47- }
45+ override suspend fun <T : Any > get (key : StateKey <T >): T ? {
46+ checkNotInsideRun( )
47+ return resolveSerde< T ?>(key.serdeInfo())
48+ . let { serde ->
49+ SingleDurableFutureImpl (handlerContext.get(key.name()).await()).simpleMap {
50+ it.getOrNull()?. let { serde.deserialize(it) }
4851 }
49- .await()
52+ }
53+ .await()
54+ }
5055
51- override suspend fun stateKeys (): Collection <String > =
52- SingleDurableFutureImpl (handlerContext.getKeys().await()).await()
56+ override suspend fun stateKeys (): Collection <String > {
57+ checkNotInsideRun()
58+ return SingleDurableFutureImpl (handlerContext.getKeys().await()).await()
59+ }
5360
5461 override suspend fun <T : Any > set (key : StateKey <T >, value : T ) {
62+ checkNotInsideRun()
5563 handlerContext.set(key.name(), resolveAndSerialize(key.serdeInfo(), value)).await()
5664 }
5765
5866 override suspend fun clear (key : StateKey <* >) {
67+ checkNotInsideRun()
5968 handlerContext.clear(key.name()).await()
6069 }
6170
6271 override suspend fun clearAll () {
72+ checkNotInsideRun()
6373 handlerContext.clearAll().await()
6474 }
6575
66- override suspend fun timer (duration : Duration , name : String? ): DurableFuture <Unit > =
67- SingleDurableFutureImpl (handlerContext.timer(duration.toJavaDuration(), name).await()).map {}
76+ override suspend fun timer (duration : Duration , name : String? ): DurableFuture <Unit > {
77+ checkNotInsideRun()
78+ return SingleDurableFutureImpl (handlerContext.timer(duration.toJavaDuration(), name).await())
79+ .map {}
80+ }
6881
6982 override suspend fun <Req : Any ?, Res : Any ?> call (
7083 request : Request <Req , Res >
71- ): CallDurableFuture <Res > =
72- resolveSerde<Res >(request.getResponseTypeTag()).let { responseSerde ->
73- val callHandle =
74- handlerContext
75- .call(
76- request.getTarget(),
77- resolveAndSerialize<Req >(request.getRequestTypeTag(), request.getRequest()),
78- request.getIdempotencyKey(),
79- request.getHeaders()?.entries,
80- )
81- .await()
82-
83- val callAsyncResult =
84- callHandle.callAsyncResult.map {
85- CompletableFuture .completedFuture<Res >(responseSerde.deserialize(it))
86- }
84+ ): CallDurableFuture <Res > {
85+ checkNotInsideRun()
86+ return resolveSerde<Res >(request.getResponseTypeTag()).let { responseSerde ->
87+ val callHandle =
88+ handlerContext
89+ .call(
90+ request.getTarget(),
91+ resolveAndSerialize<Req >(request.getRequestTypeTag(), request.getRequest()),
92+ request.getIdempotencyKey(),
93+ request.getHeaders()?.entries,
94+ )
95+ .await()
8796
88- return @let CallDurableFutureImpl (callAsyncResult, callHandle.invocationIdAsyncResult)
89- }
97+ val callAsyncResult =
98+ callHandle.callAsyncResult.map {
99+ CompletableFuture .completedFuture<Res >(responseSerde.deserialize(it))
100+ }
101+
102+ return @let CallDurableFutureImpl (callAsyncResult, callHandle.invocationIdAsyncResult)
103+ }
104+ }
90105
91106 override suspend fun <Req : Any ?, Res : Any ?> send (
92107 request : Request <Req , Res >,
93108 delay : Duration ? ,
94- ): InvocationHandle <Res > =
95- resolveSerde<Res >(request.getResponseTypeTag()).let { responseSerde ->
96- val invocationIdAsyncResult =
97- handlerContext
98- .send(
99- request.getTarget(),
100- resolveAndSerialize<Req >(request.getRequestTypeTag(), request.getRequest()),
101- request.getIdempotencyKey(),
102- request.getHeaders()?.entries,
103- delay?.toJavaDuration(),
104- )
105- .await()
109+ ): InvocationHandle <Res > {
110+ checkNotInsideRun()
111+ return resolveSerde<Res >(request.getResponseTypeTag()).let { responseSerde ->
112+ val invocationIdAsyncResult =
113+ handlerContext
114+ .send(
115+ request.getTarget(),
116+ resolveAndSerialize<Req >(request.getRequestTypeTag(), request.getRequest()),
117+ request.getIdempotencyKey(),
118+ request.getHeaders()?.entries,
119+ delay?.toJavaDuration(),
120+ )
121+ .await()
106122
107- object : BaseInvocationHandle <Res >(handlerContext, responseSerde) {
108- override suspend fun invocationId (): String = invocationIdAsyncResult.poll().await()
109- }
123+ object : BaseInvocationHandle <Res >(handlerContext, responseSerde) {
124+ override suspend fun invocationId (): String = invocationIdAsyncResult.poll().await()
110125 }
126+ }
127+ }
111128
112129 override fun <Res > invocationHandle (
113130 invocationId : String ,
@@ -125,6 +142,7 @@ internal constructor(
125142 retryPolicy : RetryPolicy ? ,
126143 block : suspend () -> T ,
127144 ): DurableFuture <T > {
145+ checkNotInsideRun()
128146 val serde: Serde <T > = resolveSerde(typeTag)
129147 val coroutineCtx = currentCoroutineContext()
130148 val javaRetryPolicy =
@@ -138,7 +156,10 @@ internal constructor(
138156 .setMaxDuration(it.maxDuration?.toJavaDuration())
139157 }
140158
141- val scope = CoroutineScope (coroutineCtx + CoroutineName (" restate-run-$name " ))
159+ val scope =
160+ CoroutineScope (
161+ coroutineCtx + CoroutineName (" restate-run-$name " ) + InsideRunElement .INSTANCE
162+ )
142163
143164 val asyncResult =
144165 handlerContext
@@ -159,6 +180,7 @@ internal constructor(
159180 }
160181
161182 override suspend fun <T : Any > awakeable (typeTag : TypeTag <T >): Awakeable <T > {
183+ checkNotInsideRun()
162184 val serde: Serde <T > = resolveSerde(typeTag)
163185 val awk = handlerContext.awakeable().await()
164186 return AwakeableImpl (awk.asyncResult, serde, awk.id)
@@ -184,22 +206,27 @@ internal constructor(
184206 DurablePromise <T > {
185207 val serde: Serde <T > = resolveSerde(key.serdeInfo())
186208
187- override suspend fun future (): DurableFuture <T > =
188- SingleDurableFutureImpl (handlerContext.promise(key.name()).await()).simpleMap {
189- serde.deserialize(it)
190- }
209+ override suspend fun future (): DurableFuture <T > {
210+ checkNotInsideRun()
211+ return SingleDurableFutureImpl (handlerContext.promise(key.name()).await()).simpleMap {
212+ serde.deserialize(it)
213+ }
214+ }
191215
192- override suspend fun peek (): Output <T > =
193- SingleDurableFutureImpl (handlerContext.peekPromise(key.name()).await())
194- .simpleMap { it.map { serde.deserialize(it) } }
195- .await()
216+ override suspend fun peek (): Output <T > {
217+ checkNotInsideRun()
218+ return SingleDurableFutureImpl (handlerContext.peekPromise(key.name()).await())
219+ .simpleMap { it.map { serde.deserialize(it) } }
220+ .await()
221+ }
196222 }
197223
198224 inner class DurablePromiseHandleImpl <T : Any >(private val key : DurablePromiseKey <T >) :
199225 DurablePromiseHandle <T > {
200226 val serde: Serde <T > = resolveSerde(key.serdeInfo())
201227
202228 override suspend fun resolve (payload : T ) {
229+ checkNotInsideRun()
203230 SingleDurableFutureImpl (
204231 handlerContext
205232 .resolvePromise(
@@ -212,6 +239,7 @@ internal constructor(
212239 }
213240
214241 override suspend fun reject (reason : String ) {
242+ checkNotInsideRun()
215243 SingleDurableFutureImpl (
216244 handlerContext.rejectPromise(key.name(), TerminalException (reason)).await()
217245 )
0 commit comments