72 changes: 28 additions & 44 deletions cassandra/cluster.py
@@ -1683,14 +1683,7 @@ def protocol_downgrade(self, host_endpoint, previous_version):
"http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint)
self.protocol_version = new_version

def _add_resolved_hosts(self):
for endpoint in self.endpoints_resolved:
host, new = self.add_host(endpoint, signal=False)
if new:
host.set_up()
for listener in self.listeners:
listener.on_add(host)

def _populate_hosts(self):
self.profile_manager.populate(
weakref.proxy(self), self.metadata.all_hosts())
self.load_balancing_policy.populate(
@@ -1717,17 +1710,10 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
self.contact_points, self.protocol_version)
self.connection_class.initialize_reactor()
_register_cluster_shutdown(self)

self._add_resolved_hosts()

try:
self.control_connection.connect()

# we set all contact points up for connecting, but we won't infer state after this
for endpoint in self.endpoints_resolved:
h = self.metadata.get_host(endpoint)
if h and self.profile_manager.distance(h) == HostDistance.IGNORED:
h.is_up = None
self._populate_hosts()

log.debug("Control connection created")
except Exception:
@@ -3534,28 +3520,22 @@ def _set_new_connection(self, conn):
if old:
log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn)
old.close()
def _connect_host_in_lbp(self):

def _try_connect_to_hosts(self):
errors = {}
lbp = (
self._cluster.load_balancing_policy
if self._cluster._config_mode == _ConfigMode.LEGACY else
self._cluster._default_load_balancing_policy
)

for host in lbp.make_query_plan():
lbp = self._cluster.load_balancing_policy \
if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy

for endpoint in chain((host.endpoint for host in lbp.make_query_plan()), self._cluster.endpoints_resolved):
try:
return (self._try_connect(host), None)
except ConnectionException as exc:
errors[str(host.endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
self._cluster.signal_connection_failure(host, exc, is_host_addition=False)
return (self._try_connect(endpoint), None)

Why did you remove the except ConnectionException as exc?
I'm not saying it was wrong, I just don't understand. Please explain more, preferably in the commit message.

Collaborator:

Let's do that in a separate PR.
except Exception as exc:
errors[str(host.endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
errors[str(endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True)
if self._is_shutdown:
raise DriverException("[control connection] Reconnection in progress during shutdown")

return (None, errors)
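
To make the question above concrete: the removed branch caught ConnectionException specifically and, in addition to recording the error, signalled a connection failure for that host, while the new loop funnels every exception into one generic handler. A minimal sketch of the two shapes, with hypothetical stand-ins (try_connect, signal_failure) for the driver internals; ConnectionError here stands in for the driver's ConnectionException:

def connect_old_style(hosts, try_connect, signal_failure):
    errors = {}
    for host in hosts:
        try:
            return try_connect(host), None
        except ConnectionError as exc:
            errors[str(host)] = exc
            signal_failure(host, exc)  # the removed branch also marked the host as failed
        except Exception as exc:
            errors[str(host)] = exc
    return None, errors

def connect_new_style(endpoints, try_connect):
    errors = {}
    for endpoint in endpoints:
        try:
            return try_connect(endpoint), None
        except Exception as exc:  # single handler; no host-failure signalling
            errors[str(endpoint)] = exc
    return None, errors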

def _reconnect_internal(self):
@@ -3567,43 +3547,43 @@ def _reconnect_internal(self):
to the exception that was raised when an attempt was made to open
a connection to that host.
"""
(conn, _) = self._connect_host_in_lbp()
(conn, _) = self._try_connect_to_hosts()
if conn is not None:
return conn

# Try to re-resolve hostnames as a fallback when all hosts are unreachable
self._cluster._resolve_hostnames()

self._cluster._add_resolved_hosts()
self._cluster._populate_hosts()

(conn, errors) = self._connect_host_in_lbp()
(conn, errors) = self._try_connect_to_hosts()
if conn is not None:
return conn

raise NoHostAvailable("Unable to connect to any servers", errors)

def _try_connect(self, host):
def _try_connect(self, endpoint):
"""
Creates a new Connection, registers for pushed events, and refreshes
node/token and schema metadata.
"""
log.debug("[control connection] Opening new connection to %s", host)
log.debug("[control connection] Opening new connection to %s", endpoint)

while True:
try:
connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True)
connection = self._cluster.connection_factory(endpoint, is_control_connection=True)
if self._is_shutdown:
connection.close()
raise DriverException("Reconnecting during shutdown")
break
except ProtocolVersionUnsupported as e:
self._cluster.protocol_downgrade(host.endpoint, e.startup_version)
self._cluster.protocol_downgrade(endpoint, e.startup_version)
except ProtocolException as e:
# protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver
# protocol version. If the protocol version was not explicitly specified,
# and that the server raises a beta protocol error, we should downgrade.
if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error:
self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version)
self._cluster.protocol_downgrade(endpoint, self._cluster.protocol_version)
else:
raise

@@ -3821,7 +3801,10 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
tokens = local_row.get("tokens")

host = self._cluster.metadata.get_host(connection.original_endpoint)
if host:
if not host:
log.info("[control connection] Local host %s not found in metadata, adding it", connection.original_endpoint)
peers_result.append(local_row)
else:
datacenter = local_row.get("data_center")
rack = local_row.get("rack")
self._update_location_info(host, datacenter, rack)
@@ -4177,8 +4160,9 @@ def _get_peers_query(self, peers_query_type, connection=None):
query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE
if peers_query_type == self.PeersQueryType.PEERS_SCHEMA
else self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
host_release_version = self._cluster.metadata.get_host(connection.original_endpoint).release_version
host_dse_version = self._cluster.metadata.get_host(connection.original_endpoint).dse_version
original_endpoint_host = self._cluster.metadata.get_host(connection.original_endpoint)
host_release_version = None if original_endpoint_host is None else original_endpoint_host.release_version
host_dse_version = None if original_endpoint_host is None else original_endpoint_host.dse_version
uses_native_address_query = (
host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION)

4 changes: 4 additions & 0 deletions cassandra/metadata.py
@@ -139,6 +139,10 @@ def export_schema_as_string(self):
def refresh(self, connection, timeout, target_type=None, change_type=None, fetch_size=None,
metadata_request_timeout=None, **kwargs):

# If the host is not in metadata, we can't proceed; hosts should be added after successfully establishing the control connection
if not self.get_host(connection.original_endpoint):
return

Comment on lines +143 to +145

Please add a comment explaining what is going on here.

Collaborator:

This is the wrong fix; we need to address it in a separate PR. The correct fix would be to pull the version from system.local when this information is absent.
Here is the issue for it.

Collaborator Author:

I added a comment, let me know if this can stay for now, or whether I should change it.

Doesn't return imply a successful refresh? Shouldn't we throw an exception here?

Collaborator @dkropachev, Dec 30, 2025:

@Lorak-mmk, this issue is triggered only on the first metadata load; at that time there are no known hosts, so we can't pull version info from them. But we still need to pull the metadata, so the best option would be to pull the versions from system.local.
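
A minimal sketch of that suggested fallback, not the driver's actual code: run_query is a hypothetical helper standing in for the control connection's query path, and dse_version handling is elided since that column is DSE-specific.

def resolve_server_versions(metadata, connection, run_query):
    # Prefer versions from host metadata; fall back to system.local.
    host = metadata.get_host(connection.original_endpoint)
    if host is not None:
        return host.release_version, host.dse_version
    # First metadata load: no host entry exists yet, so ask the server directly.
    rows = run_query(connection,
                     "SELECT release_version FROM system.local WHERE key='local'")
    return rows[0]["release_version"], None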

server_version = self.get_host(connection.original_endpoint).release_version
dse_version = self.get_host(connection.original_endpoint).dse_version
parser = get_schema_parser(connection, server_version, dse_version, timeout, metadata_request_timeout, fetch_size)
3 changes: 3 additions & 0 deletions cassandra/policies.py
@@ -264,6 +264,9 @@ def populate(self, cluster, hosts):

def distance(self, host):
dc = self._dc(host)
if not self.local_dc:
self.local_dc = dc
return HostDistance.LOCAL
if dc == self.local_dc:
return HostDistance.LOCAL
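
A small usage sketch of this change: a DCAwareRoundRobinPolicy created without an explicit local_dc now adopts the datacenter of the first host whose distance it evaluates. Host construction details below are illustrative (host_id is required per the pool.py change that follows).

import uuid
from unittest.mock import Mock

from cassandra.policies import DCAwareRoundRobinPolicy, HostDistance
from cassandra.pool import Host

policy = DCAwareRoundRobinPolicy()  # no local_dc specified
host = Host("127.0.0.1", Mock(), datacenter="dc1", rack="rack1", host_id=uuid.uuid4())

assert policy.distance(host) == HostDistance.LOCAL  # first DC seen becomes the local DC
assert policy.local_dc == "dc1"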

2 changes: 1 addition & 1 deletion cassandra/pool.py
@@ -176,7 +176,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None
self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint)
self.conviction_policy = conviction_policy_factory(self)
if not host_id:
host_id = uuid.uuid4()
raise ValueError("host_id may not be None")
self.host_id = host_id
Comment on lines 177 to 180:

Commit: "Don't create Host instances with random host_id"

The change here is the one that the commit message explains. Perhaps the chain((host.endpoint for host in lbp.make_query_plan()), self._cluster.endpoints_resolved) line is also explained. Other changes are not explained, and are not at all obvious to me.

When writing commits, please assume that a reader won't be as familiar with the relevant code as you are. It is almost always true - even if the reviewer is an active maintainer, there is a high chance they did not work with this specific area recently.
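
For readers hitting that chain(...) line cold: it yields endpoints from the load-balancing policy's query plan first, then the resolved contact points, so the control connection still has candidates when metadata is empty. A toy illustration with made-up endpoint values:

from itertools import chain

query_plan_endpoints = []  # empty on a first connect, before any metadata exists
resolved_contact_points = ["10.0.0.1:9042", "10.0.0.2:9042"]

for endpoint in chain(query_plan_endpoints, resolved_contact_points):
    print("would try", endpoint)  # falls back to contact points when the plan is empty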

self.set_location_info(datacenter, rack)
self.lock = RLock()
36 changes: 24 additions & 12 deletions tests/integration/standard/test_cluster.py
@@ -900,8 +900,9 @@ def test_profile_lb_swap(self):
"""
Tests that profile load balancing policies are not shared

Creates two LBP, runs a few queries, and validates that each LBP is execised
seperately between EP's
Creates two LBP, runs a few queries, and validates that each LBP is exercised
separately between EP's. Each RoundRobinPolicy starts from its own random
position and maintains independent round-robin ordering.

@since 3.5
@jira_ticket PYTHON-569
@@ -916,17 +917,28 @@ def test_profile_lb_swap(self):
with TestCluster(execution_profiles=exec_profiles) as cluster:
session = cluster.connect(wait_for_all_pools=True)

# default is DCA RR for all hosts
expected_hosts = set(cluster.metadata.all_hosts())
rr1_queried_hosts = set()
rr2_queried_hosts = set()

rs = session.execute(query, execution_profile='rr1')
rr1_queried_hosts.add(rs.response_future._current_host)
rs = session.execute(query, execution_profile='rr2')
rr2_queried_hosts.add(rs.response_future._current_host)

assert rr2_queried_hosts == rr1_queried_hosts
num_hosts = len(expected_hosts)
assert num_hosts > 1, "Need at least 2 hosts for this test"

rr1_queried_hosts = []
rr2_queried_hosts = []

for _ in range(num_hosts * 2):
rs = session.execute(query, execution_profile='rr1')
rr1_queried_hosts.append(rs.response_future._current_host)
rs = session.execute(query, execution_profile='rr2')
rr2_queried_hosts.append(rs.response_future._current_host)

# Both policies should have queried all hosts
assert set(rr1_queried_hosts) == expected_hosts
assert set(rr2_queried_hosts) == expected_hosts

# The order of hosts should demonstrate round-robin behavior
# After num_hosts queries, the pattern should repeat
for i in range(num_hosts):
assert rr1_queried_hosts[i] == rr1_queried_hosts[i + num_hosts]
assert rr2_queried_hosts[i] == rr2_queried_hosts[i + num_hosts]

def test_ta_lbp(self):
"""
8 changes: 6 additions & 2 deletions tests/integration/standard/test_control_connection.py
@@ -101,8 +101,12 @@ def test_get_control_connection_host(self):

# reconnect and make sure that the new host is reflected correctly
self.cluster.control_connection._reconnect()
new_host = self.cluster.get_control_connection_host()
assert host != new_host
new_host1 = self.cluster.get_control_connection_host()

self.cluster.control_connection._reconnect()
new_host2 = self.cluster.get_control_connection_host()

assert new_host1 != new_host2

# TODO: enable after https://github.com/scylladb/python-driver/issues/121 is fixed
@unittest.skip('Fails on scylla due to the broadcast_rpc_port is None')
9 changes: 4 additions & 5 deletions tests/integration/standard/test_policies.py
@@ -45,9 +45,6 @@ def test_predicate_changes(self):
external_event = True
contact_point = DefaultEndPoint("127.0.0.1")

single_host = {Host(contact_point, SimpleConvictionPolicy)}
all_hosts = {Host(DefaultEndPoint("127.0.0.{}".format(i)), SimpleConvictionPolicy) for i in (1, 2, 3)}

predicate = lambda host: host.endpoint == contact_point if external_event else True
hfp = ExecutionProfile(
load_balancing_policy=HostFilterPolicy(RoundRobinPolicy(), predicate=predicate)
@@ -62,7 +59,8 @@ def test_predicate_changes(self):
response = session.execute("SELECT * from system.local WHERE key='local'")
queried_hosts.update(response.response_future.attempted_hosts)

assert queried_hosts == single_host
assert len(queried_hosts) == 1
assert queried_hosts.pop().endpoint == contact_point

external_event = False
futures = session.update_created_pools()
Expand All @@ -72,7 +70,8 @@ def test_predicate_changes(self):
for _ in range(10):
response = session.execute("SELECT * from system.local WHERE key='local'")
queried_hosts.update(response.response_future.attempted_hosts)
assert queried_hosts == all_hosts
assert len(queried_hosts) == 3
assert {host.endpoint for host in queried_hosts} == {DefaultEndPoint(f"127.0.0.{i}") for i in range(1, 4)}


class WhiteListRoundRobinPolicyTests(unittest.TestCase):
3 changes: 2 additions & 1 deletion tests/integration/standard/test_query.py
@@ -460,7 +460,8 @@ def make_query_plan(self, working_keyspace=None, query=None):
live_hosts = sorted(list(self._live_hosts))
host = []
try:
host = [live_hosts[self.host_index_to_use]]
if len(live_hosts) > 0:
host = [live_hosts[self.host_index_to_use]]
except IndexError as e:
raise IndexError(
'You specified an index larger than the number of hosts. Total hosts: {}. Index specified: {}'.format(
5 changes: 3 additions & 2 deletions tests/unit/advanced/test_policies.py
@@ -13,6 +13,7 @@
# limitations under the License.
import unittest
from unittest.mock import Mock
import uuid

from cassandra.pool import Host
from cassandra.policies import RoundRobinPolicy
@@ -72,7 +73,7 @@ def test_target_no_host(self):

def test_target_host_down(self):
node_count = 4
hosts = [Host(i, Mock()) for i in range(node_count)]
hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)]
target_host = hosts[1]

policy = DSELoadBalancingPolicy(RoundRobinPolicy())
@@ -87,7 +88,7 @@ def test_target_host_nominal(self):

def test_target_host_nominal(self):
node_count = 4
hosts = [Host(i, Mock()) for i in range(node_count)]
hosts = [Host(i, Mock(), host_id=uuid.uuid4()) for i in range(node_count)]
target_host = hosts[1]
target_host.is_up = True
