From a0d3b7d9e310a520a01d31a758d6a23391b66d72 Mon Sep 17 00:00:00 2001
From: Wilmer Paulino <wilmer.paulino@gmail.com>
Date: Thu, 6 Dec 2018 21:13:35 -0800
Subject: [PATCH] chainntnfs: support caching confirm/spend hints for scripts

In this commit, we refactor the HeightHintCache and its underlying
interfaces to be able to manipulate hints for ConfRequests and
SpendRequests. By doing so, we'll be able to manipulate hints for
scripts if the request includes either a zero hash or a zero outpoint.
---
 chainntnfs/height_hint_cache.go      | 96 +++++++++++++---------------
 chainntnfs/height_hint_cache_test.go | 53 ++++++++-------
 chainntnfs/txnotifier.go             | 33 ++++++++++
 chainntnfs/txnotifier_test.go        | 44 +++++++------
 4 files changed, 130 insertions(+), 96 deletions(-)

diff --git a/chainntnfs/height_hint_cache.go b/chainntnfs/height_hint_cache.go
index b05f09df..bfd0b574 100644
--- a/chainntnfs/height_hint_cache.go
+++ b/chainntnfs/height_hint_cache.go
@@ -4,8 +4,6 @@ import (
 	"bytes"
 	"errors"
 
-	"github.com/btcsuite/btcd/chaincfg/chainhash"
-	"github.com/btcsuite/btcd/wire"
 	bolt "github.com/coreos/bbolt"
 	"github.com/lightningnetwork/lnd/channeldb"
 )
