Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ Zen for Node.js 16+ is compatible with:
- βœ… [`@clickhouse/client`](https://www.npmjs.com/package/@clickhouse/client) 1.x
- βœ… [`@prisma/client`](https://www.npmjs.com/package/@prisma/client) 6.x and 5.x
- βœ… [`@prisma/adapter-pg`](https://www.npmjs.com/package/@prisma/adapter-pg) 7.x
- βœ… [`@prisma/adapter-better-sqlite3`](https://www.npmjs.com/package/@prisma/adapter-better-sqlite3) 7.x

### Cloud providers

Expand Down
151 changes: 151 additions & 0 deletions end2end/tests-new/prisma-better-sqlite3.test.mjs
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import { spawn } from "child_process";
import { resolve } from "path";
import { test, before } from "node:test";
import { equal, fail, match, doesNotMatch } from "node:assert";
import { getRandomPort } from "./utils/get-port.mjs";
import { timeout } from "./utils/timeout.mjs";
import { promisify } from "util";
import { exec as execCb } from "child_process";

const execAsync = promisify(execCb);

const appDir = resolve(
import.meta.dirname,
"../../sample-apps/prisma-better-sqlite3"
);

process.env.DATABASE_URL = "file:./data/dev.db";

const port = await getRandomPort();
const port2 = await getRandomPort();

before(async () => {
await execAsync(
"npx prisma generate", // Generate prisma client
{
cwd: appDir,
}
);

await execAsync(
"npx prisma migrate reset --force", // Rset db and apply migrations
{
cwd: appDir,
}
);
});

test("it blocks request in blocking mode", async () => {
const server = spawn(
`node`,
["-r", "@aikidosec/firewall/instrument", "--import", "tsx", "app.ts", port],
{
cwd: appDir,
env: {
...process.env,
AIKIDO_DEBUG: "true",
AIKIDO_BLOCK: "true",
},
}
);

try {
server.on("error", (err) => {
fail(err.message);
});

let stdout = "";
server.stdout.on("data", (data) => {
stdout += data.toString();
});

let stderr = "";
server.stderr.on("data", (data) => {
stderr += data.toString();
});

// Wait for the server to start
await timeout(2000);

const [sqlInjection, normalAdd] = await Promise.all([
fetch(`http://127.0.0.1:${port}/posts/Test' OR 1=1 -- C`, {
method: "GET",
signal: AbortSignal.timeout(5000),
}),
fetch(`http://127.0.0.1:${port}/posts/Happy`, {
method: "GET",
signal: AbortSignal.timeout(5000),
}),
]);

equal(sqlInjection.status, 500);
equal(normalAdd.status, 200);
match(stdout, /Starting agent/);
match(stderr, /Zen has blocked an SQL injection/);
} catch (err) {
fail(err);
} finally {
server.kill();
}
});

test("it does not block request in non-blocking mode", async () => {
const server = spawn(
`node`,
[
"-r",
"@aikidosec/firewall/instrument",
"--import",
"tsx",
"app.ts",
port2,
],
{
cwd: appDir,
env: {
...process.env,
AIKIDO_DEBUG: "true",
AIKIDO_BLOCK: "false",
},
}
);

try {
server.on("error", (err) => {
fail(err.message);
});

let stdout = "";
server.stdout.on("data", (data) => {
stdout += data.toString();
});

let stderr = "";
server.stderr.on("data", (data) => {
stderr += data.toString();
});

// Wait for the server to start
await timeout(2000);

const [sqlInjection, normalAdd] = await Promise.all([
fetch(`http://127.0.0.1:${port2}/posts/Test' OR 1=1 -- C`, {
method: "GET",
signal: AbortSignal.timeout(5000),
}),
fetch(`http://127.0.0.1:${port2}/posts/Happy`, {
method: "GET",
signal: AbortSignal.timeout(5000),
}),
]);

equal(sqlInjection.status, 200);
equal(normalAdd.status, 200);
match(stdout, /Starting agent/);
doesNotMatch(stderr, /Zen has blocked an SQL injection/);
} catch (err) {
fail(err);
} finally {
server.kill();
}
});
4 changes: 3 additions & 1 deletion library/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

117 changes: 78 additions & 39 deletions library/sinks/BetterSQLite3.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { Hooks } from "../agent/hooks/Hooks";
import type { PackageFunctionInstrumentationInstruction } from "../agent/hooks/instrumentation/types";
import { InterceptorResult } from "../agent/hooks/InterceptorResult";
import { wrapExport } from "../agent/hooks/wrapExport";
import type { PartialWrapPackageInfo } from "../agent/hooks/WrapPackageInfo";
import { Wrapper } from "../agent/Wrapper";
import { isPlainObject } from "../helpers/isPlainObject";
import { checkContextForIdor } from "../vulnerabilities/idor/checkContextForIdor";
Expand All @@ -11,6 +12,10 @@ import { checkContextForSqlInjection } from "../vulnerabilities/sql-injection/ch
import { SQLDialect } from "../vulnerabilities/sql-injection/dialects/SQLDialect";
import { SQLDialectSQLite } from "../vulnerabilities/sql-injection/dialects/SQLDialectSQLite";

const sqlFunctions = ["exec", "pragma"];
const fsPathFunctions = ["backup", "loadExtension"];
const statementSqlFunctions = ["run", "get", "all", "iterate", "bind"];

export class BetterSQLite3 implements Wrapper {
private readonly dialect: SQLDialect = new SQLDialectSQLite();

Expand Down Expand Up @@ -143,51 +148,51 @@ export class BetterSQLite3 implements Wrapper {
return undefined;
}

wrap(hooks: Hooks) {
const sqlFunctions = ["exec", "pragma"];
const fsPathFunctions = ["backup", "loadExtension"];
const statementSqlFunctions = ["run", "get", "all", "iterate", "bind"];
private wrapExports(exports: any, pkgInfo: PartialWrapPackageInfo) {
for (const func of sqlFunctions) {
wrapExport(exports, func, pkgInfo, {
kind: "sql_op",
inspectArgs: (args) => {
return this.inspectQuery(`better-sqlite3.${func}`, args);
},
});
}
for (const func of fsPathFunctions) {
wrapExport(exports, func, pkgInfo, {
kind: "fs_op",
inspectArgs: (args) => {
return this.inspectPath(`better-sqlite3.${func}`, args);
},
});
}

wrapExport(exports, "prepare", pkgInfo, {
kind: "sql_op",
modifyReturnValue: (args, statement) => {
Comment thread
timokoessler marked this conversation as resolved.
for (const func of statementSqlFunctions) {
wrapExport(statement, func, pkgInfo, {
kind: "sql_op",
inspectArgs: (args, _, statement) => {
return this.inspectStatementOperation(
`better-sqlite3.prepare(...).${func}`,
args,
statement
);
},
});
}
return statement;
},
});
}

wrap(hooks: Hooks) {
const pkg = hooks
.addPackage("better-sqlite3")
.withVersion("^12.0.0 || ^11.0.0 || ^10.0.0 || ^9.0.0 || ^8.0.0");

pkg.onRequire((exports, pkgInfo) => {
for (const func of sqlFunctions) {
wrapExport(exports.prototype, func, pkgInfo, {
kind: "sql_op",
inspectArgs: (args) => {
return this.inspectQuery(`better-sqlite3.${func}`, args);
},
});
}
for (const func of fsPathFunctions) {
wrapExport(exports.prototype, func, pkgInfo, {
kind: "fs_op",
inspectArgs: (args) => {
return this.inspectPath(`better-sqlite3.${func}`, args);
},
});
}

wrapExport(exports.prototype, "prepare", pkgInfo, {
kind: "sql_op",
modifyReturnValue: (args, statement) => {
for (const func of statementSqlFunctions) {
wrapExport(statement, func, pkgInfo, {
kind: "sql_op",
inspectArgs: (args, _, statement) => {
return this.inspectStatementOperation(
`better-sqlite3.prepare(...).${func}`,
args,
statement
);
},
});
}
return statement;
},
});
this.wrapExports(exports.prototype, pkgInfo);
});

const wrapperFunctionsInstructions: PackageFunctionInstrumentationInstruction[] =
Expand Down Expand Up @@ -257,5 +262,39 @@ export class BetterSQLite3 implements Wrapper {
},
],
});

hooks
.addPackage("@prisma/adapter-better-sqlite3")
.withVersion("^7.0.0")
.addFileInstrumentation({
// This is not needed for CJS, as we can see sub-imports of CJS packages
path: "dist/index.mjs",
functions: [
{
nodeType: "MethodDefinition",
className: "PrismaBetterSqlite3Adapter",
name: "constructor",
operationKind: undefined,
modifyArgs: (args) => {
const pkgInfo = {
type: "external",
name: "@prisma/adapter-better-sqlite3",
} satisfies PartialWrapPackageInfo;

if (
!args[0] ||
typeof args[0] !== "object" ||
!("prepare" in args[0])
) {
return args;
}

this.wrapExports(args[0], pkgInfo);

return args;
},
},
],
});
}
}
7 changes: 7 additions & 0 deletions sample-apps/prisma-better-sqlite3/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
node_modules
# Keep environment variables out of version control
.env

/generated/prisma

data/
51 changes: 51 additions & 0 deletions sample-apps/prisma-better-sqlite3/app.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import Zen from "@aikidosec/firewall";

import { Hono } from "hono";
import { serve } from "@hono/node-server";
import { prisma } from "./lib/prisma";

process.env.DATABASE_URL = "file:./data/dev.db";

function getPort() {
const port = parseInt(process.argv[2], 10) || 4000;

if (isNaN(port)) {
console.error("Invalid port");
process.exit(1);
}

return port;
}

async function main() {
const port = getPort();

const app = new Hono();

Zen.addHonoMiddleware(app);

app.get("/", async (c) => {
return c.text("Hello, world!");
});

app.get("/posts/:title", async (c) => {
// Insecure, do not use in production
const posts = await prisma.$queryRawUnsafe(
`SELECT * FROM "Post" WHERE title = '${c.req.param().title}'`
);
return c.json(posts);
});

serve({
fetch: app.fetch,
port: port,
}).on("listening", () => {
console.log(`Server is running on port ${port}`);
});
}

main().catch(async (e) => {
console.error(e);
await prisma.$disconnect();
process.exit(1);
});
10 changes: 10 additions & 0 deletions sample-apps/prisma-better-sqlite3/lib/prisma.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import "dotenv/config";
import { PrismaBetterSqlite3 } from "@prisma/adapter-better-sqlite3";
import { PrismaClient } from "../generated/prisma/client";

const connectionString = `${process.env.DATABASE_URL}`;

const adapter = new PrismaBetterSqlite3({ url: connectionString });
const prisma = new PrismaClient({ adapter });

export { prisma };
Loading