diff --git a/frontend/src/components/experiment_dashboard/agent_participant_configuration_dialog.scss b/frontend/src/components/experiment_dashboard/agent_participant_configuration_dialog.scss new file mode 100644 index 000000000..146149406 --- /dev/null +++ b/frontend/src/components/experiment_dashboard/agent_participant_configuration_dialog.scss @@ -0,0 +1,94 @@ +@use '../../sass/colors'; +@use '../../sass/common'; +@use '../../sass/typescale'; + +:host { + @include common.overlay; + z-index: 20; // Above cohort-list header (z-index: 10) +} + +.dialog { + @include common.dialog; + height: 700px; + width: 800px; +} + +.header { + @include typescale.title-medium; + @include common.flex-row-align-center; + border-bottom: 1px solid var(--md-sys-color-outline-variant); + flex-shrink: 0; + gap: common.$spacing-medium; + height: common.$header-height; + justify-content: space-between; + padding: 0 common.$main-content-padding; +} + +.body { + @include common.flex-column; + flex-grow: 1; + gap: common.$spacing-xl; + overflow: auto; + padding: common.$main-content-padding; +} + +.section { + @include common.flex-column; + gap: common.$spacing-large; +} + +.section-title { + @include typescale.title-medium; +} + +.field { + @include common.flex-column; + gap: common.$spacing-small; + width: 100%; +} + +.field-title { + @include typescale.label-small; +} + +.description { + @include typescale.label-small; + color: var(--md-sys-color-outline); + font-style: italic; +} + +.small { + @include typescale.label-small; + color: var(--md-sys-color-outline); +} + +.action-buttons { + @include common.flex-row; + flex-wrap: wrap; + gap: common.$spacing-medium; +} + +.checkbox-wrapper { + @include common.flex-row-align-center; + gap: common.$spacing-small; + overflow-wrap: break-word; + + md-checkbox { + flex-shrink: 0; + } +} + +.number-input { + @include common.number-input; + width: max-content; +} + +.divider { + border-bottom: 1px solid var(--md-sys-color-outline-variant); + 
width: 100%; +} + +.buttons-wrapper { + @include common.flex-row-align-center; + justify-content: end; +} diff --git a/frontend/src/components/experiment_dashboard/agent_participant_configuration_dialog.ts b/frontend/src/components/experiment_dashboard/agent_participant_configuration_dialog.ts index 6ecd7f479..c7a909f5f 100644 --- a/frontend/src/components/experiment_dashboard/agent_participant_configuration_dialog.ts +++ b/frontend/src/components/experiment_dashboard/agent_participant_configuration_dialog.ts @@ -1,6 +1,7 @@ import '../../pair-components/button'; import '../../pair-components/icon_button'; import '@material/web/textfield/filled-text-field.js'; +import '@material/web/checkbox/checkbox.js'; import {MobxLitElement} from '@adobe/lit-mobx'; import {CSSResultGroup, html, nothing} from 'lit'; @@ -15,11 +16,11 @@ import { AgentPersonaConfig, CohortConfig, ApiKeyType, - AgentPersonaType, + AgentChatSettings, createAgentModelSettings, } from '@deliberation-lab/utils'; -import {styles} from './cohort_settings_dialog.scss'; +import {styles} from './agent_participant_configuration_dialog.scss'; /** Agent participant configuration dialog */ @customElement('agent-participant-configuration-dialog') @@ -37,7 +38,15 @@ export class AgentParticipantDialog extends MobxLitElement { @property() agentId = ''; @property() promptContext = ''; @property() agent: AgentPersonaConfig | undefined = undefined; - @property() model: string = ''; + @property() apiType: ApiKeyType = ApiKeyType.GEMINI_API_KEY; + @property() model: string = 'gemini-3-pro-preview'; + @property() useWebSearch: boolean = false; + + // Chat settings + @property() wordsPerMinute: number | null = null; + @property() minMessagesBeforeResponding: number = 0; + @property() canSelfTriggerCalls: boolean = false; + @property() maxResponses: number | null = 100; private close() { this.dispatchEvent(new CustomEvent('close')); @@ -70,38 +79,41 @@ export class AgentParticipantDialog extends MobxLitElement { private 
resetFields() { this.agentId = ''; this.promptContext = ''; + this.apiType = ApiKeyType.GEMINI_API_KEY; + this.model = 'gemini-3-pro-preview'; + this.useWebSearch = false; + // Reset chat settings + this.wordsPerMinute = null; + this.minMessagesBeforeResponding = 0; + this.canSelfTriggerCalls = false; + this.maxResponses = 100; } private renderEdit() { return html` - ${this.renderAgentModel()} ${this.renderPromptContext()} +
+
Model settings
+ ${this.renderApiType()} ${this.renderModelId()} + ${this.renderWebSearchOption()} +
+
+
+
Prompt
+ ${this.renderPromptContext()} +
+
+
+
Chat settings
+
+ Configure how this agent participates in chat stages. +
+ ${this.renderChatSettings()} +
{ - this.isLoading = true; - this.analyticsService.trackButtonClick( - ButtonClick.AGENT_PARTICIPANT_ADD, - ); - if (this.cohort && this.model) { - this.experimentEditor.addAgentParticipant(); - this.agentId = ''; //Make agent ID blank for agents added from cohort panel that use default prompts - const modelSettings = createAgentModelSettings({ - apiType: ApiKeyType.GEMINI_API_KEY, - modelName: this.model, - }); - - this.experimentManager.createAgentParticipant(this.cohort.id, { - agentId: this.agentId, - promptContext: this.promptContext, - modelSettings, - }); - } - this.resetFields(); - this.isSuccess = true; - this.isLoading = false; - }} + @click=${this.handleAddAgent} > Add agent participant @@ -109,6 +121,38 @@ export class AgentParticipantDialog extends MobxLitElement { `; } + private handleAddAgent() { + this.isLoading = true; + this.analyticsService.trackButtonClick(ButtonClick.AGENT_PARTICIPANT_ADD); + if (this.cohort && this.model) { + this.experimentEditor.addAgentParticipant(); + this.agentId = ''; // Make agent ID blank for agents added from cohort panel that use default prompts + const modelSettings = createAgentModelSettings({ + apiType: this.apiType, + modelName: this.model, + useWebSearch: this.useWebSearch, + }); + + const chatSettings: AgentChatSettings = { + initialMessage: '', + wordsPerMinute: this.wordsPerMinute, + minMessagesBeforeResponding: this.minMessagesBeforeResponding, + canSelfTriggerCalls: this.canSelfTriggerCalls, + maxResponses: this.maxResponses, + }; + + this.experimentManager.createAgentParticipant(this.cohort.id, { + agentId: this.agentId, + promptContext: this.promptContext, + modelSettings, + chatSettings, + }); + } + this.resetFields(); + this.isSuccess = true; + this.isLoading = false; + } + private renderSuccess() { return html`
Agent participant added!
@@ -124,47 +168,101 @@ export class AgentParticipantDialog extends MobxLitElement { `; } - private renderAgentModel() { + private renderApiType() { return html` -
-
Model to use for this specific agent participant:
-
- ${this.renderModelButton( - 'gemini-2.5-flash', - 'Gemini 2.5 Flash', - ApiKeyType.GEMINI_API_KEY, - )} - ${this.renderModelButton( - 'gemini-2.5-pro', - 'Gemini 2.5 Pro', - ApiKeyType.GEMINI_API_KEY, - )} +
+
LLM API
+
+ ${this.renderApiTypeButton('Gemini', ApiKeyType.GEMINI_API_KEY)} + ${this.renderApiTypeButton('OpenAI', ApiKeyType.OPENAI_API_KEY)} + ${this.renderApiTypeButton('Anthropic', ApiKeyType.CLAUDE_API_KEY)} + ${this.renderApiTypeButton('Ollama', ApiKeyType.OLLAMA_CUSTOM_URL)}
`; } - private renderModelButton( - modelId: string, - modelName: string, - apiType: ApiKeyType, - ) { - const updateModel = () => { - this.model = modelId; + private renderApiTypeButton(apiName: string, apiType: ApiKeyType) { + const updateApiType = () => { + this.apiType = apiType; + // Set a default model name when switching API types + if (apiType === ApiKeyType.GEMINI_API_KEY) { + this.model = 'gemini-3-pro-preview'; + } else if (apiType === ApiKeyType.OPENAI_API_KEY) { + this.model = 'gpt-5.1-2025-11-13'; + } else if (apiType === ApiKeyType.CLAUDE_API_KEY) { + this.model = 'claude-opus-4-5-20251101'; + } else if (apiType === ApiKeyType.OLLAMA_CUSTOM_URL) { + this.model = 'llama3.2'; + } + // Reset web search when switching to OpenAI or Ollama (not supported) + if ( + apiType === ApiKeyType.OLLAMA_CUSTOM_URL || + apiType === ApiKeyType.OPENAI_API_KEY + ) { + this.useWebSearch = false; + } }; - const isActive = modelId == this.model; + const isActive = apiType === this.apiType; return html` - ${modelName} + ${apiName} `; } + private renderModelId() { + const updateModel = (e: InputEvent) => { + const content = (e.target as HTMLTextAreaElement).value; + this.model = content; + }; + + return html` +
+ + +
+ `; + } + + private renderWebSearchOption() { + // Only show for Gemini and Anthropic (OpenAI and Ollama don't support web search in chat completions) + if ( + this.apiType === ApiKeyType.OLLAMA_CUSTOM_URL || + this.apiType === ApiKeyType.OPENAI_API_KEY + ) { + return nothing; + } + + const toggleWebSearch = (event: Event) => { + const checked = (event.target as HTMLInputElement).checked; + this.useWebSearch = checked; + }; + + return html` +
+ + +
Enable web search
+
+ `; + } + private renderPromptContext() { const updatePromptContext = (e: InputEvent) => { const content = (e.target as HTMLTextAreaElement).value; @@ -172,14 +270,99 @@ export class AgentParticipantDialog extends MobxLitElement { }; return html` - - +
+ + +
+ Additional context to include in the agent's prompts. +
+
+ `; + } + + private renderChatSettings() { + return html` +
+ +
+ Agent's typing speed. Leave empty for instant messages. +
+
+ { + const value = (e.target as HTMLInputElement).value; + this.wordsPerMinute = value === '' ? null : Number(value); + }} + /> +
+
+
+ +
+ Number of chat messages that must exist before the agent can respond. +
+
+ { + const value = Number((e.target as HTMLInputElement).value); + if (!isNaN(value)) { + this.minMessagesBeforeResponding = value; + } + }} + /> +
+
+
+ { + this.canSelfTriggerCalls = (e.target as HTMLInputElement).checked; + }} + > + +
+ Can respond multiple times in a row + + (Agent's own messages can trigger new responses) + +
+
+
+ +
+ Maximum total responses during the chat. Leave empty for no limit. +
+
+ { + const value = (e.target as HTMLInputElement).value; + this.maxResponses = value === '' ? null : Number(value); + }} + /> +
+
`; } } diff --git a/frontend/src/components/experiment_dashboard/cohort_settings_dialog.scss b/frontend/src/components/experiment_dashboard/cohort_settings_dialog.scss index 415a3e3b2..b51865567 100644 --- a/frontend/src/components/experiment_dashboard/cohort_settings_dialog.scss +++ b/frontend/src/components/experiment_dashboard/cohort_settings_dialog.scss @@ -4,6 +4,7 @@ :host { @include common.overlay; + z-index: 20; // Above cohort-list header (z-index: 10) } .dialog { diff --git a/frontend/src/components/experiment_dashboard/participant_stats.scss b/frontend/src/components/experiment_dashboard/participant_stats.scss index 46c22b68e..cde400431 100644 --- a/frontend/src/components/experiment_dashboard/participant_stats.scss +++ b/frontend/src/components/experiment_dashboard/participant_stats.scss @@ -40,3 +40,80 @@ h4 { .stats-wrapper { padding: 0 common.$spacing-large; } + +// Agent log section +.agent-log-wrapper { + @include common.flex-column; + gap: common.$spacing-medium; + padding: 0 common.$spacing-large; +} + +.agent-config { + margin-bottom: common.$spacing-medium; +} + +.agent-log-summary { + @include common.flex-row; + flex-wrap: wrap; + gap: common.$spacing-small; + margin-bottom: common.$spacing-medium; +} + +.empty-message { + color: var(--md-sys-color-outline); + padding: common.$spacing-medium 0; +} + +.log-entry { + @include common.flex-column; + border: 1px solid var(--md-sys-color-outline-variant); + border-radius: common.$spacing-small; + gap: common.$spacing-small; + padding: common.$spacing-medium; + margin-bottom: common.$spacing-small; +} + +.log-entry-header { + @include common.flex-row-align-center; + flex-wrap: wrap; + gap: common.$spacing-small; +} + +.log-timestamp { + @include typescale.label-small; + color: var(--md-sys-color-outline); +} + +.log-duration { + @include typescale.label-small; + color: var(--md-sys-color-outline); +} + +.log-description { + @include typescale.body-small; + color: 
var(--md-sys-color-on-surface-variant); +} + +details { + margin-top: common.$spacing-small; +} + +details summary { + @include typescale.label-medium; + cursor: pointer; + color: var(--md-sys-color-primary); +} + +details pre { + background: var(--md-sys-color-surface-variant); + border-radius: common.$spacing-small; + overflow: auto; + padding: common.$spacing-medium; + margin-top: common.$spacing-small; + white-space: pre-wrap; + word-break: break-word; +} + +details code { + font-size: 90%; +} diff --git a/frontend/src/components/experiment_dashboard/participant_stats.ts b/frontend/src/components/experiment_dashboard/participant_stats.ts index d82a41140..b7641f99c 100644 --- a/frontend/src/components/experiment_dashboard/participant_stats.ts +++ b/frontend/src/components/experiment_dashboard/participant_stats.ts @@ -15,6 +15,12 @@ import { StageKind, UnifiedTimestamp, calculatePayoutTotal, + LogEntry, + LogEntryType, + ModelLogEntry, + ModelResponseStatus, + getUnifiedDurationSeconds, + convertUnifiedTimestampToDateTime, } from '@deliberation-lab/utils'; import {getCohortName} from '../../shared/cohort.utils'; import {getParticipantInlineDisplay} from '../../shared/participant.utils'; @@ -56,7 +62,7 @@ export class Preview extends MobxLitElement { return html` ${this.renderChips()} ${this.renderTable()} ${this.renderStats()}
- ${this.renderStageDatas()} + ${this.renderAgentLog()} ${this.renderStageDatas()} `; } @@ -329,6 +335,126 @@ export class Preview extends MobxLitElement {
${label}: ${convertUnifiedTimestampToDate(value)}
`; } + + /** Get logs filtered to this participant */ + private getParticipantLogs(): ModelLogEntry[] { + if (!this.profile) return []; + const privateId = this.profile.privateId; + return this.experimentManager.logs + .filter( + (log): log is ModelLogEntry => + log.type === LogEntryType.MODEL && log.privateId === privateId, + ) + .sort((a, b) => { + // Sort by query timestamp descending (newest first) + const aTime = Number(a.queryTimestamp ?? 0); + const bTime = Number(b.queryTimestamp ?? 0); + return bTime - aTime; + }); + } + + /** Render the Agent log section */ + private renderAgentLog() { + if (!this.profile?.agentConfig) { + return nothing; + } + + const agentConfig = this.profile.agentConfig; + const logs = this.getParticipantLogs(); + + return html` +

Agent log

+
+ ${this.renderAgentConfig(agentConfig)} +
+ + Total API calls: ${logs.length} + + + Successful: + ${logs.filter((l) => l.response.status === ModelResponseStatus.OK) + .length} + + + Failed: + ${logs.filter((l) => l.response.status !== ModelResponseStatus.OK) + .length} + +
+ ${logs.length === 0 + ? html`
No API calls yet
` + : logs.map((log) => this.renderLogEntry(log))} +
+
+ `; + } + + /** Render agent configuration info */ + private renderAgentConfig( + agentConfig: NonNullable, + ) { + return html` +
+
+
+
API Type
+
${agentConfig.modelSettings.apiType}
+
+
+
Model
+
${agentConfig.modelSettings.modelName}
+
+ ${agentConfig.promptContext + ? html` +
+
Prompt context
+
${agentConfig.promptContext}
+
+ ` + : nothing} +
+
+ `; + } + + /** Render a single log entry */ + private renderLogEntry(log: ModelLogEntry) { + const status = log.response.status; + const isSuccess = status === ModelResponseStatus.OK; + + return html` +
+
+ + ${status.toUpperCase()} + + + ${log.queryTimestamp + ? convertUnifiedTimestampToDateTime(log.queryTimestamp) + : 'N/A'} + + + (${getUnifiedDurationSeconds( + log.queryTimestamp, + log.responseTimestamp, + )}s) + + ${this.getStageName(log.stageId)} +
+ ${log.description + ? html`
${log.description}
` + : nothing} +
+ Prompt +
${log.prompt}
+
+
+ Response +
${JSON.stringify(log.response, null, 2)}
+
+
+ `; + } } declare global { diff --git a/functions/src/agent.utils.ts b/functions/src/agent.utils.ts index e4e00608b..d055650b0 100644 --- a/functions/src/agent.utils.ts +++ b/functions/src/agent.utils.ts @@ -144,6 +144,7 @@ export async function getAgentResponse( prompt, generationConfig, structuredOutputConfig, + modelSettings.useWebSearch, ); } else if (modelSettings.apiType === ApiKeyType.OPENAI_API_KEY) { response = await getOpenAIAPIResponse( @@ -152,6 +153,7 @@ export async function getAgentResponse( prompt, generationConfig, structuredOutputConfig, + modelSettings.useWebSearch, ); } else if (modelSettings.apiType === ApiKeyType.CLAUDE_API_KEY) { response = await getClaudeAPIResponse( @@ -160,6 +162,7 @@ export async function getAgentResponse( prompt, generationConfig, structuredOutputConfig, + modelSettings.useWebSearch, ); } else if (modelSettings.apiType === ApiKeyType.OLLAMA_CUSTOM_URL) { response = await getOllamaResponse( @@ -191,6 +194,7 @@ export async function getGeminiResponse( prompt: string | Array<{role: string; content: string; name?: string}>, generationConfig: ModelGenerationConfig, structuredOutputConfig?: StructuredOutputConfig, + useWebSearch?: boolean, ): Promise { return await getGeminiAPIResponse( apiKeyConfig.geminiApiKey, @@ -198,6 +202,7 @@ export async function getGeminiResponse( prompt, generationConfig, structuredOutputConfig, + useWebSearch, ); } @@ -207,6 +212,7 @@ export async function getOpenAIAPIResponse( prompt: string | Array<{role: string; content: string; name?: string}>, generationConfig: ModelGenerationConfig, structuredOutputConfig?: StructuredOutputConfig, + useWebSearch?: boolean, ): Promise { return await getOpenAIAPIChatCompletionResponse( apiKeyConfig.openAIApiKey?.apiKey || '', @@ -215,6 +221,7 @@ export async function getOpenAIAPIResponse( prompt, generationConfig, structuredOutputConfig, + useWebSearch, ); } @@ -224,6 +231,7 @@ export async function getClaudeAPIResponse( prompt: string | Array<{role: string; 
content: string; name?: string}>, generationConfig: ModelGenerationConfig, structuredOutputConfig?: StructuredOutputConfig, + useWebSearch?: boolean, ): Promise { return await getClaudeAPIChatCompletionResponse( apiKeyConfig.claudeApiKey?.apiKey || '', @@ -232,6 +240,7 @@ export async function getClaudeAPIResponse( prompt, generationConfig, structuredOutputConfig, + useWebSearch, ); } diff --git a/functions/src/agent_participant.utils.ts b/functions/src/agent_participant.utils.ts index 424ba372b..8409aef13 100644 --- a/functions/src/agent_participant.utils.ts +++ b/functions/src/agent_participant.utils.ts @@ -29,6 +29,10 @@ export async function completeStageAsAgentParticipant( participant: ParticipantProfileExtended, ) { const experimentId = experiment.id; + console.log( + `[AgentParticipant] completeStageAsAgentParticipant called for ${participant.publicId}, stage: ${participant.currentStageId}, status: ${participant.currentStatus}`, + ); + const participantDoc = getFirestoreParticipantRef( experimentId, participant.privateId, @@ -40,11 +44,13 @@ export async function completeStageAsAgentParticipant( if (!stage) { console.error( - `Could not find stage ${participant.currentStageId} for experiment ${experimentId}`, + `[AgentParticipant] Could not find stage ${participant.currentStageId} for experiment ${experimentId}`, ); return; } + console.log(`[AgentParticipant] Stage kind: ${stage.kind}`); + const status = participant.currentStatus; let updatedStatus = false; @@ -99,7 +105,14 @@ export async function completeStageAsAgentParticipant( stage, ); + console.log( + `[AgentParticipant] stageActions for ${participant.publicId}: callApi=${stageActions.callApi}, moveToNextStage=${stageActions.moveToNextStage}`, + ); + if (stageActions.callApi) { + console.log( + `[AgentParticipant] Calling API for ${participant.publicId}...`, + ); const response = await getParsedAgentParticipantPromptResponse( experimenterData, experiment.id, @@ -109,12 +122,20 @@ export async function 
completeStageAsAgentParticipant( // TODO: Try fetching custom participant prompt first stageManager.getDefaultParticipantStructuredPrompt(stage), ); + console.log( + `[AgentParticipant] API response for ${participant.publicId}:`, + response ? JSON.stringify(response) : 'null', + ); if (response) { const answer = stageManager.extractAgentParticipantAnswerFromResponse( participant, stage, response, ); + console.log( + `[AgentParticipant] Extracted answer for ${participant.publicId}:`, + answer ? JSON.stringify(answer) : 'undefined', + ); // If profile stage, no action needed as there is no "answer" // TODO: Consider making "set profile" not part of a stage // Otherwise, write answer to storage @@ -128,22 +149,43 @@ export async function completeStageAsAgentParticipant( .doc(participant.privateId) .collection('stageData') .doc(stage.id); + console.log( + `[AgentParticipant] Writing answer to stageData for ${participant.publicId}`, + ); answerDoc.set(answer); } + // For profile stage, log that the participant object was modified + if (stage.kind === StageKind.PROFILE) { + console.log( + `[AgentParticipant] Profile stage - participant updated: name=${participant.name}, avatar=${participant.avatar}, pronouns=${participant.pronouns}`, + ); + } } } if (stageActions.moveToNextStage) { + console.log( + `[AgentParticipant] Moving ${participant.publicId} to next stage...`, + ); await updateParticipantNextStage( experimentId, participant, experiment.stageIds, ); + console.log( + `[AgentParticipant] ${participant.publicId} now on stage ${participant.currentStageId}`, + ); } // Write ParticipantAnswer doc if profile has been updated if (stageActions.moveToNextStage || updatedStatus) { - participantDoc.set(participant); + console.log( + `[AgentParticipant] Writing participant doc for ${participant.publicId} (moveToNextStage=${stageActions.moveToNextStage}, updatedStatus=${updatedStatus})`, + ); + await participantDoc.set(participant); + console.log( + `[AgentParticipant] 
Participant doc written for ${participant.publicId}`, + ); } } diff --git a/functions/src/api/claude.api.ts b/functions/src/api/claude.api.ts index 4180e5211..38143bd67 100644 --- a/functions/src/api/claude.api.ts +++ b/functions/src/api/claude.api.ts @@ -58,30 +58,6 @@ function makeStructuredOutputSchema(schema: StructuredOutputSchema): object { }; } -function makeStructuredOutputGenerationConfig( - structuredOutputConfig?: StructuredOutputConfig, -): Partial { - if ( - !structuredOutputConfig || - structuredOutputConfig.type === StructuredOutputType.NONE - ) { - return {responseMimeType: 'text/plain'}; - } - if (structuredOutputConfig.type === StructuredOutputType.JSON_FORMAT) { - return {responseMimeType: 'application/json'}; - } - if (!structuredOutputConfig.schema) { - throw new Error( - `Expected schema for structured output type ${structuredOutputConfig.type}`, - ); - } - const schema = makeStructuredOutputSchema(structuredOutputConfig.schema); - return { - responseMimeType: 'application/json', - responseSchema: schema, - }; -} - function convertToClaudeFormat( prompt: string | Array<{role: string; content: string; name?: string}>, ): Array<{role: string; content: string; name?: string}> { @@ -98,6 +74,7 @@ export async function callClaudeChatCompletion( prompt: string | Array<{role: string; content: string; name?: string}>, generationConfig: ModelGenerationConfig, structuredOutputConfig?: StructuredOutputConfig, + useWebSearch?: boolean, ): Promise { const client = new Anthropic({apiKey, baseURL: baseUrl}); const allMessages = convertToClaudeFormat(prompt); @@ -112,6 +89,12 @@ export async function callClaudeChatCompletion( content: string; }[]; + // Configure web search tool if enabled + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const tools: any[] | undefined = useWebSearch + ? 
[{type: 'web_search_20250305', name: 'web_search'}] + : undefined; + let response; try { response = await client.messages.create({ @@ -119,6 +102,7 @@ export async function callClaudeChatCompletion( system: systemPrompt, // The system prompt as a top-level string messages: filteredMessages, // The array containing only user/assistant turns max_tokens: generationConfig.maxTokens, + ...(tools && {tools}), ...(generationConfig.temperature !== 1.0 ? {temperature: generationConfig.temperature} : generationConfig.topP !== 1.0 @@ -161,13 +145,19 @@ export async function callClaudeChatCompletion( }; } + // Find the text content block (when web search is used, response contains multiple blocks) + const textBlock = response.content.find( + (block: {type: string}) => block.type === 'text', + ) as {type: string; text?: string} | undefined; + const responseText = textBlock?.text; + const finishReason = response.stop_reason; if (finishReason === MAX_TOKENS_FINISH_REASON) { return { status: ModelResponseStatus.LENGTH_ERROR, generationConfig, rawResponse: JSON.stringify(response), - text: response.content[0].text, + text: responseText, errorMessage: `Token limit (${generationConfig.maxTokens}) exceeded`, }; } else if (finishReason !== SUCCESS_FINISH_REASON) { @@ -175,7 +165,7 @@ export async function callClaudeChatCompletion( status: ModelResponseStatus.UNKNOWN_ERROR, generationConfig, rawResponse: JSON.stringify(response), - text: response.content[0].text, + text: responseText, errorMessage: `Provider sent unrecognized finish_reason: ${finishReason}`, }; } @@ -184,7 +174,7 @@ export async function callClaudeChatCompletion( status: ModelResponseStatus.OK, generationConfig, rawResponse: JSON.stringify(response), - text: response.content[0].text, + text: responseText, }; if (structuredOutputConfig?.enabled) { return addParsedModelResponse(modelResponse); @@ -199,6 +189,7 @@ export async function getClaudeAPIChatCompletionResponse( promptText: string | Array<{role: string; content: 
string; name?: string}>, generationConfig: ModelGenerationConfig, structuredOutputConfig?: StructuredOutputConfig, + useWebSearch?: boolean, ): Promise { try { const response = await callClaudeChatCompletion( @@ -208,6 +199,7 @@ export async function getClaudeAPIChatCompletionResponse( promptText, generationConfig, structuredOutputConfig, + useWebSearch, ); if (!response) { return { diff --git a/functions/src/api/gemini.api.ts b/functions/src/api/gemini.api.ts index eaaddac9d..b351d472d 100644 --- a/functions/src/api/gemini.api.ts +++ b/functions/src/api/gemini.api.ts @@ -15,6 +15,7 @@ import { StructuredOutputSchema, ModelResponseStatus, ModelResponse, + addParsedModelResponse, } from '@deliberation-lab/utils'; const GEMINI_DEFAULT_MODEL = 'gemini-2.5-flash'; @@ -165,6 +166,7 @@ export async function callGemini( generationConfig: GenerationConfig, modelName = GEMINI_DEFAULT_MODEL, safetySettings?: SafetySetting[], + useGoogleSearch?: boolean, ): Promise { const genAI = new GoogleGenAI({apiKey}); @@ -181,6 +183,11 @@ export async function callGemini( config.systemInstruction = systemInstruction; } + // Add Google Search grounding tool if enabled + if (useGoogleSearch) { + config.tools = [{googleSearch: {}}]; + } + const response = await genAI.models.generateContent({ model: modelName, contents: contents, @@ -253,7 +260,8 @@ export async function callGemini( imageDataList: imageDataList.length > 0 ? imageDataList : undefined, }; - return modelResponse; + // Parse JSON from response text if present + return addParsedModelResponse(modelResponse) ?? modelResponse; } /** Constructs Gemini API query and returns response. 
*/ @@ -263,6 +271,7 @@ export async function getGeminiAPIResponse( promptText: string | Array<{role: string; content: string; name?: string}>, generationConfig: ModelGenerationConfig, structuredOutputConfig?: StructuredOutputConfig, + useGoogleSearch?: boolean, ): Promise { // Extract disableSafetyFilters setting from generationConfig const disableSafetyFilters = generationConfig.disableSafetyFilters ?? false; @@ -319,6 +328,7 @@ export async function getGeminiAPIResponse( geminiConfig, modelName, safetySettings, + useGoogleSearch, ); // eslint-disable-next-line @typescript-eslint/no-explicit-any } catch (error: any) { diff --git a/functions/src/api/openai.api.ts b/functions/src/api/openai.api.ts index 5be019dab..f056c0acb 100644 --- a/functions/src/api/openai.api.ts +++ b/functions/src/api/openai.api.ts @@ -15,14 +15,15 @@ const MAX_TOKENS_FINISH_REASON = 'length'; const REFUSAL_FINISH_REASON = 'content_filter'; function makeStructuredOutputSchema(schema: StructuredOutputSchema): object { + // OpenAI JSON Schema requires lowercase type names const typeMap: {[key in StructuredOutputDataType]?: string} = { - [StructuredOutputDataType.STRING]: 'STRING', - [StructuredOutputDataType.NUMBER]: 'NUMBER', - [StructuredOutputDataType.INTEGER]: 'INTEGER', - [StructuredOutputDataType.BOOLEAN]: 'BOOLEAN', - [StructuredOutputDataType.ARRAY]: 'ARRAY', - [StructuredOutputDataType.OBJECT]: 'OBJECT', - [StructuredOutputDataType.ENUM]: 'STRING', + [StructuredOutputDataType.STRING]: 'string', + [StructuredOutputDataType.NUMBER]: 'number', + [StructuredOutputDataType.INTEGER]: 'integer', + [StructuredOutputDataType.BOOLEAN]: 'boolean', + [StructuredOutputDataType.ARRAY]: 'array', + [StructuredOutputDataType.OBJECT]: 'object', + [StructuredOutputDataType.ENUM]: 'string', }; const type = typeMap[schema.type]; if (!type) { @@ -83,8 +84,11 @@ function makeStructuredOutputParameters( const schema = makeStructuredOutputSchema(structuredOutputConfig.schema); return { type: 'json_schema', - 
strict: true, - json_schema: schema, + json_schema: { + name: 'response_schema', + strict: true, + schema: schema, + }, }; } @@ -116,6 +120,7 @@ export async function callOpenAIChatCompletion( prompt: string | Array<{role: string; content: string; name?: string}>, generationConfig: ModelGenerationConfig, structuredOutputConfig?: StructuredOutputConfig, + _useWebSearch?: boolean, // Accepted but not implemented - OpenAI chat completions API doesn't support web search ): Promise { const client = new OpenAI({ apiKey: apiKey, @@ -247,6 +252,7 @@ export async function getOpenAIAPIChatCompletionResponse( promptText: string | Array<{role: string; content: string; name?: string}>, generationConfig: ModelGenerationConfig, structuredOutputConfig?: StructuredOutputConfig, + useWebSearch?: boolean, ): Promise { if (!modelName) { console.warn('OpenAI API model name not set.'); @@ -275,6 +281,7 @@ export async function getOpenAIAPIChatCompletionResponse( promptText, generationConfig, structuredOutputConfig, + useWebSearch, ); if (!response) { return { diff --git a/functions/src/chat/chat.agent.ts b/functions/src/chat/chat.agent.ts index 7066fca31..79791e8a1 100644 --- a/functions/src/chat/chat.agent.ts +++ b/functions/src/chat/chat.agent.ts @@ -57,11 +57,23 @@ export async function createAgentChatMessageFromPrompt( // Profile of agent who will be sending the chat message user: ParticipantProfileExtended | MediatorProfileExtended, ) { - if (!user.agentConfig) return false; + console.log( + `[AgentChat] createAgentChatMessageFromPrompt called for user ${user.publicId} (type: ${user.type}), stage: ${stageId}, triggerChatId: ${triggerChatId || 'initial'}`, + ); + + if (!user.agentConfig) { + console.log( + `[AgentChat] No agentConfig for user ${user.publicId}, returning false`, + ); + return false; + } // Stage (in order to determine stage kind) const stage = await getFirestoreStage(experimentId, stageId); - if (!stage) return false; + if (!stage) { + console.log(`[AgentChat] Stage 
${stageId} not found, returning false`); + return false; + } // Fetches stored (else default) prompt config for given stage const promptConfig = (await getStructuredPromptConfig( @@ -71,9 +83,20 @@ export async function createAgentChatMessageFromPrompt( )) as ChatPromptConfig | undefined; if (!promptConfig) { + console.log( + `[AgentChat] No promptConfig for user ${user.publicId} on stage ${stageId}, returning false`, + ); return false; } + console.log( + `[AgentChat] Got promptConfig for user ${user.publicId}:`, + JSON.stringify({ + hasPrompt: !!promptConfig.prompt, + chatSettings: promptConfig.chatSettings, + }), + ); + // Check if this is an initial message request (empty triggerChatId) if (triggerChatId === '') { // Check if we've already sent an initial message for this user @@ -203,9 +226,19 @@ export async function getAgentChatMessage( // Confirm that agent can send chat messages based on prompt config const chatSettings = promptConfig.chatSettings; + console.log( + `[AgentChat] getAgentChatMessage for ${user.publicId}: checking canSendAgentChatMessage with ${chatMessages.length} existing messages, chatSettings:`, + JSON.stringify(chatSettings), + ); if (!canSendAgentChatMessage(user.publicId, chatSettings, chatMessages)) { + console.log( + `[AgentChat] canSendAgentChatMessage returned false for ${user.publicId}`, + ); return {message: null, success: true}; } + console.log( + `[AgentChat] canSendAgentChatMessage returned true for ${user.publicId}`, + ); // Ensure user has agent config if (!user.agentConfig) { @@ -306,39 +339,51 @@ export async function getAgentChatMessage( let shouldRespond = true; let readyToEndChat = false; - if (structured?.enabled && response.text) { - const jsonMatch = response.text.match(/```json\n(\{[\s\S]*\})\n```/); - if (jsonMatch && jsonMatch[1]) { - try { - const parsed = JSON.parse(jsonMatch[1]) as Record; - - const shouldRespondValue = structured.shouldRespondField - ? 
parsed[structured.shouldRespondField] - : undefined; - shouldRespond = shouldRespondValue === false ? false : true; - - const messageField = structured.messageField || 'response'; - if (typeof parsed[messageField] === 'string') { - message = parsed[messageField] as string; + if (structured?.enabled) { + // Try to get parsed response - either from parsedResponse (set by addParsedModelResponse) + // or by parsing markdown code blocks in the text + let parsed: Record<string, unknown> | null = null; + + if (response.parsedResponse) { + // Use already-parsed response (from OpenAI structured output, etc.) + parsed = response.parsedResponse as Record<string, unknown>; + } else if (response.text) { + // Try to parse from markdown code block (legacy format) + const jsonMatch = response.text.match(/```json\n(\{[\s\S]*\})\n```/); + if (jsonMatch && jsonMatch[1]) { + try { + parsed = JSON.parse(jsonMatch[1]) as Record<string, unknown>; + } catch (error) { + console.error('getAgentChatMessage JSON parse error in text:', error); } + } + } - const explanationField = structured.explanationField || 'explanation'; - if (typeof parsed[explanationField] === 'string') { - explanation = parsed[explanationField] as string; - } + if (parsed) { + const shouldRespondValue = structured.shouldRespondField + ? parsed[structured.shouldRespondField] + : undefined; + shouldRespond = shouldRespondValue === false ? false : true; + + const messageField = structured.messageField || 'response'; + if (typeof parsed[messageField] === 'string') { + message = parsed[messageField] as string; + } - readyToEndChat = structured.readyToEndField - ? 
Boolean(parsed[structured.readyToEndField]) - : false; - } catch (error) { - console.error('getAgentChatMessage JSON parse error in text:', error); - // message remains response.text + const explanationField = structured.explanationField || 'explanation'; + if (typeof parsed[explanationField] === 'string') { + explanation = parsed[explanationField] as string; } - } else { - // JSON block not found, message remains response.text + + readyToEndChat = structured.readyToEndField + ? Boolean(parsed[structured.readyToEndField]) + : false; } - } else if ( + } + + if ( !response.text && + !response.parsedResponse && (!response.imageDataList || response.imageDataList.length === 0) ) { return {message: null, success: false}; @@ -505,7 +550,13 @@ export async function sendAgentGroupChatMessage( .doc(chatMessage.id); chatMessage.timestamp = Timestamp.now(); - agentDocument.set(chatMessage); + console.log( + `[AgentChat] sendAgentGroupChatMessage: Writing message ${chatMessage.id} from ${chatMessage.senderId} to experiments/${experimentId}/cohorts/${cohortId}/publicStageData/${stageId}/chats`, + ); + await agentDocument.set(chatMessage); + console.log( + `[AgentChat] sendAgentGroupChatMessage: Message ${chatMessage.id} written successfully`, + ); return true; } diff --git a/functions/src/structured_prompt.utils.ts b/functions/src/structured_prompt.utils.ts index 2a7c2bf9e..2dcd3e805 100644 --- a/functions/src/structured_prompt.utils.ts +++ b/functions/src/structured_prompt.utils.ts @@ -44,6 +44,8 @@ import {stageManager} from './app'; /** Attempts to fetch corresponding prompt config from storage, * else returns the stage's default config. + * If the user has chat settings in their agentConfig, those are merged into + * the prompt config (for quick-add agents that use default prompts). 
*/ export async function getStructuredPromptConfig( experimentId: string, @@ -53,6 +55,8 @@ export async function getStructuredPromptConfig( if (!user.agentConfig) { return undefined; } + let promptConfig: BasePromptConfig | undefined; + switch (user.type) { case UserType.PARTICIPANT: const participantPrompt = await getAgentParticipantPrompt( @@ -61,10 +65,10 @@ export async function getStructuredPromptConfig( user.agentConfig?.agentId, ); // Return stored prompt or fallback default prompt - return ( + promptConfig = participantPrompt ?? - stageManager.getDefaultParticipantStructuredPrompt(stage) - ); + stageManager.getDefaultParticipantStructuredPrompt(stage); + break; case UserType.MEDIATOR: const mediatorPrompt = await getAgentMediatorPrompt( experimentId, @@ -72,10 +76,32 @@ export async function getStructuredPromptConfig( user.agentConfig?.agentId, ); // If prompt not stored under experiment, then return undefined - return mediatorPrompt ?? undefined; + promptConfig = mediatorPrompt ?? 
undefined; + break; default: return undefined; } + + // If we have a prompt config and the user has chat settings override, + // merge them into the prompt config (for chat stages) + if ( + promptConfig && + user.agentConfig.chatSettings && + (stage.kind === StageKind.CHAT || stage.kind === StageKind.PRIVATE_CHAT) + ) { + // Cast to ChatPromptConfig to access chatSettings + const chatPromptConfig = promptConfig as BasePromptConfig & { + chatSettings?: typeof user.agentConfig.chatSettings; + }; + if (chatPromptConfig.chatSettings) { + chatPromptConfig.chatSettings = { + ...chatPromptConfig.chatSettings, + ...user.agentConfig.chatSettings, + }; + } + } + + return promptConfig; } /** Populates data object with Firestore documents needed for given diff --git a/functions/src/triggers/chat.triggers.ts b/functions/src/triggers/chat.triggers.ts index b54ca33f6..f93581ee9 100644 --- a/functions/src/triggers/chat.triggers.ts +++ b/functions/src/triggers/chat.triggers.ts @@ -91,6 +91,14 @@ export const onPublicChatMessageCreated = onDocumentCreated( const agentParticipants = allParticipants.filter( (p) => p.agentConfig && p.currentStatus === ParticipantStatus.IN_PROGRESS, ); + console.log( + `[ChatTrigger] onPublicChatMessageCreated: Found ${allParticipants.length} total participants, ${agentParticipants.length} agent participants with IN_PROGRESS status`, + ); + agentParticipants.forEach((p) => { + console.log( + `[ChatTrigger] Agent participant: publicId=${p.publicId}, privateId=${p.privateId}, status=${p.currentStatus}, stageId=${p.currentStageId}`, + ); + }); await Promise.all( agentParticipants.map((participant) => createAgentChatMessageFromPrompt( diff --git a/functions/src/utils/firestore.ts b/functions/src/utils/firestore.ts index 595887825..29e7cdefd 100644 --- a/functions/src/utils/firestore.ts +++ b/functions/src/utils/firestore.ts @@ -386,8 +386,13 @@ export async function getAgentMediatorPrompt( export async function getAgentParticipantPrompt( experimentId: 
string, stageId: string, - agentId: string, + agentId: string | undefined, ): Promise<ParticipantPromptConfig | null> { + // If no agentId, return null to use default prompt + if (!agentId) { + return null; + } + const prompt = await app .firestore() .collection('experiments') diff --git a/utils/src/agent.ts b/utils/src/agent.ts index bd4fb1bc6..40abe495a 100644 --- a/utils/src/agent.ts +++ b/utils/src/agent.ts @@ -4,10 +4,7 @@ import { } from './participant'; import {generateId} from './shared'; import {StageKind} from './stages/stage'; -import { - ChatMediatorStructuredOutputConfig, - createStructuredOutputConfig, -} from './structured_output'; +import {ChatMediatorStructuredOutputConfig} from './structured_output'; import { MediatorPromptConfig, ParticipantPromptConfig, @@ -37,6 +34,8 @@ export interface ProfileAgentConfig { agentId: string; // ID of agent persona used promptContext: string; // Additional text to concatenate to agent prompts modelSettings: AgentModelSettings; + // Optional chat settings override for quick-add agents (when no persona prompt is defined) + chatSettings?: AgentChatSettings; } /** Generation config for a specific stage's model call. */ @@ -65,6 +64,8 @@ export interface ModelGenerationConfig { export interface AgentModelSettings { apiType: ApiKeyType; modelName: string; + // Enable web search grounding (Google Search for Gemini, web_search for OpenAI/Claude) + useWebSearch?: boolean; } // TODO: Move to structured_prompt.ts @@ -191,6 +192,7 @@ export function createAgentModelSettings( apiType: config.apiType ?? DEFAULT_AGENT_API_TYPE, // TODO: pick model name that matches API? modelName: config.modelName ?? DEFAULT_AGENT_API_MODEL, + useWebSearch: config.useWebSearch ?? 
false, }; } diff --git a/utils/src/model_response.ts b/utils/src/model_response.ts index 6c19cf2e7..cb20348ae 100644 --- a/utils/src/model_response.ts +++ b/utils/src/model_response.ts @@ -114,8 +114,13 @@ export function addParsedModelResponse(response: ModelResponse) { return; } + // If there's no text to parse, return the response as-is + if (!response.text) { + return response; + } + try { - const cleanedText = response.text!.replace(/```json\s*|\s*```/g, '').trim(); + const cleanedText = response.text.replace(/```json\s*|\s*```/g, '').trim(); response.parsedResponse = JSON.parse(cleanedText); return response; } catch (error) {