diff --git a/src/main/java/com/cohere/api/core/Stream.java b/src/main/java/com/cohere/api/core/Stream.java index 3ee1c8a..b950383 100644 --- a/src/main/java/com/cohere/api/core/Stream.java +++ b/src/main/java/com/cohere/api/core/Stream.java @@ -224,74 +224,33 @@ private boolean readNextMessage() { try { while (sseScanner.hasNextLine()) { String line = sseScanner.nextLine(); - - if (line.trim().isEmpty()) { - if (eventDataBuffer.length() > 0) { - try { - nextItem = ObjectMappers.JSON_MAPPER.readValue(eventDataBuffer.toString(), valueType); - hasNextItem = true; - eventDataBuffer.setLength(0); - currentEventType = null; - return true; - } catch (Exception parseEx) { - System.err.println("Failed to parse SSE event: " + parseEx.getMessage()); - eventDataBuffer.setLength(0); - currentEventType = null; - continue; - } - } + if (line == null) continue; + line = line.trim(); + if (line.isEmpty()) { continue; } - if (line.startsWith(DATA_PREFIX)) { - String dataContent = line.substring(DATA_PREFIX.length()); - if (dataContent.startsWith(" ")) { - dataContent = dataContent.substring(1); - } - - if (eventDataBuffer.length() == 0 - && streamTerminator != null - && dataContent.trim().equals(streamTerminator)) { + String dataContent = line.substring(DATA_PREFIX.length()).trim(); + // Check for stream terminator + if (streamTerminator != null && dataContent.equals(streamTerminator)) { endOfStream = true; return false; } - - if (eventDataBuffer.length() > 0) { - eventDataBuffer.append('\n'); - } - eventDataBuffer.append(dataContent); - } else if (line.startsWith("event:")) { - String eventValue = line.length() > 6 ? line.substring(6) : ""; - if (eventValue.startsWith(" ")) { - eventValue = eventValue.substring(1); + if (!dataContent.isEmpty()) { + try { + nextItem = ObjectMappers.JSON_MAPPER.readValue(dataContent, valueType); + hasNextItem = true; + return true; + } catch (Exception parseEx) { + System.err.println("Failed to parse SSE data line: " + parseEx.getMessage()); + continue; + } } - currentEventType = eventValue; - } else if (line.startsWith("id:")) { - // Event ID field (ignored) - } else if (line.startsWith("retry:")) { - // Retry field (ignored) - } else if (line.startsWith(":")) { - // Comment line (ignored) - } - } - - if (eventDataBuffer.length() > 0) { - try { - nextItem = ObjectMappers.JSON_MAPPER.readValue(eventDataBuffer.toString(), valueType); - hasNextItem = true; - eventDataBuffer.setLength(0); - currentEventType = null; - return true; - } catch (Exception parseEx) { - System.err.println("Failed to parse final SSE event: " + parseEx.getMessage()); - eventDataBuffer.setLength(0); - currentEventType = null; } + // Ignore event:, id:, retry:, and comment lines } - endOfStream = true; return false; - } catch (Exception e) { System.err.println("Failed to parse SSE stream: " + e.getMessage()); endOfStream = true; diff --git a/src/test/java/com/cohere/api/StreamTest.java b/src/test/java/com/cohere/api/StreamTest.java index 4de3538..a342902 100644 --- a/src/test/java/com/cohere/api/StreamTest.java +++ b/src/test/java/com/cohere/api/StreamTest.java @@ -72,6 +72,44 @@ public void testStreamResourceManagement() throws IOException { assertFalse(testStream.iterator().hasNext()); } + @Test + public void testSseStreamIgnoresEventLinesAndParsesData() { + // Simulate Cohere SSE stream with event: and data: lines + String sse = String.join("\n", + "event: message-start", + "data: {\"id\":\"id1\",\"type\":\"message-start\",\"delta\":{\"message\":{\"role\":\"assistant\",\"content\":[]}}}", + "", + "event: content-delta", + "data: {\"type\":\"content-delta\",\"index\":0,\"delta\":{\"message\":{\"content\":{\"text\":\"hello\"}}}}", + "", + "event: content-delta", + "data: {\"type\":\"content-delta\",\"index\":1,\"delta\":{\"message\":{\"content\":{\"text\":\" world\"}}}}", + "" + ); + StringReader sseInput = new StringReader(sse); + Stream sseStream = Stream.fromSse(Map.class, sseInput); + int count = 0; + StringBuilder content = new StringBuilder(); + for (Map event : sseStream) { + assertNotNull(event); + assertTrue(event.containsKey("type")); + if ("content-delta".equals(event.get("type"))) { + // Drill down to message.content.text + Map delta = (Map) event.get("delta"); + assertNotNull(delta); + Map message = (Map) delta.get("message"); + assertNotNull(message); + Map contentMap = (Map) message.get("content"); + assertNotNull(contentMap); + String text = (String) contentMap.get("text"); + if (text != null) content.append(text); + } + count++; + } + assertEquals(3, count); // 1 message-start + 2 content-delta + assertEquals("hello world", content.toString()); + } + private static String mapToJson(Map map) { try { return ObjectMappers.JSON_MAPPER.writeValueAsString(map);