from arroyo.dlq import InvalidMessage


class LaunchpadRunTaskWithMultiprocessing(RunTaskWithMultiprocessing):
    # NOTE(review): class header reconstructed from the call site
    # (`LaunchpadRunTaskWithMultiprocessing(...)`) and the name-mangled parent
    # attributes (`_RunTaskWithMultiprocessing__*`) — confirm against the full file.

    def __init__(
        self,
        function: Callable[[Message[TStrategyPayload]], Any],
        next_step: ProcessingStrategy[FilteredPayload | Any],
        pool: MultiprocessingPool,
        input_block_size: int | None = None,
        output_block_size: int | None = None,
        batch_timeout: float | None = None,
    ) -> None:
        """Single-message multiprocessing strategy with an optional batch timeout.

        Batch size and batch time are pinned to 1 (each message is processed
        immediately); ``batch_timeout``, when set, bounds how many seconds a
        worker may spend on one in-flight batch before the pool is recycled and
        the batch's messages are routed to the DLQ.
        """
        # Batch size/time are hard-coded to 1: the timeout logic in
        # _check_batch_timeouts only supports single-message batches.
        super().__init__(function, next_step, 1, 1, pool, input_block_size, output_block_size)

        self._batch_timeout = batch_timeout
        # Maps id(batch) -> monotonic timestamp of when the batch was first
        # observed in flight.
        self._batch_submit_times: dict[int, float] = {}

        # Override SIGCHLD handler - child exits are expected with maxtasksperchild=1
        signal.signal(
            signal.SIGCHLD,
            lambda signum, frame: logger.debug(f"Worker process exited normally (SIGCHLD {signum})"),
        )

    def poll(self) -> None:
        """Run the batch-timeout check (if configured), then the parent poll."""
        if self._batch_timeout is not None:
            try:
                self._check_batch_timeouts()
            except Exception as e:
                # Timeout bookkeeping must never take down the consumer loop.
                logger.error(f"Error checking batch timeouts: {e}", exc_info=True)

        super().poll()

    def _check_batch_timeouts(self) -> None:
        """
        Check if any in-flight batches have exceeded timeout and terminate those workers.

        This accesses parent class private members via name mangling. If parent class
        changes these member names, this will break.
        """
        # Guard here as well as in poll(): a direct call with no timeout
        # configured must be a no-op, and this narrows the Optional for the
        # comparison below.
        if self._batch_timeout is None:
            return

        try:
            processes = self._RunTaskWithMultiprocessing__processes
            pool = self._RunTaskWithMultiprocessing__pool
            invalid_messages = self._RunTaskWithMultiprocessing__invalid_messages
        except AttributeError:
            logger.exception("Failed to access parent class private members - please check if the Arroyo API changed:")
            return

        # Track batch times to know which batch is timed out.
        # Make sure to clean up finished/timed out batches so this doesn't grow unbounded.
        if processes:
            current_batch_ids = {id(batch) for batch in processes}
            completed_batch_ids = [bid for bid in self._batch_submit_times if bid not in current_batch_ids]
            for batch_id in completed_batch_ids:
                del self._batch_submit_times[batch_id]
        else:
            self._batch_submit_times.clear()
            return

        # Check the first batch only; batch size is pinned to 1 in __init__.
        first_batch = processes[0]
        batch_id = id(first_batch)

        if batch_id not in self._batch_submit_times:
            # First sighting of this batch: start its clock. Use
            # time.monotonic() (not time.time()) so wall-clock adjustments
            # cannot cause spurious or missed timeouts.
            self._batch_submit_times[batch_id] = time.monotonic()
            return

        elapsed = time.monotonic() - self._batch_submit_times[batch_id]

        if elapsed > self._batch_timeout:
            logger.error(f"Batch exceeded timeout of {self._batch_timeout}s (elapsed={elapsed:.2f}s).")

            input_batch = first_batch[0]

            pool.close()  # Terminates all workers
            pool.maybe_create_pool()  # Recreate fresh pool

            # Remove timed-out batch from queue
            processes.popleft()

            # Convert batch messages to InvalidMessages for DLQ.
            # NOTE(review): assumes each batch entry is an (index, message)
            # pair and message.value carries partition/offset — matches the
            # original code; confirm against the Arroyo version in use.
            for _idx, message in input_batch:
                invalid_messages.append(
                    InvalidMessage(
                        message.value.partition,
                        message.value.offset,
                        reason=f"Batch processing exceeded {self._batch_timeout}s timeout",
                    )
                )

            del self._batch_submit_times[batch_id]

            logger.info(f"Terminated worker and sent {len(input_batch)} messages to DLQ")
create_kafka_consumer() -> LaunchpadKafkaConsumer: topics = [Topic(topic) for topic in config.topics] topic = topics[0] if topics else Topic("default") + processor = StreamProcessor( consumer=arroyo_consumer, topic=topic, @@ -262,11 +337,10 @@ def create_with_partitions( strategy = LaunchpadRunTaskWithMultiprocessing( process_kafka_message_with_service, next_step=next_step, - max_batch_size=1, # Process immediately, subject to be re-tuned - max_batch_time=1, # Process after 1 second max, subject to be re-tuned pool=self._pool, input_block_size=None, output_block_size=None, + batch_timeout=60.0 * 12, # 12 minutes ) return strategy