|
| 1 | +"""SQL sanitization utilities for handling identifiers with special characters.""" |
| 2 | + |
| 3 | +import re |
| 4 | +from typing import Set, Tuple |
| 5 | + |
| 6 | + |
| 7 | +class SQLIdentifierQuoter: |
| 8 | + """ |
| 9 | + Utility class for automatically quoting SQL identifiers (table/column names) |
| 10 | + that contain special characters like dashes. |
| 11 | + """ |
| 12 | + |
| 13 | + # Characters that require quoting in identifiers |
| 14 | + SPECIAL_CHARS = {'-', ' ', '.', '@', '#', '$', '%', '^', '&', '*', '(', |
| 15 | + ')', '+', '=', '[', ']', '{', '}', '|', '\\', ':', |
| 16 | + ';', '"', "'", '<', '>', ',', '?', '/'} |
| 17 | + # SQL keywords that should not be quoted |
| 18 | + SQL_KEYWORDS = { |
| 19 | + 'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', |
| 20 | + 'AS', 'AND', 'OR', 'NOT', 'IN', 'BETWEEN', 'LIKE', 'IS', 'NULL', 'ORDER', |
| 21 | + 'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'INSERT', 'UPDATE', 'DELETE', |
| 22 | + 'CREATE', 'DROP', 'ALTER', 'TABLE', 'INTO', 'VALUES', 'SET', 'COUNT', |
| 23 | + 'SUM', 'AVG', 'MAX', 'MIN', 'DISTINCT', 'ALL', 'UNION', 'INTERSECT', |
| 24 | + 'EXCEPT', 'CASE', 'WHEN', 'THEN', 'ELSE', 'END', 'CAST', 'ASC', 'DESC' |
| 25 | + } |
| 26 | + |
| 27 | + @classmethod |
| 28 | + def needs_quoting(cls, identifier: str) -> bool: |
| 29 | + """ |
| 30 | + Check if an identifier needs quoting based on special characters. |
| 31 | + |
| 32 | + Args: |
| 33 | + identifier: The table or column name to check |
| 34 | + |
| 35 | + Returns: |
| 36 | + True if the identifier needs quoting, False otherwise |
| 37 | + """ |
| 38 | + # Already quoted |
| 39 | + if (identifier.startswith('"') and identifier.endswith('"')) or \ |
| 40 | + (identifier.startswith('`') and identifier.endswith('`')): |
| 41 | + return False |
| 42 | + |
| 43 | + # Check if it's a SQL keyword |
| 44 | + if identifier.upper() in cls.SQL_KEYWORDS: |
| 45 | + return False |
| 46 | + |
| 47 | + # Check for special characters |
| 48 | + return any(char in cls.SPECIAL_CHARS for char in identifier) |
| 49 | + |
| 50 | + @staticmethod |
| 51 | + def quote_identifier(identifier: str, quote_char: str = '"') -> str: |
| 52 | + """ |
| 53 | + Quote an identifier if not already quoted. |
| 54 | + |
| 55 | + Args: |
| 56 | + identifier: The identifier to quote |
| 57 | + quote_char: The quote character to use (default: " for PostgreSQL/standard SQL) |
| 58 | + |
| 59 | + Returns: |
| 60 | + Quoted identifier |
| 61 | + """ |
| 62 | + identifier = identifier.strip() |
| 63 | + |
| 64 | + # Don't double-quote |
| 65 | + if (identifier.startswith('"') and identifier.endswith('"')) or \ |
| 66 | + (identifier.startswith('`') and identifier.endswith('`')): |
| 67 | + return identifier |
| 68 | + |
| 69 | + return f'{quote_char}{identifier}{quote_char}' |
| 70 | + |
| 71 | + @classmethod |
| 72 | + def extract_table_names_from_query(cls, sql_query: str) -> Set[str]: |
| 73 | + """ |
| 74 | + Extract potential table names from a SQL query. |
| 75 | + Looks for identifiers after FROM, JOIN, UPDATE, INSERT INTO, etc. |
| 76 | + |
| 77 | + Args: |
| 78 | + sql_query: The SQL query to parse |
| 79 | + |
| 80 | + Returns: |
| 81 | + Set of potential table names |
| 82 | + """ |
| 83 | + table_names = set() |
| 84 | + |
| 85 | + # Pattern to match table names after FROM, JOIN, UPDATE, INSERT INTO, etc. |
| 86 | + # This is a heuristic approach - not perfect but handles common cases |
| 87 | + patterns = [ |
| 88 | + r'\bFROM\s+([a-zA-Z0-9_\-]+)', |
| 89 | + r'\bJOIN\s+([a-zA-Z0-9_\-]+)', |
| 90 | + r'\bUPDATE\s+([a-zA-Z0-9_\-]+)', |
| 91 | + r'\bINSERT\s+INTO\s+([a-zA-Z0-9_\-]+)', |
| 92 | + r'\bTABLE\s+([a-zA-Z0-9_\-]+)', |
| 93 | + ] |
| 94 | + |
| 95 | + for pattern in patterns: |
| 96 | + matches = re.finditer(pattern, sql_query, re.IGNORECASE) |
| 97 | + for match in matches: |
| 98 | + table_name = match.group(1).strip() |
| 99 | + # Skip if it's already quoted or an alias |
| 100 | + if not ((table_name.startswith('"') and table_name.endswith('"')) or |
| 101 | + (table_name.startswith('`') and table_name.endswith('`'))): |
| 102 | + table_names.add(table_name) |
| 103 | + |
| 104 | + return table_names |
| 105 | + |
| 106 | + @classmethod |
| 107 | + def auto_quote_identifiers( |
| 108 | + cls, |
| 109 | + sql_query: str, |
| 110 | + known_tables: Set[str], |
| 111 | + quote_char: str = '"' |
| 112 | + ) -> Tuple[str, bool]: |
| 113 | + """ |
| 114 | + Automatically quote table names with special characters in a SQL query. |
| 115 | + |
| 116 | + Args: |
| 117 | + sql_query: The SQL query to process |
| 118 | + known_tables: Set of known table names from the database schema |
| 119 | + quote_char: Quote character to use (default: " for PostgreSQL, use ` for MySQL) |
| 120 | + |
| 121 | + Returns: |
| 122 | + Tuple of (modified_query, was_modified) |
| 123 | + """ |
| 124 | + modified = False |
| 125 | + result_query = sql_query |
| 126 | + |
| 127 | + # Extract potential table names from query |
| 128 | + query_tables = cls.extract_table_names_from_query(sql_query) |
| 129 | + |
| 130 | + # For each table that needs quoting |
| 131 | + for table in query_tables: |
| 132 | + # Check if this table exists in known schema and needs quoting |
| 133 | + if table in known_tables and cls.needs_quoting(table): |
| 134 | + # Quote the table name |
| 135 | + quoted = cls.quote_identifier(table, quote_char) |
| 136 | + |
| 137 | + # Replace unquoted occurrences with quoted version |
| 138 | + # Use word boundaries to avoid partial replacements |
| 139 | + # Handle cases: FROM table, JOIN table, table.column, etc. |
| 140 | + patterns_to_replace = [ |
| 141 | + (rf'\b{re.escape(table)}\b(?!\s*\.)', quoted), |
| 142 | + (rf'\b{re.escape(table)}\.', f'{quoted}.'), |
| 143 | + ] |
| 144 | + |
| 145 | + for pattern, replacement in patterns_to_replace: |
| 146 | + new_query = re.sub(pattern, replacement, result_query, flags=re.IGNORECASE) |
| 147 | + if new_query != result_query: |
| 148 | + modified = True |
| 149 | + result_query = new_query |
| 150 | + |
| 151 | + return result_query, modified |
| 152 | + |
| 153 | + |
| 154 | +class DatabaseSpecificQuoter: # pylint: disable=too-few-public-methods |
| 155 | + """Factory class to get the appropriate quote character for different databases.""" |
| 156 | + |
| 157 | + @staticmethod |
| 158 | + def get_quote_char(db_type: str) -> str: |
| 159 | + """ |
| 160 | + Get the appropriate quote character for a database type. |
| 161 | + |
| 162 | + Args: |
| 163 | + db_type: Database type ('postgresql', 'mysql', etc.) |
| 164 | + |
| 165 | + Returns: |
| 166 | + Quote character to use |
| 167 | + """ |
| 168 | + if db_type.lower() in ['mysql', 'mariadb']: |
| 169 | + return '`' |
| 170 | + # PostgreSQL, SQLite, SQL Server (standard SQL) use double quotes |
| 171 | + return '"' |
0 commit comments