diff --git a/htlcswitch/decayedlog.go b/htlcswitch/decayedlog.go new file mode 100644 index 00000000..662a87b4 --- /dev/null +++ b/htlcswitch/decayedlog.go @@ -0,0 +1,405 @@ +package htlcswitch + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/chainntnfs" +) + +const ( + // defaultDbDirectory is the default directory where our decayed log + // will store our (sharedHash, CLTV) key-value pairs. + defaultDbDirectory = "sharedhashes" + + // dbPermissions sets the database permissions to user write-and-readable. + dbPermissions = 0600 +) + +var ( + // sharedHashBucket is a bucket which houses the first HashPrefixSize + // bytes of a received HTLC's hashed shared secret as the key and the HTLC's + // CLTV expiry as the value. + sharedHashBucket = []byte("shared-hash") + + // batchReplayBucket is a bucket that maps batch identifiers to + // serialized ReplaySets. This is used to give idempotency in the event + // that a batch is processed more than once. + batchReplayBucket = []byte("batch-replay") +) + +var ( + // ErrDecayedLogInit is used to indicate a decayed log failed to create + // the proper bucketing structure on startup. + ErrDecayedLogInit = errors.New("unable to initialize decayed log") + + // ErrDecayedLogCorrupted signals that the anticipated bucketing + // structure has diverged since initialization. + ErrDecayedLogCorrupted = errors.New("decayed log structure corrupted") +) + +// DecayedLog implements the PersistLog interface. It stores the first +// HashPrefixSize bytes of a sha256-hashed shared secret along with a node's +// CLTV value. It is a decaying log meaning there will be a garbage collector +// to collect entries which are expired according to their stored CLTV value +// and the current block height. DecayedLog wraps boltdb for simplicity and +// batches writes to the database to decrease write contention. +type DecayedLog struct { + started int32 + stopped int32 + + dbPath string + + db *bolt.DB + + notifier chainntnfs.ChainNotifier + + wg sync.WaitGroup + quit chan struct{} +} + +// NewDecayedLog creates a new DecayedLog, which caches recently seen hash +// shared secrets. Entries are evicted as their cltv expires using block epochs +// from the given notifier. +func NewDecayedLog(dbPath string, + notifier chainntnfs.ChainNotifier) *DecayedLog { + + // Use default path for log database + if dbPath == "" { + dbPath = defaultDbDirectory + } + + return &DecayedLog{ + dbPath: dbPath, + notifier: notifier, + quit: make(chan struct{}), + } +} + +// Start opens the database we will be using to store hashed shared secrets. +// It also starts the garbage collector in a goroutine to remove stale +// database entries. +func (d *DecayedLog) Start() error { + if !atomic.CompareAndSwapInt32(&d.started, 0, 1) { + return nil + } + + // Open the boltdb for use. + var err error + if d.db, err = bolt.Open(d.dbPath, dbPermissions, nil); err != nil { + return fmt.Errorf("Could not open boltdb: %v", err) + } + + // Initialize the primary buckets used by the decayed log. + if err := d.initBuckets(); err != nil { + return err + } + + // Start garbage collector. + if d.notifier != nil { + epochClient, err := d.notifier.RegisterBlockEpochNtfn() + if err != nil { + return fmt.Errorf("Unable to register for epoch "+ + "notifications: %v", err) + } + + d.wg.Add(1) + go d.garbageCollector(epochClient) + } + + return nil +} + +// initBuckets initializes the primary buckets used by the decayed log, namely +// the shared hash bucket, and batch replay +func (d *DecayedLog) initBuckets() error { + return d.db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucketIfNotExists(sharedHashBucket) + if err != nil { + return ErrDecayedLogInit + } + + _, err = tx.CreateBucketIfNotExists(batchReplayBucket) + if err != nil { + return ErrDecayedLogInit + } + + return nil + }) +} + +// Stop halts the garbage collector and closes boltdb. +func (d *DecayedLog) Stop() error { + if !atomic.CompareAndSwapInt32(&d.stopped, 0, 1) { + return nil + } + + // Stop garbage collector. + close(d.quit) + + d.wg.Wait() + + // Close boltdb. + d.db.Close() + + return nil +} + +// garbageCollector deletes entries from sharedHashBucket whose expiry height +// has already past. This function MUST be run as a goroutine. +func (d *DecayedLog) garbageCollector(epochClient *chainntnfs.BlockEpochEvent) { + defer d.wg.Done() + defer epochClient.Cancel() + + for { + select { + case epoch, ok := <-epochClient.Epochs: + if !ok { + // Block epoch was canceled, shutting down. + log.Infof("Block epoch canceled, " + + "decaying hash log shutting down") + return + } + + // Perform a bout of garbage collection using the + // epoch's block height. + height := uint32(epoch.Height) + numExpired, err := d.gcExpiredHashes(height) + if err != nil { + log.Errorf("unable to expire hashes at "+ + "height=%d", height) + } + + if numExpired > 0 { + log.Infof("Garbage collected %v shared "+ + "secret hashes at height=%v", + numExpired, height) + } + + case <-d.quit: + // Received shutdown request. + log.Infof("Decaying hash log received " + + "shutdown request") + return + } + } +} + +// gcExpiredHashes purges the decaying log of all entries whose CLTV expires +// below the provided height. +func (d *DecayedLog) gcExpiredHashes(height uint32) (uint32, error) { + var numExpiredHashes uint32 + + err := d.db.Batch(func(tx *bolt.Tx) error { + numExpiredHashes = 0 + + // Grab the shared hash bucket + sharedHashes := tx.Bucket(sharedHashBucket) + if sharedHashes == nil { + return fmt.Errorf("sharedHashBucket " + + "is nil") + } + + var expiredCltv [][]byte + if err := sharedHashes.ForEach(func(k, v []byte) error { + // Deserialize the CLTV value for this entry. + cltv := uint32(binary.BigEndian.Uint32(v)) + + if cltv < height { + // This CLTV is expired. We must add it to an + // array which we'll loop over and delete every + // hash contained from the db. + expiredCltv = append(expiredCltv, k) + numExpiredHashes++ + } + + return nil + }); err != nil { + return err + } + + // Delete every item in the array. This must + // be done explicitly outside of the ForEach + // function for safety reasons. + for _, hash := range expiredCltv { + err := sharedHashes.Delete(hash) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return 0, nil + } + + return numExpiredHashes, nil +} + +// Delete removes a key-pair from the +// sharedHashBucket. +func (d *DecayedLog) Delete(hash *sphinx.HashPrefix) error { + return d.db.Batch(func(tx *bolt.Tx) error { + sharedHashes := tx.Bucket(sharedHashBucket) + if sharedHashes == nil { + return ErrDecayedLogCorrupted + } + + return sharedHashes.Delete(hash[:]) + }) +} + +// Get retrieves the CLTV of a processed HTLC given the first 20 bytes of the +// Sha-256 hash of the shared secret. +func (d *DecayedLog) Get(hash *sphinx.HashPrefix) (uint32, error) { + var value uint32 + + err := d.db.View(func(tx *bolt.Tx) error { + // Grab the shared hash bucket which stores the mapping from + // truncated sha-256 hashes of shared secrets to CLTV's. + sharedHashes := tx.Bucket(sharedHashBucket) + if sharedHashes == nil { + return fmt.Errorf("sharedHashes is nil, could " + + "not retrieve CLTV value") + } + + // Retrieve the bytes which represents the CLTV + valueBytes := sharedHashes.Get(hash[:]) + if valueBytes == nil { + return sphinx.ErrLogEntryNotFound + } + + // The first 4 bytes represent the CLTV, store it in value. + value = uint32(binary.BigEndian.Uint32(valueBytes)) + + return nil + }) + if err != nil { + return value, err + } + + return value, nil +} + +// Put stores a shared secret hash as the key and the CLTV as the value. +func (d *DecayedLog) Put(hash *sphinx.HashPrefix, cltv uint32) error { + // Optimisitically serialize the cltv value into the scratch buffer. + var scratch [4]byte + binary.BigEndian.PutUint32(scratch[:], cltv) + + return d.db.Batch(func(tx *bolt.Tx) error { + sharedHashes := tx.Bucket(sharedHashBucket) + if sharedHashes == nil { + return ErrDecayedLogCorrupted + } + + // Check to see if this hash prefix has been recorded before. If + // a value is found, this packet is being replayed. + valueBytes := sharedHashes.Get(hash[:]) + if valueBytes != nil { + return sphinx.ErrReplayedPacket + } + + return sharedHashes.Put(hash[:], scratch[:]) + }) +} + +// PutBatch accepts a pending batch of hashed secret entries to write to disk. +// Each hashed secret is inserted with a corresponding time value, dictating +// when the entry will be evicted from the log. +// NOTE: This method enforces idempotency by writing the replay set obtained +// from the first attempt for a particular batch ID, and decoding the return +// value to subsequent calls. For the indices of the replay set to be aligned +// properly, the batch MUST be constructed identically to the first attempt, +// pruning will cause the indices to become invalid. +func (d *DecayedLog) PutBatch(b *sphinx.Batch) (*sphinx.ReplaySet, error) { + // Since batched boltdb txns may be executed multiple times before + // succeeding, we will create a new replay set for each invocation to + // avoid any side-effects. If the txn is successful, this replay set + // will be merged with the replay set computed during batch construction + // to generate the complete replay set. If this batch was previously + // processed, the replay set will be deserialized from disk. + var replays *sphinx.ReplaySet + if err := d.db.Batch(func(tx *bolt.Tx) error { + sharedHashes := tx.Bucket(sharedHashBucket) + if sharedHashes == nil { + return ErrDecayedLogCorrupted + } + + // Load the batch replay bucket, which will be used to either + // retrieve the result of previously processing this batch, or + // to write the result of this operation. + batchReplayBkt := tx.Bucket(batchReplayBucket) + if batchReplayBkt == nil { + return ErrDecayedLogCorrupted + } + + // Check for the existence of this batch's id in the replay + // bucket. If a non-nil value is found, this indicates that we + // have already processed this batch before. We deserialize the + // resulting and return it to ensure calls to put batch are + // idempotent. + replayBytes := batchReplayBkt.Get(b.ID) + if replayBytes != nil { + replays = sphinx.NewReplaySet() + return replays.Decode(bytes.NewReader(replayBytes)) + } + + // The CLTV will be stored into scratch and then stored into the + // sharedHashBucket. + var scratch [4]byte + + replays = sphinx.NewReplaySet() + err := b.ForEach(func(seqNum uint16, hashPrefix *sphinx.HashPrefix, cltv uint32) error { + // Retrieve the bytes which represents the CLTV + valueBytes := sharedHashes.Get(hashPrefix[:]) + if valueBytes != nil { + replays.Add(seqNum) + return nil + } + + // Serialize the cltv value and write an entry keyed by + // the hash prefix. + binary.BigEndian.PutUint32(scratch[:], cltv) + return sharedHashes.Put(hashPrefix[:], scratch[:]) + }) + if err != nil { + return err + } + + // Merge the replay set computed from checking the on-disk + // entries with the in-batch replays computed during this + // batch's construction. + replays.Merge(b.ReplaySet) + + // Write the replay set under the batch identifier to the batch + // replays bucket. This can be used during recovery to test (1) + // that a particular batch was successfully processed and (2) + // recover the indexes of the adds that were rejected as + // replays. + var replayBuf bytes.Buffer + if err := replays.Encode(&replayBuf); err != nil { + return err + } + + return batchReplayBkt.Put(b.ID, replayBuf.Bytes()) + }); err != nil { + return nil, err + } + + b.ReplaySet = replays + b.IsCommitted = true + + return replays, nil +} + +// A compile time check to see if DecayedLog adheres to the PersistLog +// interface. +var _ sphinx.ReplayLog = (*DecayedLog)(nil) diff --git a/htlcswitch/decayedlog_test.go b/htlcswitch/decayedlog_test.go new file mode 100644 index 00000000..6ce217d1 --- /dev/null +++ b/htlcswitch/decayedlog_test.go @@ -0,0 +1,289 @@ +package htlcswitch + +import ( + "crypto/rand" + "os" + "testing" + "time" + + "github.com/lightningnetwork/lightning-onion" + "github.com/lightningnetwork/lnd/chainntnfs" +) + +const ( + cltv uint32 = 100000 +) + +// startup sets up the DecayedLog and possibly the garbage collector. +func startup(notifier bool) (sphinx.ReplayLog, *mockNotifier, + *sphinx.HashPrefix, error) { + + var log sphinx.ReplayLog + var chainNotifier *mockNotifier + if notifier { + + // Create the MockNotifier which triggers the garbage collector + chainNotifier = &mockNotifier{ + epochChan: make(chan *chainntnfs.BlockEpoch, 1), + } + + // Initialize the DecayedLog object + log = NewDecayedLog("tempdir", chainNotifier) + } else { + // Initialize the DecayedLog object + log = NewDecayedLog("tempdir", nil) + } + + // Open the channeldb (start the garbage collector) + err := log.Start() + if err != nil { + return nil, nil, nil, err + } + + // Create a HashPrefix identifier for a packet. Instead of actually + // generating an ECDH secret and hashing it, simulate with random bytes. + // This is used as a key to retrieve the cltv value. + var hashedSecret sphinx.HashPrefix + _, err = rand.Read(hashedSecret[:]) + if err != nil { + return nil, nil, nil, err + } + + return log, chainNotifier, &hashedSecret, nil +} + +// shutdown deletes the temporary directory that the test database uses +// and handles closing the database. +func shutdown(dir string, d sphinx.ReplayLog) { + d.Stop() + os.RemoveAll(dir) +} + +// TestDecayedLogGarbageCollector tests the ability of the garbage collector +// to delete expired cltv values every time a block is received. Expired cltv +// values are cltv values that are < current block height. +func TestDecayedLogGarbageCollector(t *testing.T) { + d, notifier, hashedSecret, err := startup(true) + if err != nil { + t.Fatalf("Unable to start up DecayedLog: %v", err) + } + defer shutdown("tempdir", d) + + // Store in the sharedHashBucket. + err = d.Put(hashedSecret, cltv) + if err != nil { + t.Fatalf("Unable to store in channeldb: %v", err) + } + + // Wait for database write (GC is in a goroutine) + time.Sleep(500 * time.Millisecond) + + // Send block notifications to garbage collector. The garbage collector + // should remove the entry by block 100001. + + // Send block 100000 + notifier.epochChan <- &chainntnfs.BlockEpoch{ + Height: 100000, + } + + // Assert that hashedSecret is still in the sharedHashBucket + val, err := d.Get(hashedSecret) + if err != nil { + t.Fatalf("Get failed - received an error upon Get: %v", err) + } + + if val != cltv { + t.Fatalf("GC incorrectly deleted CLTV") + } + + // Send block 100001 (expiry block) + notifier.epochChan <- &chainntnfs.BlockEpoch{ + Height: 100001, + } + + // Wait for database write (GC is in a goroutine) + time.Sleep(500 * time.Millisecond) + + // Assert that hashedSecret is not in the sharedHashBucket + val, err = d.Get(hashedSecret) + if err == nil { + t.Fatalf("CLTV was not deleted") + } + if err != sphinx.ErrLogEntryNotFound { + t.Fatalf("Get failed - received unexpected error upon Get: %v", err) + } +} + +// TestDecayedLogPersistentGarbageCollector tests the persistence property of +// the garbage collector. The garbage collector will be restarted immediately and +// a block that expires the stored CLTV value will be sent to the ChainNotifier. +// We test that this causes the pair to be deleted even +// on GC restarts. +func TestDecayedLogPersistentGarbageCollector(t *testing.T) { + d, _, hashedSecret, err := startup(true) + if err != nil { + t.Fatalf("Unable to start up DecayedLog: %v", err) + } + defer shutdown("tempdir", d) + + // Store in the sharedHashBucket + if err = d.Put(hashedSecret, cltv); err != nil { + t.Fatalf("Unable to store in channeldb: %v", err) + } + + // Wait for database write (GC is in a goroutine) + time.Sleep(500 * time.Millisecond) + + // Shut down DecayedLog and the garbage collector along with it. + d.Stop() + + d2, notifier2, hashedSecret2, err := startup(true) + if err != nil { + t.Fatalf("Unable to restart DecayedLog: %v", err) + } + defer shutdown("tempdir", d2) + + // Send a block notification to the garbage collector that expires + // the stored CLTV. + notifier2.epochChan <- &chainntnfs.BlockEpoch{ + Height: int32(100001), + } + + // Wait for database write (GC is in a goroutine) + time.Sleep(500 * time.Millisecond) + + // Assert that hashedSecret is not in the sharedHashBucket + _, err = d2.Get(hashedSecret2) + if err == nil { + t.Fatalf("CLTV was not deleted") + } + if err != sphinx.ErrLogEntryNotFound { + t.Fatalf("Get failed - received unexpected error upon Get: %v", err) + } +} + +// TestDecayedLogInsertionAndRetrieval inserts a cltv value into the +// sharedHashBucket and then deletes it and finally asserts that we can no +// longer retrieve it. +func TestDecayedLogInsertionAndDeletion(t *testing.T) { + d, _, hashedSecret, err := startup(false) + if err != nil { + t.Fatalf("Unable to start up DecayedLog: %v", err) + } + defer shutdown("tempdir", d) + + // Store in the sharedHashBucket. + err = d.Put(hashedSecret, cltv) + if err != nil { + t.Fatalf("Unable to store in channeldb: %v", err) + } + + // Delete hashedSecret from the sharedHashBucket. + err = d.Delete(hashedSecret) + if err != nil { + t.Fatalf("Unable to delete from channeldb: %v", err) + } + + // Assert that hashedSecret is not in the sharedHashBucket + _, err = d.Get(hashedSecret) + if err == nil { + t.Fatalf("CLTV was not deleted") + } + if err != sphinx.ErrLogEntryNotFound { + t.Fatalf("Get failed - received unexpected error upon Get: %v", err) + } +} + +// TestDecayedLogStartAndStop tests for persistence. The DecayedLog is started, +// a cltv value is stored in the sharedHashBucket, and then it the DecayedLog +// is stopped. The DecayedLog is then started up again and we test that the +// cltv value is indeed still stored in the sharedHashBucket. We then delete +// the cltv value and check that it persists upon startup. +func TestDecayedLogStartAndStop(t *testing.T) { + d, _, hashedSecret, err := startup(false) + if err != nil { + t.Fatalf("Unable to start up DecayedLog: %v", err) + } + defer shutdown("tempdir", d) + + // Store in the sharedHashBucket. + err = d.Put(hashedSecret, cltv) + if err != nil { + t.Fatalf("Unable to store in channeldb: %v", err) + } + + // Shutdown the DecayedLog's channeldb + d.Stop() + + d2, _, hashedSecret2, err := startup(false) + if err != nil { + t.Fatalf("Unable to restart DecayedLog: %v", err) + } + defer shutdown("tempdir", d2) + + // Retrieve the stored cltv value given the hashedSecret key. + value, err := d2.Get(hashedSecret) + if err != nil { + t.Fatalf("Unable to retrieve from channeldb: %v", err) + } + + // Check that the original cltv value matches the retrieved cltv + // value. + if cltv != value { + t.Fatalf("Value retrieved doesn't match value stored") + } + + // Delete hashedSecret from sharedHashBucket + err = d2.Delete(hashedSecret2) + if err != nil { + t.Fatalf("Unable to delete from channeldb: %v", err) + } + + // Shutdown the DecayedLog's channeldb + d2.Stop() + + d3, _, hashedSecret3, err := startup(false) + if err != nil { + t.Fatalf("Unable to restart DecayedLog: %v", err) + } + defer shutdown("tempdir", d3) + + // Assert that hashedSecret is not in the sharedHashBucket + _, err = d3.Get(hashedSecret3) + if err == nil { + t.Fatalf("CLTV was not deleted") + } + if err != sphinx.ErrLogEntryNotFound { + t.Fatalf("Get failed - received unexpected error upon Get: %v", err) + } +} + +// TestDecayedLogStorageAndRetrieval stores a cltv value and then retrieves it +// via the nested sharedHashBucket and finally asserts that the original stored +// and retrieved cltv values are equal. +func TestDecayedLogStorageAndRetrieval(t *testing.T) { + d, _, hashedSecret, err := startup(false) + if err != nil { + t.Fatalf("Unable to start up DecayedLog: %v", err) + } + defer shutdown("tempdir", d) + + // Store in the sharedHashBucket + err = d.Put(hashedSecret, cltv) + if err != nil { + t.Fatalf("Unable to store in channeldb: %v", err) + } + + // Retrieve the stored cltv value given the hashedSecret key. + value, err := d.Get(hashedSecret) + if err != nil { + t.Fatalf("Unable to retrieve from channeldb: %v", err) + } + + // If the original cltv value does not match the value retrieved, + // then the test failed. + if cltv != value { + t.Fatalf("Value retrieved doesn't match value stored") + } + +} diff --git a/htlcswitch/mock.go b/htlcswitch/mock.go index 33898293..75b1ee3e 100644 --- a/htlcswitch/mock.go +++ b/htlcswitch/mock.go @@ -748,13 +748,18 @@ func (m *mockSigner) ComputeInputScript(tx *wire.MsgTx, signDesc *lnwallet.SignD } type mockNotifier struct { + epochChan chan *chainntnfs.BlockEpoch } -func (m *mockNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash, numConfs uint32) (*chainntnfs.ConfirmationEvent, error) { +func (m *mockNotifier) RegisterConfirmationsNtfn(txid *chainhash.Hash, + numConfs uint32, heightHint uint32) (*chainntnfs.ConfirmationEvent, error) { return nil, nil } func (m *mockNotifier) RegisterBlockEpochNtfn() (*chainntnfs.BlockEpochEvent, error) { - return nil, nil + return &chainntnfs.BlockEpochEvent{ + Epochs: m.epochChan, + Cancel: func() {}, + }, nil } func (m *mockNotifier) Start() error { @@ -764,7 +769,10 @@ func (m *mockNotifier) Start() error { func (m *mockNotifier) Stop() error { return nil } -func (m *mockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint) (*chainntnfs.SpendEvent, error) { + +func (m *mockNotifier) RegisterSpendNtfn(outpoint *wire.OutPoint, + heightHint uint32, mempool bool) (*chainntnfs.SpendEvent, error) { + return &chainntnfs.SpendEvent{ Spend: make(chan *chainntnfs.SpendDetail), }, nil