Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { IConfigurationService } from '../../../../../platform/configuration/com
import { createDecorator } from '../../../../../platform/instantiation/common/instantiation.js';
import { ChatConfiguration } from '../../common/constants.js';
import { ChatMessageRole, ILanguageModelsService } from '../../common/languageModels.js';
import { TerminalToolId } from '../../common/tools/terminalToolIds.js';
import { IToolData } from '../../common/tools/languageModelToolsService.js';

export const enum ToolRiskLevel {
Expand Down Expand Up @@ -64,16 +65,15 @@ export class ChatToolRiskAssessmentService implements IChatToolRiskAssessmentSer
}

getCached(tool: IToolData, parameters: unknown): IToolRiskAssessment | undefined {
const key = tool.id + '::' + stableStringify(parameters);
return this._cache.get(key)?.assessment;
return this._cache.get(this._cacheKey(tool, parameters))?.assessment;
}

async assess(tool: IToolData, parameters: unknown, token: CancellationToken): Promise<IToolRiskAssessment | undefined> {
if (!this.isEnabled()) {
return undefined;
}

const key = tool.id + '::' + stableStringify(parameters);
const key = this._cacheKey(tool, parameters);

const cached = this._cache.get(key);
if (cached) {
Expand Down Expand Up @@ -104,6 +104,10 @@ export class ChatToolRiskAssessmentService implements IChatToolRiskAssessmentSer
return promise;
}

private _cacheKey(tool: IToolData, parameters: unknown): string {
return tool.id + '::' + stableStringify(normalizeRiskCacheParameters(tool, parameters));
}

private async _invokeModel(tool: IToolData, parameters: unknown, token: CancellationToken): Promise<IToolRiskAssessment | undefined> {
const modelId = this._configurationService.getValue<string>(ChatConfiguration.ToolRiskAssessmentModel) || 'copilot-fast';

Expand Down Expand Up @@ -145,6 +149,19 @@ export class ChatToolRiskAssessmentService implements IChatToolRiskAssessmentSer
}
}

/**
* Compute the subset of tool parameters that are relevant to the risk
* assessment, used as the cache key so re-invocations of the same tool call
* hit the cache even when model-generated descriptive fields differ.
*/
function normalizeRiskCacheParameters(tool: IToolData, parameters: unknown): unknown {
if (tool.id === TerminalToolId.RunInTerminal && parameters && typeof parameters === 'object') {
const p = parameters as Record<string, unknown>;
return { command: p.command };
}
return parameters;
}

function buildPrompt(tool: IToolData, parameters: unknown): string {
let argsJson: string;
try {
Expand All @@ -155,7 +172,6 @@ function buildPrompt(tool: IToolData, parameters: unknown): string {
if (argsJson.length > MAX_PARAM_BYTES) {
argsJson = argsJson.slice(0, MAX_PARAM_BYTES) + '...[truncated]';
}

return [
`You assess what one terminal command does for a code-editing AI agent, and how risky it is.`,
`Reply with STRICT JSON only (no prose, no markdown fences):`,
Expand Down
Loading