diff --git a/.fern/metadata.json b/.fern/metadata.json
new file mode 100644
index 0000000..71d1ec6
--- /dev/null
+++ b/.fern/metadata.json
@@ -0,0 +1,8 @@
+{
+ "cliVersion": "0.112.1",
+ "generatorName": "fernapi/fern-java-sdk",
+ "generatorVersion": "3.14.2",
+ "generatorConfig": {
+ "client-class-name": "Cohere"
+ }
+}
\ No newline at end of file
diff --git a/build.gradle b/build.gradle
index 3c5788d..eb378aa 100644
--- a/build.gradle
+++ b/build.gradle
@@ -20,6 +20,7 @@ dependencies {
api 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.17.2'
testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2'
testImplementation 'org.junit.jupiter:junit-jupiter-engine:5.8.2'
+ testImplementation 'org.junit.jupiter:junit-jupiter-params:5.8.2'
}
@@ -46,7 +47,7 @@ java {
group = 'com.cohere'
-version = '1.9.2'
+version = '1.9.3'
jar {
dependsOn(":generatePomFileForMavenPublication")
@@ -77,7 +78,7 @@ publishing {
maven(MavenPublication) {
groupId = 'com.cohere'
artifactId = 'cohere-java'
- version = '1.9.2'
+ version = '1.9.3'
from components.java
pom {
name = 'cohere'
@@ -120,9 +121,10 @@ sonatypeCentralUpload {
}
signing {
- def signingKeyId = "$System.env.MAVEN_SIGNATURE_SECRET_KEY"
+ def signingKeyId = "$System.env.MAVEN_SIGNATURE_KID"
+ def signingKey = "$System.env.MAVEN_SIGNATURE_SECRET_KEY"
def signingPassword = "$System.env.MAVEN_SIGNATURE_PASSWORD"
- useInMemoryPgpKeys(signingKeyId, signingPassword)
+ useInMemoryPgpKeys(signingKeyId, signingKey, signingPassword)
sign publishing.publications.maven
}
diff --git a/reference.md b/reference.md
index 3468658..578cd07 100644
--- a/reference.md
+++ b/reference.md
@@ -290,7 +290,7 @@ Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private D
-
-**p:** `Optional`
+**p:** `Optional`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
@@ -330,7 +330,7 @@ Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private D
-
-**frequencyPenalty:** `Optional`
+**frequencyPenalty:** `Optional`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
@@ -344,7 +344,7 @@ Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private D
-
-**presencePenalty:** `Optional`
+**presencePenalty:** `Optional`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
@@ -746,7 +746,7 @@ Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private D
-
-**p:** `Optional`
+**p:** `Optional`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
@@ -786,7 +786,7 @@ Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private D
-
-**frequencyPenalty:** `Optional`
+**frequencyPenalty:** `Optional`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
@@ -800,7 +800,7 @@ Compatible Deployments: Cohere Platform, Azure, AWS Sagemaker/Bedrock, Private D
-
-**presencePenalty:** `Optional`
+**presencePenalty:** `Optional`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
@@ -1037,7 +1037,7 @@ If `NONE` is selected, when the input exceeds the maximum input token length an
-
-**temperature:** `Optional`
+**temperature:** `Optional`
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
@@ -1100,7 +1100,7 @@ Defaults to `0`, min value of `0`, max value of `500`.
-
-**p:** `Optional`
+**p:** `Optional`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
@@ -1111,7 +1111,7 @@ Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
-
-**frequencyPenalty:** `Optional`
+**frequencyPenalty:** `Optional`
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
@@ -1123,7 +1123,7 @@ Using `frequency_penalty` in combination with `presence_penalty` is not supporte
-
-**presencePenalty:** `Optional`
+**presencePenalty:** `Optional`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
@@ -1288,7 +1288,7 @@ If `NONE` is selected, when the input exceeds the maximum input token length an
-
-**temperature:** `Optional`
+**temperature:** `Optional`
A non-negative float that tunes the degree of randomness in generation. Lower temperatures mean less random generations. See [Temperature](/temperature-wiki) for more details.
Defaults to `0.75`, min value of `0.0`, max value of `5.0`.
@@ -1351,7 +1351,7 @@ Defaults to `0`, min value of `0`, max value of `500`.
-
-**p:** `Optional`
+**p:** `Optional`
Ensures that only the most likely tokens, with total probability mass of `p`, are considered for generation at each step. If both `k` and `p` are enabled, `p` acts after `k`.
Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
@@ -1362,7 +1362,7 @@ Defaults to `0.75`. min value of `0.01`, max value of `0.99`.
-
-**frequencyPenalty:** `Optional`
+**frequencyPenalty:** `Optional`
Used to reduce repetitiveness of generated tokens. The higher the value, the stronger a penalty is applied to previously present tokens, proportional to how many times they have already appeared in the prompt or prior generation.
@@ -1374,7 +1374,7 @@ Using `frequency_penalty` in combination with `presence_penalty` is not supporte
-
-**presencePenalty:** `Optional`
+**presencePenalty:** `Optional`
Defaults to `0.0`, min value of `0.0`, max value of `1.0`.
@@ -1449,7 +1449,7 @@ client.embed(
EmbedRequest
.builder()
.texts(
- new ArrayList(
+ Optional.of(
Arrays.asList("hello", "goodbye")
)
)
@@ -1574,14 +1574,12 @@ client.rerank(
RerankRequest
.builder()
.documents(
- new ArrayList(
- Arrays.asList(
- RerankRequestDocumentsItem.of(),
- RerankRequestDocumentsItem.of(),
- RerankRequestDocumentsItem.of(),
- RerankRequestDocumentsItem.of(),
- RerankRequestDocumentsItem.of()
- )
+ Arrays.asList(
+ RerankRequestDocumentsItem.of(),
+ RerankRequestDocumentsItem.of(),
+ RerankRequestDocumentsItem.of(),
+ RerankRequestDocumentsItem.of(),
+ RerankRequestDocumentsItem.of()
)
)
.query("What is the capital of the United States?")
@@ -1705,12 +1703,10 @@ client.classify(
ClassifyRequest
.builder()
.inputs(
- new ArrayList(
- Arrays.asList("Confirm your email address", "hey i need u to send some $")
- )
+ Arrays.asList("Confirm your email address", "hey i need u to send some $")
)
.examples(
- new ArrayList(
+ Optional.of(
Arrays.asList(
ClassifyExample
.builder()
@@ -1929,7 +1925,7 @@ client.summarize(
-
-**temperature:** `Optional` — Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.
+**temperature:** `Optional` — Ranges from 0 to 5. Controls the randomness of the output. Lower values tend to generate more “predictable” output, while higher values tend to generate more “creative” output. The sweet spot is typically between 0 and 1.
@@ -2048,9 +2044,7 @@ client.detokenize(
DetokenizeRequest
.builder()
.tokens(
- new ArrayList(
- Arrays.asList(10002, 2261, 2012, 8, 2792, 43)
- )
+ Arrays.asList(10002, 2261, 2012, 8, 2792, 43)
)
.model("command")
.build()
@@ -2163,16 +2157,14 @@ client.v2().chatStream(
.builder()
.model("command-a-03-2025")
.messages(
- new ArrayList(
- Arrays.asList(
- ChatMessageV2.user(
- UserMessageV2
- .builder()
- .content(
- UserMessageV2Content.of("Tell me about LLMs")
- )
- .build()
- )
+ Arrays.asList(
+ ChatMessageV2.user(
+ UserMessageV2
+ .builder()
+ .content(
+ UserMessageV2Content.of("Tell me about LLMs")
+ )
+ .build()
)
)
)
@@ -2460,16 +2452,14 @@ client.v2().chatStream(
.builder()
.model("command-a-03-2025")
.messages(
- new ArrayList(
- Arrays.asList(
- ChatMessageV2.user(
- UserMessageV2
- .builder()
- .content(
- UserMessageV2Content.of("Tell me about LLMs")
- )
- .build()
- )
+ Arrays.asList(
+ ChatMessageV2.user(
+ UserMessageV2
+ .builder()
+ .content(
+ UserMessageV2Content.of("Tell me about LLMs")
+ )
+ .build()
)
)
)
@@ -2760,12 +2750,12 @@ client.v2().embed(
.model("embed-v4.0")
.inputType(EmbedInputType.CLASSIFICATION)
.texts(
- new ArrayList(
+ Optional.of(
Arrays.asList("hello", "goodbye")
)
)
.embeddingTypes(
- new ArrayList(
+ Optional.of(
Arrays.asList(EmbeddingType.FLOAT)
)
)
@@ -2927,9 +2917,7 @@ client.v2().rerank(
V2RerankRequest
.builder()
.documents(
- new ArrayList(
- Arrays.asList("Carson City is the capital city of the American state of Nevada.", "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", "Capital punishment has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.")
- )
+ Arrays.asList("Carson City is the capital city of the American state of Nevada.", "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.", "Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.", "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.", "Capital punishment has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.")
)
.query("What is the capital of the United States?")
.model("rerank-v3.5")
@@ -3617,7 +3605,7 @@ client.datasets().list(
-
-**limit:** `Optional` — optional limit to number of results
+**limit:** `Optional` — optional limit to number of results
@@ -3625,7 +3613,7 @@ client.datasets().list(
-
-**offset:** `Optional` — optional offset to start of results
+**offset:** `Optional` — optional offset to start of results
@@ -3961,7 +3949,7 @@ client.connectors().list(
-
-**limit:** `Optional` — Maximum number of connectors to return [0, 100].
+**limit:** `Optional` — Maximum number of connectors to return [0, 100].
@@ -3969,7 +3957,7 @@ client.connectors().list(
-
-**offset:** `Optional` — Number of connectors to skip before returning results [0, inf].
+**offset:** `Optional` — Number of connectors to skip before returning results [0, inf].
@@ -4493,7 +4481,7 @@ client.models().list(
-
-**pageSize:** `Optional`
+**pageSize:** `Optional`
Maximum number of models to include in a page
Defaults to `20`, min value of `1`, max value of `1000`.
diff --git a/src/main/java/com/cohere/api/AsyncCohereBuilder.java b/src/main/java/com/cohere/api/AsyncCohereBuilder.java
index 231aa21..0aa8d20 100644
--- a/src/main/java/com/cohere/api/AsyncCohereBuilder.java
+++ b/src/main/java/com/cohere/api/AsyncCohereBuilder.java
@@ -5,6 +5,8 @@
import com.cohere.api.core.ClientOptions;
import com.cohere.api.core.Environment;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Optional;
import okhttp3.OkHttpClient;
@@ -13,6 +15,8 @@ public class AsyncCohereBuilder {
private Optional maxRetries = Optional.empty();
+ private final Map customHeaders = new HashMap<>();
+
private String token = System.getenv("CO_API_KEY");
private String clientName = null;
@@ -72,6 +76,19 @@ public AsyncCohereBuilder httpClient(OkHttpClient httpClient) {
return this;
}
+ /**
+ * Add a custom header to be sent with all requests.
+ * For headers that need to be computed dynamically or conditionally, use the setAdditional() method override instead.
+ *
+ * @param name The header name
+ * @param value The header value
+ * @return This builder for method chaining
+ */
+ public AsyncCohereBuilder addHeader(String name, String value) {
+ this.customHeaders.put(name, value);
+ return this;
+ }
+
protected ClientOptions buildClientOptions() {
ClientOptions.Builder builder = ClientOptions.builder();
setEnvironment(builder);
@@ -80,6 +97,9 @@ protected ClientOptions buildClientOptions() {
setHttpClient(builder);
setTimeouts(builder);
setRetries(builder);
+ for (Map.Entry header : this.customHeaders.entrySet()) {
+ builder.addHeader(header.getKey(), header.getValue());
+ }
setAdditional(builder);
return builder.build();
}
diff --git a/src/main/java/com/cohere/api/AsyncRawCohere.java b/src/main/java/com/cohere/api/AsyncRawCohere.java
index 32d586b..4e4ac98 100644
--- a/src/main/java/com/cohere/api/AsyncRawCohere.java
+++ b/src/main/java/com/cohere/api/AsyncRawCohere.java
@@ -88,79 +88,31 @@ public CompletableFuture>> cha
.build();
Map properties = new HashMap<>();
properties.put("message", request.getMessage());
- if (request.getModel().isPresent()) {
- properties.put("model", request.getModel());
- }
+ properties.put("model", request.getModel());
properties.put("stream", request.getStream());
- if (request.getPreamble().isPresent()) {
- properties.put("preamble", request.getPreamble());
- }
- if (request.getChatHistory().isPresent()) {
- properties.put("chat_history", request.getChatHistory());
- }
- if (request.getConversationId().isPresent()) {
- properties.put("conversation_id", request.getConversationId());
- }
- if (request.getPromptTruncation().isPresent()) {
- properties.put("prompt_truncation", request.getPromptTruncation());
- }
- if (request.getConnectors().isPresent()) {
- properties.put("connectors", request.getConnectors());
- }
- if (request.getSearchQueriesOnly().isPresent()) {
- properties.put("search_queries_only", request.getSearchQueriesOnly());
- }
- if (request.getDocuments().isPresent()) {
- properties.put("documents", request.getDocuments());
- }
- if (request.getCitationQuality().isPresent()) {
- properties.put("citation_quality", request.getCitationQuality());
- }
- if (request.getTemperature().isPresent()) {
- properties.put("temperature", request.getTemperature());
- }
- if (request.getMaxTokens().isPresent()) {
- properties.put("max_tokens", request.getMaxTokens());
- }
- if (request.getMaxInputTokens().isPresent()) {
- properties.put("max_input_tokens", request.getMaxInputTokens());
- }
- if (request.getK().isPresent()) {
- properties.put("k", request.getK());
- }
- if (request.getP().isPresent()) {
- properties.put("p", request.getP());
- }
- if (request.getSeed().isPresent()) {
- properties.put("seed", request.getSeed());
- }
- if (request.getStopSequences().isPresent()) {
- properties.put("stop_sequences", request.getStopSequences());
- }
- if (request.getFrequencyPenalty().isPresent()) {
- properties.put("frequency_penalty", request.getFrequencyPenalty());
- }
- if (request.getPresencePenalty().isPresent()) {
- properties.put("presence_penalty", request.getPresencePenalty());
- }
- if (request.getRawPrompting().isPresent()) {
- properties.put("raw_prompting", request.getRawPrompting());
- }
- if (request.getTools().isPresent()) {
- properties.put("tools", request.getTools());
- }
- if (request.getToolResults().isPresent()) {
- properties.put("tool_results", request.getToolResults());
- }
- if (request.getForceSingleStep().isPresent()) {
- properties.put("force_single_step", request.getForceSingleStep());
- }
- if (request.getResponseFormat().isPresent()) {
- properties.put("response_format", request.getResponseFormat());
- }
- if (request.getSafetyMode().isPresent()) {
- properties.put("safety_mode", request.getSafetyMode());
- }
+ properties.put("preamble", request.getPreamble());
+ properties.put("chat_history", request.getChatHistory());
+ properties.put("conversation_id", request.getConversationId());
+ properties.put("prompt_truncation", request.getPromptTruncation());
+ properties.put("connectors", request.getConnectors());
+ properties.put("search_queries_only", request.getSearchQueriesOnly());
+ properties.put("documents", request.getDocuments());
+ properties.put("citation_quality", request.getCitationQuality());
+ properties.put("temperature", request.getTemperature());
+ properties.put("max_tokens", request.getMaxTokens());
+ properties.put("max_input_tokens", request.getMaxInputTokens());
+ properties.put("k", request.getK());
+ properties.put("p", request.getP());
+ properties.put("seed", request.getSeed());
+ properties.put("stop_sequences", request.getStopSequences());
+ properties.put("frequency_penalty", request.getFrequencyPenalty());
+ properties.put("presence_penalty", request.getPresencePenalty());
+ properties.put("raw_prompting", request.getRawPrompting());
+ properties.put("tools", request.getTools());
+ properties.put("tool_results", request.getToolResults());
+ properties.put("force_single_step", request.getForceSingleStep());
+ properties.put("response_format", request.getResponseFormat());
+ properties.put("safety_mode", request.getSafetyMode());
RequestBody body;
try {
body = RequestBody.create(
@@ -261,11 +213,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -300,79 +250,31 @@ public CompletableFuture> chat(
.build();
Map properties = new HashMap<>();
properties.put("message", request.getMessage());
- if (request.getModel().isPresent()) {
- properties.put("model", request.getModel());
- }
+ properties.put("model", request.getModel());
properties.put("stream", request.getStream());
- if (request.getPreamble().isPresent()) {
- properties.put("preamble", request.getPreamble());
- }
- if (request.getChatHistory().isPresent()) {
- properties.put("chat_history", request.getChatHistory());
- }
- if (request.getConversationId().isPresent()) {
- properties.put("conversation_id", request.getConversationId());
- }
- if (request.getPromptTruncation().isPresent()) {
- properties.put("prompt_truncation", request.getPromptTruncation());
- }
- if (request.getConnectors().isPresent()) {
- properties.put("connectors", request.getConnectors());
- }
- if (request.getSearchQueriesOnly().isPresent()) {
- properties.put("search_queries_only", request.getSearchQueriesOnly());
- }
- if (request.getDocuments().isPresent()) {
- properties.put("documents", request.getDocuments());
- }
- if (request.getCitationQuality().isPresent()) {
- properties.put("citation_quality", request.getCitationQuality());
- }
- if (request.getTemperature().isPresent()) {
- properties.put("temperature", request.getTemperature());
- }
- if (request.getMaxTokens().isPresent()) {
- properties.put("max_tokens", request.getMaxTokens());
- }
- if (request.getMaxInputTokens().isPresent()) {
- properties.put("max_input_tokens", request.getMaxInputTokens());
- }
- if (request.getK().isPresent()) {
- properties.put("k", request.getK());
- }
- if (request.getP().isPresent()) {
- properties.put("p", request.getP());
- }
- if (request.getSeed().isPresent()) {
- properties.put("seed", request.getSeed());
- }
- if (request.getStopSequences().isPresent()) {
- properties.put("stop_sequences", request.getStopSequences());
- }
- if (request.getFrequencyPenalty().isPresent()) {
- properties.put("frequency_penalty", request.getFrequencyPenalty());
- }
- if (request.getPresencePenalty().isPresent()) {
- properties.put("presence_penalty", request.getPresencePenalty());
- }
- if (request.getRawPrompting().isPresent()) {
- properties.put("raw_prompting", request.getRawPrompting());
- }
- if (request.getTools().isPresent()) {
- properties.put("tools", request.getTools());
- }
- if (request.getToolResults().isPresent()) {
- properties.put("tool_results", request.getToolResults());
- }
- if (request.getForceSingleStep().isPresent()) {
- properties.put("force_single_step", request.getForceSingleStep());
- }
- if (request.getResponseFormat().isPresent()) {
- properties.put("response_format", request.getResponseFormat());
- }
- if (request.getSafetyMode().isPresent()) {
- properties.put("safety_mode", request.getSafetyMode());
- }
+ properties.put("preamble", request.getPreamble());
+ properties.put("chat_history", request.getChatHistory());
+ properties.put("conversation_id", request.getConversationId());
+ properties.put("prompt_truncation", request.getPromptTruncation());
+ properties.put("connectors", request.getConnectors());
+ properties.put("search_queries_only", request.getSearchQueriesOnly());
+ properties.put("documents", request.getDocuments());
+ properties.put("citation_quality", request.getCitationQuality());
+ properties.put("temperature", request.getTemperature());
+ properties.put("max_tokens", request.getMaxTokens());
+ properties.put("max_input_tokens", request.getMaxInputTokens());
+ properties.put("k", request.getK());
+ properties.put("p", request.getP());
+ properties.put("seed", request.getSeed());
+ properties.put("stop_sequences", request.getStopSequences());
+ properties.put("frequency_penalty", request.getFrequencyPenalty());
+ properties.put("presence_penalty", request.getPresencePenalty());
+ properties.put("raw_prompting", request.getRawPrompting());
+ properties.put("tools", request.getTools());
+ properties.put("tool_results", request.getToolResults());
+ properties.put("force_single_step", request.getForceSingleStep());
+ properties.put("response_format", request.getResponseFormat());
+ properties.put("safety_mode", request.getSafetyMode());
RequestBody body;
try {
body = RequestBody.create(
@@ -399,14 +301,13 @@ public CompletableFuture> chat(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(
- responseBody.string(), NonStreamedChatResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, NonStreamedChatResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -473,11 +374,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -613,11 +512,9 @@ GenerateStreamedResponse.class, new ResponseBodyReader(response), "\\n"),
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -677,13 +574,12 @@ public CompletableFuture> generate(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), Generation.class),
- response));
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Generation.class), response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -750,11 +646,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -821,13 +715,13 @@ public CompletableFuture> embed(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), EmbedResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, EmbedResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -894,11 +788,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -952,13 +844,13 @@ public CompletableFuture> rerank(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), RerankResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, RerankResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1025,11 +917,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -1085,13 +975,13 @@ public CompletableFuture> classify(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), ClassifyResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, ClassifyResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1158,11 +1048,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -1222,13 +1110,13 @@ public CompletableFuture> summarize(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), SummarizeResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, SummarizeResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1295,11 +1183,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -1353,13 +1239,13 @@ public CompletableFuture> tokenize(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), TokenizeResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, TokenizeResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1426,11 +1312,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -1484,13 +1368,13 @@ public CompletableFuture> detokenize(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), DetokenizeResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, DetokenizeResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1557,11 +1441,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -1606,13 +1488,13 @@ public CompletableFuture> checkApiKey(Re
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), CheckApiKeyResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, CheckApiKeyResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1679,11 +1561,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
diff --git a/src/main/java/com/cohere/api/CohereBuilder.java b/src/main/java/com/cohere/api/CohereBuilder.java
index 041d384..b1bbff9 100644
--- a/src/main/java/com/cohere/api/CohereBuilder.java
+++ b/src/main/java/com/cohere/api/CohereBuilder.java
@@ -5,6 +5,8 @@
import com.cohere.api.core.ClientOptions;
import com.cohere.api.core.Environment;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Optional;
import okhttp3.OkHttpClient;
@@ -13,6 +15,8 @@ public class CohereBuilder {
private Optional maxRetries = Optional.empty();
+ private final Map customHeaders = new HashMap<>();
+
private String token = System.getenv("CO_API_KEY");
private String clientName = null;
@@ -72,6 +76,19 @@ public CohereBuilder httpClient(OkHttpClient httpClient) {
return this;
}
+ /**
+ * Add a custom header to be sent with all requests.
+ * For headers that need to be computed dynamically or conditionally, use the setAdditional() method override instead.
+ *
+ * @param name The header name
+ * @param value The header value
+ * @return This builder for method chaining
+ */
+ public CohereBuilder addHeader(String name, String value) {
+ this.customHeaders.put(name, value);
+ return this;
+ }
+
protected ClientOptions buildClientOptions() {
ClientOptions.Builder builder = ClientOptions.builder();
setEnvironment(builder);
@@ -80,6 +97,9 @@ protected ClientOptions buildClientOptions() {
setHttpClient(builder);
setTimeouts(builder);
setRetries(builder);
+ for (Map.Entry header : this.customHeaders.entrySet()) {
+ builder.addHeader(header.getKey(), header.getValue());
+ }
setAdditional(builder);
return builder.build();
}
diff --git a/src/main/java/com/cohere/api/RawCohere.java b/src/main/java/com/cohere/api/RawCohere.java
index 2e43695..7f36061 100644
--- a/src/main/java/com/cohere/api/RawCohere.java
+++ b/src/main/java/com/cohere/api/RawCohere.java
@@ -84,79 +84,31 @@ public CohereHttpResponse> chatStream(
.build();
Map properties = new HashMap<>();
properties.put("message", request.getMessage());
- if (request.getModel().isPresent()) {
- properties.put("model", request.getModel());
- }
+ properties.put("model", request.getModel());
properties.put("stream", request.getStream());
- if (request.getPreamble().isPresent()) {
- properties.put("preamble", request.getPreamble());
- }
- if (request.getChatHistory().isPresent()) {
- properties.put("chat_history", request.getChatHistory());
- }
- if (request.getConversationId().isPresent()) {
- properties.put("conversation_id", request.getConversationId());
- }
- if (request.getPromptTruncation().isPresent()) {
- properties.put("prompt_truncation", request.getPromptTruncation());
- }
- if (request.getConnectors().isPresent()) {
- properties.put("connectors", request.getConnectors());
- }
- if (request.getSearchQueriesOnly().isPresent()) {
- properties.put("search_queries_only", request.getSearchQueriesOnly());
- }
- if (request.getDocuments().isPresent()) {
- properties.put("documents", request.getDocuments());
- }
- if (request.getCitationQuality().isPresent()) {
- properties.put("citation_quality", request.getCitationQuality());
- }
- if (request.getTemperature().isPresent()) {
- properties.put("temperature", request.getTemperature());
- }
- if (request.getMaxTokens().isPresent()) {
- properties.put("max_tokens", request.getMaxTokens());
- }
- if (request.getMaxInputTokens().isPresent()) {
- properties.put("max_input_tokens", request.getMaxInputTokens());
- }
- if (request.getK().isPresent()) {
- properties.put("k", request.getK());
- }
- if (request.getP().isPresent()) {
- properties.put("p", request.getP());
- }
- if (request.getSeed().isPresent()) {
- properties.put("seed", request.getSeed());
- }
- if (request.getStopSequences().isPresent()) {
- properties.put("stop_sequences", request.getStopSequences());
- }
- if (request.getFrequencyPenalty().isPresent()) {
- properties.put("frequency_penalty", request.getFrequencyPenalty());
- }
- if (request.getPresencePenalty().isPresent()) {
- properties.put("presence_penalty", request.getPresencePenalty());
- }
- if (request.getRawPrompting().isPresent()) {
- properties.put("raw_prompting", request.getRawPrompting());
- }
- if (request.getTools().isPresent()) {
- properties.put("tools", request.getTools());
- }
- if (request.getToolResults().isPresent()) {
- properties.put("tool_results", request.getToolResults());
- }
- if (request.getForceSingleStep().isPresent()) {
- properties.put("force_single_step", request.getForceSingleStep());
- }
- if (request.getResponseFormat().isPresent()) {
- properties.put("response_format", request.getResponseFormat());
- }
- if (request.getSafetyMode().isPresent()) {
- properties.put("safety_mode", request.getSafetyMode());
- }
+ properties.put("preamble", request.getPreamble());
+ properties.put("chat_history", request.getChatHistory());
+ properties.put("conversation_id", request.getConversationId());
+ properties.put("prompt_truncation", request.getPromptTruncation());
+ properties.put("connectors", request.getConnectors());
+ properties.put("search_queries_only", request.getSearchQueriesOnly());
+ properties.put("documents", request.getDocuments());
+ properties.put("citation_quality", request.getCitationQuality());
+ properties.put("temperature", request.getTemperature());
+ properties.put("max_tokens", request.getMaxTokens());
+ properties.put("max_input_tokens", request.getMaxInputTokens());
+ properties.put("k", request.getK());
+ properties.put("p", request.getP());
+ properties.put("seed", request.getSeed());
+ properties.put("stop_sequences", request.getStopSequences());
+ properties.put("frequency_penalty", request.getFrequencyPenalty());
+ properties.put("presence_penalty", request.getPresencePenalty());
+ properties.put("raw_prompting", request.getRawPrompting());
+ properties.put("tools", request.getTools());
+ properties.put("tool_results", request.getToolResults());
+ properties.put("force_single_step", request.getForceSingleStep());
+ properties.put("response_format", request.getResponseFormat());
+ properties.put("safety_mode", request.getSafetyMode());
RequestBody body;
try {
body = RequestBody.create(
@@ -228,11 +180,9 @@ public CohereHttpResponse> chatStream(
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -257,79 +207,31 @@ public CohereHttpResponse chat(ChatRequest request, Req
.build();
Map properties = new HashMap<>();
properties.put("message", request.getMessage());
- if (request.getModel().isPresent()) {
- properties.put("model", request.getModel());
- }
+ properties.put("model", request.getModel());
properties.put("stream", request.getStream());
- if (request.getPreamble().isPresent()) {
- properties.put("preamble", request.getPreamble());
- }
- if (request.getChatHistory().isPresent()) {
- properties.put("chat_history", request.getChatHistory());
- }
- if (request.getConversationId().isPresent()) {
- properties.put("conversation_id", request.getConversationId());
- }
- if (request.getPromptTruncation().isPresent()) {
- properties.put("prompt_truncation", request.getPromptTruncation());
- }
- if (request.getConnectors().isPresent()) {
- properties.put("connectors", request.getConnectors());
- }
- if (request.getSearchQueriesOnly().isPresent()) {
- properties.put("search_queries_only", request.getSearchQueriesOnly());
- }
- if (request.getDocuments().isPresent()) {
- properties.put("documents", request.getDocuments());
- }
- if (request.getCitationQuality().isPresent()) {
- properties.put("citation_quality", request.getCitationQuality());
- }
- if (request.getTemperature().isPresent()) {
- properties.put("temperature", request.getTemperature());
- }
- if (request.getMaxTokens().isPresent()) {
- properties.put("max_tokens", request.getMaxTokens());
- }
- if (request.getMaxInputTokens().isPresent()) {
- properties.put("max_input_tokens", request.getMaxInputTokens());
- }
- if (request.getK().isPresent()) {
- properties.put("k", request.getK());
- }
- if (request.getP().isPresent()) {
- properties.put("p", request.getP());
- }
- if (request.getSeed().isPresent()) {
- properties.put("seed", request.getSeed());
- }
- if (request.getStopSequences().isPresent()) {
- properties.put("stop_sequences", request.getStopSequences());
- }
- if (request.getFrequencyPenalty().isPresent()) {
- properties.put("frequency_penalty", request.getFrequencyPenalty());
- }
- if (request.getPresencePenalty().isPresent()) {
- properties.put("presence_penalty", request.getPresencePenalty());
- }
- if (request.getRawPrompting().isPresent()) {
- properties.put("raw_prompting", request.getRawPrompting());
- }
- if (request.getTools().isPresent()) {
- properties.put("tools", request.getTools());
- }
- if (request.getToolResults().isPresent()) {
- properties.put("tool_results", request.getToolResults());
- }
- if (request.getForceSingleStep().isPresent()) {
- properties.put("force_single_step", request.getForceSingleStep());
- }
- if (request.getResponseFormat().isPresent()) {
- properties.put("response_format", request.getResponseFormat());
- }
- if (request.getSafetyMode().isPresent()) {
- properties.put("safety_mode", request.getSafetyMode());
- }
+ properties.put("preamble", request.getPreamble());
+ properties.put("chat_history", request.getChatHistory());
+ properties.put("conversation_id", request.getConversationId());
+ properties.put("prompt_truncation", request.getPromptTruncation());
+ properties.put("connectors", request.getConnectors());
+ properties.put("search_queries_only", request.getSearchQueriesOnly());
+ properties.put("documents", request.getDocuments());
+ properties.put("citation_quality", request.getCitationQuality());
+ properties.put("temperature", request.getTemperature());
+ properties.put("max_tokens", request.getMaxTokens());
+ properties.put("max_input_tokens", request.getMaxInputTokens());
+ properties.put("k", request.getK());
+ properties.put("p", request.getP());
+ properties.put("seed", request.getSeed());
+ properties.put("stop_sequences", request.getStopSequences());
+ properties.put("frequency_penalty", request.getFrequencyPenalty());
+ properties.put("presence_penalty", request.getPresencePenalty());
+ properties.put("raw_prompting", request.getRawPrompting());
+ properties.put("tools", request.getTools());
+ properties.put("tool_results", request.getToolResults());
+ properties.put("force_single_step", request.getForceSingleStep());
+ properties.put("response_format", request.getResponseFormat());
+ properties.put("safety_mode", request.getSafetyMode());
RequestBody body;
try {
body = RequestBody.create(
@@ -353,12 +255,12 @@ public CohereHttpResponse chat(ChatRequest request, Req
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), NonStreamedChatResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, NonStreamedChatResponse.class),
response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -401,11 +303,9 @@ public CohereHttpResponse chat(ChatRequest request, Req
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -502,11 +402,9 @@ public CohereHttpResponse> generateStream(
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -553,11 +451,11 @@ public CohereHttpResponse generate(GenerateRequest request, RequestO
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), Generation.class), response);
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Generation.class), response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -600,11 +498,9 @@ public CohereHttpResponse generate(GenerateRequest request, RequestO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -658,11 +554,11 @@ public CohereHttpResponse embed(EmbedRequest request, RequestOpti
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), EmbedResponse.class), response);
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, EmbedResponse.class), response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -705,11 +601,9 @@ public CohereHttpResponse embed(EmbedRequest request, RequestOpti
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -750,11 +644,11 @@ public CohereHttpResponse rerank(RerankRequest request, RequestO
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), RerankResponse.class), response);
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, RerankResponse.class), response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -797,11 +691,9 @@ public CohereHttpResponse rerank(RerankRequest request, RequestO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -844,11 +736,11 @@ public CohereHttpResponse classify(ClassifyRequest request, Re
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), ClassifyResponse.class), response);
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, ClassifyResponse.class), response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -891,11 +783,9 @@ public CohereHttpResponse classify(ClassifyRequest request, Re
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -942,11 +832,11 @@ public CohereHttpResponse summarize(SummarizeRequest request,
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), SummarizeResponse.class), response);
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, SummarizeResponse.class), response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -989,11 +879,9 @@ public CohereHttpResponse summarize(SummarizeRequest request,
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -1034,11 +922,11 @@ public CohereHttpResponse tokenize(TokenizeRequest request, Re
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), TokenizeResponse.class), response);
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, TokenizeResponse.class), response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1081,11 +969,9 @@ public CohereHttpResponse tokenize(TokenizeRequest request, Re
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -1126,11 +1012,11 @@ public CohereHttpResponse detokenize(DetokenizeRequest reque
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), DetokenizeResponse.class), response);
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, DetokenizeResponse.class), response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1173,11 +1059,9 @@ public CohereHttpResponse detokenize(DetokenizeRequest reque
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
@@ -1210,12 +1094,11 @@ public CohereHttpResponse checkApiKey(RequestOptions reques
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
return new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), CheckApiKeyResponse.class),
- response);
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, CheckApiKeyResponse.class), response);
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -1258,11 +1141,9 @@ public CohereHttpResponse checkApiKey(RequestOptions reques
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
throw new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response);
+ "Error with status code " + response.code(), response.code(), errorBody, response);
} catch (IOException e) {
throw new CohereException("Network error executing HTTP request", e);
}
diff --git a/src/main/java/com/cohere/api/core/ClientOptions.java b/src/main/java/com/cohere/api/core/ClientOptions.java
index 07e7379..1333deb 100644
--- a/src/main/java/com/cohere/api/core/ClientOptions.java
+++ b/src/main/java/com/cohere/api/core/ClientOptions.java
@@ -21,26 +21,30 @@ public final class ClientOptions {
private final int timeout;
+ private final int maxRetries;
+
private ClientOptions(
Environment environment,
Map headers,
Map> headerSuppliers,
OkHttpClient httpClient,
- int timeout) {
+ int timeout,
+ int maxRetries) {
this.environment = environment;
this.headers = new HashMap<>();
this.headers.putAll(headers);
this.headers.putAll(new HashMap() {
{
- put("User-Agent", "com.cohere:cohere-java/1.8.0");
+ put("User-Agent", "com.cohere:cohere-java/1.9.3");
put("X-Fern-Language", "JAVA");
put("X-Fern-SDK-Name", "com.cohere.fern:api-sdk");
- put("X-Fern-SDK-Version", "1.8.0");
+ put("X-Fern-SDK-Version", "1.9.3");
}
});
this.headerSuppliers = headerSuppliers;
this.httpClient = httpClient;
this.timeout = timeout;
+ this.maxRetries = maxRetries;
}
public Environment environment() {
@@ -82,6 +86,10 @@ public OkHttpClient httpClientWithTimeout(RequestOptions requestOptions) {
.build();
}
+ public int maxRetries() {
+ return this.maxRetries;
+ }
+
public static Builder builder() {
return new Builder();
}
@@ -165,7 +173,8 @@ public ClientOptions build() {
this.httpClient = httpClientBuilder.build();
this.timeout = Optional.of(httpClient.callTimeoutMillis() / 1000);
- return new ClientOptions(environment, headers, headerSuppliers, httpClient, this.timeout.get());
+ return new ClientOptions(
+ environment, headers, headerSuppliers, httpClient, this.timeout.get(), this.maxRetries);
}
/**
diff --git a/src/main/java/com/cohere/api/core/CohereApiException.java b/src/main/java/com/cohere/api/core/CohereApiException.java
index 46c8bb9..b0b6a3b 100644
--- a/src/main/java/com/cohere/api/core/CohereApiException.java
+++ b/src/main/java/com/cohere/api/core/CohereApiException.java
@@ -65,9 +65,9 @@ public Map> headers() {
return this.headers;
}
- @java.lang.Override
+ @Override
public String toString() {
- return "CohereApiException{" + "message: " + getMessage() + ", statusCode: " + statusCode + ", body: " + body
- + "}";
+ return "CohereApiException{" + "message: " + getMessage() + ", statusCode: " + statusCode + ", body: "
+ + ObjectMappers.stringify(body) + "}";
}
}
diff --git a/src/main/java/com/cohere/api/core/NullableNonemptyFilter.java b/src/main/java/com/cohere/api/core/NullableNonemptyFilter.java
index 62644da..6b92248 100644
--- a/src/main/java/com/cohere/api/core/NullableNonemptyFilter.java
+++ b/src/main/java/com/cohere/api/core/NullableNonemptyFilter.java
@@ -14,6 +14,9 @@ public boolean equals(Object o) {
}
private boolean isOptionalEmpty(Object o) {
- return o instanceof Optional && !((Optional>) o).isPresent();
+ if (o instanceof Optional) {
+ return !((Optional>) o).isPresent();
+ }
+ return false;
}
}
diff --git a/src/main/java/com/cohere/api/core/ObjectMappers.java b/src/main/java/com/cohere/api/core/ObjectMappers.java
index e4a1b5c..1790e8e 100644
--- a/src/main/java/com/cohere/api/core/ObjectMappers.java
+++ b/src/main/java/com/cohere/api/core/ObjectMappers.java
@@ -4,6 +4,7 @@
package com.cohere.api.core;
import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
@@ -33,4 +34,12 @@ public static String stringify(Object o) {
return o.getClass().getName() + "@" + Integer.toHexString(o.hashCode());
}
}
+
+ public static Object parseErrorBody(String responseBodyString) {
+ try {
+ return JSON_MAPPER.readValue(responseBodyString, Object.class);
+ } catch (JsonProcessingException ignored) {
+ return responseBodyString;
+ }
+ }
}
diff --git a/src/main/java/com/cohere/api/core/RetryInterceptor.java b/src/main/java/com/cohere/api/core/RetryInterceptor.java
index 5b17329..7fa6843 100644
--- a/src/main/java/com/cohere/api/core/RetryInterceptor.java
+++ b/src/main/java/com/cohere/api/core/RetryInterceptor.java
@@ -5,6 +5,9 @@
import java.io.IOException;
import java.time.Duration;
+import java.time.ZonedDateTime;
+import java.time.format.DateTimeFormatter;
+import java.time.format.DateTimeParseException;
import java.util.Optional;
import java.util.Random;
import okhttp3.Interceptor;
@@ -12,7 +15,10 @@
public class RetryInterceptor implements Interceptor {
- private static final Duration ONE_SECOND = Duration.ofSeconds(1);
+ private static final Duration INITIAL_RETRY_DELAY = Duration.ofMillis(1000);
+ private static final Duration MAX_RETRY_DELAY = Duration.ofMillis(60000);
+ private static final double JITTER_FACTOR = 0.2;
+
private final ExponentialBackoff backoff;
private final Random random = new Random();
@@ -32,7 +38,7 @@ public Response intercept(Chain chain) throws IOException {
}
private Response retryChain(Response response, Chain chain) throws IOException {
- Optional nextBackoff = this.backoff.nextBackoff();
+ Optional nextBackoff = this.backoff.nextBackoff(response);
while (nextBackoff.isPresent()) {
try {
Thread.sleep(nextBackoff.get().toMillis());
@@ -42,7 +48,7 @@ private Response retryChain(Response response, Chain chain) throws IOException {
response.close();
response = chain.proceed(chain.request());
if (shouldRetry(response.code())) {
- nextBackoff = this.backoff.nextBackoff();
+ nextBackoff = this.backoff.nextBackoff(response);
} else {
return response;
}
@@ -51,6 +57,102 @@ private Response retryChain(Response response, Chain chain) throws IOException {
return response;
}
+ /**
+ * Calculates the retry delay from response headers, with fallback to exponential backoff.
+ * Priority: Retry-After > X-RateLimit-Reset > Exponential Backoff
+ */
+ private Duration getRetryDelayFromHeaders(Response response, int retryAttempt) {
+ // Check for Retry-After header first (RFC 7231), with no jitter
+ String retryAfter = response.header("Retry-After");
+ if (retryAfter != null) {
+ // Parse as number of seconds...
+ Optional secondsDelay = tryParseLong(retryAfter)
+ .map(seconds -> seconds * 1000)
+ .filter(delayMs -> delayMs > 0)
+ .map(delayMs -> Math.min(delayMs, MAX_RETRY_DELAY.toMillis()))
+ .map(Duration::ofMillis);
+ if (secondsDelay.isPresent()) {
+ return secondsDelay.get();
+ }
+
+ // ...or as an HTTP date; both are valid
+ Optional dateDelay = tryParseHttpDate(retryAfter)
+ .map(resetTime -> resetTime.toInstant().toEpochMilli() - System.currentTimeMillis())
+ .filter(delayMs -> delayMs > 0)
+ .map(delayMs -> Math.min(delayMs, MAX_RETRY_DELAY.toMillis()))
+ .map(Duration::ofMillis);
+ if (dateDelay.isPresent()) {
+ return dateDelay.get();
+ }
+ }
+
+ // Then check for industry-standard X-RateLimit-Reset header, with positive jitter
+ String rateLimitReset = response.header("X-RateLimit-Reset");
+ if (rateLimitReset != null) {
+ // Assume Unix timestamp in epoch seconds
+ Optional rateLimitDelay = tryParseLong(rateLimitReset)
+ .map(resetTimeSeconds -> (resetTimeSeconds * 1000) - System.currentTimeMillis())
+ .filter(delayMs -> delayMs > 0)
+ .map(delayMs -> Math.min(delayMs, MAX_RETRY_DELAY.toMillis()))
+ .map(this::addPositiveJitter)
+ .map(Duration::ofMillis);
+ if (rateLimitDelay.isPresent()) {
+ return rateLimitDelay.get();
+ }
+ }
+
+ // Fall back to exponential backoff, with symmetric jitter
+ long baseDelay = INITIAL_RETRY_DELAY.toMillis() * (1L << retryAttempt); // 2^retryAttempt
+ long cappedDelay = Math.min(baseDelay, MAX_RETRY_DELAY.toMillis());
+ return Duration.ofMillis(addSymmetricJitter(cappedDelay));
+ }
+
+ /**
+ * Attempts to parse a string as a long, returning empty Optional on failure.
+ */
+ private Optional tryParseLong(String value) {
+ if (value == null) {
+ return Optional.empty();
+ }
+ try {
+ return Optional.of(Long.parseLong(value));
+ } catch (NumberFormatException e) {
+ return Optional.empty();
+ }
+ }
+
+ /**
+ * Attempts to parse a string as an HTTP date (RFC 1123), returning empty Optional on failure.
+ */
+ private Optional tryParseHttpDate(String value) {
+ if (value == null) {
+ return Optional.empty();
+ }
+ try {
+ return Optional.of(ZonedDateTime.parse(value, DateTimeFormatter.RFC_1123_DATE_TIME));
+ } catch (DateTimeParseException e) {
+ return Optional.empty();
+ }
+ }
+
+ /**
+ * Adds positive jitter (100-120% of original value) to prevent thundering herd.
+ * Used for X-RateLimit-Reset header delays.
+ */
+ private long addPositiveJitter(long delayMs) {
+ double jitterMultiplier = 1.0 + (random.nextDouble() * JITTER_FACTOR);
+ return (long) (delayMs * jitterMultiplier);
+ }
+
+ /**
+ * Adds symmetric jitter (90-110% of original value) to prevent thundering herd.
+ * Used for exponential backoff delays.
+ */
+ private long addSymmetricJitter(long delayMs) {
+ double jitterMultiplier = 1.0 + ((random.nextDouble() - 0.5) * JITTER_FACTOR);
+ return (long) (delayMs * jitterMultiplier);
+ }
+
private static boolean shouldRetry(int statusCode) {
return statusCode == 408 || statusCode == 429 || statusCode >= 500;
}
@@ -65,14 +167,14 @@ private final class ExponentialBackoff {
this.maxNumRetries = maxNumRetries;
}
- public Optional nextBackoff() {
- retryNumber += 1;
- if (retryNumber > maxNumRetries) {
+ public Optional nextBackoff(Response response) {
+ if (retryNumber >= maxNumRetries) {
return Optional.empty();
}
- int upperBound = (int) Math.pow(2, retryNumber);
- return Optional.of(ONE_SECOND.multipliedBy(random.nextInt(upperBound)));
+ Duration delay = getRetryDelayFromHeaders(response, retryNumber);
+ retryNumber += 1;
+ return Optional.of(delay);
}
}
}
diff --git a/src/main/java/com/cohere/api/core/Stream.java b/src/main/java/com/cohere/api/core/Stream.java
index 9b1c240..3ee1c8a 100644
--- a/src/main/java/com/cohere/api/core/Stream.java
+++ b/src/main/java/com/cohere/api/core/Stream.java
@@ -174,8 +174,8 @@ private final class SSEIterator implements Iterator {
private T nextItem;
private boolean hasNextItem = false;
private boolean endOfStream = false;
- private StringBuilder buffer = new StringBuilder();
- private boolean prefixSeen = false;
+ private StringBuilder eventDataBuffer = new StringBuilder();
+ private String currentEventType = null;
private SSEIterator() {
if (sseReader != null && !isStreamClosed()) {
@@ -223,39 +223,69 @@ private boolean readNextMessage() {
try {
while (sseScanner.hasNextLine()) {
- String chunk = sseScanner.nextLine();
- buffer.append(chunk).append(NEWLINE);
-
- int terminatorIndex;
- while ((terminatorIndex = buffer.indexOf(messageTerminator)) >= 0) {
- String line = buffer.substring(0, terminatorIndex + messageTerminator.length());
- buffer.delete(0, terminatorIndex + messageTerminator.length());
-
- line = line.trim();
- if (line.isEmpty()) {
- continue;
+ String line = sseScanner.nextLine();
+
+ if (line.trim().isEmpty()) {
+ if (eventDataBuffer.length() > 0) {
+ try {
+ nextItem = ObjectMappers.JSON_MAPPER.readValue(eventDataBuffer.toString(), valueType);
+ hasNextItem = true;
+ eventDataBuffer.setLength(0);
+ currentEventType = null;
+ return true;
+ } catch (Exception parseEx) {
+ System.err.println("Failed to parse SSE event: " + parseEx.getMessage());
+ eventDataBuffer.setLength(0);
+ currentEventType = null;
+ continue;
+ }
}
+ continue;
+ }
- if (!prefixSeen && line.startsWith(DATA_PREFIX)) {
- prefixSeen = true;
- line = line.substring(DATA_PREFIX.length()).trim();
- } else if (!prefixSeen) {
- continue;
+ if (line.startsWith(DATA_PREFIX)) {
+ String dataContent = line.substring(DATA_PREFIX.length());
+ if (dataContent.startsWith(" ")) {
+ dataContent = dataContent.substring(1);
}
- if (streamTerminator != null && line.contains(streamTerminator)) {
+ if (eventDataBuffer.length() == 0
+ && streamTerminator != null
+ && dataContent.trim().equals(streamTerminator)) {
endOfStream = true;
return false;
}
- try {
- nextItem = ObjectMappers.JSON_MAPPER.readValue(line, valueType);
- hasNextItem = true;
- prefixSeen = false;
- return true;
- } catch (Exception parseEx) {
- continue;
+ if (eventDataBuffer.length() > 0) {
+ eventDataBuffer.append('\n');
+ }
+ eventDataBuffer.append(dataContent);
+ } else if (line.startsWith("event:")) {
+ String eventValue = line.length() > 6 ? line.substring(6) : "";
+ if (eventValue.startsWith(" ")) {
+ eventValue = eventValue.substring(1);
}
+ currentEventType = eventValue;
+ } else if (line.startsWith("id:")) {
+ // Event ID field (ignored)
+ } else if (line.startsWith("retry:")) {
+ // Retry field (ignored)
+ } else if (line.startsWith(":")) {
+ // Comment line (ignored)
+ }
+ }
+
+ if (eventDataBuffer.length() > 0) {
+ try {
+ nextItem = ObjectMappers.JSON_MAPPER.readValue(eventDataBuffer.toString(), valueType);
+ hasNextItem = true;
+ eventDataBuffer.setLength(0);
+ currentEventType = null;
+ return true;
+ } catch (Exception parseEx) {
+ System.err.println("Failed to parse final SSE event: " + parseEx.getMessage());
+ eventDataBuffer.setLength(0);
+ currentEventType = null;
}
}
diff --git a/src/main/java/com/cohere/api/requests/ClassifyRequest.java b/src/main/java/com/cohere/api/requests/ClassifyRequest.java
index 3dd6207..ff7aebc 100644
--- a/src/main/java/com/cohere/api/requests/ClassifyRequest.java
+++ b/src/main/java/com/cohere/api/requests/ClassifyRequest.java
@@ -163,7 +163,9 @@ public Builder from(ClassifyRequest other) {
@JsonSetter(value = "inputs", nulls = Nulls.SKIP)
public Builder inputs(List inputs) {
this.inputs.clear();
- this.inputs.addAll(inputs);
+ if (inputs != null) {
+ this.inputs.addAll(inputs);
+ }
return this;
}
@@ -173,7 +175,9 @@ public Builder addInputs(String inputs) {
}
public Builder addAllInputs(List inputs) {
- this.inputs.addAll(inputs);
+ if (inputs != null) {
+ this.inputs.addAll(inputs);
+ }
return this;
}
diff --git a/src/main/java/com/cohere/api/requests/DetokenizeRequest.java b/src/main/java/com/cohere/api/requests/DetokenizeRequest.java
index a5111d5..45b2d14 100644
--- a/src/main/java/com/cohere/api/requests/DetokenizeRequest.java
+++ b/src/main/java/com/cohere/api/requests/DetokenizeRequest.java
@@ -137,7 +137,9 @@ public _FinalStage model(@NotNull String model) {
*/
@java.lang.Override
public _FinalStage addAllTokens(List tokens) {
- this.tokens.addAll(tokens);
+ if (tokens != null) {
+ this.tokens.addAll(tokens);
+ }
return this;
}
@@ -158,7 +160,9 @@ public _FinalStage addTokens(Integer tokens) {
@JsonSetter(value = "tokens", nulls = Nulls.SKIP)
public _FinalStage tokens(List tokens) {
this.tokens.clear();
- this.tokens.addAll(tokens);
+ if (tokens != null) {
+ this.tokens.addAll(tokens);
+ }
return this;
}
diff --git a/src/main/java/com/cohere/api/requests/RerankRequest.java b/src/main/java/com/cohere/api/requests/RerankRequest.java
index 5e4c069..46c22ab 100644
--- a/src/main/java/com/cohere/api/requests/RerankRequest.java
+++ b/src/main/java/com/cohere/api/requests/RerankRequest.java
@@ -366,7 +366,9 @@ public _FinalStage topN(Optional topN) {
*/
@java.lang.Override
public _FinalStage addAllDocuments(List documents) {
- this.documents.addAll(documents);
+ if (documents != null) {
+ this.documents.addAll(documents);
+ }
return this;
}
@@ -393,7 +395,9 @@ public _FinalStage addDocuments(RerankRequestDocumentsItem documents) {
@JsonSetter(value = "documents", nulls = Nulls.SKIP)
public _FinalStage documents(List documents) {
this.documents.clear();
- this.documents.addAll(documents);
+ if (documents != null) {
+ this.documents.addAll(documents);
+ }
return this;
}
diff --git a/src/main/java/com/cohere/api/resources/batches/AsyncRawBatchesClient.java b/src/main/java/com/cohere/api/resources/batches/AsyncRawBatchesClient.java
index 8983224..7b1d79d 100644
--- a/src/main/java/com/cohere/api/resources/batches/AsyncRawBatchesClient.java
+++ b/src/main/java/com/cohere/api/resources/batches/AsyncRawBatchesClient.java
@@ -94,13 +94,13 @@ public CompletableFuture> list(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), ListBatchesResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, ListBatchesResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -137,11 +137,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -195,13 +193,13 @@ public CompletableFuture> create(
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), CreateBatchResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, CreateBatchResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -238,11 +236,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -288,13 +284,13 @@ public CompletableFuture> retrieve(String i
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
- ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), GetBatchResponse.class),
+ ObjectMappers.JSON_MAPPER.readValue(responseBodyString, GetBatchResponse.class),
response));
return;
}
- String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
@@ -331,11 +327,9 @@ public void onResponse(@NotNull Call call, @NotNull Response response) throws IO
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
+ Object errorBody = ObjectMappers.parseErrorBody(responseBodyString);
future.completeExceptionally(new CohereApiException(
- "Error with status code " + response.code(),
- response.code(),
- ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class),
- response));
+ "Error with status code " + response.code(), response.code(), errorBody, response));
return;
} catch (IOException e) {
future.completeExceptionally(new CohereException("Network error executing HTTP request", e));
@@ -382,14 +376,14 @@ public CompletableFuture>> cancel(String
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
try (ResponseBody responseBody = response.body()) {
+ String responseBodyString = responseBody != null ? responseBody.string() : "{}";
if (response.isSuccessful()) {
future.complete(new CohereHttpResponse<>(
ObjectMappers.JSON_MAPPER.readValue(
- responseBody.string(), new TypeReference