Skip to content

Commit ca9ead6

Browse files
Add check for context run (#601)
1 parent 6f437f4 commit ca9ead6

File tree

10 files changed

+257
-65
lines changed

10 files changed

+257
-65
lines changed

sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt

Lines changed: 81 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import dev.restate.sdk.common.HandlerRequest
1616
import dev.restate.sdk.common.StateKey
1717
import dev.restate.sdk.common.TerminalException
1818
import dev.restate.sdk.endpoint.definition.HandlerContext
19+
import dev.restate.sdk.kotlin.internal.InsideRunElement
20+
import dev.restate.sdk.kotlin.internal.InsideRunElement.Key.checkNotInsideRun
1921
import dev.restate.serde.Serde
2022
import dev.restate.serde.SerdeFactory
2123
import dev.restate.serde.TypeTag
@@ -31,6 +33,7 @@ internal constructor(
3133
internal val handlerContext: HandlerContext,
3234
internal val contextSerdeFactory: SerdeFactory,
3335
) : WorkflowContext {
36+
3437
override fun key(): String {
3538
return this.handlerContext.objectKey()
3639
}
@@ -39,75 +42,89 @@ internal constructor(
3942
return this.handlerContext.request()
4043
}
4144

42-
override suspend fun <T : Any> get(key: StateKey<T>): T? =
43-
resolveSerde<T?>(key.serdeInfo())
44-
.let { serde ->
45-
SingleDurableFutureImpl(handlerContext.get(key.name()).await()).simpleMap {
46-
it.getOrNull()?.let { serde.deserialize(it) }
47-
}
45+
override suspend fun <T : Any> get(key: StateKey<T>): T? {
46+
checkNotInsideRun()
47+
return resolveSerde<T?>(key.serdeInfo())
48+
.let { serde ->
49+
SingleDurableFutureImpl(handlerContext.get(key.name()).await()).simpleMap {
50+
it.getOrNull()?.let { serde.deserialize(it) }
4851
}
49-
.await()
52+
}
53+
.await()
54+
}
5055

51-
override suspend fun stateKeys(): Collection<String> =
52-
SingleDurableFutureImpl(handlerContext.getKeys().await()).await()
56+
override suspend fun stateKeys(): Collection<String> {
57+
checkNotInsideRun()
58+
return SingleDurableFutureImpl(handlerContext.getKeys().await()).await()
59+
}
5360

5461
override suspend fun <T : Any> set(key: StateKey<T>, value: T) {
62+
checkNotInsideRun()
5563
handlerContext.set(key.name(), resolveAndSerialize(key.serdeInfo(), value)).await()
5664
}
5765

5866
override suspend fun clear(key: StateKey<*>) {
67+
checkNotInsideRun()
5968
handlerContext.clear(key.name()).await()
6069
}
6170

6271
override suspend fun clearAll() {
72+
checkNotInsideRun()
6373
handlerContext.clearAll().await()
6474
}
6575

66-
override suspend fun timer(duration: Duration, name: String?): DurableFuture<Unit> =
67-
SingleDurableFutureImpl(handlerContext.timer(duration.toJavaDuration(), name).await()).map {}
76+
override suspend fun timer(duration: Duration, name: String?): DurableFuture<Unit> {
77+
checkNotInsideRun()
78+
return SingleDurableFutureImpl(handlerContext.timer(duration.toJavaDuration(), name).await())
79+
.map {}
80+
}
6881

6982
override suspend fun <Req : Any?, Res : Any?> call(
7083
request: Request<Req, Res>
71-
): CallDurableFuture<Res> =
72-
resolveSerde<Res>(request.getResponseTypeTag()).let { responseSerde ->
73-
val callHandle =
74-
handlerContext
75-
.call(
76-
request.getTarget(),
77-
resolveAndSerialize<Req>(request.getRequestTypeTag(), request.getRequest()),
78-
request.getIdempotencyKey(),
79-
request.getHeaders()?.entries,
80-
)
81-
.await()
82-
83-
val callAsyncResult =
84-
callHandle.callAsyncResult.map {
85-
CompletableFuture.completedFuture<Res>(responseSerde.deserialize(it))
86-
}
84+
): CallDurableFuture<Res> {
85+
checkNotInsideRun()
86+
return resolveSerde<Res>(request.getResponseTypeTag()).let { responseSerde ->
87+
val callHandle =
88+
handlerContext
89+
.call(
90+
request.getTarget(),
91+
resolveAndSerialize<Req>(request.getRequestTypeTag(), request.getRequest()),
92+
request.getIdempotencyKey(),
93+
request.getHeaders()?.entries,
94+
)
95+
.await()
8796

88-
return@let CallDurableFutureImpl(callAsyncResult, callHandle.invocationIdAsyncResult)
89-
}
97+
val callAsyncResult =
98+
callHandle.callAsyncResult.map {
99+
CompletableFuture.completedFuture<Res>(responseSerde.deserialize(it))
100+
}
101+
102+
return@let CallDurableFutureImpl(callAsyncResult, callHandle.invocationIdAsyncResult)
103+
}
104+
}
90105

91106
override suspend fun <Req : Any?, Res : Any?> send(
92107
request: Request<Req, Res>,
93108
delay: Duration?,
94-
): InvocationHandle<Res> =
95-
resolveSerde<Res>(request.getResponseTypeTag()).let { responseSerde ->
96-
val invocationIdAsyncResult =
97-
handlerContext
98-
.send(
99-
request.getTarget(),
100-
resolveAndSerialize<Req>(request.getRequestTypeTag(), request.getRequest()),
101-
request.getIdempotencyKey(),
102-
request.getHeaders()?.entries,
103-
delay?.toJavaDuration(),
104-
)
105-
.await()
109+
): InvocationHandle<Res> {
110+
checkNotInsideRun()
111+
return resolveSerde<Res>(request.getResponseTypeTag()).let { responseSerde ->
112+
val invocationIdAsyncResult =
113+
handlerContext
114+
.send(
115+
request.getTarget(),
116+
resolveAndSerialize<Req>(request.getRequestTypeTag(), request.getRequest()),
117+
request.getIdempotencyKey(),
118+
request.getHeaders()?.entries,
119+
delay?.toJavaDuration(),
120+
)
121+
.await()
106122

107-
object : BaseInvocationHandle<Res>(handlerContext, responseSerde) {
108-
override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await()
109-
}
123+
object : BaseInvocationHandle<Res>(handlerContext, responseSerde) {
124+
override suspend fun invocationId(): String = invocationIdAsyncResult.poll().await()
110125
}
126+
}
127+
}
111128

112129
override fun <Res> invocationHandle(
113130
invocationId: String,
@@ -125,6 +142,7 @@ internal constructor(
125142
retryPolicy: RetryPolicy?,
126143
block: suspend () -> T,
127144
): DurableFuture<T> {
145+
checkNotInsideRun()
128146
val serde: Serde<T> = resolveSerde(typeTag)
129147
val coroutineCtx = currentCoroutineContext()
130148
val javaRetryPolicy =
@@ -138,7 +156,10 @@ internal constructor(
138156
.setMaxDuration(it.maxDuration?.toJavaDuration())
139157
}
140158

141-
val scope = CoroutineScope(coroutineCtx + CoroutineName("restate-run-$name"))
159+
val scope =
160+
CoroutineScope(
161+
coroutineCtx + CoroutineName("restate-run-$name") + InsideRunElement.INSTANCE
162+
)
142163

143164
val asyncResult =
144165
handlerContext
@@ -159,6 +180,7 @@ internal constructor(
159180
}
160181

161182
override suspend fun <T : Any> awakeable(typeTag: TypeTag<T>): Awakeable<T> {
183+
checkNotInsideRun()
162184
val serde: Serde<T> = resolveSerde(typeTag)
163185
val awk = handlerContext.awakeable().await()
164186
return AwakeableImpl(awk.asyncResult, serde, awk.id)
@@ -184,22 +206,27 @@ internal constructor(
184206
DurablePromise<T> {
185207
val serde: Serde<T> = resolveSerde(key.serdeInfo())
186208

187-
override suspend fun future(): DurableFuture<T> =
188-
SingleDurableFutureImpl(handlerContext.promise(key.name()).await()).simpleMap {
189-
serde.deserialize(it)
190-
}
209+
override suspend fun future(): DurableFuture<T> {
210+
checkNotInsideRun()
211+
return SingleDurableFutureImpl(handlerContext.promise(key.name()).await()).simpleMap {
212+
serde.deserialize(it)
213+
}
214+
}
191215

192-
override suspend fun peek(): Output<T> =
193-
SingleDurableFutureImpl(handlerContext.peekPromise(key.name()).await())
194-
.simpleMap { it.map { serde.deserialize(it) } }
195-
.await()
216+
override suspend fun peek(): Output<T> {
217+
checkNotInsideRun()
218+
return SingleDurableFutureImpl(handlerContext.peekPromise(key.name()).await())
219+
.simpleMap { it.map { serde.deserialize(it) } }
220+
.await()
221+
}
196222
}
197223

198224
inner class DurablePromiseHandleImpl<T : Any>(private val key: DurablePromiseKey<T>) :
199225
DurablePromiseHandle<T> {
200226
val serde: Serde<T> = resolveSerde(key.serdeInfo())
201227

202228
override suspend fun resolve(payload: T) {
229+
checkNotInsideRun()
203230
SingleDurableFutureImpl(
204231
handlerContext
205232
.resolvePromise(
@@ -212,6 +239,7 @@ internal constructor(
212239
}
213240

214241
override suspend fun reject(reason: String) {
242+
checkNotInsideRun()
215243
SingleDurableFutureImpl(
216244
handlerContext.rejectPromise(key.name(), TerminalException(reason)).await()
217245
)

sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/futures.kt

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import dev.restate.sdk.common.TerminalException
1414
import dev.restate.sdk.common.TimeoutException
1515
import dev.restate.sdk.endpoint.definition.AsyncResult
1616
import dev.restate.sdk.endpoint.definition.HandlerContext
17+
import dev.restate.sdk.kotlin.internal.InsideRunElement.Key.checkNotInsideRun
1718
import dev.restate.serde.Serde
1819
import dev.restate.serde.TypeTag
1920
import java.util.concurrent.CompletableFuture
@@ -32,6 +33,7 @@ internal abstract class BaseDurableFutureImpl<T : Any?> : DurableFuture<T> {
3233
get() = SelectClauseImpl(this)
3334

3435
override suspend fun await(): T {
36+
checkNotInsideRun()
3537
return asyncResult().poll().await()
3638
}
3739

@@ -193,20 +195,25 @@ internal constructor(
193195
private val responseSerde: Serde<Res>,
194196
) : InvocationHandle<Res> {
195197
override suspend fun cancel() {
198+
checkNotInsideRun()
196199
val ignored = handlerContext.cancelInvocation(invocationId()).await()
197200
}
198201

199-
override suspend fun attach(): DurableFuture<Res> =
200-
SingleDurableFutureImpl(
201-
handlerContext.attachInvocation(invocationId()).await().map {
202-
CompletableFuture.completedFuture<Res>(responseSerde.deserialize(it))
203-
}
204-
)
202+
override suspend fun attach(): DurableFuture<Res> {
203+
checkNotInsideRun()
204+
return SingleDurableFutureImpl(
205+
handlerContext.attachInvocation(invocationId()).await().map {
206+
CompletableFuture.completedFuture<Res>(responseSerde.deserialize(it))
207+
}
208+
)
209+
}
205210

206-
override suspend fun output(): Output<Res> =
207-
SingleDurableFutureImpl(handlerContext.getInvocationOutput(invocationId()).await())
208-
.simpleMap { it.map { responseSerde.deserialize(it) } }
209-
.await()
211+
override suspend fun output(): Output<Res> {
212+
checkNotInsideRun()
213+
return SingleDurableFutureImpl(handlerContext.getInvocationOutput(invocationId()).await())
214+
.simpleMap { it.map { responseSerde.deserialize(it) } }
215+
.await()
216+
}
210217
}
211218

212219
internal class AwakeableImpl<T : Any?>
@@ -218,13 +225,14 @@ internal constructor(asyncResult: AsyncResult<Slice>, serde: Serde<T>, override
218225

219226
internal class AwakeableHandleImpl(val contextImpl: ContextImpl, val id: String) : AwakeableHandle {
220227
override suspend fun <T : Any> resolve(typeTag: TypeTag<T>, payload: T) {
228+
checkNotInsideRun()
221229
contextImpl.handlerContext
222230
.resolveAwakeable(id, contextImpl.resolveAndSerialize(typeTag, payload))
223231
.await()
224232
}
225233

226234
override suspend fun reject(reason: String) {
227-
return
235+
checkNotInsideRun()
228236
contextImpl.handlerContext.rejectAwakeable(id, TerminalException(reason)).await()
229237
}
230238
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
2+
//
3+
// This file is part of the Restate Java SDK,
4+
// which is released under the MIT license.
5+
//
6+
// You can find a copy of the license in file LICENSE in the root
7+
// directory of this repository or package, or at
8+
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
9+
package dev.restate.sdk.kotlin.internal
10+
11+
import kotlin.coroutines.AbstractCoroutineContextElement
12+
import kotlin.coroutines.CoroutineContext
13+
import kotlinx.coroutines.currentCoroutineContext
14+
15+
/**
16+
* Coroutine context element that marks the current coroutine as executing inside a `ctx.run()`
17+
* block. Context methods check for this element and throw [IllegalStateException] if present.
18+
*/
19+
internal class InsideRunElement private constructor() : AbstractCoroutineContextElement(Key) {
20+
companion object Key : CoroutineContext.Key<InsideRunElement> {
21+
val INSTANCE = InsideRunElement()
22+
23+
suspend fun checkNotInsideRun() {
24+
if (currentCoroutineContext()[Key] != null) {
25+
throw IllegalStateException(
26+
"Cannot invoke context method inside ctx.run(). " +
27+
"The run closure is meant for non-deterministic operations (e.g., HTTP calls, database reads). " +
28+
"You MUST use context methods outside of ctx.run(), check the documentation: https://docs.restate.dev/develop/java/durable-steps#run"
29+
)
30+
}
31+
}
32+
}
33+
}

0 commit comments

Comments
 (0)