Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/confluent_kafka/schema_registry/_async/protobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,13 +788,23 @@ def _get_message_desc_proto(
) -> Tuple[str, descriptor_pb2.DescriptorProto]:
index = msg_index[0]
if isinstance(desc, descriptor_pb2.FileDescriptorProto):
msg = desc.message_type[index]
messages = desc.message_type
if index < 0 or index >= len(messages):
raise SerializationError(
"message index {} out of range, schema has {} top-level message(s)".format(index, len(messages))
)
msg = messages[index]
path = path + "." + msg.name if path else msg.name
if len(msg_index) == 1:
return path, msg
return self._get_message_desc_proto(path, msg, msg_index[1:])
else:
msg = desc.nested_type[index]
messages = desc.nested_type
if index < 0 or index >= len(messages):
raise SerializationError(
"message index {} out of range, message has {} nested message(s)".format(index, len(messages))
)
msg = messages[index]
path = path + "." + msg.name if path else msg.name
if len(msg_index) == 1:
return path, msg
Expand Down
14 changes: 12 additions & 2 deletions src/confluent_kafka/schema_registry/_sync/protobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,13 +777,23 @@ def _get_message_desc_proto(
) -> Tuple[str, descriptor_pb2.DescriptorProto]:
index = msg_index[0]
if isinstance(desc, descriptor_pb2.FileDescriptorProto):
msg = desc.message_type[index]
messages = desc.message_type
if index < 0 or index >= len(messages):
raise SerializationError(
"message index {} out of range, schema has {} top-level message(s)".format(index, len(messages))
)
msg = messages[index]
path = path + "." + msg.name if path else msg.name
if len(msg_index) == 1:
return path, msg
return self._get_message_desc_proto(path, msg, msg_index[1:])
else:
msg = desc.nested_type[index]
messages = desc.nested_type
if index < 0 or index >= len(messages):
raise SerializationError(
"message index {} out of range, message has {} nested message(s)".format(index, len(messages))
)
msg = messages[index]
path = path + "." + msg.name if path else msg.name
if len(msg_index) == 1:
return path, msg
Expand Down
37 changes: 37 additions & 0 deletions tests/schema_registry/_async/test_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
from io import BytesIO

import pytest
from google.protobuf import descriptor_pb2

from confluent_kafka.schema_registry.protobuf import (
AsyncProtobufDeserializer,
AsyncProtobufSerializer,
_create_index_array,
decimal_to_protobuf,
protobuf_to_decimal,
)
from confluent_kafka.schema_registry.serde import SchemaId
from confluent_kafka.serialization import SerializationError
from tests.integration.schema_registry.data.proto import DependencyTestProto_pb2, metadata_proto_pb2


Expand All @@ -48,6 +51,40 @@ def test_create_index(pb2, coordinates):
assert msg_idx == coordinates


def _two_message_file_proto():
fdp = descriptor_pb2.FileDescriptorProto()
fdp.name = "test.proto"
fdp.package = "pkg"
first = fdp.message_type.add()
first.name = "First"
nested = first.nested_type.add()
nested.name = "Inner"
second = fdp.message_type.add()
second.name = "Second"
return fdp


def test_message_index_in_range():
deserializer = object.__new__(AsyncProtobufDeserializer)
fdp = _two_message_file_proto()

assert deserializer._get_message_desc_proto("", fdp, [0])[0] == "First"
assert deserializer._get_message_desc_proto("", fdp, [1])[0] == "Second"
assert deserializer._get_message_desc_proto("", fdp, [0, 0])[0] == "First.Inner"


@pytest.mark.parametrize("msg_index", [[-1], [2], [0, -1], [0, 5]])
def test_message_index_out_of_range(msg_index):
# The message index array is attacker-controlled wire framing; a zigzag
# varint can decode to a negative or out-of-range value. A negative index
# would otherwise wrap around and resolve to a different message type.
deserializer = object.__new__(AsyncProtobufDeserializer)
fdp = _two_message_file_proto()

with pytest.raises(SerializationError, match="out of range"):
deserializer._get_message_desc_proto("", fdp, msg_index)


@pytest.mark.parametrize(
"pb2",
[
Expand Down
37 changes: 37 additions & 0 deletions tests/schema_registry/_sync/test_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
from io import BytesIO

import pytest
from google.protobuf import descriptor_pb2

from confluent_kafka.schema_registry.protobuf import (
ProtobufDeserializer,
ProtobufSerializer,
_create_index_array,
decimal_to_protobuf,
protobuf_to_decimal,
)
from confluent_kafka.schema_registry.serde import SchemaId
from confluent_kafka.serialization import SerializationError
from tests.integration.schema_registry.data.proto import DependencyTestProto_pb2, metadata_proto_pb2


Expand All @@ -48,6 +51,40 @@ def test_create_index(pb2, coordinates):
assert msg_idx == coordinates


def _two_message_file_proto():
fdp = descriptor_pb2.FileDescriptorProto()
fdp.name = "test.proto"
fdp.package = "pkg"
first = fdp.message_type.add()
first.name = "First"
nested = first.nested_type.add()
nested.name = "Inner"
second = fdp.message_type.add()
second.name = "Second"
return fdp


def test_message_index_in_range():
deserializer = object.__new__(ProtobufDeserializer)
fdp = _two_message_file_proto()

assert deserializer._get_message_desc_proto("", fdp, [0])[0] == "First"
assert deserializer._get_message_desc_proto("", fdp, [1])[0] == "Second"
assert deserializer._get_message_desc_proto("", fdp, [0, 0])[0] == "First.Inner"


@pytest.mark.parametrize("msg_index", [[-1], [2], [0, -1], [0, 5]])
def test_message_index_out_of_range(msg_index):
# The message index array is attacker-controlled wire framing; a zigzag
# varint can decode to a negative or out-of-range value. A negative index
# would otherwise wrap around and resolve to a different message type.
deserializer = object.__new__(ProtobufDeserializer)
fdp = _two_message_file_proto()

with pytest.raises(SerializationError, match="out of range"):
deserializer._get_message_desc_proto("", fdp, msg_index)


@pytest.mark.parametrize(
"pb2",
[
Expand Down