Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 37 additions & 104 deletions src/selectools/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,31 @@ def _make_key(session_id: str, namespace: Optional[str]) -> str:
return session_id


def _validate_session_id(session_id: str) -> None:
"""Validate session ID."""
if not session_id:
raise ValueError("session_id must not be empty")
if "\x00" in session_id:
raise ValueError(f"session_id must not contain null bytes: {session_id!r}")
if len(session_id) > 512:
raise ValueError(f"session_id too long ({len(session_id)} chars, max 512): {session_id!r}")


def _validate_namespace(namespace: Optional[str]) -> None:
"""Validate namespace."""
if namespace is None:
return

if not namespace:
raise ValueError("namespace must not be empty when provided")

if "\x00" in namespace:
raise ValueError(f"namespace must not contain null bytes: {namespace!r}")

if len(namespace) > 512:
raise ValueError(f"namespace too long ({len(namespace)} chars, max 512): {namespace!r}")


@stable
@dataclass
class SessionMetadata:
Expand Down Expand Up @@ -856,37 +881,14 @@ def __init__(
self._prefix = prefix
self._default_ttl = default_ttl

@staticmethod
def _validate_session_id(session_id: str) -> None:
"""Reject session IDs that could cause key collisions or other problems."""
if not session_id:
raise ValueError("session_id must not be empty")
if "\x00" in session_id:
raise ValueError(f"session_id must not contain null bytes: {session_id!r}")
if len(session_id) > 512:
raise ValueError(
f"session_id too long ({len(session_id)} chars, max 512): {session_id!r}"
)

@staticmethod
def _validate_namespace(namespace: Optional[str]) -> None:
if namespace is None:
return
if not namespace:
raise ValueError("namespace must not be empty when provided")
if "\x00" in namespace:
raise ValueError(f"namespace must not contain null bytes: {namespace!r}")
if len(namespace) > 512:
raise ValueError(f"namespace too long ({len(namespace)} chars, max 512): {namespace!r}")

def _key(self, session_id: str, namespace: Optional[str] = None) -> str:
self._validate_session_id(session_id)
self._validate_namespace(namespace)
_validate_session_id(session_id)
_validate_namespace(namespace)
return f"{self._prefix}{_make_key(session_id, namespace)}"

def _meta_key(self, session_id: str, namespace: Optional[str] = None) -> str:
self._validate_session_id(session_id)
self._validate_namespace(namespace)
_validate_session_id(session_id)
_validate_namespace(namespace)
return f"{self._prefix}__meta__{_make_key(session_id, namespace)}"

# -- public API --------------------------------------------------------
Expand Down Expand Up @@ -1129,32 +1131,9 @@ def __init__(

# -- validation helpers ------------------------------------------------

@staticmethod
def _validate_session_id(session_id: str) -> None:
"""Reject session IDs that could cause key collisions or other problems."""
if not session_id:
raise ValueError("session_id must not be empty")
if "\x00" in session_id:
raise ValueError(f"session_id must not contain null bytes: {session_id!r}")
if len(session_id) > 512:
raise ValueError(
f"session_id too long ({len(session_id)} chars, max 512): {session_id!r}"
)

@staticmethod
def _validate_namespace(namespace: Optional[str]) -> None:
if namespace is None:
return
if not namespace:
raise ValueError("namespace must not be empty when provided")
if "\x00" in namespace:
raise ValueError(f"namespace must not contain null bytes: {namespace!r}")
if len(namespace) > 512:
raise ValueError(f"namespace too long ({len(namespace)} chars, max 512): {namespace!r}")

def _key(self, session_id: str, namespace: Optional[str] = None) -> str:
self._validate_session_id(session_id)
self._validate_namespace(namespace)
_validate_session_id(session_id)
_validate_namespace(namespace)
return _make_key(session_id, namespace)

# -- public API --------------------------------------------------------
Expand Down Expand Up @@ -1358,33 +1337,9 @@ def __init__(
# expireAfterSeconds=0 deletes docs once `expires_at` is reached.
self._collection.create_index("expires_at", expireAfterSeconds=0)

# -- validation helpers ------------------------------------------------

@staticmethod
def _validate_session_id(session_id: str) -> None:
if not session_id:
raise ValueError("session_id must not be empty")
if "\x00" in session_id:
raise ValueError(f"session_id must not contain null bytes: {session_id!r}")
if len(session_id) > 512:
raise ValueError(
f"session_id too long ({len(session_id)} chars, max 512): {session_id!r}"
)

@staticmethod
def _validate_namespace(namespace: Optional[str]) -> None:
if namespace is None:
return
if not namespace:
raise ValueError("namespace must not be empty when provided")
if "\x00" in namespace:
raise ValueError(f"namespace must not contain null bytes: {namespace!r}")
if len(namespace) > 512:
raise ValueError(f"namespace too long ({len(namespace)} chars, max 512): {namespace!r}")

def _key(self, session_id: str, namespace: Optional[str] = None) -> str:
self._validate_session_id(session_id)
self._validate_namespace(namespace)
_validate_session_id(session_id)
_validate_namespace(namespace)
return _make_key(session_id, namespace)

# -- public API --------------------------------------------------------
Expand Down Expand Up @@ -1468,7 +1423,7 @@ def search(
return []
mongo_filter: Dict[str, Any] = {}
if namespace is not None:
self._validate_namespace(namespace)
_validate_namespace(namespace)
mongo_filter["namespace"] = namespace
results: List[SessionSearchResult] = []
for doc in self._collection.find(mongo_filter):
Expand Down Expand Up @@ -1545,31 +1500,9 @@ def __init__(

# -- validation helpers ------------------------------------------------

@staticmethod
def _validate_session_id(session_id: str) -> None:
if not session_id:
raise ValueError("session_id must not be empty")
if "\x00" in session_id:
raise ValueError(f"session_id must not contain null bytes: {session_id!r}")
if len(session_id) > 512:
raise ValueError(
f"session_id too long ({len(session_id)} chars, max 512): {session_id!r}"
)

@staticmethod
def _validate_namespace(namespace: Optional[str]) -> None:
if namespace is None:
return
if not namespace:
raise ValueError("namespace must not be empty when provided")
if "\x00" in namespace:
raise ValueError(f"namespace must not contain null bytes: {namespace!r}")
if len(namespace) > 512:
raise ValueError(f"namespace too long ({len(namespace)} chars, max 512): {namespace!r}")

def _key(self, session_id: str, namespace: Optional[str] = None) -> str:
self._validate_session_id(session_id)
self._validate_namespace(namespace)
_validate_session_id(session_id)
_validate_namespace(namespace)
return _make_key(session_id, namespace)

# -- public API --------------------------------------------------------
Expand Down Expand Up @@ -1670,7 +1603,7 @@ def search(
if not terms or limit <= 0:
return []
if namespace is not None:
self._validate_namespace(namespace)
_validate_namespace(namespace)
results: List[SessionSearchResult] = []
for item in self._scan_all():
if namespace is not None and item.get("namespace") != namespace:
Expand Down
10 changes: 5 additions & 5 deletions tests/test_sessions_supabase.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from selectools.memory import ConversationMemory
from selectools.sessions import SessionMetadata, SupabaseSessionStore
from selectools.types import Message, Role, ToolCall

from selectools.sessions import _validate_namespace
# ======================================================================
# Fake Supabase client
# ======================================================================
Expand Down Expand Up @@ -494,21 +494,21 @@ def test_session_id_at_limit_passes(self) -> None:
def test_empty_namespace_raises(self) -> None:
store = self._store_without_supabase()
with pytest.raises(ValueError, match="must not be empty"):
store._validate_namespace("")
_validate_namespace("")

def test_null_byte_in_namespace_raises(self) -> None:
store = self._store_without_supabase()
with pytest.raises(ValueError, match="null bytes"):
store._validate_namespace("bad\x00ns")
_validate_namespace("bad\x00ns")

def test_namespace_too_long_raises(self) -> None:
store = self._store_without_supabase()
with pytest.raises(ValueError, match="too long"):
store._validate_namespace("n" * 513)
_validate_namespace("n" * 513)

def test_none_namespace_passes(self) -> None:
store = self._store_without_supabase()
store._validate_namespace(None) # must not raise
_validate_namespace(None) # must not raise

def test_save_empty_session_id_raises(self) -> None:
client = FakeSupabaseClient()
Expand Down