Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 79 additions & 5 deletions src/launchpad/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from arroyo import Message, Topic, configure_metrics
from arroyo.backends.kafka import KafkaConsumer as ArroyoKafkaConsumer
from arroyo.backends.kafka import KafkaPayload
from arroyo.dlq import InvalidMessage
from arroyo.processing.processor import StreamProcessor
from arroyo.processing.strategies import ProcessingStrategy, ProcessingStrategyFactory
from arroyo.processing.strategies.commit import CommitOffsets
Expand Down Expand Up @@ -62,19 +63,92 @@ def __init__(
self,
function: Callable[[Message[TStrategyPayload]], Any],
next_step: ProcessingStrategy[FilteredPayload | Any],
max_batch_size: int,
max_batch_time: float,
pool: MultiprocessingPool,
input_block_size: int | None = None,
output_block_size: int | None = None,
batch_timeout: float | None = None,
) -> None:
super().__init__(function, next_step, max_batch_size, max_batch_time, pool, input_block_size, output_block_size)
super().__init__(function, next_step, 1, 1, pool, input_block_size, output_block_size)

self._batch_timeout = batch_timeout
self._batch_submit_times: dict[int, float] = {} # Maps batch id to submission timestamp

# Override SIGCHLD handler - child exits are expected with maxtasksperchild=1
signal.signal(
signal.SIGCHLD,
lambda signum, frame: logger.debug(f"Worker process exited normally (SIGCHLD {signum})"),
)

def poll(self) -> None:
    """Poll the strategy, first running the batch-timeout watchdog if configured.

    The watchdog is best-effort by design: any error it raises is logged and
    swallowed so that a bug in the timeout bookkeeping can never stall normal
    message processing in the parent strategy.
    """
    if self._batch_timeout is not None:
        try:
            self._check_batch_timeouts()
        except Exception:
            # Deliberately broad: the watchdog must never break the consumer.
            # logger.exception records the full traceback; lazy formatting
            # avoids f-string work on this hot poll path.
            logger.exception("Error checking batch timeouts")

    super().poll()

def _check_batch_timeouts(self) -> None:
    """Terminate the worker pool and DLQ the in-flight batch if it timed out.

    This accesses parent class private members via name mangling. If the
    parent class changes these member names, this will break. That is brittle,
    but timed-out batches are rare and the AttributeError guard below logs
    loudly enough for alerting to catch an Arroyo API change.
    """
    # Narrow the Optional for type-safety; poll() only calls us when a
    # timeout is configured, but guard anyway so a direct call cannot
    # compare against None below.
    if self._batch_timeout is None:
        return

    try:
        # Name-mangled private state of arroyo's RunTaskWithMultiprocessing.
        processes = self._RunTaskWithMultiprocessing__processes
        pool = self._RunTaskWithMultiprocessing__pool
        invalid_messages = self._RunTaskWithMultiprocessing__invalid_messages
    except AttributeError:
        logger.exception("Failed to access parent class private members - please check if the Arroyo API changed:")
        return

    # Track batch times to know which batch is timed out.
    # Make sure to clean up finished/timed out batches so this doesn't grow unbounded.
    if processes:
        current_batch_ids = {id(batch) for batch in processes}
        completed_batch_ids = [bid for bid in self._batch_submit_times if bid not in current_batch_ids]
        for batch_id in completed_batch_ids:
            del self._batch_submit_times[batch_id]
    else:
        self._batch_submit_times.clear()
        return

    # Check the first batch, we only support batchsize=1 right now.
    first_batch = processes[0]
    batch_id = id(first_batch)

    if batch_id not in self._batch_submit_times:
        # First time we observe this batch: start its timeout clock now.
        self._batch_submit_times[batch_id] = time.time()
        return

    elapsed = time.time() - self._batch_submit_times[batch_id]

    if elapsed > self._batch_timeout:
        logger.error("Batch exceeded timeout of %ss (elapsed=%.2fs).", self._batch_timeout, elapsed)

        # NOTE(review): assumes each queued entry is an (input_batch, result)
        # style tuple whose first element holds (index, message) pairs —
        # confirm against the pinned Arroyo version.
        input_batch = first_batch[0]

        pool.close()  # Terminates all workers
        pool.maybe_create_pool()  # Recreate fresh pool

        # Remove timed-out batch from queue
        processes.popleft()

        # Convert batch messages to InvalidMessages for DLQ. We cannot raise
        # InvalidMessage from this code path (we are inside poll(), not a
        # submit path), so append to the parent's pending invalid-message
        # list instead and let the parent surface them.
        for _idx, message in input_batch:
            invalid_msg = InvalidMessage(
                message.value.partition,
                message.value.offset,
                reason=f"Batch processing exceeded {self._batch_timeout}s timeout",
            )
            invalid_messages.append(invalid_msg)

        del self._batch_submit_times[batch_id]

        logger.info("Terminated worker and sent %d messages to DLQ", len(input_batch))


def process_kafka_message_with_service(msg: Message[KafkaPayload]) -> Any:
"""Process a Kafka message using the actual service logic in a worker process."""
Expand Down Expand Up @@ -134,6 +208,7 @@ def create_kafka_consumer() -> LaunchpadKafkaConsumer:

topics = [Topic(topic) for topic in config.topics]
topic = topics[0] if topics else Topic("default")

processor = StreamProcessor(
consumer=arroyo_consumer,
topic=topic,
Expand Down Expand Up @@ -262,11 +337,10 @@ def create_with_partitions(
strategy = LaunchpadRunTaskWithMultiprocessing(
process_kafka_message_with_service,
next_step=next_step,
max_batch_size=1, # Process immediately, subject to be re-tuned
max_batch_time=1, # Process after 1 second max, subject to be re-tuned
pool=self._pool,
input_block_size=None,
output_block_size=None,
batch_timeout=60.0 * 12, # 12 minutes
)

return strategy
Expand Down
Loading