diff --git a/channeldb/witness_cache.go b/channeldb/witness_cache.go index 5a7b7db7..d21d9373 100644 --- a/channeldb/witness_cache.go +++ b/channeldb/witness_cache.go @@ -70,12 +70,42 @@ func (d *DB) NewWitnessCache() *WitnessCache { } } -// AddWitness adds a new witness of wType to the witness cache. The type of the -// witness will be used to map the witness to the key that will be used to look -// it up. +// witnessEntry is a key-value struct that holds each key -> witness pair, used +// when inserting records into the cache. +type witnessEntry struct { + key []byte + witness []byte +} + +// AddWitnesses adds a batch of new witnesses of wType to the witness cache. The +// type of the witness will be used to map each witness to the key that will be +// used to look it up. All witnesses should be of the same WitnessType. // // TODO(roasbeef): fake closure to map instead a constructor? -func (w *WitnessCache) AddWitness(wType WitnessType, witness []byte) error { +func (w *WitnessCache) AddWitnesses(wType WitnessType, witnesses ...[]byte) error { + // Optimistically compute the witness keys before attempting to start + // the db transaction. + entries := make([]witnessEntry, 0, len(witnesses)) + for _, witness := range witnesses { + // Map each witness to its key by applying the appropriate + // transformation for the given witness type. + switch wType { + case Sha256HashWitness: + key := sha256.Sum256(witness) + entries = append(entries, witnessEntry{ + key: key[:], + witness: witness, + }) + default: + return ErrUnknownWitnessType + } + } + + // Exit early if there are no witnesses to add. + if len(entries) == 0 { + return nil + } + return w.db.Batch(func(tx *bbolt.Tx) error { witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) if err != nil { @@ -93,16 +123,14 @@ func (w *WitnessCache) AddWitness(wType WitnessType, witness []byte) error { return err } - // Now that we have the proper bucket for this witness, we'll map the - // witness type to the proper key. - var witnessKey []byte - switch wType { - case Sha256HashWitness: - key := sha256.Sum256(witness) - witnessKey = key[:] + for _, entry := range entries { + err = witnessTypeBucket.Put(entry.key, entry.witness) + if err != nil { + return err + } } - return witnessTypeBucket.Put(witnessKey, witness) + return nil }) } diff --git a/channeldb/witness_cache_test.go b/channeldb/witness_cache_test.go index e4dac0b7..b68f2518 100644 --- a/channeldb/witness_cache_test.go +++ b/channeldb/witness_cache_test.go @@ -25,7 +25,7 @@ func TestWitnessCacheRetrieval(t *testing.T) { witnessKey := sha256.Sum256(witness) // First, we'll attempt to add the witness to the database. - err = wCache.AddWitness(Sha256HashWitness, witness) + err = wCache.AddWitnesses(Sha256HashWitness, witness) if err != nil { t.Fatalf("unable to add witness: %v", err) } @@ -59,13 +59,13 @@ func TestWitnessCacheDeletion(t *testing.T) { // We'll start by adding two witnesses to the cache. witness1 := rev[:] witness1Key := sha256.Sum256(witness1) - if err := wCache.AddWitness(Sha256HashWitness, witness1); err != nil { + if err := wCache.AddWitnesses(Sha256HashWitness, witness1); err != nil { t.Fatalf("unable to add witness: %v", err) } witness2 := key[:] witness2Key := sha256.Sum256(witness2) - if err := wCache.AddWitness(Sha256HashWitness, witness2); err != nil { + if err := wCache.AddWitnesses(Sha256HashWitness, witness2); err != nil { t.Fatalf("unable to add witness: %v", err) } @@ -107,7 +107,7 @@ func TestWitnessCacheUnknownWitness(t *testing.T) { // We'll attempt to add a new, undefined witness type to the database. // We should get an error. - err = wCache.AddWitness(234, key[:]) + err = wCache.AddWitnesses(234, key[:]) if err != ErrUnknownWitnessType { t.Fatalf("expected ErrUnknownWitnessType, got %v", err) } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index bc47ebe5..2f49121f 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -64,8 +64,10 @@ type WitnessBeacon interface { // True is returned for the second argument if the preimage is found. LookupPreimage(payhash []byte) ([]byte, bool) - // AddPreImage adds a newly discovered preimage to the global cache. - AddPreimage(pre []byte) error + // AddPreimages adds a batch of newly discovered preimages to the global + // cache, and also signals any subscribers of the newly discovered + // witness. + AddPreimages(preimages ...[]byte) error } // ChannelArbitratorConfig contains all the functionality that the diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 6075b379..e343a3db 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -79,7 +79,7 @@ func (h *htlcOutgoingContestResolver) Resolve() (ContractResolver, error) { // With the preimage obtained, we can now add it to the global // cache. - if err := h.PreimageDB.AddPreimage(preimage[:]); err != nil { + if err := h.PreimageDB.AddPreimages(preimage[:]); err != nil { log.Errorf("%T(%v): unable to add witness to cache", h, h.htlcResolution.ClaimOutpoint) } diff --git a/htlcswitch/link.go b/htlcswitch/link.go index c6d3dbcf..b1faddb2 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -1417,7 +1417,7 @@ func (l *channelLink) handleUpstreamMsg(msg lnwire.Message) { // any contested contracts watched by any on-chain arbitrators // can now sweep this HTLC on-chain. go func() { - err := l.cfg.PreimageCache.AddPreimage(pre[:]) + err := l.cfg.PreimageCache.AddPreimages(pre[:]) if err != nil { l.errorf("unable to add preimage=%x to "+ "cache", pre[:]) diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 89913c17..1bd90bcf 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -46,11 +46,13 @@ func (m *mockPreimageCache) LookupPreimage(hash []byte) ([]byte, bool) { return p, ok } -func (m *mockPreimageCache) AddPreimage(preimage []byte) error { +func (m *mockPreimageCache) AddPreimages(preimages ...[]byte) error { m.Lock() defer m.Unlock() - m.preimageMap[sha256.Sum256(preimage[:])] = preimage + for _, preimage := range preimages { + m.preimageMap[sha256.Sum256(preimage)] = preimage + } return nil } diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index ddf2ee1d..02012162 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -584,7 +584,7 @@ func TestForceClose(t *testing.T) { // Before we force close Alice's channel, we'll add the pre-image of // Bob's HTLC to her preimage cache. - aliceChannel.pCache.AddPreimage(preimageBob[:]) + aliceChannel.pCache.AddPreimages(preimageBob[:]) // With the cache populated, we'll now attempt the force close // initiated by Alice. @@ -4953,7 +4953,7 @@ func TestChannelUnilateralCloseHtlcResolution(t *testing.T) { // Now that Bob has force closed, we'll modify Alice's pre image cache // such that she now gains the ability to also settle the incoming HTLC // from Bob. - aliceChannel.pCache.AddPreimage(preimageBob[:]) + aliceChannel.pCache.AddPreimages(preimageBob[:]) // We'll then use Bob's transaction to trigger a spend notification for // Alice. diff --git a/lnwallet/interface.go b/lnwallet/interface.go index 1fe245d6..68a26308 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -274,9 +274,13 @@ type PreimageCache interface { // argument. Otherwise, it'll return false. LookupPreimage(hash []byte) ([]byte, bool) - // AddPreimage attempts to add a new preimage to the global cache. If - // successful a nil error will be returned. - AddPreimage(preimage []byte) error + // AddPreimages adds a batch of newly discovered preimages to the global + // cache, and also signals any subscribers of the newly discovered + // witness. + // + // NOTE: The backing slice of MUST NOT be modified, otherwise the + // subscribers may be notified of the incorrect preimages. + AddPreimages(preimages ...[]byte) error } // WalletDriver represents a "driver" for a particular concrete diff --git a/lnwallet/test_utils.go b/lnwallet/test_utils.go index e273a6a5..9f5afc2c 100644 --- a/lnwallet/test_utils.go +++ b/lnwallet/test_utils.go @@ -403,11 +403,13 @@ func (m *mockPreimageCache) LookupPreimage(hash []byte) ([]byte, bool) { return p, ok } -func (m *mockPreimageCache) AddPreimage(preimage []byte) error { +func (m *mockPreimageCache) AddPreimages(preimages ...[]byte) error { m.Lock() defer m.Unlock() - m.preimageMap[sha256.Sum256(preimage[:])] = preimage + for _, preimage := range preimages { + m.preimageMap[sha256.Sum256(preimage)] = preimage + } return nil } diff --git a/mock.go b/mock.go index 2cbac0c5..2aac5f7d 100644 --- a/mock.go +++ b/mock.go @@ -317,11 +317,13 @@ func (m *mockPreimageCache) LookupPreimage(hash []byte) ([]byte, bool) { return p, ok } -func (m *mockPreimageCache) AddPreimage(preimage []byte) error { +func (m *mockPreimageCache) AddPreimages(preimages ...[]byte) error { m.Lock() defer m.Unlock() - m.preimageMap[sha256.Sum256(preimage[:])] = preimage + for _, preimage := range preimages { + m.preimageMap[sha256.Sum256(preimage)] = preimage + } return nil } diff --git a/witness_beacon.go b/witness_beacon.go index d7b92b01..98584318 100644 --- a/witness_beacon.go +++ b/witness_beacon.go @@ -101,28 +101,41 @@ func (p *preimageBeacon) LookupPreimage(payHash []byte) ([]byte, bool) { return preimage, true } -// AddPreImage adds a newly discovered preimage to the global cache, and also -// signals any subscribers of the newly discovered witness. -func (p *preimageBeacon) AddPreimage(pre []byte) error { - p.Lock() - defer p.Unlock() +// AddPreimages adds a batch of newly discovered preimages to the global cache, +// and also signals any subscribers of the newly discovered witness. +func (p *preimageBeacon) AddPreimages(preimages ...[]byte) error { + // Exit early if no preimages are presented. + if len(preimages) == 0 { + return nil + } - srvrLog.Infof("Adding preimage=%x to witness cache", pre[:]) + // Copy the preimages to ensure the backing area can't be modified by + // the caller when delivering notifications. + preimageCopies := make([][]byte, 0, len(preimages)) + for _, preimage := range preimages { + srvrLog.Infof("Adding preimage=%x to witness cache", preimage) + preimageCopies = append(preimageCopies, preimage) + } // First, we'll add the witness to the decaying witness cache. - err := p.wCache.AddWitness(channeldb.Sha256HashWitness, pre) + err := p.wCache.AddWitnesses(channeldb.Sha256HashWitness, preimages...) if err != nil { return err } + p.Lock() + defer p.Unlock() + // With the preimage added to our state, we'll now send a new // notification to all subscribers. for _, client := range p.subscribers { go func(c *preimageSubscriber) { - select { - case c.updateChan <- pre: - case <-c.quit: - return + for _, preimage := range preimageCopies { + select { + case c.updateChan <- preimage: + case <-c.quit: + return + } } }(client) }