Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions packages/django-google-spanner/django_spanner/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@
)


def _escape_tzname(tzname):
"""Escape a time zone name for embedding in a Spanner string literal.

The datetime helpers below inline the time zone name into a quoted
string literal. Cloud Spanner (GoogleSQL) uses backslash escaping inside
string literals, so a name containing a quote would otherwise close the
literal and inject SQL. Both quote characters are escaped so the result is
safe in single- and double-quoted contexts.
"""
return tzname.replace("\\", "\\\\").replace('"', '\\"').replace("'", "\\'")


class DatabaseOperations(BaseDatabaseOperations):
"""A Spanner-specific version of Django database operations."""

Expand Down Expand Up @@ -431,6 +443,7 @@ def datetime_extract_sql(self, lookup_type, field_name, params, tzname):
:returns: A SQL statement for extracting.
"""
tzname = tzname if settings.USE_TZ and tzname else "UTC"
tzname = _escape_tzname(tzname)
lookup_type = self.extract_names.get(lookup_type, lookup_type)
return (
'EXTRACT(%s FROM %s AT TIME ZONE "%s")'
Expand Down Expand Up @@ -518,6 +531,7 @@ def datetime_trunc_sql(self, lookup_type, field_name, params, tzname="UTC"):
"""
# https://cloud.google.com/spanner/docs/functions-and-operators#timestamp_trunc
tzname = tzname if settings.USE_TZ and tzname else "UTC"
tzname = _escape_tzname(tzname)
if lookup_type == "week":
# Spanner truncates to Sunday but Django expects Monday. First,
# subtract a day so that a Sunday will be truncated to the previous
Expand Down Expand Up @@ -553,6 +567,7 @@ def time_trunc_sql(self, lookup_type, field_name, params, tzname="UTC"):
"""
# https://cloud.google.com/spanner/docs/functions-and-operators#timestamp_trunc
tzname = tzname if settings.USE_TZ and tzname else "UTC"
tzname = _escape_tzname(tzname)
return (
'TIMESTAMP_TRUNC(%s, %s, "%s")'
% (
Expand Down Expand Up @@ -581,6 +596,7 @@ def datetime_cast_date_sql(self, field_name, params, tzname):
"""
# https://cloud.google.com/spanner/docs/functions-and-operators#date
tzname = tzname if settings.USE_TZ and tzname else "UTC"
tzname = _escape_tzname(tzname)
return 'DATE(%s, "%s")' % (field_name, tzname), params

def datetime_cast_time_sql(self, field_name, params, tzname):
Expand All @@ -600,6 +616,7 @@ def datetime_cast_time_sql(self, field_name, params, tzname):
:returns: A SQL statement for casting.
"""
tzname = tzname if settings.USE_TZ and tzname else "UTC"
tzname = _escape_tzname(tzname)
# Cloud Spanner doesn't have a function for converting
# TIMESTAMP to another time zone.
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,62 @@ def test_datetime_cast_time_sql_use_tz_false(self):
)
settings.USE_TZ = True # reset changes.

def test_datetime_extract_sql_escapes_tzname(self):
settings.USE_TZ = True
self.assertEqual(
self.db_operations.datetime_extract_sql(
"year", "dummy_field", None, 'X" OR "a"="a'
),
(
'EXTRACT(year FROM dummy_field AT TIME ZONE "X\\" OR \\"a\\"=\\"a")',
None,
),
)

def test_datetime_trunc_sql_escapes_tzname(self):
settings.USE_TZ = True
self.assertEqual(
self.db_operations.datetime_trunc_sql(
"day", "dummy_field", None, 'X" OR "a"="a'
),
(
'TIMESTAMP_TRUNC(dummy_field, day, "X\\" OR \\"a\\"=\\"a")',
None,
),
)

def test_time_trunc_sql_escapes_tzname(self):
settings.USE_TZ = True
self.assertEqual(
self.db_operations.time_trunc_sql(
"day", "dummy_field", None, 'X" OR "a"="a'
),
(
'TIMESTAMP_TRUNC(dummy_field, day, "X\\" OR \\"a\\"=\\"a")',
None,
),
)

def test_datetime_cast_date_sql_escapes_tzname(self):
settings.USE_TZ = True
self.assertEqual(
self.db_operations.datetime_cast_date_sql(
"dummy_field", None, 'X" OR "a"="a'
),
('DATE(dummy_field, "X\\" OR \\"a\\"=\\"a")', None),
)

def test_datetime_cast_time_sql_escapes_tzname(self):
settings.USE_TZ = True
self.assertEqual(
self.db_operations.datetime_cast_time_sql("dummy_field", None, "X' || 'a"),
(
"TIMESTAMP(FORMAT_TIMESTAMP('%Y-%m-%d %R:%E9S %Z', "
"dummy_field, 'X\\' || \\'a'))",
None,
),
)

def test_date_interval_sql(self):
self.assertEqual(
self.db_operations.date_interval_sql(timedelta(days=1)),
Expand Down
Loading