Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using Google.Cloud.Spanner.Common.V1;
using Google.Cloud.Spanner.V1;
using System;
using System.Threading.Tasks;
using Xunit;

namespace Google.Cloud.Spanner.Data.Tests;

public class SessionManagerTests
{
private const string ConnectionString = "DataSource=projects/x/instances/y/databases/z";
private static readonly DatabaseName s_databaseName = DatabaseName.FromProjectInstanceDatabase("x", "y", "z");

[Fact]
public async Task EqualOptions_SameClient()
{
int factoryCalls = 0;
Func<SpannerClientCreationOptions, SpannerSettings, Task<SpannerClient>> factory = (options, settings) =>
{
factoryCalls++;
return Task.FromResult<SpannerClient>(new FailingSpannerClient());
};
var manager = new SessionManager(new SpannerSettings(), factory);

var clientOptions1 = new SpannerClientCreationOptions(new SpannerConnectionStringBuilder(ConnectionString));
var clientOptions2 = new SpannerClientCreationOptions(new SpannerConnectionStringBuilder(ConnectionString));

var sessionOptions1 = new SessionAcquisitionOptions(clientOptions1, s_databaseName, null, null);
var sessionOptions2 = new SessionAcquisitionOptions(clientOptions2, s_databaseName, null, null);

var session1 = await manager.AcquireSessionAsync(sessionOptions1);
var session2 = await manager.AcquireSessionAsync(sessionOptions2);

// Factory calls should be 1 because clientOptions1 and clientOptions2 are equal
Assert.Equal(1, factoryCalls);
Assert.Same(session1, session2); // Sessions should also be same (cached)
}

[Fact]
public async Task DifferentOptions_DifferentClients()
{
int factoryCalls = 0;
Func<SpannerClientCreationOptions, SpannerSettings, Task<SpannerClient>> factory = (options, settings) =>
{
factoryCalls++;
return Task.FromResult<SpannerClient>(new FailingSpannerClient());
};
var manager = new SessionManager(new SpannerSettings(), factory);

var clientOptions1 = new SpannerClientCreationOptions(new SpannerConnectionStringBuilder(ConnectionString));
var clientOptions2 = new SpannerClientCreationOptions(new SpannerConnectionStringBuilder(ConnectionString) { Port = 1234 });

var sessionOptions1 = new SessionAcquisitionOptions(clientOptions1, s_databaseName, null, null);
var sessionOptions2 = new SessionAcquisitionOptions(clientOptions2, s_databaseName, null, null);

var session1 = await manager.AcquireSessionAsync(sessionOptions1);
var session2 = await manager.AcquireSessionAsync(sessionOptions2);

Assert.Equal(2, factoryCalls);
Assert.NotSame(session1, session2);
}

[Fact]
public async Task SameClient_DifferentDatabase_DifferentSessions()
{
int factoryCalls = 0;
Func<SpannerClientCreationOptions, SpannerSettings, Task<SpannerClient>> factory = (options, settings) =>
{
factoryCalls++;
return Task.FromResult<SpannerClient>(new FailingSpannerClient());
};
var manager = new SessionManager(new SpannerSettings(), factory);
var clientOptions = new SpannerClientCreationOptions(new SpannerConnectionStringBuilder(ConnectionString));

var db1 = DatabaseName.FromProjectInstanceDatabase("x", "y", "db1");
var db2 = DatabaseName.FromProjectInstanceDatabase("x", "y", "db2");

var sessionOptions1 = new SessionAcquisitionOptions(clientOptions, db1, null, null);
var sessionOptions2 = new SessionAcquisitionOptions(clientOptions, db2, null, null);

var session1 = await manager.AcquireSessionAsync(sessionOptions1);
var session2 = await manager.AcquireSessionAsync(sessionOptions2);

Assert.Equal(1, factoryCalls);
Assert.NotSame(session1, session2);
}

private class FailingSpannerClient : SpannerClient
{
public FailingSpannerClient(SpannerSettings settings = null)
{
Settings = settings ?? SpannerSettings.GetDefault();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public async Task DifferentOptions_DifferentSessionPools()
#pragma warning restore CS0618 // Type or member is obsolete
}

[Fact]
[Fact(Skip = "After MUX we clone the SpannerSettings, which don't have an Equal override. Session pool types will be deprecated soon.")]
public async Task UsesSpannerSettings()
{
ClientFactory factory = (options, settings) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ public void TransactionConstructor()
{
var connection = new SpannerConnection();
var managedTransaction = ManagedTransaction.FromTransaction(SpannerClientHelpers.CreateMockClient(Logger.DefaultLogger), new Session(), ByteString.CopyFromUtf8("transactionId"), new V1.TransactionOptions { ReadOnly = new V1.TransactionOptions.Types.ReadOnly() }, null);
var session = new PooledSession(managedTransaction);

var transaction = new SpannerTransaction(connection, session, SpannerTransactionCreationOptions.ReadWrite, transactionOptions: null, isRetriable: false);
var transaction = new SpannerTransaction(connection, managedTransaction, SpannerTransactionCreationOptions.ReadWrite, transactionOptions: null, isRetriable: false);
var command = new SpannerBatchCommand(transaction);

Assert.Empty(command.Commands);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public void ReadWrite_Values()

Assert.Null(readWrite.TimestampBound);
Assert.Null(readWrite.TransactionId);
Assert.Null(readWrite.EffectiveTimestampBound);
Assert.Equal(TransactionMode.ReadWrite, readWrite.TransactionMode);
Assert.False(readWrite.IsDetached);
Assert.False(readWrite.IsSingleUse);
Expand All @@ -73,6 +74,7 @@ public void PartitionedDml_Values()

Assert.Null(partitionedDml.TimestampBound);
Assert.Null(partitionedDml.TransactionId);
Assert.Null(partitionedDml.EffectiveTimestampBound);
Assert.Equal(TransactionMode.ReadWrite, partitionedDml.TransactionMode);
Assert.False(partitionedDml.IsDetached);
Assert.False(partitionedDml.IsSingleUse);
Expand All @@ -90,6 +92,7 @@ public void ReadOnly_Values()

Assert.Equal(TimestampBound.Strong, readOnly.TimestampBound);
Assert.Null(readOnly.TransactionId);
Assert.Equal(TimestampBound.Strong, readOnly.EffectiveTimestampBound);
Assert.Equal(TransactionMode.ReadOnly, readOnly.TransactionMode);
Assert.False(readOnly.IsDetached);
Assert.False(readOnly.IsSingleUse);
Expand All @@ -106,6 +109,7 @@ public void ForTimestampBoundReadOnly_Null()
var options = SpannerTransactionCreationOptions.ForTimestampBoundReadOnly(null);
Assert.Equal(TimestampBound.Strong, options.TimestampBound);
Assert.Null(options.TransactionId);
Assert.Equal(TimestampBound.Strong, options.EffectiveTimestampBound);
Assert.Equal(TransactionMode.ReadOnly, options.TransactionMode);
Assert.False(options.IsDetached);
Assert.False(options.IsSingleUse);
Expand All @@ -123,6 +127,7 @@ public void ForTimestampBoundReadOnly_Custom()
var options = SpannerTransactionCreationOptions.ForTimestampBoundReadOnly(timestampBound);
Assert.Equal(timestampBound, options.TimestampBound);
Assert.Null(options.TransactionId);
Assert.Equal(timestampBound, options.EffectiveTimestampBound);
Assert.Equal(TransactionMode.ReadOnly, options.TransactionMode);
Assert.False(options.IsDetached);
Assert.True(options.IsSingleUse);
Expand All @@ -145,13 +150,14 @@ public void FromReadOnlyTransactionId_NotNull()
var options = SpannerTransactionCreationOptions.FromReadOnlyTransactionId(transactionId);
Assert.Equal(transactionId, options.TransactionId);
Assert.Null(options.TimestampBound);
Assert.Equal(TimestampBound.Strong, options.EffectiveTimestampBound);
Assert.Equal(TransactionMode.ReadOnly, options.TransactionMode);
Assert.True(options.IsDetached);
Assert.False(options.IsSingleUse);
Assert.False(options.IsPartitionedDml);
Assert.False(options.ExcludeFromChangeStreams);
Assert.Equal(IsolationLevel.Unspecified, options.IsolationLevel);
Assert.Null(options.GetTransactionOptions());
Assert.Equal(TimestampBound.Strong.ToTransactionOptions(), options.GetTransactionOptions());
Assert.Equal(ReadLockMode.Unspecified, options.ReadLockMode);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@ Task<ReliableStreamReader> ISpannerTransaction.ExecuteReadOrQueryAsync(ReadOrQue

async Task<ReliableStreamReader> Impl()
{
PooledSession session = await _connection.AcquireSessionAsync(_creationOptions, cancellationToken, out _).ConfigureAwait(false);
ManagedTransaction managedTransaction = await _connection.BeginManagedTransactionAsync(_creationOptions, cancellationToken, out _).ConfigureAwait(false);
var callSettings = _connection.CreateCallSettings(
request.GetCallSettings,
cancellationToken);
var reader = request.ExecuteReadOrQueryStreamReader(session, callSettings);
reader.StreamClosed += delegate { session.ReleaseToPool(forceDelete: false); };
var reader = request.ExecuteReadOrQueryStreamReader(managedTransaction, callSettings);
reader.StreamClosed += delegate { _ = Task.Run(() => managedTransaction.DisposeAsync()); };
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

ManagedTransaction.DisposeAsync() returns a ValueTask. To correctly run this in a background task using Task.Run, you should convert it to a Task using .AsTask(). This ensures the ValueTask is properly handled and avoids potential issues with how Task.Run handles delegates returning ValueTask.

                reader.StreamClosed += delegate { _ = Task.Run(() => managedTransaction.DisposeAsync().AsTask()); };

return reader;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,18 @@ internal async Task<TResult> RunAsync<TResult>(Func<SpannerTransaction, Task<TRe
{
GaxPreconditions.CheckNotNull(asyncWork, nameof(asyncWork));

// Session will be initialized and subsequently modified by CommitAttempt.
PooledSession session = null;
// Transaction will be initialized and subsequently modified by CommitAttempt.
ManagedTransaction managedTransaction = null;
try
{
return await ExecuteWithRetryAsync(CommitAttempt, cancellationToken).ConfigureAwait(false);
}
finally
{
session?.Dispose();
if (managedTransaction != null)
{
_ = Task.Run(() => managedTransaction.DisposeAsync().AsTask());
}
}

async Task<TResult> CommitAttempt()
Expand All @@ -82,9 +85,9 @@ async Task<TResult> CommitAttempt()
try
{
SpannerTransactionCreationOptions effectiveCreationOptions = _creationOptions;
session = await (session?.RefreshedOrNewAsync(cancellationToken) ?? _connection.AcquireSessionAsync(_creationOptions, cancellationToken, out effectiveCreationOptions)).ConfigureAwait(false);
managedTransaction = managedTransaction?.FreshAfterAbort() ?? await _connection.BeginManagedTransactionAsync(_creationOptions, cancellationToken, out effectiveCreationOptions).ConfigureAwait(false);

transaction = new SpannerTransaction(_connection, session, effectiveCreationOptions, _transactionOptions, isRetriable: true);
transaction = new SpannerTransaction(_connection, managedTransaction, effectiveCreationOptions, _transactionOptions, isRetriable: true);

TResult result = await asyncWork(transaction).ConfigureAwait(false);
await transaction.CommitAsync(cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -114,12 +117,6 @@ async Task<TResult> CommitAttempt()
{
if (transaction != null)
{
// Since the transaction was marked as retriable, disposing of it won't attempt to dispose of or
// return the underlying session to the pool. That's because we'll be attempting to get a
// fresh transaction for this same session first.
// If that fails will attempt a new session acquisition.
// This session will be disposed of by the pool if it can't be refreshed or by the RunAsync method
// if we are not retrying anymore.
transaction.Dispose();
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using Google.Api.Gax;
using Google.Cloud.Spanner.Common.V1;
using Google.Cloud.Spanner.V1;
using System;
using System.Collections.Concurrent;
using System.Threading.Tasks;

namespace Google.Cloud.Spanner.Data;

/// <summary>
/// Manages ManagedSessions used by SpannerConnection.
/// </summary>
public sealed class SessionManager
{
/// <summary>
/// The default session manager, used by <see cref="SpannerConnection"/> unless a different manager
/// is specified on construction.
/// </summary>
public static SessionManager Default { get; } = new SessionManager(new SpannerSettings(), null);

internal SpannerSettings SpannerSettings => _spannerSettings;

private readonly SpannerSettings _spannerSettings;
private readonly Func<SpannerClientCreationOptions, SpannerSettings, Task<SpannerClient>> _clientFactory;
private readonly ConcurrentDictionary<SpannerClientCreationOptions, Task<SpannerClient>> _clients = new ConcurrentDictionary<SpannerClientCreationOptions, Task<SpannerClient>>();
private readonly ConcurrentDictionary<SessionAcquisitionOptions, Task<ManagedSession>> _sessions = new ConcurrentDictionary<SessionAcquisitionOptions, Task<ManagedSession>>();

internal SessionManager(SpannerSettings spannerSettings, Func<SpannerClientCreationOptions, SpannerSettings, Task<SpannerClient>> clientFactory)
{
_spannerSettings = GaxPreconditions.CheckNotNull(spannerSettings, nameof(spannerSettings));
_spannerSettings.VersionHeaderBuilder.AppendAssemblyVersion("gccl", typeof(SessionManager));

_clientFactory = clientFactory ?? ((options, settings) => options.CreateSpannerClientAsync(settings));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not an issue with your changes, but I think CreatespannerClientAsync should really be taking a CancelationToken and passing it forward to the BuildAsync(CancellationToken) method within. Then all the async methods here could take a cancellation token and we won't have to ignore the cancellation token in the public SpannerConnection.OpenAsync(CancellationToken).

Since we already don't honor that cancellation token and all this is not part of the public surface, I'm fine with not fixing it here & now, just wanted to point it out.

}

/// <summary>
/// Creates a new <see cref="SessionManager"/> identical to this one but with the given
/// <see cref="SpannerSettings"/>.
/// </summary>
/// <param name="spannerSettings">
/// Spanner settings to apply to the new session manager.
/// May be null, in which case, defaults will be used.
/// </param>
public SessionManager WithSpannerSettings(SpannerSettings spannerSettings) =>
new SessionManager(spannerSettings?.Clone() ?? new SpannerSettings(), _clientFactory);

internal Task<ManagedSession> AcquireSessionAsync(SessionAcquisitionOptions sessionOptions) =>
_sessions.GetOrAdd(sessionOptions, CreateSessionAsync);

internal Task<SpannerClient> AcquireClientAsync(SpannerClientCreationOptions clientOptions) =>
_clients.GetOrAdd(clientOptions, options => _clientFactory(options, _spannerSettings));

private async Task<ManagedSession> CreateSessionAsync(SessionAcquisitionOptions sessionOptions)
{
var client = await AcquireClientAsync(sessionOptions.ClientOptions).ConfigureAwait(false);
var options = ManagedSessionOptions.Create(sessionOptions.DatabaseName, client)
.WithDatabaseRole(sessionOptions.DatabaseRole)
.WithTimeout(sessionOptions.Timeout);
return new ManagedSession(options);
}
}

internal readonly struct SessionAcquisitionOptions : IEquatable<SessionAcquisitionOptions>
{
public SpannerClientCreationOptions ClientOptions { get; }
public DatabaseName DatabaseName { get; }
public string DatabaseRole { get; }
public TimeSpan? Timeout { get; }

public SessionAcquisitionOptions(SpannerClientCreationOptions clientOptions, DatabaseName databaseName, string databaseRole, TimeSpan? timeout)
{
ClientOptions = GaxPreconditions.CheckNotNull(clientOptions, nameof(clientOptions));
DatabaseName = GaxPreconditions.CheckNotNull(databaseName, nameof(databaseName));
DatabaseRole = databaseRole;
Timeout = timeout;
}

public bool Equals(SessionAcquisitionOptions other) =>
ClientOptions.Equals(other.ClientOptions) &&
DatabaseName.Equals(other.DatabaseName) &&
string.Equals(DatabaseRole, other.DatabaseRole, StringComparison.Ordinal) &&
Equals(Timeout, other.Timeout);

public override bool Equals(object obj) => obj is SessionAcquisitionOptions other && Equals(other);

public override int GetHashCode() => GaxEqualityHelpers.CombineHashCodes(
ClientOptions.GetHashCode(),
DatabaseName.GetHashCode(),
DatabaseRole?.GetHashCode() ?? 0,
Timeout?.GetHashCode() ?? 0);
}
Loading
Loading