Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions src/control/system/raft/database.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//
// (C) Copyright 2020-2024 Intel Corporation.
// (C) Copyright 2025 Hewlett Packard Enterprise Development LP
// (C) Copyright 2025-2026 Hewlett Packard Enterprise Development LP
//
// SPDX-License-Identifier: BSD-2-Clause-Patent
//
Expand Down Expand Up @@ -95,6 +95,7 @@ type (
raftTransport raft.Transport
raft syncRaft
raftLeaderNotifyCh chan bool
cbMutex sync.Mutex
onLeadershipGained []onLeadershipGainedFn
onLeadershipLost []onLeadershipLostFn
onRaftShutdown []onRaftShutdownFn
Expand Down Expand Up @@ -246,11 +247,14 @@ func NewDatabase(log logging.Logger, cfg *DatabaseConfig) (*Database, error) {
repAddr, _ := cfg.LocalReplicaAddr()

db := &Database{
log: log,
cfg: cfg,
replicaAddr: repAddr,
shutdownErrCh: make(chan error),
raftLeaderNotifyCh: make(chan bool),
log: log,
cfg: cfg,
replicaAddr: repAddr,
shutdownErrCh: make(chan error),
// Buffered so hashicorp/raft can post a leadership
// transition without blocking if the monitor goroutine
// is momentarily busy running callbacks.
raftLeaderNotifyCh: make(chan bool, 1),

data: &dbData{
log: log,
Expand Down Expand Up @@ -416,21 +420,47 @@ func (db *Database) IsLeader() bool {
// OnLeadershipGained adds the supplied functions to the list of callbacks
// that are run when this instance gains the leadership role. The list is
// modified while holding cbMutex, so registration is serialized against
// other registrations and against the helpers that copy the list.
func (db *Database) OnLeadershipGained(fns ...onLeadershipGainedFn) {
	db.cbMutex.Lock()
	db.onLeadershipGained = append(db.onLeadershipGained, fns...)
	db.cbMutex.Unlock()
}

// OnLeadershipLost adds the supplied functions to the list of callbacks
// that are run when this instance loses the leadership role. The list is
// modified while holding cbMutex, so registration is serialized against
// other registrations and against the helpers that copy the list.
func (db *Database) OnLeadershipLost(fns ...onLeadershipLostFn) {
	db.cbMutex.Lock()
	db.onLeadershipLost = append(db.onLeadershipLost, fns...)
	db.cbMutex.Unlock()
}

// OnRaftShutdown adds the supplied functions to the list of callbacks
// that are run when this instance shuts down. The list is modified while
// holding cbMutex, so registration is serialized against other
// registrations and against the helpers that copy the list.
func (db *Database) OnRaftShutdown(fns ...onRaftShutdownFn) {
	db.cbMutex.Lock()
	db.onRaftShutdown = append(db.onRaftShutdown, fns...)
	db.cbMutex.Unlock()
}

// The helpers below return copies of the registered callbacks. Each copy
// is made while holding the lock, so that the caller can then iterate over
// and execute the callbacks without holding the lock.
func (db *Database) onLeadershipGainedCbs() []onLeadershipGainedFn {
db.cbMutex.Lock()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason these couldn't be read locks? These cb functions only read back the contents of the slices, so read locks could also prevent unnecessary blocking.

defer db.cbMutex.Unlock()
return append([]onLeadershipGainedFn(nil), db.onLeadershipGained...)
}

// onLeadershipLostCbs returns a copy of the registered onLeadershipLost
// callbacks. The copy is made while holding cbMutex, so the caller can
// range over and invoke the result without holding the lock.
func (db *Database) onLeadershipLostCbs() []onLeadershipLostFn {
	db.cbMutex.Lock()
	defer db.cbMutex.Unlock()

	// Appending to a nil slice yields fresh backing storage, so later
	// registrations cannot alias the returned slice.
	var cbs []onLeadershipLostFn
	cbs = append(cbs, db.onLeadershipLost...)
	return cbs
}

// onRaftShutdownCbs returns a copy of the registered onRaftShutdown
// callbacks. The copy is made while holding cbMutex, so the caller can
// range over and invoke the result without holding the lock.
func (db *Database) onRaftShutdownCbs() []onRaftShutdownFn {
	db.cbMutex.Lock()
	defer db.cbMutex.Unlock()

	// Appending to a nil slice yields fresh backing storage, so later
	// registrations cannot alias the returned slice.
	var cbs []onRaftShutdownFn
	cbs = append(cbs, db.onRaftShutdown...)
	return cbs
}

// Start checks to see if the system is configured as a MS replica. If
// not, it returns early without an error. If it is, the persistent storage
// is initialized if necessary, and the replica is started to begin the
Expand Down Expand Up @@ -490,7 +520,7 @@ func (db *Database) monitorLeadershipState(parent context.Context) {
var cancelGainedCtx context.CancelFunc

runOnLeadershipLost := func() {
for _, fn := range db.onLeadershipLost {
for _, fn := range db.onLeadershipLostCbs() {
if err := fn(); err != nil {
db.log.Errorf("failure in onLeadershipLost callback: %s", err)
}
Expand Down Expand Up @@ -544,7 +574,7 @@ func (db *Database) stepUp(ctx context.Context, cancel context.CancelFunc) {
return // restart the monitoring loop
}

for i, fn := range db.onLeadershipGained {
for i, fn := range db.onLeadershipGainedCbs() {
db.log.Tracef("executing onLeadershipGained[%d]", i)

if err := fn(ctx); err != nil {
Expand Down
29 changes: 19 additions & 10 deletions src/control/system/raft/database_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//
// (C) Copyright 2020-2024 Intel Corporation.
// (C) Copyright 2025 Hewlett Packard Enterprise Development LP
// (C) Copyright 2025-2026 Hewlett Packard Enterprise Development LP
//
// SPDX-License-Identifier: BSD-2-Clause-Patent
//
Expand Down Expand Up @@ -40,16 +40,22 @@ import (

func waitForLeadership(ctx context.Context, t *testing.T, db *Database, gained bool) {
t.Helper()

ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()

ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()

for {
if db.IsLeader() == gained {
return
}
select {
case <-ctx.Done():
t.Fatal(ctx.Err())
t.Fatalf("timed out waiting for leadership gained=%t: %s", gained, ctx.Err())
return
default:
if db.IsLeader() == gained {
return
}
time.Sleep(1 * time.Second)
case <-ticker.C:
}
}
}
Expand Down Expand Up @@ -131,10 +137,9 @@ func TestSystem_Database_LeadershipCallbacks(t *testing.T) {

db, cleanup := TestDatabase(t, log, localhost)
defer cleanup()
if err := db.Start(dbCtx); err != nil {
t.Fatal(err)
}

// Register callbacks before Start() so the monitor goroutine
// cannot race ahead and iterate an empty slice under load.
var onGainedCalled, onLostCalled uint32
db.OnLeadershipGained(func(_ context.Context) error {
atomic.StoreUint32(&onGainedCalled, 1)
Expand All @@ -145,6 +150,10 @@ func TestSystem_Database_LeadershipCallbacks(t *testing.T) {
return nil
})

if err := db.Start(dbCtx); err != nil {
t.Fatal(err)
}

waitForLeadership(ctx, t, db, true)
dbCancel()
waitForLeadership(ctx, t, db, false)
Expand Down
2 changes: 1 addition & 1 deletion src/control/system/raft/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (db *Database) ShutdownRaft() error {
// run as many of them as possible in order to clean things
// up.
if shutdownErr == nil {
for _, cb := range db.onRaftShutdown {
for _, cb := range db.onRaftShutdownCbs() {
if cbErr := cb(); cbErr != nil {
db.log.Errorf("onRaftShutdown callback failed: %s", cbErr)
}
Expand Down
Loading