diff --git a/src/vs/workbench/contrib/chat/browser/tools/chatToolRiskAssessmentService.ts b/src/vs/workbench/contrib/chat/browser/tools/chatToolRiskAssessmentService.ts index d5c544079143c..6347b50a12fe7 100644 --- a/src/vs/workbench/contrib/chat/browser/tools/chatToolRiskAssessmentService.ts +++ b/src/vs/workbench/contrib/chat/browser/tools/chatToolRiskAssessmentService.ts @@ -11,6 +11,7 @@ import { IConfigurationService } from '../../../../../platform/configuration/com import { createDecorator } from '../../../../../platform/instantiation/common/instantiation.js'; import { ChatConfiguration } from '../../common/constants.js'; import { ChatMessageRole, ILanguageModelsService } from '../../common/languageModels.js'; +import { TerminalToolId } from '../../common/tools/terminalToolIds.js'; import { IToolData } from '../../common/tools/languageModelToolsService.js'; export const enum ToolRiskLevel { @@ -64,8 +65,7 @@ export class ChatToolRiskAssessmentService implements IChatToolRiskAssessmentSer } getCached(tool: IToolData, parameters: unknown): IToolRiskAssessment | undefined { - const key = tool.id + '::' + stableStringify(parameters); - return this._cache.get(key)?.assessment; + return this._cache.get(this._cacheKey(tool, parameters))?.assessment; } async assess(tool: IToolData, parameters: unknown, token: CancellationToken): Promise { @@ -73,7 +73,7 @@ export class ChatToolRiskAssessmentService implements IChatToolRiskAssessmentSer return undefined; } - const key = tool.id + '::' + stableStringify(parameters); + const key = this._cacheKey(tool, parameters); const cached = this._cache.get(key); if (cached) { @@ -104,6 +104,10 @@ export class ChatToolRiskAssessmentService implements IChatToolRiskAssessmentSer return promise; } + private _cacheKey(tool: IToolData, parameters: unknown): string { + return tool.id + '::' + stableStringify(normalizeRiskCacheParameters(tool, parameters)); + } + private async _invokeModel(tool: IToolData, parameters: unknown, token: CancellationToken): Promise { const modelId = this._configurationService.getValue(ChatConfiguration.ToolRiskAssessmentModel) || 'copilot-fast'; @@ -145,6 +149,19 @@ export class ChatToolRiskAssessmentService implements IChatToolRiskAssessmentSer } } +/** + * Compute the subset of tool parameters that are relevant to the risk + * assessment, used as the cache key so re-invocations of the same tool call + * hit the cache even when model-generated descriptive fields differ. + */ +function normalizeRiskCacheParameters(tool: IToolData, parameters: unknown): unknown { + if (tool.id === TerminalToolId.RunInTerminal && parameters && typeof parameters === 'object') { + const p = parameters as Record; + return { command: p.command }; + } + return parameters; +} + function buildPrompt(tool: IToolData, parameters: unknown): string { let argsJson: string; try { @@ -155,7 +172,6 @@ function buildPrompt(tool: IToolData, parameters: unknown): string { if (argsJson.length > MAX_PARAM_BYTES) { argsJson = argsJson.slice(0, MAX_PARAM_BYTES) + '...[truncated]'; } - return [ `You assess what one terminal command does for a code-editing AI agent, and how risky it is.`, `Reply with STRICT JSON only (no prose, no markdown fences):`,