diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 diff --git a/.travis.yml b/.travis.yml old mode 100644 new mode 100755 diff --git a/DESIGN.md b/DESIGN.md new file mode 100755 index 0000000..2edff8d --- /dev/null +++ b/DESIGN.md @@ -0,0 +1,45 @@ +## CURP + +### Completed for CURP +* Record and sync RPCs. +* Keys sent with client requests to track commutativity in client operations. +* Witnesses accept a record only if it commutes with the operations they already store. +* Master applies a command only locally if it is commutative. If not commutative, it replicates synchronously and responds that it synced. +* Master synchronously replicates commands sent in Sync RPCs. +* GC records at witnesses when done applying. +* Send to witnesses and master in parallel, check for success or sync. If failure, send sync to master. + +### CURP Code Base +* `raft.go`: Witness state defined. Garbage collect at witnesses when operation completed. Support for handling record requests: accept and record if keys commutative and not leader, reject otherwise. Master syncs if operation not commutative, support for sync operation at master. +* `commands.go`: Sync and Record RPCs, add Synced field to ClientResponse to know if master synced. Add keys to ClientRequests. +* `session.go`: Sending to all witnesses and master in parallel. If all succeeded or synced at master, succeed. Otherwise, send Sync RPC to master. Keep repeating until success. +* `log.go`: Update log entry to contain keys for commutativity checks. +* `api.go`: Add witness state to raft nodes. +* `net_transport.go`: Add new RPC types. 
+ +## RIFL + +### Completed for RIFL +* Added client IDs and sequence numbers to client RPCs +* Assign client ID at master using global nextClientId +* Replicate nextClientId counter to other servers with LogNextClientId operation +* Store responses to client RPCs in cache that is periodically garbage-collected based on configurable timeout +* Check for duplicate before applying to state machine +* Make nextClientId and cache of client responses persistent. + +### RIFL Code Base +* `raft.go`: Support for ClientId RPC handling, incrementing nextClientId at all replicas +* `fsm.go`: Before applying a command locally, check for cached response. +* `client_response_cache.go`: Stores state about the response to a client RPC along with a timestamp. Cache is periodically garbage collected. +* `session.go`: Starting a client session requires getting a new client ID, use that client ID and assign monotonically increasing sequence numbers for client RPCs. +* `commands.go`: RPC format for ClientRequest and ClientResponse updated to contain Client ID and sequence number, new RPC format ClientIdRequest and ClientIdResponse. GenericClientRequest for sending a request to a Raft leader. +* `log.go`: Update log entry to contain client IDs and sequence numbers. +* `config.go`: Set interval at which to garbage collect cache and how long responses to client RPCs should remain cached +* `api.go`: client response cache and next client ID state added to each raft node and snapshot restoring operations. +* `snapshot.go`: Support for snapshotting the client response cache and the next client ID (must be stored persistently). +* `file_snapshot.go`: Support for snapshotting the client response cache and the next client ID. +* `inmem_snapshot.go`: Support for snapshotting the client response cache and the next client ID. +* `net_transport.go`: Add new RPC types. 
+ +Run tests for RIFL: `src/test/runTests.sh` +Currently has a race condition diff --git a/src/raft/.gitignore b/src/raft/.gitignore old mode 100644 new mode 100755 diff --git a/src/raft/.travis.yml b/src/raft/.travis.yml old mode 100644 new mode 100755 diff --git a/src/raft/LICENSE b/src/raft/LICENSE old mode 100644 new mode 100755 diff --git a/src/raft/Makefile b/src/raft/Makefile old mode 100644 new mode 100755 diff --git a/src/raft/README.md b/src/raft/README.md old mode 100644 new mode 100755 diff --git a/src/raft/api.go b/src/raft/api.go old mode 100644 new mode 100755 index 814d7e7..ad93ce1 --- a/src/raft/api.go +++ b/src/raft/api.go @@ -1,6 +1,7 @@ package raft import ( + "encoding/json" "errors" "fmt" "io" @@ -48,6 +49,36 @@ var ( // ErrCantBootstrap is returned when attempt is made to bootstrap a // cluster that already has state present. ErrCantBootstrap = errors.New("bootstrap only works on new clusters") + + // ErrBadClientId is returned when a client issues a RPC with a client + // ID the cluster doesn't recognize. + ErrBadClientId = errors.New("bad client ID used") + + // ErrNotCommutative is returned when a client tries to push an operation + // to a witness that is not commutative with other operations stored at + // the witness. + ErrNotCommutative = errors.New("operation not commutative with operations in witness") + + // ErrNotWitness is returned when a client contacts a leader instead of + // a witness. + ErrNotWitness = errors.New("contacted leader instead of witness") + + // ErrNoActiveServers is returned when a client tries to contact a cluster + // and cannot reach any servers. + ErrNoActiveServers = errors.New("no active raft servers found") + + // ErrNoActiveLeader is returned when a client tries to contact a leader + // and cannot reach an active leader. 
+ ErrNoActiveLeader = errors.New("no active leader found") + + // ErrWitnessFrozen is returned when a client tries to record a command + // in a witness that cannot accept client record requests. + ErrWitnessFrozen = errors.New("witness cannot accept record request, frozen") + + // ErrStaleTerm is returned when a client tries to record a command in + // a witness using a stale term number, meaning that it is sending the + // command to a potentially stale set of witnesses. + ErrStaleTerm = errors.New("witness cannot accept record request with stale term") ) // Raft implements a Raft node. @@ -81,6 +112,10 @@ type Raft struct { // fsmSnapshotCh is used to trigger a new snapshot being taken fsmSnapshotCh chan *reqSnapshotFuture + // True if witness can't accept client record requests, false otherwise. + frozen bool + frozenLock sync.RWMutex + // lastContact is the last time we had contact from the // leader node. This can be used to gauge staleness. lastContact time.Time @@ -108,6 +143,12 @@ type Raft struct { // LogStore provides durable storage for logs logs LogStore + // Cache of client responses. Used for RIFL. Map of ClientIDs to + // map of client RPC sequence numbers to response data. Periodically + // garbage collected. + clientResponseCache map[uint64]map[uint64]clientResponseEntry + clientResponseLock sync.RWMutex + // Used to request the leader to make configuration changes. configurationChangeCh chan *configurationChangeFuture @@ -115,6 +156,9 @@ type Raft struct { // the log/snapshot. configurations configurations + // Next Client ID to assign to new client. Used for RIFL. 
+ nextClientId uint64 + // RPC chan comes from the transport layer rpcCh <-chan RPC @@ -193,6 +237,9 @@ func BootstrapCluster(conf *Config, logs LogStore, stable StableStore, return fmt.Errorf("failed to save current term: %v", err) } + // Set empty maps for witness state + stableSetWitnessState(stable, make(map[ClientSeqNo]Log), make(map[uint32]Key)) + // Append configuration entry to log. entry := &Log{ Index: 1, @@ -268,6 +315,8 @@ func RecoverCluster(conf *Config, fsm FSM, logs LogStore, stable StableStore, // Attempt to restore any snapshots we find, newest to oldest. var snapshotIndex uint64 var snapshotTerm uint64 + var snapshotClientId uint64 + var snapshotClientResponseCache map[uint64]map[uint64]clientResponseEntry snapshots, err := snaps.List() if err != nil { return fmt.Errorf("failed to list snapshots: %v", err) @@ -288,6 +337,8 @@ func RecoverCluster(conf *Config, fsm FSM, logs LogStore, stable StableStore, snapshotIndex = snapshot.Index snapshotTerm = snapshot.Term + snapshotClientId = snapshot.NextClientId + snapshotClientResponseCache = snapshot.ClientResponseCache break } if len(snapshots) > 0 && (snapshotIndex == 0 || snapshotTerm == 0) { @@ -298,6 +349,8 @@ func RecoverCluster(conf *Config, fsm FSM, logs LogStore, stable StableStore, // until we play back the Raft log entries. lastIndex := snapshotIndex lastTerm := snapshotTerm + lastClientId := snapshotClientId + lastClientResponseCache := snapshotClientResponseCache // Apply any Raft log entries past the snapshot. 
lastLogIndex, err := logs.LastIndex() @@ -310,7 +363,25 @@ func RecoverCluster(conf *Config, fsm FSM, logs LogStore, stable StableStore, return fmt.Errorf("failed to get log at index %d: %v", index, err) } if entry.Type == LogCommand { - _,_ = fsm.Apply(&entry) + resp := fsm.Apply(&entry) + data, err := json.Marshal(resp) + if err != nil { + return fmt.Errorf("failed to marshal response to command at index %d: %v", index, err) + } + clientCache, ok := lastClientResponseCache[entry.ClientID] + if !ok { + clientCache = make(map[uint64]clientResponseEntry) + } + clientCache[entry.SeqNo] = clientResponseEntry{ + response: data, + timestamp: time.Now(), // will be garbage collected later + } + lastClientResponseCache[entry.ClientID] = clientCache + } + if entry.Type == LogNextClientId { + if err := decodeMsgPack(entry.Data, &lastClientId); err != nil { + panic(fmt.Errorf("failed to decode next cliend id: %v", err)) + } } lastIndex = entry.Index lastTerm = entry.Term @@ -323,7 +394,7 @@ func RecoverCluster(conf *Config, fsm FSM, logs LogStore, stable StableStore, return fmt.Errorf("failed to snapshot FSM: %v", err) } version := getSnapshotVersion(conf.ProtocolVersion) - sink, err := snaps.Create(version, lastIndex, lastTerm, configuration, 1, trans) + sink, err := snaps.Create(version, lastIndex, lastTerm, configuration, 1, lastClientId, lastClientResponseCache, trans) if err != nil { return fmt.Errorf("failed to create snapshot: %v", err) } @@ -399,6 +470,7 @@ func NewRaft(conf *Config, fsm FSM, logs LogStore, stable StableStore, snaps Sna logger = conf.Logger } else { if conf.LogOutput == nil { + //devNull = open(os.devnull, 'w') conf.LogOutput = os.Stderr } logger = log.New(conf.LogOutput, "", log.LstdFlags) @@ -437,19 +509,22 @@ func NewRaft(conf *Config, fsm FSM, logs LogStore, stable StableStore, snaps Sna // Create Raft struct. 
r := &Raft{ - protocolVersion: protocolVersion, - applyCh: make(chan *logFuture), - conf: *conf, - fsm: fsm, - fsmMutateCh: make(chan interface{}, 128), - fsmSnapshotCh: make(chan *reqSnapshotFuture), - leaderCh: make(chan bool), - localID: localID, - localAddr: localAddr, - logger: logger, - logs: logs, + protocolVersion: protocolVersion, + applyCh: make(chan *logFuture), + conf: *conf, + clientResponseCache: make(map[uint64]map[uint64]clientResponseEntry), + frozen: false, + fsm: fsm, + fsmMutateCh: make(chan interface{}, 128), + fsmSnapshotCh: make(chan *reqSnapshotFuture), + leaderCh: make(chan bool), + localID: localID, + localAddr: localAddr, + logger: logger, + logs: logs, configurationChangeCh: make(chan *configurationChangeFuture), configurations: configurations{}, + nextClientId: 0, rpcCh: trans.Consumer(), snapshots: snaps, userSnapshotCh: make(chan *userSnapshotFuture), @@ -461,7 +536,7 @@ func NewRaft(conf *Config, fsm FSM, logs LogStore, stable StableStore, snaps Sna configurationsCh: make(chan *configurationsFuture, 8), bootstrapCh: make(chan *bootstrapFuture), observers: make(map[uint64]*Observer), - } + } // Initialize as a follower. r.setState(Follower) @@ -505,6 +580,7 @@ func NewRaft(conf *Config, fsm FSM, logs LogStore, stable StableStore, snaps Sna r.goFunc(r.run) r.goFunc(r.runFSM) r.goFunc(r.runSnapshots) + r.goFunc(r.runGcClientResponseCache) return r, nil } @@ -599,7 +675,7 @@ func (r *Raft) Leader() ServerAddress { // An optional timeout can be provided to limit the amount of time we wait // for the command to be started. This must be run on the leader or it // will fail. 
-func (r *Raft) Apply(cmd []byte, timeout time.Duration) ApplyFuture { +func (r *Raft) Apply(log *Log, timeout time.Duration) ApplyFuture { metrics.IncrCounter([]string{"raft", "apply"}, 1) var timer <-chan time.Time if timeout > 0 { @@ -607,10 +683,38 @@ func (r *Raft) Apply(cmd []byte, timeout time.Duration) ApplyFuture { } // Create a log future, no index or term yet + logFuture := &logFuture{ + log: *log, + } + logFuture.init() + + select { + case <-timer: + return errorFuture{ErrEnqueueTimeout} + case <-r.shutdownCh: + return errorFuture{ErrRaftShutdown} + case r.applyCh <- logFuture: + return logFuture + } +} + +// Updates all Raft nodes with the value of NextClientId at the leader. +// This must be run at the leader. +func (r *Raft) SendNextClientId(timeout time.Duration) Future { + var timer <-chan time.Time + if timeout > 0 { + timer = time.After(timeout) + } + + buf, err := encodeMsgPack(r.nextClientId) + if err != nil { + panic(fmt.Errorf("failed to encode next client id: %v", err)) + } + logFuture := &logFuture{ log: Log{ - Type: LogCommand, - Data: cmd, + Type: LogNextClientId, + Data: buf.Bytes(), }, } logFuture.init() @@ -1006,3 +1110,8 @@ func (r *Raft) LastIndex() uint64 { func (r *Raft) AppliedIndex() uint64 { return r.getLastApplied() } + +// Checks if raft node is the current leader. +func (r *Raft) IsLeader() bool { + return r.getState() == Leader +} diff --git a/src/raft/bench/bench.go b/src/raft/bench/bench.go old mode 100644 new mode 100755 diff --git a/src/raft/client_response_cache.go b/src/raft/client_response_cache.go new file mode 100755 index 0000000..6ff2189 --- /dev/null +++ b/src/raft/client_response_cache.go @@ -0,0 +1,47 @@ +package raft + +import ( + "time" +) + +// Manages the cache of client responses for use in RIFL, including +// garbage collecting the cache. + +// clientResponseEntry holds state about the response to a client RPC. +// For use in RIFL. 
+type clientResponseEntry struct { + response interface{} + timestamp time.Time +} + +// Continuously check to garbage collect the cache. +func (r *Raft) runGcClientResponseCache() { + for { + select { + case <-randomTimeout(r.conf.ClientResponseGcInterval): + r.gcClientResponseCache() + + case <-r.shutdownCh: + return + } + } +} + +// Garbage collect entries in the cache that have expired. +func (r *Raft) gcClientResponseCache() { + r.clientResponseLock.RLock() + currTime := time.Now() + for clientID, clientCache := range r.clientResponseCache { + for seqNo, entry := range clientCache { + if currTime.Sub(entry.timestamp) >= r.conf.ClientResponseGcRemoveTime { + r.clientResponseLock.RUnlock() + r.clientResponseLock.Lock() + delete(clientCache, seqNo) // does nothing if key does not exist, no race condition + r.clientResponseLock.Unlock() + r.clientResponseLock.RLock() + } + } + r.clientResponseCache[clientID] = clientCache + } + r.clientResponseLock.RUnlock() +} diff --git a/src/raft/commands.go b/src/raft/commands.go old mode 100644 new mode 100755 index 1f0e447..ed3fc69 --- a/src/raft/commands.go +++ b/src/raft/commands.go @@ -150,17 +150,131 @@ func (r *InstallSnapshotResponse) GetRPCHeader() RPCHeader { return r.RPCHeader } -type ClientRequest struct { +// Record RPCs are used to store commutative operations at witnesses. +// Accepted if commutative with other operations at witness, rejected +// otherwise. +type RecordRequest struct { + RPCHeader + + // Entry to commit + Entry *Log + // Use term to make sure witness is valid. + Term uint64 +} + +// See WithRPCHeader. +func (r *RecordRequest) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// Record RPCs are used to store commutative operations at witnesses. +// Accepted if commutative with other operations at witness, rejected +// otherwise. +type RecordResponse struct { + RPCHeader + + // True if operation recorded at witness, false otherwise. + Success bool + // Discover term if term not correct. 
+ Term uint64 +} + +// See WithRPCHeader. +func (r *RecordResponse) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// Issued by a client to the master when a client cannot record an +// operation in all witnesses. +type SyncRequest struct { + RPCHeader + + Entry *Log +} + +// See WithRPCHeader. +func (r *SyncRequest) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// Interface used for all generic client requests so that client library +// can find active leader. +type GenericClientResponse interface { + GetLeaderAddress() ServerAddress +} + +// Sent when the master has completed the sync in response to SyncRequest. +type SyncResponse struct { + RPCHeader + + // True if successfully synced at master. + Success bool + LeaderAddress ServerAddress + ResponseData []byte +} + +// See WithRPCHeader. +func (r *SyncResponse) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// See GenericClientResponse interface. +func (r *SyncResponse) GetLeaderAddress() ServerAddress { + return r.LeaderAddress +} + +// Sent from new leader to witness to set witness into recovery +// mode (don't receive requests) and get all client requests stored +// at witness. +type RecoveryDataRequest struct { + RPCHeader +} + +// See WithRPCHeader. +func (r *RecoveryDataRequest) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// Contains all client requests stored at witness, sent from witness +// to new leader. +type RecoveryDataResponse struct { + RPCHeader + + // All client requests stored at witness. + Entries []Log +} + +// See WithRPCHeader. +func (r *RecoveryDataResponse) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// Unfreeze witness to allow it to process record requests again. +type UnfreezeRequest struct { + RPCHeader +} + +// See WithRPCHeader. +func (r *UnfreezeRequest) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// Response to UnfreezeRequest (see UnfreezeRequest). +type UnfreezeResponse struct { RPCHeader +} - // New entries to commit. 
- Entries[] *Log - // True if should initiate or maintain session, false otherwise. - KeepSession bool - // ID of client to contact raft server. - ClientAddr ServerAddress - // Command to be executed when client session terminates. - EndSessionCommand []byte +// See WithRPCHeader. +func (r *UnfreezeResponse) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// Sent by the client to apply a command at a raft cluster. +type ClientRequest struct { + RPCHeader + + // New entry to commit. + Entry *Log } // See WithRPCHeader. @@ -168,15 +282,58 @@ func (r *ClientRequest) GetRPCHeader() RPCHeader { return r.RPCHeader } +// Contains the result of applying a command, sent in response to ClientRequest. type ClientResponse struct { - RPCHeader + RPCHeader - Success bool - LeaderAddress ServerAddress - ResponseData []byte + // True if command successfully executed. + Success bool + // Address of current leader. Used to redirect from follower to leader. + LeaderAddress ServerAddress + // Response from applying command. + ResponseData []byte + // True if leader synced (not commutative), false otherwise. + Synced bool } // See WithRPCHeader. func (r *ClientResponse) GetRPCHeader() RPCHeader { return r.RPCHeader } + +// See GenericClientResponse interface. +func (r *ClientResponse) GetLeaderAddress() ServerAddress { + return r.LeaderAddress +} + +// Requests an ID for a client. Clients must have an ID allocated by +// the leader to make requests. +type ClientIdRequest struct { + RPCHeader +} + +// See WithRPCHeader. +func (r *ClientIdRequest) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// Returns an ID allocated by the leader, sent in response to ClientIdRequest. +type ClientIdResponse struct { + RPCHeader + + // ID of client assigned by cluster. + ClientID uint64 + + // Address of active leader. Used as a hint to find active leader. + LeaderAddress ServerAddress +} + +// See WithRPCHeader. 
+func (r *ClientIdResponse) GetRPCHeader() RPCHeader { + return r.RPCHeader +} + +// See GenericClientResponse. +func (r *ClientIdResponse) GetLeaderAddress() ServerAddress { + return r.LeaderAddress +} diff --git a/src/raft/commitment.go b/src/raft/commitment.go old mode 100644 new mode 100755 diff --git a/src/raft/commitment_test.go b/src/raft/commitment_test.go old mode 100644 new mode 100755 diff --git a/src/raft/config.go b/src/raft/config.go old mode 100644 new mode 100755 index c1ce03a..e0fd258 --- a/src/raft/config.go +++ b/src/raft/config.go @@ -193,21 +193,31 @@ type Config struct { // Logger is a user-provided logger. If nil, a logger writing to LogOutput // is used. Logger *log.Logger + + // Interval at which to garbage collect the client response cache used with + // RIFL. + ClientResponseGcInterval time.Duration + + // How long a client response should be kept in the cache to prevent duplicate + // execution. Used with RIFL. + ClientResponseGcRemoveTime time.Duration } // DefaultConfig returns a Config with usable defaults. 
func DefaultConfig() *Config { return &Config{ - ProtocolVersion: ProtocolVersionMax, - HeartbeatTimeout: 1000 * time.Millisecond, - ElectionTimeout: 1000 * time.Millisecond, - CommitTimeout: 50 * time.Millisecond, - MaxAppendEntries: 64, - ShutdownOnRemove: true, - TrailingLogs: 10240, - SnapshotInterval: 120 * time.Second, - SnapshotThreshold: 8192, - LeaderLeaseTimeout: 500 * time.Millisecond, + ProtocolVersion: ProtocolVersionMax, + HeartbeatTimeout: 1000 * time.Millisecond, + ElectionTimeout: 1000 * time.Millisecond, + CommitTimeout: 50 * time.Millisecond, + MaxAppendEntries: 64, + ShutdownOnRemove: true, + TrailingLogs: 10240, + SnapshotInterval: 120 * time.Second, + SnapshotThreshold: 8192, + LeaderLeaseTimeout: 500 * time.Millisecond, + ClientResponseGcInterval: time.Minute, + ClientResponseGcRemoveTime: 4 * time.Hour, } } diff --git a/src/raft/configuration.go b/src/raft/configuration.go old mode 100644 new mode 100755 diff --git a/src/raft/configuration_test.go b/src/raft/configuration_test.go old mode 100644 new mode 100755 diff --git a/src/raft/discard_snapshot.go b/src/raft/discard_snapshot.go old mode 100644 new mode 100755 diff --git a/src/raft/discard_snapshot_test.go b/src/raft/discard_snapshot_test.go old mode 100644 new mode 100755 diff --git a/src/raft/file_snapshot.go b/src/raft/file_snapshot.go old mode 100644 new mode 100755 index ffc9414..3db614d --- a/src/raft/file_snapshot.go +++ b/src/raft/file_snapshot.go @@ -141,7 +141,7 @@ func snapshotName(term, index uint64) string { // Create is used to start a new snapshot func (f *FileSnapshotStore) Create(version SnapshotVersion, index, term uint64, - configuration Configuration, configurationIndex uint64, trans Transport) (SnapshotSink, error) { + configuration Configuration, configurationIndex uint64, nextClientId uint64, clientResponseCache map[uint64]map[uint64]clientResponseEntry, trans Transport) (SnapshotSink, error) { // We only support version 1 snapshots at this time. 
if version != 1 { return nil, fmt.Errorf("unsupported snapshot version %d", version) @@ -166,13 +166,15 @@ func (f *FileSnapshotStore) Create(version SnapshotVersion, index, term uint64, parentDir: f.path, meta: fileSnapshotMeta{ SnapshotMeta: SnapshotMeta{ - Version: version, - ID: name, - Index: index, - Term: term, - Peers: encodePeers(configuration, trans), - Configuration: configuration, - ConfigurationIndex: configurationIndex, + Version: version, + ID: name, + Index: index, + Term: term, + NextClientId: nextClientId, + ClientResponseCache: clientResponseCache, + Peers: encodePeers(configuration, trans), + Configuration: configuration, + ConfigurationIndex: configurationIndex, }, CRC: nil, }, diff --git a/src/raft/file_snapshot_test.go b/src/raft/file_snapshot_test.go old mode 100644 new mode 100755 diff --git a/src/raft/fsm.go b/src/raft/fsm.go old mode 100644 new mode 100755 index 3164966..3a74715 --- a/src/raft/fsm.go +++ b/src/raft/fsm.go @@ -15,7 +15,7 @@ type FSM interface { // It returns a value which will be made available in the // ApplyFuture returned by Raft.Apply method if that // method was called on the same Raft node as the FSM. - Apply(*Log) (interface{}, []func() [][]byte) + Apply(*Log) interface{} // Snapshot is used to support log compaction. This call should // return an FSMSnapshot which can be used to save a point-in-time @@ -51,22 +51,21 @@ func (r *Raft) runFSM() { commit := func(req *commitTuple) { // Apply the log if a command - var resp interface{} - var callback []func() [][]byte + var resp interface{} if req.log.Type == LogCommand { - start := time.Now() - resp, callback = r.fsm.Apply(req.log) - metrics.MeasureSince([]string{"raft", "fsm", "apply"}, start) + r.applyCommandLocally(req.log, &resp) } // Update the indexes - lastIndex = req.log.Index - lastTerm = req.log.Term + // Need to take max because could have gotten stale client request that is replayed. 
+ if req.log.Index > lastIndex || req.log.Term > lastTerm { + lastIndex = req.log.Index + lastTerm = req.log.Term + } // Invoke the future if given if req.future != nil { req.future.response = resp - req.future.callback = callback req.future.respond(nil) } } @@ -136,3 +135,32 @@ func (r *Raft) runFSM() { } } } + +// Apply a command to the local FSM. Ensures exactly-once semantics with RIFL. +// Params: +// - log: Log entry to apply locally. Should be of type LogCommand. +// - resp: Response object to populate after executing command. +func (r *Raft) applyCommandLocally(log *Log, resp *interface{}) { + r.clientResponseLock.Lock() + clientCache, clientIdKnown := r.clientResponseCache[log.ClientID] + if !clientIdKnown { + r.clientResponseCache[log.ClientID] = make(map[uint64]clientResponseEntry) + clientCache = r.clientResponseCache[log.ClientID] + } + cachedResp, duplicateReq := clientCache[log.SeqNo] + if duplicateReq { + r.logger.Printf("found cached response for client %v with seqno %v with resp %v", log.ClientID, log.SeqNo, cachedResp.response) + *resp = cachedResp.response + } else { + start := time.Now() + *resp = r.fsm.Apply(log) + metrics.MeasureSince([]string{"raft", "fsm", "apply"}, start) + // Add response to clientResponseCache. + clientCache[log.SeqNo] = clientResponseEntry{ + response: *resp, + timestamp: time.Now(), + } + r.clientResponseCache[log.ClientID] = clientCache + } + r.clientResponseLock.Unlock() +} diff --git a/src/raft/future.go b/src/raft/future.go old mode 100644 new mode 100755 index 9d4b228..fac59a5 --- a/src/raft/future.go +++ b/src/raft/future.go @@ -36,7 +36,6 @@ type ApplyFuture interface { // by the FSM.Apply method. This must not be called // until after the Error method has returned. 
Response() interface{} - Callback() []func() [][]byte } // ConfigurationFuture is used for GetConfiguration and can return the @@ -76,10 +75,6 @@ func (e errorFuture) Index() uint64 { return 0 } -func (e errorFuture) Callback() []func() [][]byte { - return nil -} - // deferError can be embedded to allow a future // to provide an error in the future. type deferError struct { @@ -142,7 +137,6 @@ type logFuture struct { log Log response interface{} dispatch time.Time - callback []func() [][]byte } func (l *logFuture) Response() interface{} { @@ -153,10 +147,6 @@ func (l *logFuture) Index() uint64 { return l.log.Index } -func (l *logFuture) Callback() []func() [][]byte { - return l.callback -} - type shutdownFuture struct { raft *Raft } @@ -284,7 +274,6 @@ type appendFuture struct { start time.Time args *AppendEntriesRequest resp *AppendEntriesResponse - callback []func() [][]byte } func (a *appendFuture) Start() time.Time { @@ -298,7 +287,3 @@ func (a *appendFuture) Request() *AppendEntriesRequest { func (a *appendFuture) Response() *AppendEntriesResponse { return a.resp } - -func (a *appendFuture) Callback() []func() [][]byte { - return a.callback -} diff --git a/src/raft/future_test.go b/src/raft/future_test.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/apply_src.go b/src/raft/fuzzy/apply_src.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/cluster.go b/src/raft/fuzzy/cluster.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/fsm.go b/src/raft/fuzzy/fsm.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/membership_test.go b/src/raft/fuzzy/membership_test.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/node.go b/src/raft/fuzzy/node.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/partition_test.go b/src/raft/fuzzy/partition_test.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/readme.md b/src/raft/fuzzy/readme.md old mode 100644 new mode 100755 diff --git 
a/src/raft/fuzzy/resolve.go b/src/raft/fuzzy/resolve.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/simple_test.go b/src/raft/fuzzy/simple_test.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/slowvoter_test.go b/src/raft/fuzzy/slowvoter_test.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/transport.go b/src/raft/fuzzy/transport.go old mode 100644 new mode 100755 diff --git a/src/raft/fuzzy/verifier.go b/src/raft/fuzzy/verifier.go old mode 100644 new mode 100755 diff --git a/src/raft/inmem_snapshot.go b/src/raft/inmem_snapshot.go old mode 100644 new mode 100755 index 3aa92b3..63e0aa6 --- a/src/raft/inmem_snapshot.go +++ b/src/raft/inmem_snapshot.go @@ -33,7 +33,7 @@ func NewInmemSnapshotStore() *InmemSnapshotStore { // Create replaces the stored snapshot with a new one using the given args func (m *InmemSnapshotStore) Create(version SnapshotVersion, index, term uint64, - configuration Configuration, configurationIndex uint64, trans Transport) (SnapshotSink, error) { + configuration Configuration, configurationIndex uint64, nextClientId uint64, clientResponseCache map[uint64]map[uint64]clientResponseEntry, trans Transport) (SnapshotSink, error) { // We only support version 1 snapshots at this time. 
if version != 1 { return nil, fmt.Errorf("unsupported snapshot version %d", version) @@ -46,13 +46,15 @@ func (m *InmemSnapshotStore) Create(version SnapshotVersion, index, term uint64, sink := &InmemSnapshotSink{ meta: SnapshotMeta{ - Version: version, - ID: name, - Index: index, - Term: term, - Peers: encodePeers(configuration, trans), - Configuration: configuration, - ConfigurationIndex: configurationIndex, + Version: version, + ID: name, + Index: index, + Term: term, + NextClientId: nextClientId, + ClientResponseCache: clientResponseCache, + Peers: encodePeers(configuration, trans), + Configuration: configuration, + ConfigurationIndex: configurationIndex, }, contents: &bytes.Buffer{}, } diff --git a/src/raft/inmem_snapshot_test.go b/src/raft/inmem_snapshot_test.go old mode 100644 new mode 100755 diff --git a/src/raft/inmem_store.go b/src/raft/inmem_store.go old mode 100644 new mode 100755 diff --git a/src/raft/inmem_transport.go b/src/raft/inmem_transport.go old mode 100644 new mode 100755 index ce37f63..a559e72 --- a/src/raft/inmem_transport.go +++ b/src/raft/inmem_transport.go @@ -115,6 +115,34 @@ func (i *InmemTransport) RequestVote(id ServerID, target ServerAddress, args *Re return nil } +// RecoverData implements the Transport interface. +func (i *InmemTransport) RecoverData(id ServerID, target ServerAddress, args *RecoveryDataRequest, resp *RecoveryDataResponse) error { + rpcResp, err := i.makeRPC(target, args, nil, i.timeout) + if err != nil { + return err + } + + // Copy the result back + out := rpcResp.Response.(*RecoveryDataResponse) + *resp = *out + return nil +} + +// UnfreezeWitness implements the Transport interface. 
+func (i *InmemTransport) UnfreezeWitness(id ServerID, target ServerAddress, args *UnfreezeRequest, resp *UnfreezeResponse) error { + rpcResp, err := i.makeRPC(target, args, nil, i.timeout) + if err != nil { + return err + } + + // Copy the result back + out := rpcResp.Response.(*UnfreezeResponse) + *resp = *out + return nil +} + + + // InstallSnapshot implements the Transport interface. func (i *InmemTransport) InstallSnapshot(id ServerID, target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error { rpcResp, err := i.makeRPC(target, args, data, 10*i.timeout) diff --git a/src/raft/inmem_transport_test.go b/src/raft/inmem_transport_test.go old mode 100644 new mode 100755 diff --git a/src/raft/integ_test.go b/src/raft/integ_test.go old mode 100644 new mode 100755 diff --git a/src/raft/log.go b/src/raft/log.go old mode 100644 new mode 100755 index 4ade38e..de8cab3 --- a/src/raft/log.go +++ b/src/raft/log.go @@ -31,6 +31,9 @@ const ( // created when a server is added, removed, promoted, etc. Only used // when protocol version 1 or greater is in use. LogConfiguration + + // LogNextClientId is used to set the next client ID across the cluster. + LogNextClientId ) // Log entries are replicated to all members of the Raft cluster @@ -47,8 +50,20 @@ type Log struct { // Data holds the log entry's type-specific data. Data []byte + + // Client ID. Only used for LogCommand. + ClientID uint64 + + // Sequence number of command. Only used for LogCommand. + SeqNo uint64 + + // Keys associated with RPC, used to check for commutativity. + Keys []Key } +// Used to check for operations that conflict in commutativity checks. +type Key []byte + // LogStore is used to provide an interface for storing // and retrieving logs in a durable fashion. 
type LogStore interface { diff --git a/src/raft/log_cache.go b/src/raft/log_cache.go old mode 100644 new mode 100755 diff --git a/src/raft/log_cache_test.go b/src/raft/log_cache_test.go old mode 100644 new mode 100755 diff --git a/src/raft/membership.md b/src/raft/membership.md old mode 100644 new mode 100755 diff --git a/src/raft/net_transport.go b/src/raft/net_transport.go old mode 100644 new mode 100755 index a918438..2e99910 --- a/src/raft/net_transport.go +++ b/src/raft/net_transport.go @@ -18,8 +18,18 @@ const ( rpcAppendEntries uint8 = iota rpcRequestVote rpcInstallSnapshot + rpcRecoverDataRequest + rpcRecoverDataResponse + rpcUnfreezeRequest + rpcUnfreezeResponse rpcClientRequest - rpcClientResponse + rpcClientResponse + rpcClientIdRequest + rpcClientIdResponse + rpcRecordRequest + rpcRecordResponse + rpcSyncRequest + rpcSyncResponse // DefaultTimeoutScale is the default TimeoutScale in a NetworkTransport. DefaultTimeoutScale = 256 * 1024 // 256KB @@ -254,7 +264,7 @@ func (n *NetworkTransport) getPooledConn(target ServerAddress) *netConn { // getConnFromAddressProvider returns a connection from the server address provider if available, or defaults to a connection using the target server address func (n *NetworkTransport) getConnFromAddressProvider(id ServerID, target ServerAddress) (*netConn, error) { address := n.getProviderAddressOrFallback(id, target) - return n.getConn(address) + return n.getConn(address) } func (n *NetworkTransport) getProviderAddressOrFallback(id ServerID, target ServerAddress) ServerAddress { @@ -336,6 +346,16 @@ func (n *NetworkTransport) RequestVote(id ServerID, target ServerAddress, args * return n.genericRPC(id, target, rpcRequestVote, args, resp) } +// RecoverData implements the Transport interface. 
+func (n *NetworkTransport) RecoverData(id ServerID, target ServerAddress, args *RecoveryDataRequest, resp *RecoveryDataResponse) error { + return n.genericRPC(id, target, rpcRecoverDataRequest, args, resp) +} + +// UnfreezeWitness implements the Transport interface. +func (n *NetworkTransport) UnfreezeWitness(id ServerID, target ServerAddress, args *UnfreezeRequest, resp *UnfreezeResponse) error { + return n.genericRPC(id, target, rpcUnfreezeRequest, args, resp) +} + // genericRPC handles a simple request/response RPC. func (n *NetworkTransport) genericRPC(id ServerID, target ServerAddress, rpcType uint8, args interface{}, resp interface{}) error { // Get a conn @@ -414,11 +434,11 @@ func (n *NetworkTransport) DecodePeer(buf []byte) ServerAddress { // listen is used to handling incoming connections. func (n *NetworkTransport) listen() { for { - // Accept incoming connections + // Accept incoming connections conn, err := n.stream.Accept() - if err != nil { + if err != nil { if n.IsShutdown() { - n.logger.Printf("Shutting down") + n.logger.Printf("Shutting down") return } n.logger.Printf("[ERR] raft-net: Failed to accept connection: %v", err) @@ -455,7 +475,7 @@ func (n *NetworkTransport) handleConn(conn net.Conn) { // handleCommand is used to decode and dispatch a single command. 
func (n *NetworkTransport) handleCommand(r *bufio.Reader, dec *codec.Decoder, enc *codec.Encoder) error { - // Get the rpc type + // Get the rpc type rpcType, err := r.ReadByte() if err != nil { return err @@ -499,20 +519,48 @@ func (n *NetworkTransport) handleCommand(r *bufio.Reader, dec *codec.Decoder, en rpc.Command = &req rpc.Reader = io.LimitReader(r, req.Size) - case rpcClientRequest: - var req ClientRequest + case rpcSyncRequest: + var req SyncRequest + if err := dec.Decode(&req); err != nil { + return err + } + rpc.Command = &req + + case rpcRecordRequest: + var req RecordRequest + if err := dec.Decode(&req); err != nil { + return err + } + rpc.Command = &req + + case rpcRecoverDataRequest: + var req RecoveryDataRequest if err := dec.Decode(&req); err != nil { return err } rpc.Command = &req + case rpcClientRequest: + var req ClientRequest + if err := dec.Decode(&req); err != nil { + return err + } + rpc.Command = &req + + case rpcClientIdRequest: + var req ClientIdRequest + if err := dec.Decode(&req); err != nil { + return err + } + rpc.Command = &req + default: return fmt.Errorf("unknown rpc type %d", rpcType) } // Check for heartbeat fast-path if isHeartbeat { - n.heartbeatFnLock.Lock() + n.heartbeatFnLock.Lock() fn := n.heartbeatFn n.heartbeatFnLock.Unlock() if fn != nil { @@ -578,7 +626,7 @@ func decodeResponse(conn *netConn, resp interface{}) (bool, error) { func sendRPC(conn *netConn, rpcType uint8, args interface{}) error { // Write the request type if err := conn.w.WriteByte(rpcType); err != nil { - conn.Release() + conn.Release() return err } @@ -589,8 +637,8 @@ func sendRPC(conn *netConn, rpcType uint8, args interface{}) error { } // Flush - if err := conn.w.Flush(); err != nil { - conn.Release() + if err := conn.w.Flush(); err != nil { + conn.Release() return err } return nil diff --git a/src/raft/net_transport_test.go b/src/raft/net_transport_test.go old mode 100644 new mode 100755 diff --git a/src/raft/observer.go b/src/raft/observer.go old 
mode 100644 new mode 100755 diff --git a/src/raft/peersjson.go b/src/raft/peersjson.go old mode 100644 new mode 100755 diff --git a/src/raft/peersjson_test.go b/src/raft/peersjson_test.go old mode 100644 new mode 100755 diff --git a/src/raft/raft.go b/src/raft/raft.go old mode 100644 new mode 100755 index 230e224..4e05025 --- a/src/raft/raft.go +++ b/src/raft/raft.go @@ -3,13 +3,14 @@ package raft import ( "bytes" "container/list" + "crypto/sha256" + "encoding/binary" + "encoding/json" "fmt" + "github.com/armon/go-metrics" "io" "io/ioutil" "time" - "encoding/json" - "github.com/armon/go-metrics" - "sync" ) const ( @@ -17,9 +18,11 @@ const ( ) var ( - keyCurrentTerm = []byte("CurrentTerm") - keyLastVoteTerm = []byte("LastVoteTerm") - keyLastVoteCand = []byte("LastVoteCand") + keyCurrentTerm = []byte("CurrentTerm") + keyLastVoteTerm = []byte("LastVoteTerm") + keyLastVoteCand = []byte("LastVoteCand") + keyWitnessStateKeys = []byte("WitnessStateKeys") + keyWitnessStateRecords = []byte("WitnessStateRecords") ) // getRPCHeader returns an initialized RPCHeader struct for the given @@ -76,12 +79,6 @@ type commitTuple struct { future *logFuture } -type clientSession struct { - lastContact time.Time - heartbeatCh chan bool - endSessionCommand []byte -} - // leaderState is state that is used while we are a leader. type leaderState struct { commitCh chan struct{} @@ -90,8 +87,14 @@ type leaderState struct { replState map[ServerID]*followerReplication notify map[*verifyFuture]struct{} stepDown chan struct{} - clientSessions map[ServerAddress]*clientSession - clientSessionsLock sync.RWMutex +} + +// Tuple used to uniquely identify RPC using RIFL. +type ClientSeqNo struct { + // Identifies unique client. + ClientID uint64 + // Identifies unique RPC from a client. 
+ SeqNo uint64 } // setLeader is used to modify the current leader of the cluster @@ -182,7 +185,7 @@ func (r *Raft) runFollower() { b.respond(r.liveBootstrap(b.configuration)) case <-heartbeatTimer: - // Restart the heartbeat timer + // Restart the heartbeat timer heartbeatTimer = randomTimeout(r.conf.HeartbeatTimeout) // Check if we have had a successful contact @@ -346,7 +349,6 @@ func (r *Raft) runLeader() { r.leaderState.replState = make(map[ServerID]*followerReplication) r.leaderState.notify = make(map[*verifyFuture]struct{}) r.leaderState.stepDown = make(chan struct{}, 1) - r.leaderState.clientSessions = make(map[ServerAddress]*clientSession) // Cleanup state on step down defer func() { @@ -422,6 +424,9 @@ func (r *Raft) runLeader() { } r.dispatchLogs([]*logFuture{noop}) + //TODO: make sure it's safe to replay from witnesses here (can't start having client requests) + r.recoverWithWitness() + // Sit in the leader loop until we step down r.leaderLoop() } @@ -472,6 +477,74 @@ func (r *Raft) startStopReplication() { } } +// Replay requests from a witness. Should only be called when a new leader +// is elected. Witness is set to recovery mode and sends all saved client +// requests, which are replayed by the new master. +func (r *Raft) recoverWithWitness() { + // Construct request. + req := &RecoveryDataRequest{ + RPCHeader: r.getRPCHeader(), + } + resp := &RecoveryDataResponse{} + // when get it count, at teh end iterate over and if have >= ceil(f/2) + 1 counts then can commit + entryCounts := make(map[uint64]map[uint64]uint64, 0) + uniqueEntries := make(map[uint64]map[uint64]Log) + // Choose f+1 witnesses and send RecoveryDataRequest. + quorumSz := len(r.configurations.latest.Servers) / 2 + 1 + chosenWitnesses := make([]Server, 0) + for _, witness := range r.configurations.latest.Servers { + if witness.ID == r.localID { + // Don't choose self to recover from. 
+ continue + } + err := r.trans.RecoverData(witness.ID, witness.Address, req, resp) + if err == nil { + chosenWitnesses = append(chosenWitnesses, witness) + for _,entry := range resp.Entries { + if entryCounts[entry.ClientID] == nil { + entryCounts[entry.ClientID] = make(map[uint64]uint64) + uniqueEntries[entry.ClientID] = make(map[uint64]Log) + } + entryCounts[entry.ClientID][entry.SeqNo] += 1 // initialized to 0 + uniqueEntries[entry.ClientID][entry.SeqNo] = entry + } + } else { + // Cannot recover from this witness. + r.logger.Printf("[ERR] Failed to recover from witness %v: %v", witness, err) + } + // Contacted a quorum. Can assume will always reach this state (otherwise couldn't elect leader). + if len(chosenWitnesses) == quorumSz { + break + } + } + // Execute operations that are stored at ceil(f/2) + 1 witnesses synchronously. + minMajoritySz := (quorumSz / 2) + 1 + for clientID := range entryCounts { + for seqNo := range entryCounts[clientID] { + if entryCounts[clientID][seqNo] < uint64(minMajoritySz) { + continue + } + entry := uniqueEntries[clientID][seqNo] + var err error + // Can disregard return value. + r.applySynchronousCommand(&entry, &err) + if err != nil { + r.logger.Printf("[ERR] Error executing operation retrieved from witness") + } + } + } + for _,chosenWitness := range chosenWitnesses { + // Unfreeze all chosen f+1 witnesses. + unfreezeReq := &UnfreezeRequest{ + RPCHeader: r.getRPCHeader(), + } + err := r.trans.UnfreezeWitness(chosenWitness.ID, chosenWitness.Address, unfreezeReq, &UnfreezeResponse{}) + if err != nil { + r.logger.Printf("[ERR] Failed to unfreeze witness %v: %v", chosenWitness, err) + } + } +} + // configurationChangeChIfStable returns r.configurationChangeCh if it's safe // to process requests from it, or nil otherwise. This must only be called // from the main thread. 
@@ -512,7 +585,7 @@ func (r *Raft) leaderLoop() { case <-r.leaderState.commitCh: // Process the newly committed entries - oldCommitIndex := r.getCommitIndex() + oldCommitIndex := r.getCommitIndex() commitIndex := r.leaderState.commitment.getCommitIndex() r.setCommitIndex(commitIndex) @@ -755,7 +828,7 @@ func (r *Raft) restoreUserSnapshot(meta *SnapshotMeta, reader io.Reader) error { // Dump the snapshot. Note that we use the latest configuration, // not the one that came with the snapshot. sink, err := r.snapshots.Create(version, lastIndex, term, - r.configurations.latest, r.configurations.latestIndex, r.trans) + r.configurations.latest, r.configurations.latestIndex, r.nextClientId, r.clientResponseCache, r.trans) if err != nil { return fmt.Errorf("failed to create snapshot: %v", err) } @@ -842,7 +915,7 @@ func (r *Raft) appendConfigurationEntry(future *configurationChangeFuture) { // dispatchLog is called on the leader to push a log to disk, mark it // as inflight and begin replication of it. func (r *Raft) dispatchLogs(applyLogs []*logFuture) { - now := time.Now() + now := time.Now() defer metrics.MeasureSince([]string{"raft", "leader", "dispatchLog"}, now) term := r.getCurrentTerm() @@ -915,7 +988,7 @@ func (r *Raft) processLogs(index uint64, future *logFuture) { // processLog is invoked to process the application of a single committed log entry. func (r *Raft) processLog(l *Log, future *logFuture) { - switch l.Type { + switch l.Type { case LogBarrier: // Barrier is handled by the FSM fallthrough @@ -930,10 +1003,33 @@ func (r *Raft) processLog(l *Log, future *logFuture) { } } + // Garbage collect at witnesses. 
+		clientSeqNo := ClientSeqNo{
+			ClientID: l.ClientID,
+			SeqNo:    l.SeqNo,
+		}
+		// Remove the record and any key slots that still hold this entry's keys.
+		records, keys := stableGetWitnessState(r.stable)
+		delete(records, clientSeqNo)
+		for _, key := range l.Keys {
+			hash := getKeyHash(key)
+			if bytes.Equal(key, keys[hash]) {
+				delete(keys, hash)
+			}
+		}
+		stableSetWitnessState(r.stable, records, keys)
+
 		// Return so that the future is only responded to
 		// by the FSM handler when the application is done
 		return
 
+	case LogNextClientId:
+		var nextClientId uint64
+		if err := decodeMsgPack(l.Data, &nextClientId); err != nil {
+			panic(fmt.Errorf("failed to decode next client id: %v", err))
+		}
+		r.nextClientId = nextClientId
+
 	case LogConfiguration:
 	case LogAddPeerDeprecated:
 	case LogRemovePeerDeprecated:
@@ -950,6 +1046,60 @@ func (r *Raft) processLog(l *Log, future *logFuture) {
 	}
 }
 
+// getKeyHash hashes a key to be used in a direct-mapped cache.
+// Params:
+// - key: Key to get hash value of
+// Returns: first four bytes of the key's SHA-256 digest, read little-endian
+func getKeyHash(key Key) uint32 {
+	hash := sha256.Sum256(key)
+	hashSlice := hash[:]
+	return binary.LittleEndian.Uint32(hashSlice)
+}
+
+// stableSetWitnessState writes the witnessState to stable storage.
+// Should only be called if r.witnessState.Lock is held. Panics if
+// failure.
+func stableSetWitnessState(stable StableStore, records map[ClientSeqNo]Log, keys map[uint32]Key) { + recordsBuf, err1 := encodeMsgPack(records) + if err1 != nil { + panic(fmt.Errorf("failed to encode witness state records: %v", err1)) + } + err2 := stable.Set(keyWitnessStateRecords, recordsBuf.Bytes()) + if err2 != nil { + panic(fmt.Errorf("failed to write witness state records to stable storage: %v", err2)) + } + keysBuf, err3 := encodeMsgPack(keys) + if err3 != nil { + panic(fmt.Errorf("failed to encode witness state keys: %v", err3)) + } + err4 := stable.Set(keyWitnessStateKeys, keysBuf.Bytes()) + if err4 != nil { + panic(fmt.Errorf("failed to write witness state keys to stable storage: %v", err4)) + } +} + +func stableGetWitnessState(stable StableStore) (map[ClientSeqNo]Log, map[uint32]Key) { + recordsBuf, err1 := stable.Get(keyWitnessStateRecords) + if err1 != nil { + panic(fmt.Errorf("failed to read witness state records from stable storage: %v", err1)) + } + records := make(map[ClientSeqNo]Log) + err2 := decodeMsgPack(recordsBuf, &records) + if err2 != nil { + panic(fmt.Errorf("failed to decode witness state records: %v", err2)) + } + keysBuf, err3 := stable.Get(keyWitnessStateKeys) + if err3 != nil { + panic(fmt.Errorf("failed to read witness state keys from stable storage: %v", err3)) + } + keys := make(map[uint32]Key) + err4 := decodeMsgPack(keysBuf, &keys) + if err4 != nil { + panic(fmt.Errorf("failed to decode witness state keys: %v", err4)) + } + return records, keys +} + // processRPC is called to handle an incoming RPC request. This must only be // called from the main thread. 
func (r *Raft) processRPC(rpc RPC) { @@ -958,16 +1108,26 @@ func (r *Raft) processRPC(rpc RPC) { return } - switch cmd := rpc.Command.(type) { + switch cmd := rpc.Command.(type) { case *AppendEntriesRequest: r.appendEntries(rpc, cmd) case *RequestVoteRequest: r.requestVote(rpc, cmd) case *InstallSnapshotRequest: r.installSnapshot(rpc, cmd) - case *ClientRequest: - r.clientRequest(rpc, cmd) - default: + case *RecordRequest: + r.recordRequest(rpc, cmd) + case *SyncRequest: + r.syncRequest(rpc, cmd) + case *RecoveryDataRequest: + r.recoveryDataRequest(rpc, cmd) + case *UnfreezeRequest: + r.unfreezeRequest(rpc, cmd) + case *ClientRequest: + r.clientRequest(rpc, cmd) + case *ClientIdRequest: + r.clientIdRequest(rpc, cmd) + default: r.logger.Printf("[ERR] raft: Got unexpected command: %#v", rpc.Command) rpc.Respond(nil, fmt.Errorf("unexpected command")) } @@ -1291,7 +1451,7 @@ func (r *Raft) installSnapshot(rpc RPC, req *InstallSnapshotRequest) { } version := getSnapshotVersion(r.protocolVersion) sink, err := r.snapshots.Create(version, req.LastLogIndex, req.LastLogTerm, - reqConfiguration, reqConfigurationIndex, r.trans) + reqConfiguration, reqConfigurationIndex, r.nextClientId, r.clientResponseCache, r.trans) if err != nil { r.logger.Printf("[ERR] raft: Failed to create snapshot to install: %v", err) rpcErr = fmt.Errorf("failed to create snapshot: %v", err) @@ -1363,105 +1523,283 @@ func (r *Raft) installSnapshot(rpc RPC, req *InstallSnapshotRequest) { return } -// Handle a clientRequest RPC from client. -func (r *Raft) clientRequest(rpc RPC, c *ClientRequest) { - leader := r.Leader() - resp := &ClientResponse{ - Success : false, - LeaderAddress : leader, +// Handle a recoveryDataRequest from new leader to witness. Returns +// all entries stored at witness and freezes witness until unfreeze +// request is sent. +// Params: +// - rpc: RPC object used to send a response. +// - req: Recovery DAta Request being handled. 
+func (r *Raft) recoveryDataRequest(rpc RPC, req *RecoveryDataRequest) { + logMap,_ := stableGetWitnessState(r.stable) + logs := make([]Log, 0) + for _,log := range logMap { + logs = append(logs, log) } - // Have we contacted the leader? - var rpcErr error - if (r.getState() == Leader) { - // Maintain sessions - if (c.KeepSession) { - r.leaderState.clientSessionsLock.RLock() - _, ok := r.leaderState.clientSessions[c.ClientAddr] - r.leaderState.clientSessionsLock.RUnlock() - // If first session, start heartbeat loop. - if c.EndSessionCommand != nil { - if !ok { - r.leaderState.clientSessionsLock.Lock() - r.leaderState.clientSessions[c.ClientAddr] = &clientSession{} - r.leaderState.clientSessions[c.ClientAddr].heartbeatCh = make (chan bool, 1) - r.leaderState.clientSessions[c.ClientAddr].endSessionCommand = c.EndSessionCommand - r.leaderState.clientSessionsLock.Unlock() - go r.clientSessionHeartbeatLoop(c.ClientAddr) - } - r.leaderState.clientSessionsLock.RLock() - ch := r.leaderState.clientSessions[c.ClientAddr].heartbeatCh - r.leaderState.clientSessionsLock.RUnlock() - ch <- true - } - } - // Apply all commands in client request. - go func(r *Raft, resp *ClientResponse, rpc RPC, c *ClientRequest) { - var rpcErr error - for _,entry := range(c.Entries) { - if (entry != nil) { - r.applyCommand(entry.Data, resp, &rpcErr) - } - } - rpc.Respond(resp, rpcErr) - }(r, resp, rpc, c) - } else { - rpcErr = ErrNotLeader - resp.Success = false - rpc.Respond(resp, rpcErr) + resp := &RecoveryDataResponse { + RPCHeader: r.getRPCHeader(), + Entries: logs, } + r.frozenLock.Lock() + r.frozen = true + r.frozenLock.Unlock() + rpc.Respond(resp, nil) } -// Apply a command from leader to all raft FSMs. */ -func (r *Raft) applyCommand(command []byte, resp *ClientResponse, rpcErr *error) { - f := r.Apply(command, 0) - if f.Error() != nil { - r.logger.Printf("err: %v",f.Error()) - *rpcErr = f.Error() - resp.Success = false +// Handle a unfreezeRequest from new leader to witness. 
Sent after
+// recoveryDataRequest to allow witness to start receiving client
+// record requests again.
+func (r *Raft) unfreezeRequest(rpc RPC, req *UnfreezeRequest) {
+	r.frozenLock.Lock()
+	r.frozen = false
+	r.frozenLock.Unlock()
+	resp := &UnfreezeResponse{
+		RPCHeader: r.getRPCHeader(),
+	}
-	/* If callback, make leader execute callback */
-	var nextCommands [][]byte
-	callbacks := f.Callback()
-	for _,callback := range callbacks {
-		commands := callback()
-		for _, command := range commands {
-			nextCommands = append(nextCommands, command)
+	rpc.Respond(resp, nil)
+}
+
+// Handle a clientIdRequest from client. Can only be handled at
+// the leader. Assigns a new client ID and replicates the client
+// ID to followers.
+// Params:
+// - rpc: RPC object used to send a response.
+// - c: Client Id Request being handled.
+func (r *Raft) clientIdRequest(rpc RPC, c *ClientIdRequest) {
+	leader := r.Leader()
+	resp := &ClientIdResponse{
+		LeaderAddress: leader,
+	}
+	// Can only assign client IDs at the leader.
+	if r.getState() == Leader {
+		resp.ClientID = r.nextClientId
+		r.nextClientId++
+		r.clientResponseLock.Lock()
+		r.clientResponseCache[resp.ClientID] = make(map[uint64]clientResponseEntry)
+		r.clientResponseLock.Unlock()
+		r.logger.Printf("Client ID to send is %v", resp.ClientID)
+		go func(r *Raft, resp *ClientIdResponse, rpc RPC) {
+			f := r.SendNextClientId(0)
+			if f.Error() != nil {
+				r.logger.Printf("err :%v", f.Error())
+			}
+			rpc.Respond(resp, f.Error())
+		}(r, resp, rpc)
+	} else {
+		rpc.Respond(resp, ErrNotLeader)
+	}
+}
+
+// Handle a syncRequest from client. Can only be handled at the
+// leader, and requires a valid client ID. Synchronously
+// executes the client command.
+// Params:
+// - rpc: RPC object used to send a response
+// - sync: Sync Request being handled.
+func (r *Raft) syncRequest(rpc RPC, sync *SyncRequest) {
+	leader := r.Leader()
+	r.logger.Printf("leader: %v", leader)
+	resp := &SyncResponse{
+		Success:       false,
+		LeaderAddress: leader,
+	}
+	// Check if client ID is valid.
+	r.clientResponseLock.RLock()
+	_, ok := r.clientResponseCache[sync.Entry.ClientID]
+	r.clientResponseLock.RUnlock()
+	if !ok {
+		rpc.Respond(resp, ErrBadClientId)
+		return
+	}
+	// Only the leader replicates the command; any other node answers
+	// ErrNotLeader along with the leader address it currently knows.
+	if r.getState() == Leader {
+		// Run off the calling goroutine; success is reported only if the apply did not fail.
+		r.goFunc(func() {
+			var rpcErr error
+			resp.ResponseData = r.applySynchronousCommand(sync.Entry, &rpcErr)
+			resp.Success = rpcErr == nil
+			rpc.Respond(resp, rpcErr)
+		})
+	} else {
+		resp.Success = false
+		rpc.Respond(resp, ErrNotLeader)
+	}
+}
+
+// Handle a recordRequest from client. Can only be handled
+// at a witness, not the leader. Records an operation successfully
+// if it is commutative with other stored operations.
+func (r *Raft) recordRequest(rpc RPC, record *RecordRequest) {
+	// Master can't act as a witness.
+	if r.getState() == Leader {
+		resp := &RecordResponse{
+			Success: false,
+		}
+		rpc.Respond(resp, ErrNotWitness)
+		return
+	}
+
+	// Can't accept record request if frozen.
+	r.frozenLock.RLock()
+	isFrozen := r.frozen
+	r.frozenLock.RUnlock()
+	if isFrozen {
+		resp := &RecordResponse{
+			Success: false,
 		}
+		rpc.Respond(resp, ErrWitnessFrozen)
+		return
 	}
-	data, _:= json.Marshal(f.Response())
-	resp.ResponseData = data
-	resp.Success = true
-	for _,nextCommand := range nextCommands {
-		r.applyCommand(nextCommand, resp, rpcErr)
-	}
-}
-/* Manage a client session.
*/
-func (r *Raft) clientSessionHeartbeatLoop(clientAddr ServerAddress) {
-	r.leaderState.clientSessionsLock.RLock()
-	ch := r.leaderState.clientSessions[clientAddr].heartbeatCh
-	r.leaderState.clientSessionsLock.RUnlock()
-	for {
-		select {
-		case <- ch:
-			r.leaderState.clientSessionsLock.Lock()
-			r.leaderState.clientSessions[clientAddr].lastContact = time.Now()
-			r.leaderState.clientSessionsLock.Unlock()
-		case <- time.After(30*time.Second):
-			r.logger.Printf("ending client session")
-			var err error
-			r.leaderState.clientSessionsLock.RLock()
-			command := r.leaderState.clientSessions[clientAddr].endSessionCommand
-			r.leaderState.clientSessionsLock.RUnlock()
-			if command != nil {
-				r.applyCommand(command, &ClientResponse{}, &err)
-			}
-			r.leaderState.clientSessionsLock.Lock()
-			delete(r.leaderState.clientSessions, clientAddr)
-			r.leaderState.clientSessionsLock.Unlock()
-			return
+	// Can't accept record request if sending to stale set of witnesses.
+	if record.Term < r.getCurrentTerm() {
+		resp := &RecordResponse{
+			Success: false,
 		}
+		rpc.Respond(resp, ErrStaleTerm)
+		return
 	}
+
+	success := r.storeIfCommutative(record.Entry)
+	r.logger.Printf("witness says client req is commutative: %t", success)
+	// Respond to client.
+	resp := &RecordResponse{
+		Success: success,
+	}
+
+	if success {
+		rpc.Respond(resp, nil)
+	} else {
+		rpc.Respond(resp, ErrNotCommutative)
+	}
+}
+
+// Check if an operation is commutative with other operations
+// stored at the witness and if this is the case, store it and
+// return true, otherwise return false.
+// Params:
+// - log: Log entry of type LogCommand to store.
+// Returns true if successfully stored (must be commutative with
+// other operations), false otherwise.
+func (r *Raft) storeIfCommutative(log *Log) bool {
+	records, keys := stableGetWitnessState(r.stable)
+
+	// Check if operation involving key already stored at witness or no
+	// space to store in direct-mapped cache.
+ for _, key := range log.Keys { + hash := getKeyHash(key) + if _, ok := keys[hash]; ok { + return false + } + } + + // Add keys separately in case keys included multiple times by client. + for _, key := range log.Keys { + hash := getKeyHash(key) + keys[hash] = key + } + // Record RPC in witness. + clientSeqNo := ClientSeqNo{ + ClientID: log.ClientID, + SeqNo: log.SeqNo, + } + records[clientSeqNo] = *log + + // Write updates to stable storage. + stableSetWitnessState(r.stable, records, keys) + + return true +} + +// Handle a clientRequest RPC from client. Can only be handled at +// the leader. Requires a valid client ID. Only executes locally +// and reports not synced if commutative, otherwise replicates +// synchronously to followers and reports synced. +// Params: +// - rpc: RPC object used to send a response. +// - c: Client Request object being handled. +func (r *Raft) clientRequest(rpc RPC, c *ClientRequest) { + leader := r.Leader() + resp := &ClientResponse{ + Success: false, + LeaderAddress: leader, + } + // Check if client ID is valid. + r.clientResponseLock.RLock() + _, ok := r.clientResponseCache[c.Entry.ClientID] + r.clientResponseLock.RUnlock() + if !ok { + rpc.Respond(resp, ErrBadClientId) + return + } + // Check if request has already been made. + // Have we contacted the leader? + if r.getState() == Leader { + // Apply all commands in client request. + r.goFunc(func() { + var rpcErr error + r.applyCommand(c.Entry, resp, &rpcErr) + rpc.Respond(resp, rpcErr) + }) + } else { + resp.Success = false + rpc.Respond(resp, ErrNotLeader) + } +} + +// Apply a command locally if it is commutative (not synced) or +// replicate to followers (synced). Sets fields in resp based on +// execution of request and if synced. +// Params: +// - log: Log entry to apply, type LogCommand. +// - resp: Response to populate after completing command. +// - rpcErr: Pointer to error to set if necessary. 
+func (r *Raft) applyCommand(log *Log, resp *ClientResponse, rpcErr *error) {
+	commutative := r.storeIfCommutative(log)
+	if commutative {
+		// Apply locally, store in witness cache, and respond
+		resp.ResponseData = r.applyCommutativeCommand(log, rpcErr)
+		resp.Synced = false
+	} else {
+		// Sync all previous requests and execute this request synchronously.
+		resp.ResponseData = r.applySynchronousCommand(log, rpcErr)
+		resp.Synced = true
+	}
+	resp.LeaderAddress = r.Leader()
+}
+
+// Apply a command locally. Should only be called by the leader if
+// the leader has confirmed that the operation is commutative and
+// is stored in its set of current operations.
+// Params:
+// - log: Log entry to apply commutatively, type LogCommand.
+// - rpcErr: Pointer to error; this function never writes to it.
+// Returns: byte array containing response to applying command.
+func (r *Raft) applyCommutativeCommand(log *Log, rpcErr *error) []byte {
+	// Apply locally first; the JSON marshalling error below is discarded.
+	var response interface{}
+	r.applyCommandLocally(log, &response)
+	data, _ := json.Marshal(response)
+	// Replicate to followers asynchronously; the returned future is not awaited.
+	r.goFunc(func() { r.Apply(log, 0) })
+	return data
+}
+
+// Replicate a command to followers. Should be called if leader has
+// confirmed that an operation is not commutative.
+// Params:
+// - log: Log entry to apply synchronously, type LogCommand.
+// - rpcErr: Pointer to error to set if necessary.
+// Returns: byte array containing response to applying command.
+func (r *Raft) applySynchronousCommand(log *Log, rpcErr *error) []byte { + f := r.Apply(log, 0) + if f.Error() != nil { + r.logger.Printf("err: %v", f.Error()) + *rpcErr = f.Error() + } + data, _ := json.Marshal(f.Response()) + return data } // setLastContact is used to set the last contact time to now diff --git a/src/raft/raft_test.go b/src/raft/raft_test.go old mode 100644 new mode 100755 diff --git a/src/raft/replication.go b/src/raft/replication.go old mode 100644 new mode 100755 index 2b5ec47..e631b5a --- a/src/raft/replication.go +++ b/src/raft/replication.go @@ -182,7 +182,7 @@ START: } // Make the RPC call - start = time.Now() + start = time.Now() if err := r.trans.AppendEntries(s.peer.ID, s.peer.Address, &req, &resp); err != nil { r.logger.Printf("[ERR] raft: Failed to AppendEntries to %v: %v", s.peer, err) s.failures++ @@ -337,7 +337,7 @@ func (r *Raft) heartbeat(s *followerReplication, stopCh chan struct{}) { var resp AppendEntriesResponse for { // Wait for the next heartbeat interval or forced notify - select { + select { case <-s.notifyCh: case <-randomTimeout(r.conf.HeartbeatTimeout / 10): case <-stopCh: diff --git a/src/raft/session.go b/src/raft/session.go old mode 100644 new mode 100755 index 2355dce..1fcf09a --- a/src/raft/session.go +++ b/src/raft/session.go @@ -1,271 +1,337 @@ package raft import ( - "net" - "time" - "fmt" - "errors" - "bufio" - - "github.com/hashicorp/go-msgpack/codec" + "errors" + "sync" + "time" + "math" ) -type Session struct { - trans *NetworkTransport - currConn *netConn - raftServers []ServerAddress - stopCh chan bool - active bool - endSessionCommand []byte +// Client library for Raft. Provides session abstraction that handles starting +// a session, making requests, and closing a session. + +// Connection and associated lock for synchronization. +type syncedConn struct { + // Connection to Raft server. + conn *netConn + // Lock protecting conn. + lock sync.Mutex } -// Send request to cluster without using session. 
-func SendSingletonRequestToCluster(addrs []ServerAddress, data []byte, resp *ClientResponse) error { - if resp == nil { - return errors.New("Response is nil") - } - // Send RPC - clientRequest := ClientRequest{ - RPCHeader: RPCHeader{ - ProtocolVersion: ProtocolVersionMax, - }, - Entries: []*Log{ - &Log{ - Type: LogCommand, - Data: data, - }, - }, - } - return sendSingletonRpcToActiveLeader(addrs, &clientRequest, resp) +// Session abstraction used to make requests to Raft cluster. +type Session struct { + // Client network layer. + trans *NetworkTransport + // Connections to all Raft nodes. + conns []syncedConn + // Leader is index into conns or addrs arrays. + leader int + leaderLock sync.RWMutex + // Term tracks the current Raft term to avoid stale witnesses. + term uint64 + termLock sync.RWMutex + // Addresses of all Raft servers. + addrs []ServerAddress + // Client ID assigned by cluster for use in RIFL. + clientID uint64 + // Sequence number of next RPC for use in RIFL. + rpcSeqNo uint64 + // Size of superquorum (number of witnesses need to record commutative operation in). + superquorumSz int } +// Open client session to cluster. +// Params: +// - trans: Client transport layer for networking opertaions +// - addrs: Addresses of all Raft servers +// Return: created session +func CreateClientSession(trans *NetworkTransport, addrs []ServerAddress) (*Session, error) { + session := &Session{ + trans: trans, + conns: make([]syncedConn, len(addrs)), + leader: -1, + addrs: addrs, + rpcSeqNo: 0, + } + f := len(addrs) / 2 // Raft needs 2f+1 replicas + session.superquorumSz = f + int(math.Ceil(float64(f)/2.0)) + 1 -/* Open client session to cluster. Takes clientID, server addresses for all servers in cluster, and returns success or failure. - Start go routine to periodically send heartbeat messages and switch to new leader when necessary. 
*/ -func CreateClientSession(trans *NetworkTransport, addrs []ServerAddress, endSessionCommand []byte) (*Session, error) { - session := &Session{ - trans: trans, - raftServers: addrs, - active: true, - stopCh : make(chan bool, 1), - endSessionCommand: endSessionCommand, - } - var err error - session.currConn, err = findActiveServerWithTrans(addrs, trans) - if err != nil { - return nil ,err - } - if endSessionCommand != nil { - go session.sessionKeepAliveLoop() - } - return session, nil -} + // Initialize syncedConn array. + for i := range session.conns { + session.conns[i] = syncedConn{} + } + + // Open connections to all raft servers. + var err error + for i, addr := range addrs { + session.conns[i].conn, err = trans.getConn(addr) + if err == nil { + session.leader = i + } + } + // Report error if can't connect to any server. + if session.leader == -1 { + return nil, ErrNoActiveServers + } -/* Make request to open session. */ -func (s *Session) SendRequest(data []byte, resp *ClientResponse) error { - if !s.active { - return errors.New("Inactive client session.") - } - if resp == nil { - return errors.New("Response is nil") - } - req := ClientRequest { - RPCHeader: RPCHeader { - ProtocolVersion: ProtocolVersionMax, - }, - Entries: []*Log{ - &Log { - Type: LogCommand, - Data: data, - }, - }, - ClientAddr: s.trans.LocalAddr(), - EndSessionCommand: s.endSessionCommand, - KeepSession: true, - } - return s.sendToActiveLeader(&req, resp) + // Get a client ID from the leader. + req := ClientIdRequest{ + RPCHeader: RPCHeader{ + ProtocolVersion: ProtocolVersionMax, + }, + } + resp := ClientIdResponse{} + err = session.sendToActiveLeader(&req, &resp, rpcClientIdRequest) + if err != nil { + return nil, err + } + session.clientID = resp.ClientID + return session, nil } - -/* Close client session. Kill heartbeat go routine. 
*/ -func (s *Session) CloseClientSession() error { - if !s.active { - return errors.New("Inactive client session") - } - s.stopCh <- true - fmt.Println("closed client session") - return nil +// Make request to Raft cluster using open session. +// Params: +// - data: client request to send to cluster +// - keys: array of keys that request updates, used in commutativity checks +// - resp: pointer to response that will be populated +func (s *Session) SendRequest(data []byte, keys []Key, resp *ClientResponse) error { + seqNo := s.rpcSeqNo + s.rpcSeqNo++ + return s.SendRequestWithSeqNo(data, keys, resp, seqNo) } -/* Loop to send and receive heartbeat messages. */ -func (s *Session) sessionKeepAliveLoop() { - for s.active { - select { - case <-time.After(10*time.Second): - case <- s.stopCh: - s.active = false - } - if !s.active { - fmt.Println("client session no longer active") - return - } - // Send RPC - heartbeat := ClientRequest{ - RPCHeader: RPCHeader{ - ProtocolVersion: ProtocolVersionMax, - }, - Entries: nil, - ClientAddr: s.trans.LocalAddr(), - KeepSession: true, - EndSessionCommand: s.endSessionCommand, - } - s.sendToActiveLeader(&heartbeat, &ClientResponse{}) - } - fmt.Println("client session no longer active") +// Make request to Raft cluster using open session and specifying a sequence +// number. Only use for testing! (Use SendRequest in production). 
+// Params: +// - data: client request to send to cluster +// - keys: array of keys that request updates, used in commutativity checks +// - resp: pointer to response that will be populated +// - seqno: sequence number to use for request (for testing purposes) +func (s *Session) SendRequestWithSeqNo(data []byte, keys []Key, resp *ClientResponse, seqno uint64) error { + if resp == nil { + return errors.New("Response is nil") + } + req := ClientRequest{ + RPCHeader: RPCHeader{ + ProtocolVersion: ProtocolVersionMax, + }, + Entry: &Log{ + Type: LogCommand, + Data: data, + Keys: keys, + ClientID: s.clientID, + SeqNo: seqno, + }, + } + return s.sendToActiveLeader(&req, resp, rpcClientRequest) } -func (s *Session) sendToActiveLeader(request *ClientRequest, response *ClientResponse) error { - var err error = errors.New("") - retries := 5 - /* Send heartbeat to active leader. Connect to active leader if connection no longer to active leader. */ - for err != nil { - if retries <= 0 { - s.active = false - return errors.New("Failed to find active leader.") - } - if s.currConn == nil { - s.active = false - return errors.New("No current connection.") - } - err = sendRPC(s.currConn, rpcClientRequest, request) - /* Try another server if server went down. */ - for err != nil { - if retries <= 0 { - s.active = false - return errors.New("Failed to find active leader.") - } - s.currConn, err = findActiveServerWithTrans(s.raftServers, s.trans) - if err != nil || s.currConn == nil { - s.active = false - return errors.New("No active server found.") - } - retries-- - err = sendRPC(s.currConn, rpcClientRequest, request) - } - /* Decode response if necesary. Try new server to find leader if necessary. 
*/ - if (s.currConn == nil) { - return errors.New("Failed to find active leader.") - } - _, err = decodeResponse(s.currConn, &response) - if err != nil { - if response != nil && response.LeaderAddress != "" { - s.currConn, _ = s.trans.getConn(response.LeaderAddress) - } else { - /* Wait for leader to be elected. */ - time.Sleep(1000*time.Millisecond) - } - } - retries-- - } - return nil +// Close client session. +// TODO: GC client request tables. +func (s *Session) CloseClientSession() error { + return nil } -func sendSingletonRpcToActiveLeader(addrs []ServerAddress, request *ClientRequest, response *ClientResponse) error { - retries := 5 - conn, err := findActiveServerWithoutTrans(addrs) - if err != nil { - return errors.New("No active server found.") - } - err = errors.New("") - /* Send heartbeat to active leader. Connect to active leader if connection no longer to active leader. */ - for err != nil { - if conn == nil { - return errors.New("No current connection.") - } - if retries <= 0 { - conn.conn.Close() - return errors.New("Failed to find active leader.") - } - err = sendRPC(conn, rpcClientRequest, request) - /* Try another server if server went down. */ - for err != nil { - fmt.Println("error sending: ", err) - if retries <= 0 { - if conn != nil { - conn.conn.Close() - } - return errors.New("Failed to find active leader.") - } - conn, err = findActiveServerWithoutTrans(addrs) - if err != nil || conn == nil { - if conn != nil { - conn.conn.Close() - } - return errors.New("No active server found.") - } - retries-- - err = sendRPC(conn, rpcClientRequest, request) - } - /* Decode response if necesary. Try new server to find leader if necessary. */ - _, err = decodeResponse(conn, &response) - if err != nil { - if response.LeaderAddress != "" { - conn, _ = buildNetConn(response.LeaderAddress) - } else { - /* Wait for leader to be elcted. 
*/ - time.Sleep(1000*time.Millisecond) - } - } - retries-- - } - conn.conn.Close() - return nil +// Make request to Raft cluster following CURP protocol. Send to witnesses and +// master simultaneously to complete in 1 RTT. +// Params: +// - data: client request to send to cluster +// - keys: array of keys that request updates, used in commutativity checks +// - resp: pointer to response that will be populated +// - seqno: sequence number to use for request (for testing purposes) +func (s *Session) SendFastRequest(data []byte, keys []Key, resp *ClientResponse) { + seqNo := s.rpcSeqNo + s.rpcSeqNo++ + s.SendFastRequestWithSeqNo(data, keys, resp, seqNo) } -func findActiveServerWithTrans(addrs []ServerAddress, trans *NetworkTransport) (*netConn, error) { - for _, addr := range(addrs) { - conn, err := trans.getConn(addr) - if err == nil { - return conn, nil - } - } - return nil, errors.New("No active raft servers.") +// Make request to Raft cluster following CURP protocol. Send to witnesses and +// master simultaneously to complete in 1 RTT. Specify sequence number for testing +// purposes. Only use SendFastRequest in production! +// Params: +// - data: client request to send to cluster +// - keys: array of keys that request updates, used in commutativity checks +// - resp: pointer to response that will be populated +// - seqno: sequence number to use for request (for testing purposes) +func (s *Session) SendFastRequestWithSeqNo(data []byte, keys []Key, resp *ClientResponse, seqNo uint64) { + req := ClientRequest{ + RPCHeader: RPCHeader{ + ProtocolVersion: ProtocolVersionMax, + }, + Entry: &Log{ + Type: LogCommand, + Data: data, + Keys: keys, + ClientID: s.clientID, + SeqNo: seqNo, + }, + } + + // Repeat until success. 
+ // TODO: only retry limited number of times + for true { + resultCh := make(chan bool, len(s.addrs)) + go func(s *Session, req *ClientRequest, resp *ClientResponse, resultCh *chan bool) { + err := s.sendToActiveLeader(req, resp, rpcClientRequest) + if err != nil { + *resultCh <- false + } else { + *resultCh <- true + } + }(s, &req, resp, &resultCh) + s.sendToAllWitnesses(req.Entry, &resultCh) + + success := true + + // Wait for superquorum to respond. + for i := 0; i <= s.superquorumSz; i += 1 { // TODO: should this be len + 1? + result := <-resultCh + success = success && result + // TODO: if synced, automatically succeed, otherwise if not success need to retry + } + if success || resp.Synced { + return + } + // If fail to record at witnesses and not synced, issue sync request. + sync := &SyncRequest{ + RPCHeader: RPCHeader{ + ProtocolVersion: ProtocolVersionMax, + }, + Entry: req.Entry, + } + var syncResp SyncResponse + err := s.sendToActiveLeader(sync, &syncResp, rpcSyncRequest) + if err == nil && syncResp.Success { + return + } + // Failed to sync. Try everything again + } + } -func findActiveServerWithoutTrans(addrs []ServerAddress) (*netConn, error) { - for _, addr := range(addrs) { - conn, err := buildNetConn(addr) - if err == nil { - return conn, nil - } - if conn != nil { - conn.conn.Close() - } +// Send log entry to all witnesses in parallel and put results (success +// or failure) into channel. Get all values from channel to ensure that +// RPCs to witnesses have completed. +// Params: +// - entry: Log entry to send to all witnesses. +// - resultCh: channel to put completion status into. +func (s *Session) sendToAllWitnesses(entry *Log, resultCh *chan bool) { + s.termLock.RLock() + term := s.term + s.termLock.RUnlock() + + req := &RecordRequest{ + RPCHeader: RPCHeader{ + ProtocolVersion: ProtocolVersionMax, + }, + Entry: entry, + Term: term, } - return nil, errors.New("No active raft servers.") + + // Send to all witnesses. 
+ for i := range s.conns { + go func(req *RecordRequest, resultCh *chan bool) { + *resultCh <- s.sendToWitness(i, req) + }(req, resultCh) + } } -func buildNetConn(target ServerAddress) (*netConn, error) { - // Dial a new connection - conn, err := net.Dial("tcp", string(target)) +// Send request to a witness specified by id. Synchronous. +// Params: +// - id: ID of witness sending request to +// - req: RecordRequest to send to witness +// Returns: success or failure of RPC. +func (s *Session) sendToWitness(id int, req *RecordRequest) bool { + var err error + s.conns[id].lock.Lock() + if s.conns[id].conn == nil { + s.conns[id].conn, err = s.trans.getConn(s.addrs[id]) + if err != nil { + s.conns[id].lock.Unlock() + return false + } + } + err = sendRPC(s.conns[id].conn, rpcRecordRequest, req) if err != nil { - fmt.Println("error dialing: ", err) - return nil, err + s.conns[id].lock.Unlock() + return false } + resp := &RecordResponse{} + _, err = decodeResponse(s.conns[id].conn, resp) + s.conns[id].lock.Unlock() - // Wrap the conn - netConn := &netConn{ - target: target, - conn: conn, - r: bufio.NewReader(conn), - w: bufio.NewWriter(conn), + // Update term if found new term. + s.termLock.Lock() + if resp.Term > s.term { + s.term = resp.Term + } + s.termLock.Unlock() + + if err != nil || !resp.Success { + return false } + return true +} - // Setup encoder/decoders - netConn.dec = codec.NewDecoder(netConn.r, &codec.MsgpackHandle{}) - netConn.enc = codec.NewEncoder(netConn.w, &codec.MsgpackHandle{}) +// Send a RPC to the active leader. Try to use the currently cached active leader, and +// if there is no cached leader or it is unreachable, try other Raft servers until a +// leader is found. If no active Raft server is found, return an error. +// Params: +// - request: JSON representation of request +// - response: client response that contains a leader address to help find an active leader +// - rpcType: type of RPC being sent. 
+func (s *Session) sendToActiveLeader(request interface{}, response GenericClientResponse, rpcType uint8) error { + sendFailures := 0 + var err error + + s.leaderLock.Lock() + defer s.leaderLock.Unlock() + + // Continue trying to send until have tried contacting all servers. + for sendFailures < len(s.addrs) { + // If no open connection to guessed leader, try to open one. + s.conns[s.leader].lock.Lock() + if s.conns[s.leader].conn == nil { + s.conns[s.leader].conn, err = s.trans.getConn(s.addrs[s.leader]) + if err != nil { + s.conns[s.leader].lock.Unlock() + sendFailures += 1 + s.leader = (s.leader + 1) % len(s.conns) + continue + } + } + err = sendRPC(s.conns[s.leader].conn, rpcType, request) + + // Failed to send RPC - try next server. + if err != nil { + s.conns[s.leader].lock.Unlock() + sendFailures += 1 + s.leader = (s.leader + 1) % len(s.conns) + continue + } + + // Try to decode response. + _, err = decodeResponse(s.conns[s.leader].conn, &response) + s.conns[s.leader].lock.Unlock() + + // If failure, use leader hint or wait for election to complete. + if err != nil { + if response != nil && response.GetLeaderAddress() != "" { + s.leader = (s.leader + 1) % len(s.conns) + for i, addr := range s.addrs { + if addr == response.GetLeaderAddress() { + s.leader = i + break + } + } + } else { + time.Sleep(100 * time.Millisecond) + } + } else { + return nil + } + } - // Done - return netConn, nil + return ErrNoActiveLeader } diff --git a/src/raft/snapshot.go b/src/raft/snapshot.go old mode 100644 new mode 100755 index 5287ebc..3bb0c4e --- a/src/raft/snapshot.go +++ b/src/raft/snapshot.go @@ -22,6 +22,12 @@ type SnapshotMeta struct { Index uint64 Term uint64 + // Next Client ID to use. Used with RIFL. + NextClientId uint64 + + // Responses to client RPCs. Used with RIFL. 
+ ClientResponseCache map[uint64]map[uint64]clientResponseEntry + // Peers is deprecated and used to support version 0 snapshots, but will // be populated in version 1 snapshots as well to help with upgrades. Peers []byte @@ -44,7 +50,7 @@ type SnapshotStore interface { // the given committed configuration. The version parameter controls // which snapshot version to create. Create(version SnapshotVersion, index, term uint64, configuration Configuration, - configurationIndex uint64, trans Transport) (SnapshotSink, error) + configurationIndex uint64, nextClientId uint64, clientRequestCache map[uint64]map[uint64]clientResponseEntry, trans Transport) (SnapshotSink, error) // List is used to list the available snapshots in the store. // It should return then in descending order, with the highest index first. @@ -175,7 +181,7 @@ func (r *Raft) takeSnapshot() (string, error) { r.logger.Printf("[INFO] raft: Starting snapshot up to %d", snapReq.index) start := time.Now() version := getSnapshotVersion(r.protocolVersion) - sink, err := r.snapshots.Create(version, snapReq.index, snapReq.term, committed, committedIndex, r.trans) + sink, err := r.snapshots.Create(version, snapReq.index, snapReq.term, committed, committedIndex, r.nextClientId, r.clientResponseCache, r.trans) if err != nil { return "", fmt.Errorf("failed to create snapshot: %v", err) } diff --git a/src/raft/stable.go b/src/raft/stable.go old mode 100644 new mode 100755 diff --git a/src/raft/state.go b/src/raft/state.go old mode 100644 new mode 100755 diff --git a/src/raft/tcp_transport.go b/src/raft/tcp_transport.go old mode 100644 new mode 100755 diff --git a/src/raft/tcp_transport_test.go b/src/raft/tcp_transport_test.go old mode 100644 new mode 100755 diff --git a/src/raft/transport.go b/src/raft/transport.go old mode 100644 new mode 100755 index 85459b2..556d7d5 --- a/src/raft/transport.go +++ b/src/raft/transport.go @@ -43,6 +43,13 @@ type Transport interface { // RequestVote sends the appropriate RPC to the 
target node. RequestVote(id ServerID, target ServerAddress, args *RequestVoteRequest, resp *RequestVoteResponse) error + // RecoverData sends the appropriate RPC to the target node. + RecoverData(id ServerID, target ServerAddress, args *RecoveryDataRequest, resp *RecoveryDataResponse) error + + // UnfreezeWitness sends the appropriate RPC to the target node. + UnfreezeWitness(id ServerID, target ServerAddress, args *UnfreezeRequest, resp *UnfreezeResponse) error + + // InstallSnapshot is used to push a snapshot down to a follower. The data is read from // the ReadCloser and streamed to the client. InstallSnapshot(id ServerID, target ServerAddress, args *InstallSnapshotRequest, resp *InstallSnapshotResponse, data io.Reader) error diff --git a/src/raft/transport_test.go b/src/raft/transport_test.go old mode 100644 new mode 100755 diff --git a/src/raft/util.go b/src/raft/util.go old mode 100644 new mode 100755 index 69dcfba..90428d7 --- a/src/raft/util.go +++ b/src/raft/util.go @@ -33,7 +33,7 @@ func randomTimeout(minVal time.Duration) <-chan time.Time { return nil } extra := (time.Duration(rand.Int63()) % minVal) - return time.After(minVal + extra) + return time.After(minVal + extra) } // min returns the minimum. diff --git a/src/raft/util_test.go b/src/raft/util_test.go old mode 100644 new mode 100755 diff --git a/src/test/bench/TESTING.md b/src/test/bench/TESTING.md new file mode 100644 index 0000000..95f111a --- /dev/null +++ b/src/test/bench/TESTING.md @@ -0,0 +1,8 @@ +Run scripts from rcmaster. + +Run generateFigures.sh to create CDFs of latencies with 100%, 95%, and 90% +commutative requests and throughput vs latency for 100%, 95%, and 90% commutative +requests. Graphs created in current working directory. + +Uses rc20-25 (to change this, change IP addresses in generateCdf.py and +generateThroughputVsLatency.py). 
diff --git a/src/test/bench/ThroughputVsLatency-100.png b/src/test/bench/ThroughputVsLatency-100.png new file mode 100644 index 0000000..4ef19ad Binary files /dev/null and b/src/test/bench/ThroughputVsLatency-100.png differ diff --git a/src/test/bench/ThroughputVsLatency-90.png b/src/test/bench/ThroughputVsLatency-90.png new file mode 100644 index 0000000..c3877eb Binary files /dev/null and b/src/test/bench/ThroughputVsLatency-90.png differ diff --git a/src/test/bench/ThroughputVsLatency-95.png b/src/test/bench/ThroughputVsLatency-95.png new file mode 100644 index 0000000..3ef7e82 Binary files /dev/null and b/src/test/bench/ThroughputVsLatency-95.png differ diff --git a/src/test/bench/cdf-100.png b/src/test/bench/cdf-100.png new file mode 100644 index 0000000..8835a83 Binary files /dev/null and b/src/test/bench/cdf-100.png differ diff --git a/src/test/bench/cdf-90.png b/src/test/bench/cdf-90.png new file mode 100644 index 0000000..bbb2e4c Binary files /dev/null and b/src/test/bench/cdf-90.png differ diff --git a/src/test/bench/cdf-95.png b/src/test/bench/cdf-95.png new file mode 100644 index 0000000..97bb991 Binary files /dev/null and b/src/test/bench/cdf-95.png differ diff --git a/src/test/bench/client.go b/src/test/bench/client.go new file mode 100755 index 0000000..d46dfde --- /dev/null +++ b/src/test/bench/client.go @@ -0,0 +1,101 @@ +package main + +import ( + "test/keyValStore" + "raft" + "fmt" + "os" + "flag" + "test/utils" + "time" +) + +// Arguments: +// - config: path name of config +// - addr: IP address +// - comm: x/100 requests are commutative +// - n: number of total requests +// - t: number of threads +// - parallel: are requests parallelized? 
+// Prints all latencies to stdout in microseconds, 1 per line +func main() { + addrPtr := flag.String("addr", "127.0.0.1", "IP address and port number of client") + configPathPtr := flag.String("config", "config", "Path to config file") + commPercentPtr := flag.Int("comm", 100, "x/100 requests are commutative") + nPtr := flag.Int("n", 100, "total number of requests") + parallelPtr := flag.Bool("parallel", false, "true if requests are parallelized, false if serial") + + flag.Parse() + + config, err := utils.ReadConfig(*configPathPtr) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading config at %s: %s", *configPathPtr, err) + return + } + servers := config.Servers + n := *nPtr + results := make(chan int64, n) + + start := time.Now() + go runClient(n, *commPercentPtr, *addrPtr, servers, *parallelPtr, &results) + + resultList := make([]int64, n) + for i := 0; i < n; i++ { + elem := <-results + resultList[i] = elem + } + elapsed := time.Since(start) + + // Print results + for _,result := range resultList { + fmt.Println(result) + } + seconds := float64(elapsed.Seconds()) + (float64(elapsed.Nanoseconds()) / 1000000000.0) + throughput := float64(n) / seconds + fmt.Println("THROUGHPUT: ", throughput) + + +} + +func runClient(n int, commPercent int, addr string, servers []raft.ServerAddress, parallel bool, results *chan int64) { + trans, err := raft.NewTCPTransport(addr, nil, 2, time.Second, nil) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating TCP transport: %s", err) + return + } + + client, err := keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating client session: %s", err) + return + } + + for i := 0; i < n; i++ { + if parallel { + go makeRequest(client, i, commPercent, addr, results) + } else { + makeRequest(client, i, commPercent, addr, results) + } + } +} + +func makeRequest(client *keyValStore.Client, i int, commPercent int, addr string, results *chan int64) { + nonComm := 100 - commPercent + // 
interval is -1 if no non-commutative operations + interval := -1 + if nonComm != 0 { + interval = 100 / nonComm + } + start := time.Now() + if interval != -1 && i % interval == 0 { + // Non-commutative operation. + client.Set("foo", "bar") + } else { + // Commutative operation. + uniqueKey := addr + string(i) + client.Set(uniqueKey, "bar") + } + elapsed := time.Since(start) + usElapsed := elapsed.Nanoseconds() / 1000 + *results <- usElapsed +} diff --git a/src/test/bench/config b/src/test/bench/config new file mode 100755 index 0000000..45332d3 --- /dev/null +++ b/src/test/bench/config @@ -0,0 +1,3 @@ +192.168.1.120:8000 +192.168.1.121:8000 +192.168.1.122:8000 diff --git a/src/test/bench/generateCdf.py b/src/test/bench/generateCdf.py new file mode 100644 index 0000000..82e9f26 --- /dev/null +++ b/src/test/bench/generateCdf.py @@ -0,0 +1,24 @@ +from runExperiment import runExper +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import numpy as np +import sys + +# First command line argument is % commutative requests +commPercent = sys.argv[1] + +clients = ["192.168.1.123:5000"] + +latencies, throughput = runExper("/home/evd/RaftFlyer/src/test/bench/config", + clients, 100, False, int(commPercent)) + + +latencies = np.array(latencies, dtype=np.int32) +sorted_data = np.sort(latencies) +yvals=np.arange(len(sorted_data))/float(len(sorted_data)-1) +plt.plot(sorted_data, yvals) + +plt.xlabel("Latency (microseconds)") +plt.title("CDF of Latency with %s%% commutative requests" % commPercent) +plt.savefig("cdf-%s.png" % commPercent) diff --git a/src/test/bench/generateFigures.sh b/src/test/bench/generateFigures.sh new file mode 100755 index 0000000..b27b6ce --- /dev/null +++ b/src/test/bench/generateFigures.sh @@ -0,0 +1,6 @@ +python generateCdf.py 100 +python generateCdf.py 95 +python generateCdf.py 90 +python generateThroughputVsLatency.py 100 +python generateThroughputVsLatency.py 95 +python generateThroughputVsLatency.py 90 diff --git 
a/src/test/bench/generateThroughputVsLatency.py b/src/test/bench/generateThroughputVsLatency.py new file mode 100644 index 0000000..daf8871 --- /dev/null +++ b/src/test/bench/generateThroughputVsLatency.py @@ -0,0 +1,30 @@ +from runExperiment import runExper +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +import numpy as np +import sys + +# First command line argument is % commutative requests +commPercent = sys.argv[1] +maxThreads = 10 +# Add more clients with ports, iterate over by adding clients each time +clients = ["192.168.1.123", "192.168.1.124", "192.168.1.125"] +throughputList = [] +latencyList = [] +for i in range(1, maxThreads): + tempClients = [] + for client in clients: + for j in range(0, i): + tempClients.append(client + ":" + str(5000 + j)) + avgThroughput = 0 + avgLatency = 0 + latencies, throughput = runExper("/home/evd/RaftFlyer/src/test/bench/config", tempClients, 100, True, int(commPercent)) + latencies = np.array(latencies, dtype=np.int32) + throughputList.append(throughput) + latencyList.append(np.mean(latencies)) +plt.plot(throughputList, latencyList, linestyle='-', marker='o') +plt.xlabel("Throughput (ops/sec)") +plt.ylabel("Latency (microsec)") +plt.title("Throughput vs Latency with %s%% commutative requests" % commPercent) +plt.savefig("ThroughputVsLatency-%s.png" % commPercent) diff --git a/src/test/bench/runExperiment.py b/src/test/bench/runExperiment.py new file mode 100644 index 0000000..6c02e2f --- /dev/null +++ b/src/test/bench/runExperiment.py @@ -0,0 +1,56 @@ +import time +import sys, string +import subprocess +import os + +# IP addresses of the form 192.168.1.120:8000 where machine number is 20 +def ipAddrToMachineNum(ipAddr): + ip = ipAddr.split(":")[0] + fourthElem = ip.split(".")[3] + return int(fourthElem) - 100 + +def getPortNum(ipAddr): + port = ipAddr.split(":")[1].strip() + print "port: %s" % port + return port + +def runExper(config, clients, numReqs, parallel, percentCommutative): + # Read 
config + f = open(config, 'r') + servers = f.readlines() + devNull = open(os.devnull, 'w') + # Start Raft servers + serverProcesses = [] + for i in range(len(servers)): + serverCmd = "ssh rc%s \"fuser -k %s/tcp; ./RaftFlyer/src/test/bench/server -config=%s -i=%s\"" % (ipAddrToMachineNum(servers[i]), getPortNum(servers[i]), config, i) + process = subprocess.Popen(serverCmd, shell=True) + serverProcesses.append(process) + time.sleep(0.5) # Allow time to reach stability + + # Start Raft clients + clientProcesses = [] + for client in clients: + clientCmd = "ssh rc%s \"fuser -k %s/tcp; ./RaftFlyer/src/test/bench/client -config=%s -addr=%s -comm=%d -n=%d -parallel=%s\"" % (ipAddrToMachineNum(client), getPortNum(client), config, client, percentCommutative, numReqs, str(parallel)) + process = subprocess.Popen(clientCmd, shell=True, stdout=subprocess.PIPE) + clientProcesses.append(process) + + # Collect client measurements + latencies = [] + totThroughput = 0.0 + for client in clientProcesses: + output = client.stdout.read() + outputLines = output.splitlines() + latencies = latencies + outputLines[0:len(outputLines)-2] # All lines except last line + throughputArr = outputLines[len(outputLines)-1].split(":") + if len(throughputArr) < 2: + print "ERROR: cannot parse throughput %s" % outputLines[len(outputLines) - 1] + return + throughput = float(throughputArr[1]) + totThroughput += throughput + avgThroughput = totThroughput / float(len(clients)) + + # Kill all raft servers + for process in serverProcesses: + process.terminate() + + return latencies, avgThroughput diff --git a/src/test/bench/server.go b/src/test/bench/server.go new file mode 100755 index 0000000..45f1edd --- /dev/null +++ b/src/test/bench/server.go @@ -0,0 +1,37 @@ +package main + +import( + "test/keyValStore" + "flag" + "os" + "os/signal" + "test/utils" + "fmt" +) + +// Arguments: +// - i: replica number +// - config: path name of config +func main() { + configPathPtr := flag.String("config", "config", "Path 
to config file") + iPtr := flag.Int("i", 0, "Replica number, indexed by line in config file") + + flag.Parse() + + configPath := *configPathPtr + i := *iPtr + + config, err := utils.ReadConfig(configPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading config at %s: %s", config, err) + return + } + + servers := config.Servers + keyValStore.StartNode(keyValStore.CreateWorkers(1)[0], servers, i) + + // Wait for CTRL-C + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + <-c +} diff --git a/src/test/correctness/gc_client.go b/src/test/correctness/gc_client.go new file mode 100755 index 0000000..6d53332 --- /dev/null +++ b/src/test/correctness/gc_client.go @@ -0,0 +1,49 @@ +package main + +import ( + "raft" + "fmt" + "test/keyValStore" + "time" + "test/utils" + "os" +) + +// Tests correct garbage collection of client responses. Assumes that server is configured +// to garbage collect responses that have been stored for less than 1 second. + +var c *keyValStore.Client + +func main() { + trans, transErr := raft.NewTCPTransport("127.0.0.1:5000", nil, 2, time.Second, nil) + if transErr != nil { + fmt.Fprintf(os.Stderr, "Error with creating TCP transport, could not run tests: ", transErr) + return + } + var err error + servers := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + c, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating client session, could not run tests: ", err) + return + } + + testsFailed := utils.RunTestSuite(testGc) + fmt.Println(testsFailed) +} + +func testGc() (error) { + val1, err1 := c.IncWithSeqno(1234) + if err1 != nil { + return fmt.Errorf("Error sending RPC first time: %v", err1) + } + time.Sleep(time.Second) + val2, err2 := c.IncWithSeqno(1234) + if err2 != nil { + return fmt.Errorf("Error retransmitting RPC: %v", err2) + } + if val1 == val2 { + return fmt.Errorf("Cached responses not correctly garbage collected.") + } + return nil +} diff --git 
a/src/test/correctness/recovery_client.go b/src/test/correctness/recovery_client.go new file mode 100755 index 0000000..92a03e0 --- /dev/null +++ b/src/test/correctness/recovery_client.go @@ -0,0 +1,57 @@ +package main + +import ( + "test/keyValStore" + "raft" + "fmt" + "time" + "test/utils" +) + +// Sanity check to verify that client can send request and receive response. + +var c *keyValStore.Client + +func main() { + trans, err := raft.NewTCPTransport("127.0.0.1:5000", nil, 2, time.Second, nil) + if err != nil { + fmt.Println("Error with creating TCP transport: ", err) + return + } + servers := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + c, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Println("Can't create client session", err) + return + } + + testsFailed := utils.RunTestSuite(testLeaderRecovery) + fmt.Println(testsFailed) +} + +func testLeaderRecovery() (error) { + timeout := time.After(2 * time.Second) + tick := time.Tick(1*time.Millisecond) + count := uint64(0) + for { + select { + case <-timeout: + //expected := uint64((2 * 1000) + 1) // 2 seconds * # milliseconds in second + 1 for last increment + expected := count + 1 + received, err := c.Inc() + if err != nil { + return fmt.Errorf("Error sending increment RPC to test for leader recovery: %s", err) + } + if received != expected { + return fmt.Errorf("Expected %d and received %d, error in leader recovery.", expected, received) + } + return nil + case <-tick: + _,err := c.Inc() + if err == nil { + count += 1 + } + } + } + return nil +} diff --git a/src/test/correctness/recovery_cluster.go b/src/test/correctness/recovery_cluster.go new file mode 100755 index 0000000..e40d9b8 --- /dev/null +++ b/src/test/correctness/recovery_cluster.go @@ -0,0 +1,50 @@ +package main + +import( + "raft" + "test/keyValStore" + "os" + "os/signal" + "time" + "strconv" + "fmt" +) + +// Run cluster for 5 seconds and then kill leader. 
+// Optional first argument is interval at which to garbage collect entries from client response cache +// in milliseconds. Optional second argument is length of time that entries should be left in the +// client response cache before being garbage collected (in milliseconds). +func main() { + args := os.Args[1:] + var gcInterval, gcRemoveTime time.Duration + gcInterval = 0 + gcRemoveTime = 0 + if len(args) > 0 { + interval, err := strconv.Atoi(args[0]) + if err != nil { + fmt.Println("GC Interval must be an integer.") + return + } + gcInterval = time.Duration(interval) * time.Millisecond + } + if len(args) > 1 { + removeTime, err := strconv.Atoi(args[1]) + if err != nil { + fmt.Println("GC remove time must be an integer.") + return + } + gcRemoveTime = time.Duration(removeTime) * time.Millisecond + } + addrs := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002","127.0.0.1:8003","127.0.0.1:8004"} + cluster := keyValStore.MakeNewCluster(5, keyValStore.CreateWorkers(5), addrs, gcInterval, gcRemoveTime) + time.Sleep(5*time.Second) + for _,node := range cluster.Rafts { + if node.IsLeader() { + node.Shutdown() + break + } + } + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + <-c +} diff --git a/src/test/correctness/restart_cluster.go b/src/test/correctness/restart_cluster.go new file mode 100755 index 0000000..ad3b6b0 --- /dev/null +++ b/src/test/correctness/restart_cluster.go @@ -0,0 +1,47 @@ +package main + +import( + "raft" + "test/keyValStore" + "os" + "os/signal" + "time" + "strconv" + "fmt" +) + +// Run cluster for 10 seconds, and then restart. Used to test for correct snapshotting. +// Optional first argument is interval at which to garbage collect entries from client response cache +// in milliseconds. Optional second argument is length of time that entries should be left in the +// client response cache before being garbage collected (in milliseconds). 
+func main() { + args := os.Args[1:] + var gcInterval, gcRemoveTime time.Duration + gcInterval = 0 + gcRemoveTime = 0 + if len(args) > 0 { + interval, err := strconv.Atoi(args[0]) + if err != nil { + fmt.Println("GC Interval must be an integer.") + return + } + gcInterval = time.Duration(interval) * time.Millisecond + } + if len(args) > 1 { + removeTime, err := strconv.Atoi(args[1]) + if err != nil { + fmt.Println("GC remove time must be an integer.") + return + } + gcRemoveTime = time.Duration(removeTime) * time.Millisecond + } + addrs := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + cluster := keyValStore.MakeNewCluster(3, keyValStore.CreateWorkers(3), addrs, gcInterval, gcRemoveTime) + time.Sleep(10*time.Second) + keyValStore.ShutdownCluster(cluster.Rafts) + fmt.Println("Restarting cluster") + keyValStore.RestartCluster(cluster) + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + <-c +} diff --git a/src/test/correctness/restart_curp_client.go b/src/test/correctness/restart_curp_client.go new file mode 100755 index 0000000..51f15af --- /dev/null +++ b/src/test/correctness/restart_curp_client.go @@ -0,0 +1,49 @@ +package main + +import ( + "raft" + "fmt" + "test/keyValStore" + "time" + "test/utils" + "os" +) + +// Tests that cached client responses are correctly stored in a snapshot and restored +// when the cluster is restarted. 
+ +var c *keyValStore.Client + +func main() { + trans, transErr := raft.NewTCPTransport("127.0.0.1:5000", nil, 2, time.Second, nil) + if transErr != nil { + fmt.Fprintf(os.Stderr, "Error with creating TCP transport, could not run tests: ", transErr) + return + } + var err error + servers := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + c, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating client session, could not run tests: ", err) + return + } + + testsFailed := utils.RunTestSuite(testRestartWithClientCaches) + fmt.Println(testsFailed) +} + +func testRestartWithClientCaches() (error) { + err1 := c.Set("foo", "bar") + if err1 != nil { + return fmt.Errorf("Error sending RPC first time: %v", err1) + } + time.Sleep(2*time.Second) + val, err2 := c.Get("foo") + if err2 != nil { + return fmt.Errorf("Error retransmitting RPC: %v", err2) + } + if val != "bar" { + return fmt.Errorf("Didn't correctly restore value of \"bar\" after restart, isntead %s", val) + } + return nil +} diff --git a/src/test/correctness/restart_rifl_client.go b/src/test/correctness/restart_rifl_client.go new file mode 100755 index 0000000..7ea3b7d --- /dev/null +++ b/src/test/correctness/restart_rifl_client.go @@ -0,0 +1,49 @@ +package main + +import ( + "raft" + "fmt" + "test/keyValStore" + "time" + "test/utils" + "os" +) + +// Tests that cached client responses are correctly stored in a snapshot and restored +// when the cluster is restarted. 
+ +var c *keyValStore.Client + +func main() { + trans, transErr := raft.NewTCPTransport("127.0.0.1:5000", nil, 2, time.Second, nil) + if transErr != nil { + fmt.Fprintf(os.Stderr, "Error with creating TCP transport, could not run tests: ", transErr) + return + } + var err error + servers := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + c, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating client session, could not run tests: ", err) + return + } + + testsFailed := utils.RunTestSuite(testRestartWithClientCaches) + fmt.Println(testsFailed) +} + +func testRestartWithClientCaches() (error) { + val1, err1 := c.IncWithSeqno(1234) + if err1 != nil { + return fmt.Errorf("Error sending RPC first time: %v", err1) + } + time.Sleep(2*time.Second) + val2, err2 := c.IncWithSeqno(1234) + if err2 != nil { + return fmt.Errorf("Error retransmitting RPC: %v", err2) + } + if val1 != val2 { + return fmt.Errorf("Cached responses not correctly restored from snapshot after restart: %v, %v.", val1, val2) + } + return nil +} diff --git a/src/test/correctness/rifl_client.go b/src/test/correctness/rifl_client.go new file mode 100755 index 0000000..274a99b --- /dev/null +++ b/src/test/correctness/rifl_client.go @@ -0,0 +1,86 @@ +package main + +import ( + "raft" + "fmt" + "test/keyValStore" + "time" + "test/utils" + "os" +) + +// Checks that RPCs issued from the same client with the same sequence number are only +// executed once, while RPCs from different clients with the same sequence numbers and +// RPCs from the same client with different sequence numbers are reexecuted. 
+ +var c1 *keyValStore.Client +var c2 *keyValStore.Client + +func main() { + trans, transErr := raft.NewTCPTransport("127.0.0.1:5000", nil, 2, time.Second, nil) + if transErr != nil { + fmt.Fprintf(os.Stderr, "Error with creating TCP transport, could not run tests: ", transErr) + return + } + var err error + servers := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + c1, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating client session, could not run tests: ", err) + return + } + c2, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating second client session, could not run tests: ", err) + return + } + + testsFailed := utils.RunTestSuite(testSameClientSameSeqno, testSameClientDiffSeqno, testDiffClientSameSeqno) + fmt.Println(testsFailed) +} + +func testSameClientSameSeqno() (error) { + val1, val2, err := sendIncRpcs(c1, c1, 1234, 1234) + if err != nil { + return fmt.Errorf("Error sending same request from same client with same sequence number: %v", err) + } + if val1 != val2 { + return fmt.Errorf("Requests from same client with same sequence number produced different results: %v, %v", val1, val2) + } + return nil +} + +func testSameClientDiffSeqno() (error) { + // Test same client, different sequence number + val1, val2, err := sendIncRpcs(c1, c1, 12, 34) + if err != nil { + return fmt.Errorf("Error sending same request from same client with different sequence number: %v", err) + } + if val1 == val2 { + return fmt.Errorf("Requests from same client with different sequence numbers produced same results: %v, %v", val1, val2) + } + return nil +} + +func testDiffClientSameSeqno() (error) { + val1, val2, err := sendIncRpcs(c1, c2, 123, 123) + if err != nil { + return fmt.Errorf("Error sending same request from different clients with same sequence number: %v", err) + } + if val1 == val2 { + return fmt.Errorf("Requests from different clients 
with same sequence numbers produced same results: %v, %v", val1, val2) + } + return nil +} + +func sendIncRpcs(c1 *keyValStore.Client, c2 *keyValStore.Client, seqno1 uint64, seqno2 uint64) (uint64, uint64, error) { + val1, err1 := c1.IncWithSeqno(seqno1) + if err1 != nil { + return 0, 0, fmt.Errorf("Error sending RPC first time: %v", err1) + } + val2, err2 := c2.IncWithSeqno(seqno2) + if err2 != nil { + return 0, 0, fmt.Errorf("Error retransmitting RPC: %v", err2) + } + return val1, val2, nil +} diff --git a/src/test/correctness/runTests.sh b/src/test/correctness/runTests.sh new file mode 100755 index 0000000..37b39cc --- /dev/null +++ b/src/test/correctness/runTests.sh @@ -0,0 +1,49 @@ +# Run all CURP tests. + +if ! go build run_cluster.go +then + echo "Cluster build failing. Cannot run tests." + return +fi +echo "Starting tests..." +echo "" +./run_cluster > /dev/null &> /dev/null & +sleep .1 +FAILED=$(go run sanity_check.go) +FAILED=$(expr $(go run rifl_client.go) + $FAILED) +CLUSTER_JOB=$(ps aux | grep "run_cluster" | grep -v grep | awk '{print $2}') &> /dev/null +kill $CLUSTER_JOB &> /dev/null +wait $CLUSTER_JOB &> /dev/null +./run_cluster 10 100 > /dev/null &> /dev/null & +sleep .1 +FAILED=$(expr $(go run gc_client.go) + $FAILED) +CLUSTER_JOB=$(ps aux | grep "run_cluster" | grep -v grep | awk '{print $2}') &> /dev/null +kill $CLUSTER_JOB &> /dev/null +wait $CLUSTER_JOB &> /dev/null +go run restart_cluster.go > /dev/null &> /dev/null & +sleep .1 +FAILED=$(expr $(go run restart_rifl_client.go) + $FAILED) +CLUSTER_JOB=$(ps aux | grep "restart_cluster" | grep -v grep | awk '{print $2}') &> /dev/null +kill $CLUSTER_JOB &> /dev/null +wait $CLUSTER_JOB &> /dev/null +./run_cluster > /dev/null &> /dev/null & +sleep .1 +FAILED=$(expr $(go run simul_commutative.go) + $FAILED) +CLUSTER_JOB=$(ps aux | grep "run_cluster" | grep -v grep | awk '{print $2}') &> /dev/null +kill $CLUSTER_JOB &> /dev/null +wait $CLUSTER_JOB &> /dev/null +go run restart_cluster.go > /dev/null &> 
/dev/null & +sleep .1 +FAILED=$(expr $(go run restart_curp_client.go) + $FAILED) +CLUSTER_JOB=$(ps aux | grep "restart_cluster" | grep -v grep | awk '{print $2}') &> /dev/null +kill $CLUSTER_JOB &> /dev/null +wait $CLUSTER_JOB &> /dev/null +go run recovery_cluster.go > /dev/null &> /dev/null & +sleep .1 +FAILED=$(expr $(go run recovery_client.go) + $FAILED) +CLUSTER_JOB=$(ps aux | grep "recovery_cluster" | grep -v grep | awk '{print $2}') &> /dev/null +kill $CLUSTER_JOB &> /dev/null +wait $CLUSTER_JOB &> /dev/null +echo "" +echo "***** TESTS FAILED: "$FAILED" *****" + diff --git a/src/test/correctness/run_cluster.go b/src/test/correctness/run_cluster.go new file mode 100755 index 0000000..825a830 --- /dev/null +++ b/src/test/correctness/run_cluster.go @@ -0,0 +1,43 @@ +package main + +import( + "raft" + "test/keyValStore" + "os" + "os/signal" + "time" + "strconv" + "fmt" +) + +// Start a Raft cluster locally. +// Optional first argument is interval at which to garbage collect entries from client response cache +// in milliseconds. Optional second argument is length of time that entries should be left in the +// client response cache before being garbage collected (in milliseconds). 
+func main() { + args := os.Args[1:] + var gcInterval, gcRemoveTime time.Duration + gcInterval = 0 + gcRemoveTime = 0 + if len(args) > 0 { + interval, err := strconv.Atoi(args[0]) + if err != nil { + fmt.Println("GC Interval must be an integer.") + return + } + gcInterval = time.Duration(interval) * time.Millisecond + } + if len(args) > 1 { + removeTime, err := strconv.Atoi(args[1]) + if err != nil { + fmt.Println("GC remove time must be an integer.") + return + } + gcRemoveTime = time.Duration(removeTime) * time.Millisecond + } + addrs := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + keyValStore.MakeNewCluster(3, keyValStore.CreateWorkers(3), addrs, gcInterval, gcRemoveTime) + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + <-c +} diff --git a/src/test/correctness/sanity_check.go b/src/test/correctness/sanity_check.go new file mode 100755 index 0000000..bbbe3ae --- /dev/null +++ b/src/test/correctness/sanity_check.go @@ -0,0 +1,43 @@ +package main + +import ( + "test/keyValStore" + "raft" + "fmt" + "strings" + "time" + "test/utils" +) + +// Sanity check to verify that client can send request and receive response. 
+ +var c *keyValStore.Client + +func main() { + trans, err := raft.NewTCPTransport("127.0.0.1:5000", nil, 2, time.Second, nil) + if err != nil { + fmt.Println("Error with creating TCP transport: ", err) + return + } + servers := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + c, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Println("Can't create client session", err) + return + } + + testsFailed := utils.RunTestSuite(testSanityCheck) + fmt.Println(testsFailed) +} + +func testSanityCheck() (error) { + c.Set("foo","bar") + str, getErr := c.Get("foo") + if getErr != nil { + return fmt.Errorf("Error sending Get RPC: %v", getErr) + } + if strings.Compare(str,"bar") != 0 { + return fmt.Errorf("Should have received 'bar' but instead received '%v'", str) + } + return nil +} diff --git a/src/test/correctness/simul_commutative.go b/src/test/correctness/simul_commutative.go new file mode 100755 index 0000000..8034c41 --- /dev/null +++ b/src/test/correctness/simul_commutative.go @@ -0,0 +1,98 @@ +package main + +import ( + "raft" + "fmt" + "test/keyValStore" + "time" + "test/utils" + "os" +) + +// Test that simultaneous commutative and non-commutative operations execute without +// error. 
+ +var c1 *keyValStore.Client +var c2 *keyValStore.Client + +func main() { + trans, transErr := raft.NewTCPTransport("127.0.0.1:5000", nil, 2, time.Second, nil) + if transErr != nil { + fmt.Fprintf(os.Stderr, "Error with creating TCP transport, could not run tests: ", transErr) + return + } + var err error + servers := []raft.ServerAddress{"127.0.0.1:8000","127.0.0.1:8001","127.0.0.1:8002"} + c1, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating client session, could not run tests: ", err) + return + } + c2, err = keyValStore.CreateClient(trans, servers) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating second client session, could not run tests: ", err) + return + } + + testsFailed := utils.RunTestSuite(testSimultaneousCommutative, testSimultaneousNotCommutative) + fmt.Println(testsFailed) +} + +func testSimultaneousCommutative() (error) { + resultCh := make(chan error, 2) + go func() { + resultCh <- c1.Set("foo","1") + }() + go func() { + resultCh <- c2.Set("bar","1") + }() + err1 := <-resultCh + err2 := <-resultCh + if err1 != nil { + return fmt.Errorf("Error sending simultaneous commutative requests: %v", err1) + } + if err2 != nil { + return fmt.Errorf("Error sending simultaneous commutative requests: %v", err2) + } + result1, err3 := c1.Get("foo") + result2, err4 := c2.Get("bar") + if err3 != nil { + return fmt.Errorf("Error checking result of simultaneous commutative request: %s", err3) + } + if err4 != nil { + return fmt.Errorf("Error checking result of simultaneous commutative request: %s", err4) + } + if result1 != "1" { + return fmt.Errorf("Error with simultaneous commutative request for foo: expected 1 but received %d", result1) + } + if result2 != "1" { + return fmt.Errorf("Error with simultaneous commutative request for bar: expected 1 but received %d", result2) + } + return nil +} + +func testSimultaneousNotCommutative() (error) { + resultCh := make(chan error, 2) + go func() { + 
resultCh <- c1.Set("foo","1") + }() + go func() { + resultCh <- c2.Set("foo","1") + }() + err1 := <-resultCh + err2 := <-resultCh + if err1 != nil { + return fmt.Errorf("Error sending simultaneous non-commutative requests: %v", err1) + } + if err2 != nil { + return fmt.Errorf("Error sending simultaneous non-commutative requests: %v", err2) + } + result, err3 := c1.Get("foo") + if err3 != nil { + return fmt.Errorf("Error checking result of simultaneous non-commutative request: %s", err3) + } + if result != "1" { + return fmt.Errorf("Error in non-commutative operation: expected 1 and got %s", result) + } + return nil +} diff --git a/src/test/keyValStore/client.go b/src/test/keyValStore/client.go new file mode 100755 index 0000000..787d09a --- /dev/null +++ b/src/test/keyValStore/client.go @@ -0,0 +1,126 @@ +package keyValStore + +import ( + "raft" + "encoding/json" +) + +// Client library for sending RPCs to keyValStore Raft servers. Allows you +// to set a key, get a value, and increment an integer. + +// Handle given to client to make reqeuests. +type Client struct { + // Client transport layer. + trans *raft.NetworkTransport + // Addresses in cluster. + servers []raft.ServerAddress + // Open session with cluster leader. + session *raft.Session +} + +// Create new client for sending RPCs. +// Params: +// - trans: transport layer for client. +// - servers: list of Raft server addresses. +// Returns: client handle, error if any. +func CreateClient(trans *raft.NetworkTransport, servers []raft.ServerAddress) (*Client, error) { + newSession, err := raft.CreateClientSession(trans, servers) + if err != nil { + return nil, err + } + return &Client { + trans: trans, + servers: servers, + session: newSession, + }, nil +} + +// Cleanup associated with client. +func (c *Client) DestroyClient() { + c.session.CloseClientSession() +} + +// Increment integer and return value. Not idempotent! +// Returns: updated value of integer. 
+func (c *Client) Inc() (uint64, error) { + args := make(map[string]string) + args[FunctionArg] = IncCommand + data, marshal_err := json.Marshal(args) + if marshal_err != nil { + return 0, marshal_err + } + resp := raft.ClientResponse{} + keys := []raft.Key{raft.Key([]byte{1})} + c.session.SendFastRequest(data, keys, &resp) + var response IncResponse + recvErr := json.Unmarshal(resp.ResponseData, &response) + if recvErr != nil { + return 0, recvErr + } + return response.Value, nil +} + +// Same as Inc() but specifies sequence number. Use for testing +// purposes only (in production only use Inc). +// Params: +// - seqno: Sequence number to use when sending RPC. +// Returns: updated value of integer. +func (c *Client) IncWithSeqno(seqno uint64) (uint64, error) { + args := make(map[string]string) + args[FunctionArg] = IncCommand + data, marshal_err := json.Marshal(args) + if marshal_err != nil { + return 0, marshal_err + } + resp := raft.ClientResponse{} + keys := []raft.Key{raft.Key([]byte{1})} + c.session.SendFastRequestWithSeqNo(data, keys, &resp, seqno) + var response IncResponse + recvErr := json.Unmarshal(resp.ResponseData, &response) + if recvErr != nil { + return 0, recvErr + } + return response.Value, nil +} + + +// Send RPC to set the value of a key. No expected response. +// Params: +// - key: Key to access. +// - value: Value to set with key. +func (c *Client) Set(key string, value string) error { + args := make(map[string]string) + args[FunctionArg] = SetCommand + args[KeyArg] = key + args[ValueArg] = value + data, marshal_err := json.Marshal(args) + if marshal_err != nil { + return marshal_err + } + keys := []raft.Key{raft.Key([]byte(key))} + c.session.SendFastRequest(data, keys, &raft.ClientResponse{}) + return nil +} + +// Send RPC to get the value of a key. +// Params: +// - key: Key to get value of. +// Returns: value of key, empty string if error not nil. 
+func (c *Client) Get(key string) (string, error) { + args := make(map[string]string) + args[FunctionArg] = GetCommand + args[KeyArg] = key + data, marshal_err := json.Marshal(args) + if marshal_err != nil { + return "", marshal_err + } + resp := raft.ClientResponse{} + keys := []raft.Key{raft.Key([]byte(key))} + c.session.SendFastRequest(data, keys, &resp) + var response GetResponse + recvErr := json.Unmarshal(resp.ResponseData, &response) + if recvErr != nil { + return "", recvErr + } + return response.Value, nil +} diff --git a/src/test/keyValStore/cluster.go b/src/test/keyValStore/cluster.go new file mode 100755 index 0000000..bf339f1 --- /dev/null +++ b/src/test/keyValStore/cluster.go @@ -0,0 +1,220 @@ +package keyValStore + +import( + "fmt" + "raft" + "io/ioutil" + "time" + "log" + "os" +) + +// Manage keyValStore cluster locally, including making a new cluster, +// restarting a cluster, and shutting down a cluster. + +// Start a new Raft cluster locally. +// Params: +// - n: number of servers in cluster. +// - fsms: fsms to run on servers. +// - addrs: addresses of servers in cluster. +// - gcInterval: interval at which to garbage collect client responses +// - gcRemoveTime: length of time to cache client responses before garbage collection. +// Returns: running cluster. +func MakeNewCluster(n int, fsms []raft.FSM, addrs []raft.ServerAddress, gcInterval time.Duration, gcRemoveTime time.Duration) *cluster { + return MakeCluster(n, fsms, addrs, gcInterval, gcRemoveTime, nil) +} + +// Given a cluster that has been stopped, restart it. +// Params: +// - c : cluster to restart. 
+func RestartCluster(c *cluster) { + for i := range c.fsms { + trans, err := raft.NewTCPTransport(string(c.trans[i].LocalAddr()), nil, 2, time.Second, nil) + if err != nil { + fmt.Println("[ERR] err creating transport: ", err) + } + c.trans[i] = trans + } + + for i := range c.fsms { + peerConf := c.conf + peerConf.LocalID = c.configuration.Servers[i].ID + peerConf.Logger = log.New(os.Stdout, string(peerConf.LocalID) + " : ", log.Lmicroseconds) + + err := raft.RecoverCluster(peerConf, c.fsms[i], c.stores[i], c.stores[i], c.snaps[i], c.trans[i], c.configuration) + if err != nil { + fmt.Println("[ERR] err: %v", err) + } + raft, err := raft.NewRaft(peerConf, c.fsms[i], c.stores[i], c.stores[i], c.snaps[i], c.trans[i]) + if err != nil { + fmt.Println("[ERR] NewRaft failed: %v", err) + } + + raft.AddVoter(peerConf.LocalID, c.trans[i].LocalAddr(), 0, 0) + } +} + +// Shutdown a set of running Raft servers. +// Params: +// - nodes: array of running Raft nodes. +func ShutdownCluster(nodes []*raft.Raft) { + for _,node := range nodes { + f := node.Shutdown() + if f.Error() != nil { + fmt.Println("Error shutting down cluster: ", f.Error()) + } + } +} + +// Starts up a new cluster. +// Params: +// - n: number of servers in cluster. +// - fsms: array of FSMs to run at Raft servers. +// - addrs: addresses of Raft servers. +// - gcInterval: interval at which to check for expired client responses. +// - gcRemoveTime: interval at which client responses expire. +func MakeCluster(n int, fsms []raft.FSM, addrs []raft.ServerAddress, gcInterval time.Duration, gcRemoveTime time.Duration, startingCluster *cluster) (*cluster) { + conf := raft.DefaultConfig() + if gcInterval != 0 { + conf.ClientResponseGcInterval = gcInterval + } + if gcRemoveTime != 0 { + conf.ClientResponseGcRemoveTime = gcRemoveTime + } + bootstrap := true + + c := &cluster{ + conf: conf, + // Propagation takes a maximum of 2 heartbeat timeouts (time to + // get a new heartbeat that would cause a commit) plus a bit. 
+ propagateTimeout: conf.HeartbeatTimeout*2 + conf.CommitTimeout, + longstopTimeout: 5 * time.Second, + } + + // Setup the stores and transports + for i := 0; i < n; i++ { + dir, err := ioutil.TempDir("", "raft") + if err != nil { + fmt.Println("[ERR] err: %v ", err) + } + + store := raft.NewInmemStore() + c.dirs = append(c.dirs, dir) + c.stores = append(c.stores, store) + c.fsms = append(c.fsms, fsms[i]) + + + snap, err := raft.NewFileSnapshotStore(dir, 3, nil) + c.snaps = append(c.snaps, snap) + + trans, err := raft.NewTCPTransport(string(addrs[i]), nil, 2, time.Second, nil) + if err != nil { + fmt.Println("[ERR] err creating transport: ", err) + } + c.trans = append(c.trans, trans) + c.configuration.Servers = append(c.configuration.Servers, raft.Server{ + Suffrage: raft.Voter, + ID: raft.ServerID(fmt.Sprintf("server-%s", trans.LocalAddr())), + Address: addrs[i], + }) + } + + // Create all the rafts + c.startTime = time.Now() + for i := 0; i < n; i++ { + logs := c.stores[i] + store := c.stores[i] + snap := c.snaps[i] + trans := c.trans[i] + + peerConf := conf + peerConf.LocalID = c.configuration.Servers[i].ID + peerConf.Logger = log.New(os.Stdout, string(peerConf.LocalID) + " : ", log.Lmicroseconds) + + if bootstrap { + err := raft.BootstrapCluster(peerConf, logs, store, snap, trans, c.configuration) + if err != nil { + fmt.Println("[ERR] BootstrapCluster failed: %v", err) + } + } + + raft, err := raft.NewRaft(peerConf, c.fsms[i], logs, store, snap, trans) + if err != nil { + fmt.Println("[ERR] NewRaft failed: %v", err) + } + + raft.AddVoter(peerConf.LocalID, trans.LocalAddr(), 0, 0) + c.Rafts = append(c.Rafts, raft) + } + + return c +} + +// Start single node +// Used for perf metrics, so don't write logging messages. 
+func StartNode(fsm raft.FSM, addrs []raft.ServerAddress, i int) { + conf := raft.DefaultConfig() + bootstrap := true + + // Setup the stores and transports + dir, err := ioutil.TempDir("", "raft") + if err != nil { + fmt.Println("[ERR] err: %v ", err) + } + + store := raft.NewInmemStore() + + snap, err := raft.NewFileSnapshotStore(dir, 3, nil) + + trans, err := raft.NewTCPTransport(string(addrs[i]), nil, 2, time.Second, nil) + if err != nil { + fmt.Println("[ERR] err creating transport: ", err) + } + configuration := raft.Configuration{} + + for _,addr := range addrs { + configuration.Servers = append(configuration.Servers, raft.Server{ + Suffrage: raft.Voter, + ID: raft.ServerID(fmt.Sprintf("server-%s", addr)), + Address: addr, + }) + } + + // Create all the rafts + logs := store + + conf.LocalID = configuration.Servers[i].ID + //conf.Logger = log.SetOutput(ioutil.Discard) + conf.Logger = log.New(os.Stdout, string(conf.LocalID) + " : ", log.Lmicroseconds) + conf.Logger.SetOutput(ioutil.Discard) + + if bootstrap { + err := raft.BootstrapCluster(conf, logs, store, snap, trans, configuration) + if err != nil { + fmt.Println("[ERR] BootstrapCluster failed: %v", err) + } + } + + raft, err := raft.NewRaft(conf, fsm, logs, store, snap, trans) + if err != nil { + fmt.Println("[ERR] NewRaft failed: %v", err) + } + + raft.AddVoter(conf.LocalID, trans.LocalAddr(), 0, 0) +} + + +// Representation of cluster. +type cluster struct { + dirs []string + stores []*raft.InmemStore + fsms []raft.FSM + snaps []*raft.FileSnapshotStore + trans []raft.Transport + Rafts []*raft.Raft + conf *raft.Config + propagateTimeout time.Duration + longstopTimeout time.Duration + startTime time.Time + configuration raft.Configuration +} diff --git a/src/test/keyValStore/defs.go b/src/test/keyValStore/defs.go new file mode 100755 index 0000000..914c90c --- /dev/null +++ b/src/test/keyValStore/defs.go @@ -0,0 +1,19 @@ +package keyValStore + +// Client RPCs. 
+const GetCommand string = "Get" +const SetCommand string = "Set" +const IncCommand string = "Inc" +const FunctionArg string = "function" +const KeyArg string = "key" +const ValueArg string = "value" + +// Response to Get RPC. +type GetResponse struct { + Value string +} + +// Response to Inc RPC. +type IncResponse struct { + Value uint64 +} diff --git a/src/test/keyValStore/worker.go b/src/test/keyValStore/worker.go new file mode 100755 index 0000000..5d17da7 --- /dev/null +++ b/src/test/keyValStore/worker.go @@ -0,0 +1,79 @@ +package keyValStore + +import( + "raft" + "encoding/json" + "fmt" + "io" +) + +// FSM running on Raft servers to implement key-val store. +// *WorkerFSM implements raft.FSM by implementing Apply, +// Snapshot, Restore. +type WorkerFSM struct { + // Map representing key-value store. + KeyValMap map[string]string + counter uint64 +} + +type WorkerSnapshot struct{} + +// Create array of worker FSMs for starting a cluster. +// Params: +// - n: number of workers to create. +// Returns: array of raft FSMs of length n. +func CreateWorkers(n int) ([]raft.FSM) { + workers := make([]*WorkerFSM, n) + for i := range workers { + workers[i] = &WorkerFSM{ + KeyValMap: make(map[string]string), + counter: 0, + } + } + fsms := make([]raft.FSM, n) + for i, w := range workers { + fsms[i] = w + } + return fsms +} + +// Apply command to FSM and return response. +// Params: +// - log: log entry to apply to FSM. +// Returns: response JSON object. 
+func (w *WorkerFSM) Apply(log *raft.Log)(interface{}) { + args := make(map[string]string) + err := json.Unmarshal(log.Data, &args) + if err != nil { + fmt.Println("Poorly formatted request: ", err) + return nil + } + function := args[FunctionArg] + switch function { + case GetCommand: + return GetResponse{Value: w.KeyValMap[args[KeyArg]]} + case SetCommand: + w.KeyValMap[args[KeyArg]] = args[ValueArg] + return nil + case IncCommand: + w.counter += 1 + return IncResponse{Value: w.counter} + } + return nil +} + +// Don't need full implementation for testing. +func (w *WorkerFSM) Snapshot() (raft.FSMSnapshot, error) { + return WorkerSnapshot{}, nil +} + +// Don't need full implementation for testing. +func (w *WorkerFSM) Restore(i io.ReadCloser) error { + return nil +} + +// Don't need full implementation for testing. +func (s WorkerSnapshot) Persist(sink raft.SnapshotSink) error {return nil} + +// Don't need full implementation for testing. +func (s WorkerSnapshot) Release() {} diff --git a/src/test/utils/config.go b/src/test/utils/config.go new file mode 100755 index 0000000..39c588d --- /dev/null +++ b/src/test/utils/config.go @@ -0,0 +1,31 @@ +package utils + +import ( + "os" + "bufio" + "fmt" + "raft" +) + +type Config struct { + Servers []raft.ServerAddress +} + +func ReadConfig(path string) (*Config, error) { + file, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("Cannot open path: %s", path) + } + defer file.Close() + + var lines []raft.ServerAddress + scanner := bufio.NewScanner(file) + for scanner.Scan() { + lines = append(lines, raft.ServerAddress(scanner.Text())) + } + + c := &Config { + Servers: lines, + } + return c, nil +} diff --git a/src/test/utils/test.go b/src/test/utils/test.go new file mode 100755 index 0000000..1fadaf9 --- /dev/null +++ b/src/test/utils/test.go @@ -0,0 +1,30 @@ +package utils + +import ( + "fmt" + "os" + "runtime" + "reflect" +) + +// Run a series of tests and print if they pass or fail. 
// Params:
//   - tests: variable-length list of test functions to run. Expected
//     to have no arguments and return an error (nil if success,
//     description of error if failure).
// Returns: number of tests failed
//
// Exactly one result line per test is written to stderr: either
// "<name> FAILING: <error>" or "<name> passing".
func RunTestSuite(tests ...func() error) int {
	testsFailed := 0

	for _, test := range tests {
		err := test()
		// Recover the function's symbol name so the report identifies the test.
		testName := runtime.FuncForPC(reflect.ValueOf(test).Pointer()).Name()
		if err != nil {
			fmt.Fprintf(os.Stderr, "%v FAILING: %v\n", testName, err)
			testsFailed++
			// A failed test must not also be reported as passing.
			continue
		}
		fmt.Fprintf(os.Stderr, "%v passing\n", testName)
	}

	return testsFailed
}