@@ -51,16 +49,16 @@ var (
 // which an outpoint could have been spent within.
 type SpendHintCache interface {
 	// CommitSpendHint commits a spend hint for the outpoints to the cache.
-	CommitSpendHint(height uint32, ops ...wire.OutPoint) error
+	CommitSpendHint(height uint32, spendRequests ...SpendRequest) error
 
 	// QuerySpendHint returns the latest spend hint for an outpoint.
 	// ErrSpendHintNotFound is returned if a spend hint does not exist
 	// within the cache for the outpoint.
-	QuerySpendHint(op wire.OutPoint) (uint32, error)
+	QuerySpendHint(spendRequest SpendRequest) (uint32, error)
 
 	// PurgeSpendHint removes the spend hint for the outpoints from the
 	// cache.
-	PurgeSpendHint(ops ...wire.OutPoint) error
+	PurgeSpendHint(spendRequests ...SpendRequest) error
 }
 
 // ConfirmHintCache is an interface whose duty is to cache confirm hints for
@@ -69,16 +67,16 @@ type SpendHintCache interface {
 type ConfirmHintCache interface {
 	// CommitConfirmHint commits a confirm hint for the transactions to the
 	// cache.
-	CommitConfirmHint(height uint32, txids ...chainhash.Hash) error
+	CommitConfirmHint(height uint32, confRequests ...ConfRequest) error
 
 	// QueryConfirmHint returns the latest confirm hint for a transaction
 	// hash. ErrConfirmHintNotFound is returned if a confirm hint does not
 	// exist within the cache for the transaction hash.
-	QueryConfirmHint(txid chainhash.Hash) (uint32, error)
+	QueryConfirmHint(confRequest ConfRequest) (uint32, error)
 
 	// PurgeConfirmHint removes the confirm hint for the transactions from
 	// the cache.
-	PurgeConfirmHint(txids ...chainhash.Hash) error
+	PurgeConfirmHint(confRequests ...ConfRequest) error
 }
 
 // HeightHintCache is an implementation of the SpendHintCache and
@@ -118,12 +116,15 @@ func (c *HeightHintCache) initBuckets() error {
 }
 
 // CommitSpendHint commits a spend hint for the outpoints to the cache.
-func (c *HeightHintCache) CommitSpendHint(height uint32, ops ...wire.OutPoint) error {
-	if len(ops) == 0 {
+func (c *HeightHintCache) CommitSpendHint(height uint32,
+	spendRequests ...SpendRequest) error {
+
+	if len(spendRequests) == 0 {
 		return nil
 	}
 
-	Log.Tracef("Updating spend hint to height %d for %v", height, ops)
+	Log.Tracef("Updating spend hint to height %d for %v", height,
+		spendRequests)
 
 	return c.db.Batch(func(tx *bolt.Tx) error {
 		spendHints := tx.Bucket(spendHintBucket)
@@ -136,14 +137,12 @@ func (c *HeightHintCache) CommitSpendHint(height uint32, ops ...wire.OutPoint) e
 			return err
 		}
 
-		for _, op := range ops {
-			var outpoint bytes.Buffer
-			err := channeldb.WriteElement(&outpoint, op)
+		for _, spendRequest := range spendRequests {
+			spendHintKey, err := spendRequest.SpendHintKey()
 			if err != nil {
 				return err
 			}
-
-			err = spendHints.Put(outpoint.Bytes(), hint.Bytes())
+			err = spendHints.Put(spendHintKey, hint.Bytes())
 			if err != nil {
 				return err
 			}
@@ -156,7 +155,7 @@ func (c *HeightHintCache) CommitSpendHint(height uint32, ops ...wire.OutPoint) e
 // QuerySpendHint returns the latest spend hint for an outpoint.
 // ErrSpendHintNotFound is returned if a spend hint does not exist within the
 // cache for the outpoint.
-func (c *HeightHintCache) QuerySpendHint(op wire.OutPoint) (uint32, error) {
+func (c *HeightHintCache) QuerySpendHint(spendRequest SpendRequest) (uint32, error) {
 	var hint uint32
 	err := c.db.View(func(tx *bolt.Tx) error {
 		spendHints := tx.Bucket(spendHintBucket)
@@ -164,12 +163,11 @@ func (c *HeightHintCache) QuerySpendHint(op wire.OutPoint) (uint32, error) {
 			return ErrCorruptedHeightHintCache
 		}
 
-		var outpoint bytes.Buffer
-		if err := channeldb.WriteElement(&outpoint, op); err != nil {
+		spendHintKey, err := spendRequest.SpendHintKey()
+		if err != nil {
 			return err
 		}
-
-		spendHint := spendHints.Get(outpoint.Bytes())
+		spendHint := spendHints.Get(spendHintKey)
 		if spendHint == nil {
 			return ErrSpendHintNotFound
 		}
@@ -184,12 +182,12 @@ func (c *HeightHintCache) QuerySpendHint(op wire.OutPoint) (uint32, error) {
 }
 
 // PurgeSpendHint removes the spend hint for the outpoints from the cache.
-func (c *HeightHintCache) PurgeSpendHint(ops ...wire.OutPoint) error {
-	if len(ops) == 0 {
+func (c *HeightHintCache) PurgeSpendHint(spendRequests ...SpendRequest) error {
+	if len(spendRequests) == 0 {
 		return nil
 	}
 
-	Log.Tracef("Removing spend hints for %v", ops)
+	Log.Tracef("Removing spend hints for %v", spendRequests)
 
 	return c.db.Batch(func(tx *bolt.Tx) error {
 		spendHints := tx.Bucket(spendHintBucket)
@@ -197,15 +195,12 @@ func (c *HeightHintCache) PurgeSpendHint(ops ...wire.OutPoint) error {
 			return ErrCorruptedHeightHintCache
 		}
 
-		for _, op := range ops {
-			var outpoint bytes.Buffer
-			err := channeldb.WriteElement(&outpoint, op)
+		for _, spendRequest := range spendRequests {
+			spendHintKey, err := spendRequest.SpendHintKey()
 			if err != nil {
 				return err
 			}
-
-			err = spendHints.Delete(outpoint.Bytes())
-			if err != nil {
+			if err := spendHints.Delete(spendHintKey); err != nil {
 				return err
 			}
 		}
@@ -215,12 +210,15 @@ func (c *HeightHintCache) PurgeSpendHint(ops ...wire.OutPoint) error {
 }
 
 // CommitConfirmHint commits a confirm hint for the transactions to the cache.
-func (c *HeightHintCache) CommitConfirmHint(height uint32, txids ...chainhash.Hash) error {
-	if len(txids) == 0 {
+func (c *HeightHintCache) CommitConfirmHint(height uint32,
+	confRequests ...ConfRequest) error {
+
+	if len(confRequests) == 0 {
 		return nil
 	}
 
-	Log.Tracef("Updating confirm hints to height %d for %v", height, txids)
+	Log.Tracef("Updating confirm hints to height %d for %v", height,
+		confRequests)
 
 	return c.db.Batch(func(tx *bolt.Tx) error {
 		confirmHints := tx.Bucket(confirmHintBucket)
@@ -233,14 +231,12 @@ func (c *HeightHintCache) CommitConfirmHint(height uint32, txids ...chainhash.Ha
 			return err
 		}
 
-		for _, txid := range txids {
-			var txHash bytes.Buffer
-			err := channeldb.WriteElement(&txHash, txid)
+		for _, confRequest := range confRequests {
+			confHintKey, err := confRequest.ConfHintKey()
 			if err != nil {
 				return err
 			}
-
-			err = confirmHints.Put(txHash.Bytes(), hint.Bytes())
+			err = confirmHints.Put(confHintKey, hint.Bytes())
 			if err != nil {
 				return err
 			}
@@ -253,7 +249,7 @@ func (c *HeightHintCache) CommitConfirmHint(height uint32, txids ...chainhash.Ha
 // QueryConfirmHint returns the latest confirm hint for a transaction hash.
 // ErrConfirmHintNotFound is returned if a confirm hint does not exist within
 // the cache for the transaction hash.
-func (c *HeightHintCache) QueryConfirmHint(txid chainhash.Hash) (uint32, error) {
+func (c *HeightHintCache) QueryConfirmHint(confRequest ConfRequest) (uint32, error) {
 	var hint uint32
 	err := c.db.View(func(tx *bolt.Tx) error {
 		confirmHints := tx.Bucket(confirmHintBucket)
@@ -261,12 +257,11 @@ func (c *HeightHintCache) QueryConfirmHint(txid chainhash.Hash) (uint32, error)
 			return ErrCorruptedHeightHintCache
 		}
 
-		var txHash bytes.Buffer
-		if err := channeldb.WriteElement(&txHash, txid); err != nil {
+		confHintKey, err := confRequest.ConfHintKey()
+		if err != nil {
 			return err
 		}
-
-		confirmHint := confirmHints.Get(txHash.Bytes())
+		confirmHint := confirmHints.Get(confHintKey)
 		if confirmHint == nil {
 			return ErrConfirmHintNotFound
 		}
@@ -282,12 +277,12 @@ func (c *HeightHintCache) QueryConfirmHint(txid chainhash.Hash) (uint32, error)
 
 // PurgeConfirmHint removes the confirm hint for the transactions from the
 // cache.
-func (c *HeightHintCache) PurgeConfirmHint(txids ...chainhash.Hash) error {
-	if len(txids) == 0 {
+func (c *HeightHintCache) PurgeConfirmHint(confRequests ...ConfRequest) error {
+	if len(confRequests) == 0 {
 		return nil
 	}
 
-	Log.Tracef("Removing confirm hints for %v", txids)
+	Log.Tracef("Removing confirm hints for %v", confRequests)
 
 	return c.db.Batch(func(tx *bolt.Tx) error {
 		confirmHints := tx.Bucket(confirmHintBucket)
@@ -295,15 +290,12 @@ func (c *HeightHintCache) PurgeConfirmHint(txids ...chainhash.Hash) error {
 			return ErrCorruptedHeightHintCache
 		}
 
-		for _, txid := range txids {
-			var txHash bytes.Buffer
-			err := channeldb.WriteElement(&txHash, txid)
+		for _, confRequest := range confRequests {
+			confHintKey, err := confRequest.ConfHintKey()
 			if err != nil {
 				return err
 			}
-
-			err = confirmHints.Delete(txHash.Bytes())
-			if err != nil {
+			if err := confirmHints.Delete(confHintKey); err != nil {
 				return err
 			}
 		}
diff --git a/chainntnfs/height_hint_cache_test.go b/chainntnfs/height_hint_cache_test.go
index f444b18d..d2fe81c2 100644
--- a/chainntnfs/height_hint_cache_test.go
+++ b/chainntnfs/height_hint_cache_test.go
@@ -39,7 +39,9 @@ func TestHeightHintCacheConfirms(t *testing.T) {
 	// Querying for a transaction hash not found within the cache should
 	// return an error indication so.
 	var unknownHash chainhash.Hash
-	_, err := hintCache.QueryConfirmHint(unknownHash)
+	copy(unknownHash[:], bytes.Repeat([]byte{0x01}, 32))
+	unknownConfRequest := ConfRequest{TxID: unknownHash}
+	_, err := hintCache.QueryConfirmHint(unknownConfRequest)
 	if err != ErrConfirmHintNotFound {
 		t.Fatalf("expected ErrConfirmHintNotFound, got: %v", err)
 	}
@@ -48,23 +50,24 @@ func TestHeightHintCacheConfirms(t *testing.T) {
 	// cache with the same confirm hint.
 	const height = 100
 	const numHashes = 5
-	txHashes := make([]chainhash.Hash, numHashes)
+	confRequests := make([]ConfRequest, numHashes)
 	for i := 0; i < numHashes; i++ {
 		var txHash chainhash.Hash
-		copy(txHash[:], bytes.Repeat([]byte{byte(i)}, 32))
-		txHashes[i] = txHash
+		copy(txHash[:], bytes.Repeat([]byte{byte(i + 1)}, 32))
+		confRequests[i] = ConfRequest{TxID: txHash}
 	}
 
-	if err := hintCache.CommitConfirmHint(height, txHashes...); err != nil {
+	err = hintCache.CommitConfirmHint(height, confRequests...)
+	if err != nil {
 		t.Fatalf("unable to add entries to cache: %v", err)
 	}
 
 	// With the hashes committed, we'll now query the cache to ensure that
 	// we're able to properly retrieve the confirm hints.
-	for _, txHash := range txHashes {
-		confirmHint, err := hintCache.QueryConfirmHint(txHash)
+	for _, confRequest := range confRequests {
+		confirmHint, err := hintCache.QueryConfirmHint(confRequest)
 		if err != nil {
-			t.Fatalf("unable to query for hint: %v", err)
+			t.Fatalf("unable to query for hint of %v: %v", confRequest, err)
 		}
 		if confirmHint != height {
 			t.Fatalf("expected confirm hint %d, got %d", height,
@@ -74,14 +77,14 @@ func TestHeightHintCacheConfirms(t *testing.T) {
 
 	// We'll also attempt to purge all of them in a single database
 	// transaction.
-	if err := hintCache.PurgeConfirmHint(txHashes...); err != nil {
+	if err := hintCache.PurgeConfirmHint(confRequests...); err != nil {
 		t.Fatalf("unable to remove confirm hints: %v", err)
 	}
 
 	// Finally, we'll attempt to query for each hash. We should expect not
 	// to find a hint for any of them.
-	for _, txHash := range txHashes {
-		_, err := hintCache.QueryConfirmHint(txHash)
+	for _, confRequest := range confRequests {
+		_, err := hintCache.QueryConfirmHint(confRequest)
 		if err != ErrConfirmHintNotFound {
 			t.Fatalf("expected ErrConfirmHintNotFound, got :%v", err)
 		}
@@ -97,8 +100,9 @@ func TestHeightHintCacheSpends(t *testing.T) {
 
 	// Querying for an outpoint not found within the cache should return an
 	// error indication so.
-	var unknownOutPoint wire.OutPoint
-	_, err := hintCache.QuerySpendHint(unknownOutPoint)
+	unknownOutPoint := wire.OutPoint{Index: 1}
+	unknownSpendRequest := SpendRequest{OutPoint: unknownOutPoint}
+	_, err := hintCache.QuerySpendHint(unknownSpendRequest)
 	if err != ErrSpendHintNotFound {
 		t.Fatalf("expected ErrSpendHintNotFound, got: %v", err)
 	}
@@ -107,21 +111,22 @@ func TestHeightHintCacheSpends(t *testing.T) {
 	// the same spend hint.
 	const height = 100
 	const numOutpoints = 5
-	var txHash chainhash.Hash
-	copy(txHash[:], bytes.Repeat([]byte{0xFF}, 32))
-	outpoints := make([]wire.OutPoint, numOutpoints)
+	spendRequests := make([]SpendRequest, numOutpoints)
 	for i := uint32(0); i < numOutpoints; i++ {
-		outpoints[i] = wire.OutPoint{Hash: txHash, Index: i}
+		spendRequests[i] = SpendRequest{
+			OutPoint: wire.OutPoint{Index: i + 1},
+		}
 	}
 
-	if err := hintCache.CommitSpendHint(height, outpoints...); err != nil {
-		t.Fatalf("unable to add entry to cache: %v", err)
+	err = hintCache.CommitSpendHint(height, spendRequests...)
+	if err != nil {
+		t.Fatalf("unable to add entries to cache: %v", err)
 	}
 
 	// With the outpoints committed, we'll now query the cache to ensure
 	// that we're able to properly retrieve the confirm hints.
-	for _, op := range outpoints {
-		spendHint, err := hintCache.QuerySpendHint(op)
+	for _, spendRequest := range spendRequests {
+		spendHint, err := hintCache.QuerySpendHint(spendRequest)
 		if err != nil {
 			t.Fatalf("unable to query for hint: %v", err)
 		}
@@ -133,14 +138,14 @@ func TestHeightHintCacheSpends(t *testing.T) {
 
 	// We'll also attempt to purge all of them in a single database
 	// transaction.
-	if err := hintCache.PurgeSpendHint(outpoints...); err != nil {
+	if err := hintCache.PurgeSpendHint(spendRequests...); err != nil {
 		t.Fatalf("unable to remove spend hint: %v", err)
 	}
 
 	// Finally, we'll attempt to query for each outpoint. We should expect
 	// not to find a hint for any of them.
-	for _, op := range outpoints {
-		_, err = hintCache.QuerySpendHint(op)
+	for _, spendRequest := range spendRequests {
+		_, err = hintCache.QuerySpendHint(spendRequest)
 		if err != ErrSpendHintNotFound {
 			t.Fatalf("expected ErrSpendHintNotFound, got: %v", err)
 		}
diff --git a/chainntnfs/txnotifier.go b/chainntnfs/txnotifier.go
index e0cfe1c5..d64d3aaa 100644
--- a/chainntnfs/txnotifier.go
+++ b/chainntnfs/txnotifier.go
@@ -1,6 +1,7 @@
 package chainntnfs
 
 import (
+	"bytes"
 	"errors"
 	"fmt"
 	"sync"
@@ -8,6 +9,7 @@ import (
 	"github.com/btcsuite/btcd/chaincfg/chainhash"
 	"github.com/btcsuite/btcd/wire"
 	"github.com/btcsuite/btcutil"
+	"github.com/lightningnetwork/lnd/channeldb"
 )
 
 const (
@@ -165,6 +167,21 @@ func (r ConfRequest) String() string {
 	return fmt.Sprintf("script=%v", r.PkScript)
 }
 
+// ConfHintKey returns the key that will be used to index the confirmation
+// request's hint within the height hint cache.
+func (r ConfRequest) ConfHintKey() ([]byte, error) {
+	if r.TxID == ZeroHash {
+		return r.PkScript.Script(), nil
+	}
+
+	var txid bytes.Buffer
+	if err := channeldb.WriteElement(&txid, r.TxID); err != nil {
+		return nil, err
+	}
+
+	return txid.Bytes(), nil
+}
+
 // ConfNtfn represents a notifier client's request to receive a notification
 // once the target transaction gets sufficient confirmations. The client is
 // asynchronously notified via the ConfirmationEvent channels.
@@ -265,6 +282,22 @@ func (r SpendRequest) String() string {
 	return fmt.Sprintf("script=%v", r.PkScript)
 }
 
+// SpendHintKey returns the key that will be used to index the spend request's
+// hint within the height hint cache.
+func (r SpendRequest) SpendHintKey() ([]byte, error) {
+	if r.OutPoint == ZeroOutPoint {
+		return r.PkScript.Script(), nil
+	}
+
+	var outpoint bytes.Buffer
+	err := channeldb.WriteElement(&outpoint, r.OutPoint)
+	if err != nil {
+		return nil, err
+	}
+
+	return outpoint.Bytes(), nil
+}
+
 // SpendNtfn represents a client's request to receive a notification once an
 // outpoint has been spent on-chain. The client is asynchronously notified via
 // the SpendEvent channels.
diff --git a/chainntnfs/txnotifier_test.go b/chainntnfs/txnotifier_test.go
index 6a9a533f..d426ebdf 100644
--- a/chainntnfs/txnotifier_test.go
+++ b/chainntnfs/txnotifier_test.go
@@ -17,29 +17,31 @@ var (
 
 type mockHintCache struct {
 	mu         sync.Mutex
-	confHints  map[chainhash.Hash]uint32
-	spendHints map[wire.OutPoint]uint32
+	confHints  map[chainntnfs.ConfRequest]uint32
+	spendHints map[chainntnfs.SpendRequest]uint32
 }
 
 var _ chainntnfs.SpendHintCache = (*mockHintCache)(nil)
 var _ chainntnfs.ConfirmHintCache = (*mockHintCache)(nil)
 
-func (c *mockHintCache) CommitSpendHint(heightHint uint32, ops ...wire.OutPoint) error {
+func (c *mockHintCache) CommitSpendHint(heightHint uint32,
+	spendRequests ...chainntnfs.SpendRequest) error {
+
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	for _, op := range ops {
-		c.spendHints[op] = heightHint
+	for _, spendRequest := range spendRequests {
+		c.spendHints[spendRequest] = heightHint
 	}
 
 	return nil
 }
 
-func (c *mockHintCache) QuerySpendHint(op wire.OutPoint) (uint32, error) {
+func (c *mockHintCache) QuerySpendHint(spendRequest chainntnfs.SpendRequest) (uint32, error) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	hint, ok := c.spendHints[op]
+	hint, ok := c.spendHints[spendRequest]
 	if !ok {
 		return 0, chainntnfs.ErrSpendHintNotFound
 	}
@@ -47,33 +49,35 @@ func (c *mockHintCache) QuerySpendHint(op wire.OutPoint) (uint32, error) {
 	return hint, nil
 }
 
-func (c *mockHintCache) PurgeSpendHint(ops ...wire.OutPoint) error {
+func (c *mockHintCache) PurgeSpendHint(spendRequests ...chainntnfs.SpendRequest) error {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	for _, op := range ops {
-		delete(c.spendHints, op)
+	for _, spendRequest := range spendRequests {
+		delete(c.spendHints, spendRequest)
 	}
 
 	return nil
 }
 
-func (c *mockHintCache) CommitConfirmHint(heightHint uint32, txids ...chainhash.Hash) error {
+func (c *mockHintCache) CommitConfirmHint(heightHint uint32,
+	confRequests ...chainntnfs.ConfRequest) error {
+
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	for _, txid := range txids {
-		c.confHints[txid] = heightHint
+	for _, confRequest := range confRequests {
+		c.confHints[confRequest] = heightHint
 	}
 
 	return nil
 }
 
-func (c *mockHintCache) QueryConfirmHint(txid chainhash.Hash) (uint32, error) {
+func (c *mockHintCache) QueryConfirmHint(confRequest chainntnfs.ConfRequest) (uint32, error) {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	hint, ok := c.confHints[txid]
+	hint, ok := c.confHints[confRequest]
 	if !ok {
 		return 0, chainntnfs.ErrConfirmHintNotFound
 	}
@@ -81,12 +85,12 @@ func (c *mockHintCache) QueryConfirmHint(txid chainhash.Hash) (uint32, error) {
 	return hint, nil
 }
 
-func (c *mockHintCache) PurgeConfirmHint(txids ...chainhash.Hash) error {
+func (c *mockHintCache) PurgeConfirmHint(confRequests ...chainntnfs.ConfRequest) error {
 	c.mu.Lock()
 	defer c.mu.Unlock()
 
-	for _, txid := range txids {
-		delete(c.confHints, txid)
+	for _, confRequest := range confRequests {
+		delete(c.confHints, confRequest)
 	}
 
 	return nil
@@ -94,8 +98,8 @@ func (c *mockHintCache) PurgeConfirmHint(txids ...chainhash.Hash) error {
 
 func newMockHintCache() *mockHintCache {
 	return &mockHintCache{
-		confHints:  make(map[chainhash.Hash]uint32),
-		spendHints: make(map[wire.OutPoint]uint32),
+		confHints:  make(map[chainntnfs.ConfRequest]uint32),
+		spendHints: make(map[chainntnfs.SpendRequest]uint32),
 	}
 }