Skip to content

Commit 05c93f2

Browse files
authored
Merge pull request #306 from FalkorDB/fix/sql-identifier-quoting
Fix: Auto-quote SQL identifiers with special characters (dashes, spaces, etc.)
2 parents b93cc04 + 9075242 commit 05c93f2

File tree

7 files changed

+677
-2
lines changed

7 files changed

+677
-2
lines changed

.github/wordlist.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,8 @@ socio
9090
sexualized
9191
www
9292
faq
93+
sanitization
94+
Sanitization
95+
JOINs
96+
subqueries
97+
subquery

api/agents/analysis_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def _build_prompt( # pylint: disable=too-many-arguments, too-many-positional-a
202202
- Never skip explaining missing information, ambiguities, or instruction issues.
203203
- Respond ONLY in strict JSON format, without extra text.
204204
- If the query relates to a previous question, you MUST take into account the previous question and its answer, and answer based on the context and information provided so far.
205+
- CRITICAL: When table or column names contain special characters (especially dashes/hyphens like '-'), you MUST wrap them in double quotes for PostgreSQL (e.g., "table-name") or backticks for MySQL (e.g., `table-name`). This is NON-NEGOTIABLE.
205206
206207
If the user is asking a follow-up or continuing question, use the conversation history and previous answers to resolve references, context, or ambiguities. Always base your analysis on the cumulative context, not just the current question.
207208

api/core/text2sql.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from api.loaders.postgres_loader import PostgresLoader
1919
from api.loaders.mysql_loader import MySQLLoader
2020
from api.memory.graphiti_tool import MemoryTool
21+
from api.sql_utils import SQLIdentifierQuoter, DatabaseSpecificQuoter
2122

2223
# Use the same delimiter as in the JavaScript
2324
MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||"
@@ -316,6 +317,33 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m
316317
follow_up_result = ""
317318
execution_error = False
318319

320+
# Auto-quote table names with special characters (like dashes)
321+
original_sql = answer_an['sql_query']
322+
if original_sql:
323+
# Extract known table names from the result schema
324+
known_tables = {table[0] for table in result} if result else set()
325+
326+
# Determine database type and get appropriate quote character
327+
db_type, _ = get_database_type_and_loader(db_url)
328+
quote_char = DatabaseSpecificQuoter.get_quote_char(
329+
db_type or 'postgresql'
330+
)
331+
332+
# Auto-quote identifiers with special characters
333+
sanitized_sql, was_modified = (
334+
SQLIdentifierQuoter.auto_quote_identifiers(
335+
original_sql, known_tables, quote_char
336+
)
337+
)
338+
339+
if was_modified:
340+
msg = (
341+
"SQL query auto-sanitized: quoted table names with "
342+
"special characters"
343+
)
344+
logging.info(msg)
345+
answer_an['sql_query'] = sanitized_sql
346+
319347
logging.info("Generated SQL query: %s", answer_an['sql_query']) # nosemgrep
320348
yield json.dumps(
321349
{
@@ -590,7 +618,7 @@ async def generate(): # pylint: disable=too-many-locals,too-many-branches,too-m
590618
return generate()
591619

592620

593-
async def execute_destructive_operation(
621+
async def execute_destructive_operation( # pylint: disable=too-many-statements
594622
user_id: str,
595623
graph_id: str,
596624
confirm_data: ConfirmRequest,
@@ -613,7 +641,7 @@ async def execute_destructive_operation(
613641
raise InvalidArgumentError("No SQL query provided")
614642

615643
# Create a generator function for streaming the confirmation response
616-
async def generate_confirmation():
644+
async def generate_confirmation(): # pylint: disable=too-many-locals,too-many-statements
617645
# Create memory tool for saving query results
618646
memory_tool = await MemoryTool.create(user_id, graph_id)
619647

@@ -635,6 +663,39 @@ async def generate_confirmation():
635663
"message": "Step 2: Executing confirmed SQL query"}
636664
yield json.dumps(step) + MESSAGE_DELIMITER
637665

666+
# Auto-quote table names for confirmed destructive operations
667+
sql_query = confirm_data.sql_query if hasattr(
668+
confirm_data, 'sql_query'
669+
) else ""
670+
if sql_query:
671+
# Get schema to extract known tables
672+
graph = db.select_graph(graph_id)
673+
tables_query = "MATCH (t:Table) RETURN t.name"
674+
try:
675+
tables_res = (await graph.query(tables_query)).result_set
676+
known_tables = (
677+
{row[0] for row in tables_res}
678+
if tables_res else set()
679+
)
680+
except Exception: # pylint: disable=broad-exception-caught
681+
known_tables = set()
682+
683+
# Determine database type and get appropriate quote character
684+
db_type, _ = get_database_type_and_loader(db_url)
685+
quote_char = DatabaseSpecificQuoter.get_quote_char(
686+
db_type or 'postgresql'
687+
)
688+
689+
# Auto-quote identifiers
690+
sanitized_sql, was_modified = (
691+
SQLIdentifierQuoter.auto_quote_identifiers(
692+
sql_query, known_tables, quote_char
693+
)
694+
)
695+
if was_modified:
696+
logging.info("Confirmed SQL query auto-sanitized")
697+
sql_query = sanitized_sql
698+
638699
# Check if this query modifies the database schema using appropriate loader
639700
is_schema_modifying, operation_type = (
640701
loader_class.is_schema_modifying_query(sql_query)

api/sql_utils/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Utility modules for QueryWeaver API."""
2+
3+
from .sql_sanitizer import SQLIdentifierQuoter, DatabaseSpecificQuoter
4+
5+
__all__ = ['SQLIdentifierQuoter', 'DatabaseSpecificQuoter']

api/sql_utils/sql_sanitizer.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""SQL sanitization utilities for handling identifiers with special characters."""
2+
3+
import re
4+
from typing import Set, Tuple
5+
6+
7+
class SQLIdentifierQuoter:
8+
"""
9+
Utility class for automatically quoting SQL identifiers (table/column names)
10+
that contain special characters like dashes.
11+
"""
12+
13+
# Characters that require quoting in identifiers
14+
SPECIAL_CHARS = {'-', ' ', '.', '@', '#', '$', '%', '^', '&', '*', '(',
15+
')', '+', '=', '[', ']', '{', '}', '|', '\\', ':',
16+
';', '"', "'", '<', '>', ',', '?', '/'}
17+
# SQL keywords that should not be quoted
18+
SQL_KEYWORDS = {
19+
'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON',
20+
'AS', 'AND', 'OR', 'NOT', 'IN', 'BETWEEN', 'LIKE', 'IS', 'NULL', 'ORDER',
21+
'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'INSERT', 'UPDATE', 'DELETE',
22+
'CREATE', 'DROP', 'ALTER', 'TABLE', 'INTO', 'VALUES', 'SET', 'COUNT',
23+
'SUM', 'AVG', 'MAX', 'MIN', 'DISTINCT', 'ALL', 'UNION', 'INTERSECT',
24+
'EXCEPT', 'CASE', 'WHEN', 'THEN', 'ELSE', 'END', 'CAST', 'ASC', 'DESC'
25+
}
26+
27+
@classmethod
28+
def needs_quoting(cls, identifier: str) -> bool:
29+
"""
30+
Check if an identifier needs quoting based on special characters.
31+
32+
Args:
33+
identifier: The table or column name to check
34+
35+
Returns:
36+
True if the identifier needs quoting, False otherwise
37+
"""
38+
# Already quoted
39+
if (identifier.startswith('"') and identifier.endswith('"')) or \
40+
(identifier.startswith('`') and identifier.endswith('`')):
41+
return False
42+
43+
# Check if it's a SQL keyword
44+
if identifier.upper() in cls.SQL_KEYWORDS:
45+
return False
46+
47+
# Check for special characters
48+
return any(char in cls.SPECIAL_CHARS for char in identifier)
49+
50+
@staticmethod
51+
def quote_identifier(identifier: str, quote_char: str = '"') -> str:
52+
"""
53+
Quote an identifier if not already quoted.
54+
55+
Args:
56+
identifier: The identifier to quote
57+
quote_char: The quote character to use (default: " for PostgreSQL/standard SQL)
58+
59+
Returns:
60+
Quoted identifier
61+
"""
62+
identifier = identifier.strip()
63+
64+
# Don't double-quote
65+
if (identifier.startswith('"') and identifier.endswith('"')) or \
66+
(identifier.startswith('`') and identifier.endswith('`')):
67+
return identifier
68+
69+
return f'{quote_char}{identifier}{quote_char}'
70+
71+
@classmethod
72+
def extract_table_names_from_query(cls, sql_query: str) -> Set[str]:
73+
"""
74+
Extract potential table names from a SQL query.
75+
Looks for identifiers after FROM, JOIN, UPDATE, INSERT INTO, etc.
76+
77+
Args:
78+
sql_query: The SQL query to parse
79+
80+
Returns:
81+
Set of potential table names
82+
"""
83+
table_names = set()
84+
85+
# Pattern to match table names after FROM, JOIN, UPDATE, INSERT INTO, etc.
86+
# This is a heuristic approach - not perfect but handles common cases
87+
patterns = [
88+
r'\bFROM\s+([a-zA-Z0-9_\-]+)',
89+
r'\bJOIN\s+([a-zA-Z0-9_\-]+)',
90+
r'\bUPDATE\s+([a-zA-Z0-9_\-]+)',
91+
r'\bINSERT\s+INTO\s+([a-zA-Z0-9_\-]+)',
92+
r'\bTABLE\s+([a-zA-Z0-9_\-]+)',
93+
]
94+
95+
for pattern in patterns:
96+
matches = re.finditer(pattern, sql_query, re.IGNORECASE)
97+
for match in matches:
98+
table_name = match.group(1).strip()
99+
# Skip if it's already quoted or an alias
100+
if not ((table_name.startswith('"') and table_name.endswith('"')) or
101+
(table_name.startswith('`') and table_name.endswith('`'))):
102+
table_names.add(table_name)
103+
104+
return table_names
105+
106+
@classmethod
107+
def auto_quote_identifiers(
108+
cls,
109+
sql_query: str,
110+
known_tables: Set[str],
111+
quote_char: str = '"'
112+
) -> Tuple[str, bool]:
113+
"""
114+
Automatically quote table names with special characters in a SQL query.
115+
116+
Args:
117+
sql_query: The SQL query to process
118+
known_tables: Set of known table names from the database schema
119+
quote_char: Quote character to use (default: " for PostgreSQL, use ` for MySQL)
120+
121+
Returns:
122+
Tuple of (modified_query, was_modified)
123+
"""
124+
modified = False
125+
result_query = sql_query
126+
127+
# Extract potential table names from query
128+
query_tables = cls.extract_table_names_from_query(sql_query)
129+
130+
# For each table that needs quoting
131+
for table in query_tables:
132+
# Check if this table exists in known schema and needs quoting
133+
if table in known_tables and cls.needs_quoting(table):
134+
# Quote the table name
135+
quoted = cls.quote_identifier(table, quote_char)
136+
137+
# Replace unquoted occurrences with quoted version
138+
# Use word boundaries to avoid partial replacements
139+
# Handle cases: FROM table, JOIN table, table.column, etc.
140+
patterns_to_replace = [
141+
(rf'\b{re.escape(table)}\b(?!\s*\.)', quoted),
142+
(rf'\b{re.escape(table)}\.', f'{quoted}.'),
143+
]
144+
145+
for pattern, replacement in patterns_to_replace:
146+
new_query = re.sub(pattern, replacement, result_query, flags=re.IGNORECASE)
147+
if new_query != result_query:
148+
modified = True
149+
result_query = new_query
150+
151+
return result_query, modified
152+
153+
154+
class DatabaseSpecificQuoter:  # pylint: disable=too-few-public-methods
    """Factory class to get the appropriate quote character for different databases."""

    @staticmethod
    def get_quote_char(db_type: str) -> str:
        """
        Get the appropriate quote character for a database type.

        Args:
            db_type: Database type ('postgresql', 'mysql', etc.). Matching is
                case-insensitive and ignores surrounding whitespace. None or
                an empty string falls back to the double-quote default
                instead of raising.

        Returns:
            '`' for MySQL/MariaDB, '"' for everything else
        """
        normalized = (db_type or '').strip().lower()
        if normalized in ('mysql', 'mariadb'):
            return '`'
        # Every other value falls back to the standard SQL double quote
        return '"'

0 commit comments

Comments
 (0)