diff --git a/channeldb/waitingproof.go b/channeldb/waitingproof.go index 327ea628..28749d8f 100644 --- a/channeldb/waitingproof.go +++ b/channeldb/waitingproof.go @@ -2,6 +2,7 @@ package channeldb import ( "encoding/binary" + "sync" "io" @@ -35,6 +36,7 @@ type WaitingProofStore struct { // calls, when object isn't stored in it. cache map[WaitingProofKey]struct{} db *DB + mu sync.RWMutex } // NewWaitingProofStore creates new instance of proofs storage. @@ -56,7 +58,10 @@ func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { // Add adds new waiting proof in the storage. func (s *WaitingProofStore) Add(proof *WaitingProof) error { - return s.db.Batch(func(tx *bolt.Tx) error { + s.mu.Lock() + defer s.mu.Unlock() + + err := s.db.Update(func(tx *bolt.Tx) error { var err error var b bytes.Buffer @@ -72,36 +77,47 @@ func (s *WaitingProofStore) Add(proof *WaitingProof) error { } key := proof.Key() - if err := bucket.Put(key[:], b.Bytes()); err != nil { - return err - } - s.cache[proof.Key()] = struct{}{} - - return nil + return bucket.Put(key[:], b.Bytes()) }) + if err != nil { + return err + } + + // Knowing that the write succeeded, we can now update the in-memory + // cache with the proof's key. + s.cache[proof.Key()] = struct{}{} + + return nil } // Remove removes the proof from storage by its key. func (s *WaitingProofStore) Remove(key WaitingProofKey) error { + s.mu.Lock() + defer s.mu.Unlock() + if _, ok := s.cache[key]; !ok { return ErrWaitingProofNotFound } - return s.db.Batch(func(tx *bolt.Tx) error { + err := s.db.Update(func(tx *bolt.Tx) error { // Get or create the top bucket. bucket := tx.Bucket(waitingProofsBucketKey) if bucket == nil { return ErrWaitingProofNotFound } - if err := bucket.Delete(key[:]); err != nil { - return err - } - - delete(s.cache, key) - return nil + return bucket.Delete(key[:]) }) + if err != nil { + return err + } + + // Since the proof was successfully deleted from the store, we can now + // remove it from the in-memory cache. + delete(s.cache, key) + + return nil } // ForAll iterates thought all waiting proofs and passing the waiting proof @@ -135,6 +151,9 @@ func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { proof := &WaitingProof{} + s.mu.RLock() + defer s.mu.RUnlock() + if _, ok := s.cache[key]; !ok { return nil, ErrWaitingProofNotFound }