diff --git a/channeldb/migration_01_to_11/README.md b/channeldb/migration_01_to_11/README.md deleted file mode 100644 index 7e3a81ef..00000000 --- a/channeldb/migration_01_to_11/README.md +++ /dev/null @@ -1,24 +0,0 @@ -channeldb -========== - -[![Build Status](http://img.shields.io/travis/lightningnetwork/lnd.svg)](https://travis-ci.org/lightningnetwork/lnd) -[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/lightningnetwork/lnd/blob/master/LICENSE) -[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/lightningnetwork/lnd/channeldb) - -The channeldb implements the persistent storage engine for `lnd` and -generically a data storage layer for the required state within the Lightning -Network. The backing storage engine is -[boltdb](https://github.com/coreos/bbolt), an embedded pure-go key-value store -based off of LMDB. - -The package implements an object-oriented storage model with queries and -mutations flowing through a particular object instance rather than the database -itself. The storage implemented by the objects includes: open channels, past -commitment revocation states, the channel graph which includes authenticated -node and channel announcements, outgoing payments, and invoices - -## Installation and Updating - -```bash -$ go get -u github.com/lightningnetwork/lnd/channeldb -``` diff --git a/channeldb/migration_01_to_11/addr_test.go b/channeldb/migration_01_to_11/addr_test.go deleted file mode 100644 index 8cdf99c3..00000000 --- a/channeldb/migration_01_to_11/addr_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "net" - "strings" - "testing" - - "github.com/lightningnetwork/lnd/tor" -) - -type unknownAddrType struct{} - -func (t unknownAddrType) Network() string { return "unknown" } -func (t unknownAddrType) String() string { return "unknown" } - -var testIP4 = net.ParseIP("192.168.1.1") -var testIP6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329") - -var addrTests = []struct { - expAddr net.Addr - serErr string -}{ - // Valid addresses. - { - expAddr: &net.TCPAddr{ - IP: testIP4, - Port: 12345, - }, - }, - { - expAddr: &net.TCPAddr{ - IP: testIP6, - Port: 65535, - }, - }, - { - expAddr: &tor.OnionAddr{ - OnionService: "3g2upl4pq6kufc4m.onion", - Port: 9735, - }, - }, - { - expAddr: &tor.OnionAddr{ - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion", - Port: 80, - }, - }, - - // Invalid addresses. - { - expAddr: unknownAddrType{}, - serErr: ErrUnknownAddressType.Error(), - }, - { - expAddr: &net.TCPAddr{ - // Remove last byte of IPv4 address. - IP: testIP4[:len(testIP4)-1], - Port: 12345, - }, - serErr: "unable to encode", - }, - { - expAddr: &net.TCPAddr{ - // Add an extra byte of IPv4 address. - IP: append(testIP4, 0xff), - Port: 12345, - }, - serErr: "unable to encode", - }, - { - expAddr: &net.TCPAddr{ - // Remove last byte of IPv6 address. - IP: testIP6[:len(testIP6)-1], - Port: 65535, - }, - serErr: "unable to encode", - }, - { - expAddr: &net.TCPAddr{ - // Add an extra byte to the IPv6 address. - IP: append(testIP6, 0xff), - Port: 65535, - }, - serErr: "unable to encode", - }, - { - expAddr: &tor.OnionAddr{ - // Invalid suffix. - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.inion", - Port: 80, - }, - serErr: "invalid suffix", - }, - { - expAddr: &tor.OnionAddr{ - // Invalid length. - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy.onion", - Port: 80, - }, - serErr: "unknown onion service length", - }, - { - expAddr: &tor.OnionAddr{ - // Invalid encoding. - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyA.onion", - Port: 80, - }, - serErr: "illegal base32", - }, -} - -// TestAddrSerialization tests that the serialization method used by channeldb -// for net.Addr's works as intended. -func TestAddrSerialization(t *testing.T) { - t.Parallel() - - var b bytes.Buffer - for _, test := range addrTests { - err := serializeAddr(&b, test.expAddr) - switch { - case err == nil && test.serErr != "": - t.Fatalf("expected serialization err for addr %v", - test.expAddr) - - case err != nil && test.serErr == "": - t.Fatalf("unexpected serialization err for addr %v: %v", - test.expAddr, err) - - case err != nil && !strings.Contains(err.Error(), test.serErr): - t.Fatalf("unexpected serialization err for addr %v, "+ - "want: %v, got %v", test.expAddr, test.serErr, - err) - - case err != nil: - continue - } - - addr, err := deserializeAddr(&b) - if err != nil { - t.Fatalf("unable to deserialize address: %v", err) - } - - if addr.String() != test.expAddr.String() { - t.Fatalf("expected address %v after serialization, "+ - "got %v", addr, test.expAddr) - } - } -} diff --git a/channeldb/migration_01_to_11/channel.go b/channeldb/migration_01_to_11/channel.go index 23d66852..e67c0c69 100644 --- a/channeldb/migration_01_to_11/channel.go +++ b/channeldb/migration_01_to_11/channel.go @@ -1,12 +1,9 @@ package migration_01_to_11 import ( - "bytes" - "encoding/binary" "errors" "fmt" "io" - "net" "strconv" "strings" "sync" @@ -15,8 +12,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" @@ -36,90 +31,6 @@ var ( // // TODO(roasbeef): flesh out comment openChannelBucket = []byte("open-chan-bucket") - - // chanInfoKey can be accessed within the bucket for a channel - // (identified by its chanPoint). This key stores all the static - // information for a channel which is decided at the end of the - // funding flow. - chanInfoKey = []byte("chan-info-key") - - // chanCommitmentKey can be accessed within the sub-bucket for a - // particular channel. This key stores the up to date commitment state - // for a particular channel party. Appending a 0 to the end of this key - // indicates it's the commitment for the local party, and appending a 1 - // to the end of this key indicates it's the commitment for the remote - // party. - chanCommitmentKey = []byte("chan-commitment-key") - - // revocationStateKey stores their current revocation hash, our - // preimage producer and their preimage store. - revocationStateKey = []byte("revocation-state-key") - - // dataLossCommitPointKey stores the commitment point received from the - // remote peer during a channel sync in case we have lost channel state. - dataLossCommitPointKey = []byte("data-loss-commit-point-key") - - // closingTxKey points to a the closing tx that we broadcasted when - // moving the channel to state CommitBroadcasted. - closingTxKey = []byte("closing-tx-key") - - // commitDiffKey stores the current pending commitment state we've - // extended to the remote party (if any). Each time we propose a new - // state, we store the information necessary to reconstruct this state - // from the prior commitment. This allows us to resync the remote party - // to their expected state in the case of message loss. - // - // TODO(roasbeef): rename to commit chain? - commitDiffKey = []byte("commit-diff-key") - - // revocationLogBucket is dedicated for storing the necessary delta - // state between channel updates required to re-construct a past state - // in order to punish a counterparty attempting a non-cooperative - // channel closure. This key should be accessed from within the - // sub-bucket of a target channel, identified by its channel point. - revocationLogBucket = []byte("revocation-log-key") -) - -var ( - // ErrNoCommitmentsFound is returned when a channel has not set - // commitment states. - ErrNoCommitmentsFound = fmt.Errorf("no commitments found") - - // ErrNoChanInfoFound is returned when a particular channel does not - // have any channels state. - ErrNoChanInfoFound = fmt.Errorf("no chan info found") - - // ErrNoRevocationsFound is returned when revocation state for a - // particular channel cannot be found. - ErrNoRevocationsFound = fmt.Errorf("no revocations found") - - // ErrNoPendingCommit is returned when there is not a pending - // commitment for a remote party. A new commitment is written to disk - // each time we write a new state in order to be properly fault - // tolerant. - ErrNoPendingCommit = fmt.Errorf("no pending commits found") - - // ErrInvalidCircuitKeyLen signals that a circuit key could not be - // decoded because the byte slice is of an invalid length. - ErrInvalidCircuitKeyLen = fmt.Errorf( - "length of serialized circuit key must be 16 bytes") - - // ErrNoCommitPoint is returned when no data loss commit point is found - // in the database. - ErrNoCommitPoint = fmt.Errorf("no commit point found") - - // ErrNoCloseTx is returned when no closing tx is found for a channel - // in the state CommitBroadcasted. - ErrNoCloseTx = fmt.Errorf("no closing tx found") - - // ErrNoRestoredChannelMutation is returned when a caller attempts to - // mutate a channel that's been recovered. - ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " + - "channel state") - - // ErrChanBorked is returned when a caller attempts to mutate a borked - // channel. - ErrChanBorked = fmt.Errorf("cannot mutate borked channel") ) // ChannelType is an enum-like type that describes one of several possible @@ -136,30 +47,8 @@ const ( // SingleFunder represents a channel wherein one party solely funds the // entire capacity of the channel. SingleFunder ChannelType = 0 - - // DualFunder represents a channel wherein both parties contribute - // funds towards the total capacity of the channel. The channel may be - // funded symmetrically or asymmetrically. - DualFunder ChannelType = 1 - - // SingleFunderTweakless is similar to the basic SingleFunder channel - // type, but it omits the tweak for one's key in the commitment - // transaction of the remote party. - SingleFunderTweakless ChannelType = 2 ) -// IsSingleFunder returns true if the channel type if one of the known single -// funder variants. -func (c ChannelType) IsSingleFunder() bool { - return c == SingleFunder || c == SingleFunderTweakless -} - -// IsTweakless returns true if the target channel uses a commitment that -// doesn't tweak the key for the remote party. -func (c ChannelType) IsTweakless() bool { - return c == SingleFunderTweakless -} - // ChannelConstraints represents a set of constraints meant to allow a node to // limit their exposure, enact flow control and ensure that all HTLCs are // economically relevant. This struct will be mirrored for both sides of the @@ -444,10 +333,6 @@ type OpenChannel struct { // negotiate fees, or close the channel. IsInitiator bool - // chanStatus is the current status of this channel. If it is not in - // the state Default, it should not be used for forwarding payments. - chanStatus ChannelStatus - // FundingBroadcastHeight is the height in which the funding // transaction was broadcast. This value can be used by higher level // sub-systems to determine if a channel is stale and/or should have @@ -519,11 +404,6 @@ type OpenChannel struct { // implementation of secret store is shachain store. RevocationStore shachain.Store - // Packager is used to create and update forwarding packages for this - // channel, which encodes all necessary information to recover from - // failures and reforward HTLCs that were not fully processed. - Packager FwdPackager - // FundingTxn is the transaction containing this channel's funding // outpoint. Upon restarts, this txn will be rebroadcast if the channel // is found to be pending. @@ -548,657 +428,6 @@ func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID { return c.ShortChannelID } -// ChanStatus returns the current ChannelStatus of this channel. -func (c *OpenChannel) ChanStatus() ChannelStatus { - c.RLock() - defer c.RUnlock() - - return c.chanStatus -} - -// ApplyChanStatus allows the caller to modify the internal channel state in a -// thead-safe manner. -func (c *OpenChannel) ApplyChanStatus(status ChannelStatus) error { - c.Lock() - defer c.Unlock() - - return c.putChanStatus(status) -} - -// ClearChanStatus allows the caller to clear a particular channel status from -// the primary channel status bit field. After this method returns, a call to -// HasChanStatus(status) should return false. -func (c *OpenChannel) ClearChanStatus(status ChannelStatus) error { - c.Lock() - defer c.Unlock() - - return c.clearChanStatus(status) -} - -// HasChanStatus returns true if the internal bitfield channel status of the -// target channel has the specified status bit set. -func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool { - c.RLock() - defer c.RUnlock() - - return c.hasChanStatus(status) -} - -func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool { - return c.chanStatus&status == status -} - -// RefreshShortChanID updates the in-memory short channel ID using the latest -// value observed on disk. -func (c *OpenChannel) RefreshShortChanID() error { - c.Lock() - defer c.Unlock() - - var sid lnwire.ShortChannelID - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - sid = channel.ShortChannelID - - return nil - }) - if err != nil { - return err - } - - c.ShortChannelID = sid - c.Packager = NewChannelPackager(sid) - - return nil -} - -// fetchChanBucket is a helper function that returns the bucket where a -// channel's data resides in given: the public key for the node, the outpoint, -// and the chainhash that the channel resides on. -func fetchChanBucket(tx *bbolt.Tx, nodeKey *btcec.PublicKey, - outPoint *wire.OutPoint, chainHash chainhash.Hash) (*bbolt.Bucket, error) { - - // First fetch the top level bucket which stores all data related to - // current, active channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return nil, ErrNoChanDBExists - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - nodePub := nodeKey.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(nodePub) - if nodeChanBucket == nil { - return nil, ErrNoActiveChannels - } - - // We'll then recurse down an additional layer in order to fetch the - // bucket for this particular chain. - chainBucket := nodeChanBucket.Bucket(chainHash[:]) - if chainBucket == nil { - return nil, ErrNoActiveChannels - } - - // With the bucket for the node and chain fetched, we can now go down - // another level, for this channel itself. - var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { - return nil, err - } - chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) - if chanBucket == nil { - return nil, ErrChannelNotFound - } - - return chanBucket, nil -} - -// fullSync syncs the contents of an OpenChannel while re-using an existing -// database transaction. -func (c *OpenChannel) fullSync(tx *bbolt.Tx) error { - // First fetch the top level bucket which stores all data related to - // current, active channels. - openChanBucket, err := tx.CreateBucketIfNotExists(openChannelBucket) - if err != nil { - return err - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - nodePub := c.IdentityPub.SerializeCompressed() - nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub) - if err != nil { - return err - } - - // We'll then recurse down an additional layer in order to fetch the - // bucket for this particular chain. - chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(c.ChainHash[:]) - if err != nil { - return err - } - - // With the bucket for the node fetched, we can now go down another - // level, creating the bucket for this channel itself. - var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint); err != nil { - return err - } - chanBucket, err := chainBucket.CreateBucket( - chanPointBuf.Bytes(), - ) - switch { - case err == bbolt.ErrBucketExists: - // If this channel already exists, then in order to avoid - // overriding it, we'll return an error back up to the caller. - return ErrChanAlreadyExists - case err != nil: - return err - } - - return putOpenChannel(chanBucket, c) -} - -// MarkAsOpen marks a channel as fully open given a locator that uniquely -// describes its location within the chain. -func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { - c.Lock() - defer c.Unlock() - - if err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - channel.IsPending = false - channel.ShortChannelID = openLoc - - return putOpenChannel(chanBucket, channel) - }); err != nil { - return err - } - - c.IsPending = false - c.ShortChannelID = openLoc - c.Packager = NewChannelPackager(openLoc) - - return nil -} - -// MarkDataLoss marks sets the channel status to LocalDataLoss and stores the -// passed commitPoint for use to retrieve funds in case the remote force closes -// the channel. -func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { - c.Lock() - defer c.Unlock() - - var b bytes.Buffer - if err := WriteElement(&b, commitPoint); err != nil { - return err - } - - putCommitPoint := func(chanBucket *bbolt.Bucket) error { - return chanBucket.Put(dataLossCommitPointKey, b.Bytes()) - } - - return c.putChanStatus(ChanStatusLocalDataLoss, putCommitPoint) -} - -// DataLossCommitPoint retrieves the stored commit point set during -// MarkDataLoss. If not found ErrNoCommitPoint is returned. -func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { - var commitPoint *btcec.PublicKey - - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoCommitPoint - default: - return err - } - - bs := chanBucket.Get(dataLossCommitPointKey) - if bs == nil { - return ErrNoCommitPoint - } - r := bytes.NewReader(bs) - if err := ReadElements(r, &commitPoint); err != nil { - return err - } - - return nil - }) - if err != nil { - return nil, err - } - - return commitPoint, nil -} - -// MarkBorked marks the event when the channel as reached an irreconcilable -// state, such as a channel breach or state desynchronization. Borked channels -// should never be added to the switch. -func (c *OpenChannel) MarkBorked() error { - c.Lock() - defer c.Unlock() - - return c.putChanStatus(ChanStatusBorked) -} - -// ChanSyncMsg returns the ChannelReestablish message that should be sent upon -// reconnection with the remote peer that we're maintaining this channel with. -// The information contained within this message is necessary to re-sync our -// commitment chains in the case of a last or only partially processed message. -// When the remote party receiver this message one of three things may happen: -// -// 1. We're fully synced and no messages need to be sent. -// 2. We didn't get the last CommitSig message they sent, to they'll re-send -// it. -// 3. We didn't get the last RevokeAndAck message they sent, so they'll -// re-send it. -// -// If this is a restored channel, having status ChanStatusRestored, then we'll -// modify our typical chan sync message to ensure they force close even if -// we're on the very first state. -func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { - c.Lock() - defer c.Unlock() - - // The remote commitment height that we'll send in the - // ChannelReestablish message is our current commitment height plus - // one. If the receiver thinks that our commitment height is actually - // *equal* to this value, then they'll re-send the last commitment that - // they sent but we never fully processed. - localHeight := c.LocalCommitment.CommitHeight - nextLocalCommitHeight := localHeight + 1 - - // The second value we'll send is the height of the remote commitment - // from our PoV. If the receiver thinks that their height is actually - // *one plus* this value, then they'll re-send their last revocation. - remoteChainTipHeight := c.RemoteCommitment.CommitHeight - - // If this channel has undergone a commitment update, then in order to - // prove to the remote party our knowledge of their prior commitment - // state, we'll also send over the last commitment secret that the - // remote party sent. - var lastCommitSecret [32]byte - if remoteChainTipHeight != 0 { - remoteSecret, err := c.RevocationStore.LookUp( - remoteChainTipHeight - 1, - ) - if err != nil { - return nil, err - } - lastCommitSecret = [32]byte(*remoteSecret) - } - - // Additionally, we'll send over the current unrevoked commitment on - // our local commitment transaction. - currentCommitSecret, err := c.RevocationProducer.AtIndex( - localHeight, - ) - if err != nil { - return nil, err - } - - // If we've restored this channel, then we'll purposefully give them an - // invalid LocalUnrevokedCommitPoint so they'll force close the channel - // allowing us to sweep our funds. - if c.hasChanStatus(ChanStatusRestored) { - currentCommitSecret[0] ^= 1 - - // If this is a tweakless channel, then we'll purposefully send - // a next local height taht's invalid to trigger a force close - // on their end. We do this as tweakless channels don't require - // that the commitment point is valid, only that it's present. - if c.ChanType.IsTweakless() { - nextLocalCommitHeight = 0 - } - } - - return &lnwire.ChannelReestablish{ - ChanID: lnwire.NewChanIDFromOutPoint( - &c.FundingOutpoint, - ), - NextLocalCommitHeight: nextLocalCommitHeight, - RemoteCommitTailHeight: remoteChainTipHeight, - LastRemoteCommitSecret: lastCommitSecret, - LocalUnrevokedCommitPoint: input.ComputeCommitmentPoint( - currentCommitSecret[:], - ), - }, nil -} - -// isBorked returns true if the channel has been marked as borked in the -// database. This requires an existing database transaction to already be -// active. -// -// NOTE: The primary mutex should already be held before this method is called. -func (c *OpenChannel) isBorked(chanBucket *bbolt.Bucket) (bool, error) { - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return false, err - } - - return channel.chanStatus != ChanStatusDefault, nil -} - -// MarkCommitmentBroadcasted marks the channel as a commitment transaction has -// been broadcast, either our own or the remote, and we should watch the chain -// for it to confirm before taking any further action. It takes as argument the -// closing tx _we believe_ will appear in the chain. This is only used to -// republish this tx at startup to ensure propagation, and we should still -// handle the case where a different tx actually hits the chain. -func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx) error { - c.Lock() - defer c.Unlock() - - var b bytes.Buffer - if err := WriteElement(&b, closeTx); err != nil { - return err - } - - putClosingTx := func(chanBucket *bbolt.Bucket) error { - return chanBucket.Put(closingTxKey, b.Bytes()) - } - - return c.putChanStatus(ChanStatusCommitBroadcasted, putClosingTx) -} - -// BroadcastedCommitment retrieves the stored closing tx set during -// MarkCommitmentBroadcasted. If not found ErrNoCloseTx is returned. -func (c *OpenChannel) BroadcastedCommitment() (*wire.MsgTx, error) { - var closeTx *wire.MsgTx - - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoCloseTx - default: - return err - } - - bs := chanBucket.Get(closingTxKey) - if bs == nil { - return ErrNoCloseTx - } - r := bytes.NewReader(bs) - return ReadElement(r, &closeTx) - }) - if err != nil { - return nil, err - } - - return closeTx, nil -} - -// putChanStatus appends the given status to the channel. fs is an optional -// list of closures that are given the chanBucket in order to atomically add -// extra information together with the new status. -func (c *OpenChannel) putChanStatus(status ChannelStatus, - fs ...func(*bbolt.Bucket) error) error { - - if err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - // Add this status to the existing bitvector found in the DB. - status = channel.chanStatus | status - channel.chanStatus = status - - if err := putOpenChannel(chanBucket, channel); err != nil { - return err - } - - for _, f := range fs { - if err := f(chanBucket); err != nil { - return err - } - } - - return nil - }); err != nil { - return err - } - - // Update the in-memory representation to keep it in sync with the DB. - c.chanStatus = status - - return nil -} - -func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { - if err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - // Unset this bit in the bitvector on disk. - status = channel.chanStatus & ^status - channel.chanStatus = status - - return putOpenChannel(chanBucket, channel) - }); err != nil { - return err - } - - // Update the in-memory representation to keep it in sync with the DB. - c.chanStatus = status - - return nil -} - -// putChannel serializes, and stores the current state of the channel in its -// entirety. -func putOpenChannel(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - // First, we'll write out all the relatively static fields, that are - // decided upon initial channel creation. - if err := putChanInfo(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan info: %v", err) - } - - // With the static channel info written out, we'll now write out the - // current commitment state for both parties. - if err := putChanCommitments(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan commitments: %v", err) - } - - // Finally, we'll write out the revocation state for both parties - // within a distinct key space. - if err := putChanRevocationState(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan revocations: %v", err) - } - - return nil -} - -// fetchOpenChannel retrieves, and deserializes (including decrypting -// sensitive) the complete channel currently active with the passed nodeID. -func fetchOpenChannel(chanBucket *bbolt.Bucket, - chanPoint *wire.OutPoint) (*OpenChannel, error) { - - channel := &OpenChannel{ - FundingOutpoint: *chanPoint, - } - - // First, we'll read all the static information that changes less - // frequently from disk. - if err := fetchChanInfo(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan info: %v", err) - } - - // With the static information read, we'll now read the current - // commitment state for both sides of the channel. - if err := fetchChanCommitments(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan commitments: %v", err) - } - - // Finally, we'll retrieve the current revocation state so we can - // properly - if err := fetchChanRevocationState(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan revocations: %v", err) - } - - channel.Packager = NewChannelPackager(channel.ShortChannelID) - - return channel, nil -} - -// SyncPending writes the contents of the channel to the database while it's in -// the pending (waiting for funding confirmation) state. The IsPending flag -// will be set to true. When the channel's funding transaction is confirmed, -// the channel should be marked as "open" and the IsPending flag set to false. -// Note that this function also creates a LinkNode relationship between this -// newly created channel and a new LinkNode instance. This allows listing all -// channels in the database globally, or according to the LinkNode they were -// created with. -// -// TODO(roasbeef): addr param should eventually be an lnwire.NetAddress type -// that includes service bits. -func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { - c.Lock() - defer c.Unlock() - - c.FundingBroadcastHeight = pendingHeight - - return c.Db.Update(func(tx *bbolt.Tx) error { - return syncNewChannel(tx, c, []net.Addr{addr}) - }) -} - -// syncNewChannel will write the passed channel to disk, and also create a -// LinkNode (if needed) for the channel peer. -func syncNewChannel(tx *bbolt.Tx, c *OpenChannel, addrs []net.Addr) error { - // First, sync all the persistent channel state to disk. - if err := c.fullSync(tx); err != nil { - return err - } - - nodeInfoBucket, err := tx.CreateBucketIfNotExists(nodeInfoBucket) - if err != nil { - return err - } - - // If a LinkNode for this identity public key already exists, - // then we can exit early. - nodePub := c.IdentityPub.SerializeCompressed() - if nodeInfoBucket.Get(nodePub) != nil { - return nil - } - - // Next, we need to establish a (possibly) new LinkNode relationship - // for this channel. The LinkNode metadata contains reachability, - // up-time, and service bits related information. - linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...) - - // TODO(roasbeef): do away with link node all together? - - return putLinkNode(nodeInfoBucket, linkNode) -} - -// UpdateCommitment updates the commitment state for the specified party -// (remote or local). The commitment stat completely describes the balance -// state at this point in the commitment chain. This method its to be called on -// two occasions: when we revoke our prior commitment state, and when the -// remote party revokes their prior commitment state. -func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment) error { - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state as all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return ErrNoRestoredChannelMutation - } - - err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - if err = putChanInfo(chanBucket, c); err != nil { - return fmt.Errorf("unable to store chan info: %v", err) - } - - // With the proper bucket fetched, we'll now write the latest - // commitment state to disk for the target party. - err = putChanCommitment( - chanBucket, newCommitment, true, - ) - if err != nil { - return fmt.Errorf("unable to store chan "+ - "revocations: %v", err) - } - - return nil - }) - if err != nil { - return err - } - - c.LocalCommitment = *newCommitment - - return nil -} - // HTLC is the on-disk representation of a hash time-locked contract. HTLCs are // contained within ChannelDeltas which encode the current state of the // commitment between state updates. @@ -1247,101 +476,6 @@ type HTLC struct { LogIndex uint64 } -// SerializeHtlcs writes out the passed set of HTLC's into the passed writer -// using the current default on-disk serialization format. -// -// NOTE: This API is NOT stable, the on-disk format will likely change in the -// future. -func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { - numHtlcs := uint16(len(htlcs)) - if err := WriteElement(b, numHtlcs); err != nil { - return err - } - - for _, htlc := range htlcs { - if err := WriteElements(b, - htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, - htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:], - htlc.HtlcIndex, htlc.LogIndex, - ); err != nil { - return err - } - } - - return nil -} - -// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed -// io.Reader. The bytes within the passed reader MUST have been previously -// written to using the SerializeHtlcs function. -// -// NOTE: This API is NOT stable, the on-disk format will likely change in the -// future. -func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { - var numHtlcs uint16 - if err := ReadElement(r, &numHtlcs); err != nil { - return nil, err - } - - var htlcs []HTLC - if numHtlcs == 0 { - return htlcs, nil - } - - htlcs = make([]HTLC, numHtlcs) - for i := uint16(0); i < numHtlcs; i++ { - if err := ReadElements(r, - &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, - &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, - &htlcs[i].Incoming, &htlcs[i].OnionBlob, - &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, - ); err != nil { - return htlcs, err - } - } - - return htlcs, nil -} - -// Copy returns a full copy of the target HTLC. -func (h *HTLC) Copy() HTLC { - clone := HTLC{ - Incoming: h.Incoming, - Amt: h.Amt, - RefundTimeout: h.RefundTimeout, - OutputIndex: h.OutputIndex, - } - copy(clone.Signature[:], h.Signature) - copy(clone.RHash[:], h.RHash[:]) - - return clone -} - -// LogUpdate represents a pending update to the remote commitment chain. The -// log update may be an add, fail, or settle entry. We maintain this data in -// order to be able to properly retransmit our proposed -// state if necessary. -type LogUpdate struct { - // LogIndex is the log index of this proposed commitment update entry. - LogIndex uint64 - - // UpdateMsg is the update message that was included within the our - // local update log. The LogIndex value denotes the log index of this - // update which will be used when restoring our local update log if - // we're left with a dangling update on restart. - UpdateMsg lnwire.Message -} - -// Encode writes a log update to the provided io.Writer. -func (l *LogUpdate) Encode(w io.Writer) error { - return WriteElements(w, l.LogIndex, l.UpdateMsg) -} - -// Decode reads a log update from the provided io.Reader. -func (l *LogUpdate) Decode(r io.Reader) error { - return ReadElements(r, &l.LogIndex, &l.UpdateMsg) -} - // CircuitKey is used by a channel to uniquely identify the HTLCs it receives // from the switch, and is used to purge our in-memory state of HTLCs that have // already been processed by a link. Two list of CircuitKeys are included in @@ -1360,723 +494,20 @@ type CircuitKey struct { HtlcID uint64 } -// SetBytes deserializes the given bytes into this CircuitKey. -func (k *CircuitKey) SetBytes(bs []byte) error { - if len(bs) != 16 { - return ErrInvalidCircuitKeyLen - } - - k.ChanID = lnwire.NewShortChanIDFromInt( - binary.BigEndian.Uint64(bs[:8])) - k.HtlcID = binary.BigEndian.Uint64(bs[8:]) - - return nil -} - -// Bytes returns the serialized bytes for this circuit key. -func (k CircuitKey) Bytes() []byte { - var bs = make([]byte, 16) - binary.BigEndian.PutUint64(bs[:8], k.ChanID.ToUint64()) - binary.BigEndian.PutUint64(bs[8:], k.HtlcID) - return bs -} - -// Encode writes a CircuitKey to the provided io.Writer. -func (k *CircuitKey) Encode(w io.Writer) error { - var scratch [16]byte - binary.BigEndian.PutUint64(scratch[:8], k.ChanID.ToUint64()) - binary.BigEndian.PutUint64(scratch[8:], k.HtlcID) - - _, err := w.Write(scratch[:]) - return err -} - -// Decode reads a CircuitKey from the provided io.Reader. -func (k *CircuitKey) Decode(r io.Reader) error { - var scratch [16]byte - - if _, err := io.ReadFull(r, scratch[:]); err != nil { - return err - } - k.ChanID = lnwire.NewShortChanIDFromInt( - binary.BigEndian.Uint64(scratch[:8])) - k.HtlcID = binary.BigEndian.Uint64(scratch[8:]) - - return nil -} - // String returns a string representation of the CircuitKey. func (k CircuitKey) String() string { return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID) } -// CommitDiff represents the delta needed to apply the state transition between -// two subsequent commitment states. Given state N and state N+1, one is able -// to apply the set of messages contained within the CommitDiff to N to arrive -// at state N+1. Each time a new commitment is extended, we'll write a new -// commitment (along with the full commitment state) to disk so we can -// re-transmit the state in the case of a connection loss or message drop. -type CommitDiff struct { - // ChannelCommitment is the full commitment state that one would arrive - // at by applying the set of messages contained in the UpdateDiff to - // the prior accepted commitment. - Commitment ChannelCommitment - - // LogUpdates is the set of messages sent prior to the commitment state - // transition in question. Upon reconnection, if we detect that they - // don't have the commitment, then we re-send this along with the - // proper signature. - LogUpdates []LogUpdate - - // CommitSig is the exact CommitSig message that should be sent after - // the set of LogUpdates above has been retransmitted. The signatures - // within this message should properly cover the new commitment state - // and also the HTLC's within the new commitment state. - CommitSig *lnwire.CommitSig - - // OpenedCircuitKeys is a set of unique identifiers for any downstream - // Add packets included in this commitment txn. After a restart, this - // set of htlcs is acked from the link's incoming mailbox to ensure - // there isn't an attempt to re-add them to this commitment txn. - OpenedCircuitKeys []CircuitKey - - // ClosedCircuitKeys records the unique identifiers for any settle/fail - // packets that were resolved by this commitment txn. After a restart, - // this is used to ensure those circuits are removed from the circuit - // map, and the downstream packets in the link's mailbox are removed. - ClosedCircuitKeys []CircuitKey - - // AddAcks specifies the locations (commit height, pkg index) of any - // Adds that were failed/settled in this commit diff. This will ack - // entries in *this* channel's forwarding packages. - // - // NOTE: This value is not serialized, it is used to atomically mark the - // resolution of adds, such that they will not be reprocessed after a - // restart. - AddAcks []AddRef - - // SettleFailAcks specifies the locations (chan id, commit height, pkg - // index) of any Settles or Fails that were locked into this commit - // diff, and originate from *another* channel, i.e. the outgoing link. - // - // NOTE: This value is not serialized, it is used to atomically acks - // settles and fails from the forwarding packages of other channels, - // such that they will not be reforwarded internally after a restart. - SettleFailAcks []SettleFailRef -} - -func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { - if err := serializeChanCommit(w, &diff.Commitment); err != nil { - return err - } - - if err := diff.CommitSig.Encode(w, 0); err != nil { - return err - } - - numUpdates := uint16(len(diff.LogUpdates)) - if err := binary.Write(w, byteOrder, numUpdates); err != nil { - return err - } - - for _, diff := range diff.LogUpdates { - err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) - if err != nil { - return err - } - } - - numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) - if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { - return err - } - - for _, openRef := range diff.OpenedCircuitKeys { - err := WriteElements(w, openRef.ChanID, openRef.HtlcID) - if err != nil { - return err - } - } - - numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) - if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { - return err - } - - for _, closedRef := range diff.ClosedCircuitKeys { - err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) - if err != nil { - return err - } - } - - return nil -} - -func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { - var ( - d CommitDiff - err error - ) - - d.Commitment, err = deserializeChanCommit(r) - if err != nil { - return nil, err - } - - d.CommitSig = &lnwire.CommitSig{} - if err := d.CommitSig.Decode(r, 0); err != nil { - return nil, err - } - - var numUpdates uint16 - if err := binary.Read(r, byteOrder, &numUpdates); err != nil { - return nil, err - } - - d.LogUpdates = make([]LogUpdate, numUpdates) - for i := 0; i < int(numUpdates); i++ { - err := ReadElements(r, - &d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg, - ) - if err != nil { - return nil, err - } - } - - var numOpenRefs uint16 - if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { - return nil, err - } - - d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs) - for i := 0; i < int(numOpenRefs); i++ { - err := ReadElements(r, - &d.OpenedCircuitKeys[i].ChanID, - &d.OpenedCircuitKeys[i].HtlcID) - if err != nil { - return nil, err - } - } - - var numClosedRefs uint16 - if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { - return nil, err - } - - d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs) - for i := 0; i < int(numClosedRefs); i++ { - err := ReadElements(r, - &d.ClosedCircuitKeys[i].ChanID, - &d.ClosedCircuitKeys[i].HtlcID) - if err != nil { - return nil, err - } - } - - return &d, nil -} - -// AppendRemoteCommitChain appends a new CommitDiff to the end of the -// commitment chain for the remote party. This method is to be used once we -// have prepared a new commitment state for the remote party, but before we -// transmit it to the remote party. The contents of the argument should be -// sufficient to retransmit the updates and signature needed to reconstruct the -// state in full, in the case that we need to retransmit. -func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state at all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return ErrNoRestoredChannelMutation - } - - return c.Db.Update(func(tx *bbolt.Tx) error { - // First, we'll grab the writable bucket where this channel's - // data resides. - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - // Any outgoing settles and fails necessarily have a - // corresponding adds in this channel's forwarding packages. - // Mark all of these as being fully processed in our forwarding - // package, which prevents us from reprocessing them after - // startup. - err = c.Packager.AckAddHtlcs(tx, diff.AddAcks...) - if err != nil { - return err - } - - // Additionally, we ack from any fails or settles that are - // persisted in another channel's forwarding package. This - // prevents the same fails and settles from being retransmitted - // after restarts. The actual fail or settle we need to - // propagate to the remote party is now in the commit diff. - err = c.Packager.AckSettleFails(tx, diff.SettleFailAcks...) - if err != nil { - return err - } - - // TODO(roasbeef): use seqno to derive key for later LCP - - // With the bucket retrieved, we'll now serialize the commit - // diff itself, and write it to disk. - var b bytes.Buffer - if err := serializeCommitDiff(&b, diff); err != nil { - return err - } - return chanBucket.Put(commitDiffKey, b.Bytes()) - }) -} - -// RemoteCommitChainTip returns the "tip" of the current remote commitment -// chain. This value will be non-nil iff, we've created a new commitment for -// the remote party that they haven't yet ACK'd. In this case, their commitment -// chain will have a length of two: their current unrevoked commitment, and -// this new pending commitment. Once they revoked their prior state, we'll swap -// these pointers, causing the tip and the tail to point to the same entry. -func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { - var cd *CommitDiff - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoPendingCommit - default: - return err - } - - tipBytes := chanBucket.Get(commitDiffKey) - if tipBytes == nil { - return ErrNoPendingCommit - } - - tipReader := bytes.NewReader(tipBytes) - dcd, err := deserializeCommitDiff(tipReader) - if err != nil { - return err - } - - cd = dcd - return nil - }) - if err != nil { - return nil, err - } - - return cd, err -} - -// InsertNextRevocation inserts the _next_ commitment point (revocation) into -// the database, and also modifies the internal RemoteNextRevocation attribute -// to point to the passed key. This method is to be using during final channel -// set up, _after_ the channel has been fully confirmed. -// -// NOTE: If this method isn't called, then the target channel won't be able to -// propose new states for the commitment state of the remote party. -func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { - c.Lock() - defer c.Unlock() - - c.RemoteNextRevocation = revKey - - err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return putChanRevocationState(chanBucket, c) - }) - if err != nil { - return err - } - - return nil -} - -// AdvanceCommitChainTail records the new state transition within an on-disk -// append-only log which records all state transitions by the remote peer. In -// the case of an uncooperative broadcast of a prior state by the remote peer, -// this log can be consulted in order to reconstruct the state needed to -// rectify the situation. This method will add the current commitment for the -// remote party to the revocation log, and promote the current pending -// commitment to the current remote commitment. -func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error { - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state at all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return ErrNoRestoredChannelMutation - } - - var newRemoteCommit *ChannelCommitment - - err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - // Persist the latest preimage state to disk as the remote peer - // has just added to our local preimage store, and given us a - // new pending revocation key. - if err := putChanRevocationState(chanBucket, c); err != nil { - return err - } - - // With the current preimage producer/store state updated, - // append a new log entry recording this the delta of this - // state transition. - // - // TODO(roasbeef): could make the deltas relative, would save - // space, but then tradeoff for more disk-seeks to recover the - // full state. - logKey := revocationLogBucket - logBucket, err := chanBucket.CreateBucketIfNotExists(logKey) - if err != nil { - return err - } - - // Before we append this revoked state to the revocation log, - // we'll swap out what's currently the tail of the commit tip, - // with the current locked-in commitment for the remote party. - tipBytes := chanBucket.Get(commitDiffKey) - tipReader := bytes.NewReader(tipBytes) - newCommit, err := deserializeCommitDiff(tipReader) - if err != nil { - return err - } - err = putChanCommitment( - chanBucket, &newCommit.Commitment, false, - ) - if err != nil { - return err - } - if err := chanBucket.Delete(commitDiffKey); err != nil { - return err - } - - // With the commitment pointer swapped, we can now add the - // revoked (prior) state to the revocation log. - // - // TODO(roasbeef): store less - err = appendChannelLogEntry(logBucket, &c.RemoteCommitment) - if err != nil { - return err - } - - // Lastly, we write the forwarding package to disk so that we - // can properly recover from failures and reforward HTLCs that - // have not received a corresponding settle/fail. - if err := c.Packager.AddFwdPkg(tx, fwdPkg); err != nil { - return err - } - - newRemoteCommit = &newCommit.Commitment - - return nil - }) - if err != nil { - return err - } - - // With the db transaction complete, we'll swap over the in-memory - // pointer of the new remote commitment, which was previously the tip - // of the commit chain. - c.RemoteCommitment = *newRemoteCommit - - return nil -} - -// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure -// this always returns the next index that has been not been allocated, this -// will first try to examine any pending commitments, before falling back to the -// last locked-in local commitment. -func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { - // First, load the most recent commit diff that we initiated for the - // remote party. If no pending commit is found, this is not treated as - // a critical error, since we can always fall back. - pendingRemoteCommit, err := c.RemoteCommitChainTip() - if err != nil && err != ErrNoPendingCommit { - return 0, err - } - - // If a pending commit was found, its local htlc index will be at least - // as large as the one on our local commitment. - if pendingRemoteCommit != nil { - return pendingRemoteCommit.Commitment.LocalHtlcIndex, nil - } - - // Otherwise, fallback to using the local htlc index of our commitment. - return c.LocalCommitment.LocalHtlcIndex, nil -} - -// LoadFwdPkgs scans the forwarding log for any packages that haven't been -// processed, and returns their deserialized log updates in map indexed by the -// remote commitment height at which the updates were locked in. -func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { - c.RLock() - defer c.RUnlock() - - var fwdPkgs []*FwdPkg - if err := c.Db.View(func(tx *bbolt.Tx) error { - var err error - fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) - return err - }); err != nil { - return nil, err - } - - return fwdPkgs, nil -} - -// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs -// indicating that a response to this Add has been committed to the remote party. -// Doing so will prevent these Add HTLCs from being reforwarded internally. -func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - return c.Packager.AckAddHtlcs(tx, addRefs...) - }) -} - -// AckSettleFails updates the SettleFailFilter containing any of the provided -// SettleFailRefs, indicating that the response has been delivered to the -// incoming link, corresponding to a particular AddRef. Doing so will prevent -// the responses from being retransmitted internally. -func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - return c.Packager.AckSettleFails(tx, settleFailRefs...) - }) -} - -// SetFwdFilter atomically sets the forwarding filter for the forwarding package -// identified by `height`. -func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - return c.Packager.SetFwdFilter(tx, height, fwdFilter) - }) -} - -// RemoveFwdPkg atomically removes a forwarding package specified by the remote -// commitment height. -// -// NOTE: This method should only be called on packages marked FwdStateCompleted. -func (c *OpenChannel) RemoveFwdPkg(height uint64) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - return c.Packager.RemovePkg(tx, height) - }) -} - -// RevocationLogTail returns the "tail", or the end of the current revocation -// log. This entry represents the last previous state for the remote node's -// commitment chain. The ChannelDelta returned by this method will always lag -// one state behind the most current (unrevoked) state of the remote node's -// commitment chain. -func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { - c.RLock() - defer c.RUnlock() - - // If we haven't created any state updates yet, then we'll exit early as - // there's nothing to be found on disk in the revocation bucket. - if c.RemoteCommitment.CommitHeight == 0 { - return nil, nil - } - - var commit ChannelCommitment - if err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - logBucket := chanBucket.Bucket(revocationLogBucket) - if logBucket == nil { - return ErrNoPastDeltas - } - - // Once we have the bucket that stores the revocation log from - // this channel, we'll jump to the _last_ key in bucket. As we - // store the update number on disk in a big-endian format, - // this will retrieve the latest entry. - cursor := logBucket.Cursor() - _, tailLogEntry := cursor.Last() - logEntryReader := bytes.NewReader(tailLogEntry) - - // Once we have the entry, we'll decode it into the channel - // delta pointer we created above. - var dbErr error - commit, dbErr = deserializeChanCommit(logEntryReader) - if dbErr != nil { - return dbErr - } - - return nil - }); err != nil { - return nil, err - } - - return &commit, nil -} - -// CommitmentHeight returns the current commitment height. The commitment -// height represents the number of updates to the commitment state to date. -// This value is always monotonically increasing. This method is provided in -// order to allow multiple instances of a particular open channel to obtain a -// consistent view of the number of channel updates to date. -func (c *OpenChannel) CommitmentHeight() (uint64, error) { - c.RLock() - defer c.RUnlock() - - var height uint64 - err := c.Db.View(func(tx *bbolt.Tx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - commit, err := fetchChanCommitment(chanBucket, true) - if err != nil { - return err - } - - height = commit.CommitHeight - return nil - }) - if err != nil { - return 0, err - } - - return height, nil -} - -// FindPreviousState scans through the append-only log in an attempt to recover -// the previous channel state indicated by the update number. This method is -// intended to be used for obtaining the relevant data needed to claim all -// funds rightfully spendable in the case of an on-chain broadcast of the -// commitment transaction. -func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, error) { - c.RLock() - defer c.RUnlock() - - var commit ChannelCommitment - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - logBucket := chanBucket.Bucket(revocationLogBucket) - if logBucket == nil { - return ErrNoPastDeltas - } - - c, err := fetchChannelLogEntry(logBucket, updateNum) - if err != nil { - return err - } - - commit = c - return nil - }) - if err != nil { - return nil, err - } - - return &commit, nil -} - // ClosureType is an enum like structure that details exactly _how_ a channel // was closed. Three closure types are currently possible: none, cooperative, // local force close, remote force close, and (remote) breach. type ClosureType uint8 const ( - // CooperativeClose indicates that a channel has been closed - // cooperatively. This means that both channel peers were online and - // signed a new transaction paying out the settled balance of the - // contract. - CooperativeClose ClosureType = 0 - - // LocalForceClose indicates that we have unilaterally broadcast our - // current commitment state on-chain. - LocalForceClose ClosureType = 1 - // RemoteForceClose indicates that the remote peer has unilaterally // broadcast their current commitment state on-chain. RemoteForceClose ClosureType = 4 - - // BreachClose indicates that the remote peer attempted to broadcast a - // prior _revoked_ channel state. - BreachClose ClosureType = 2 - - // FundingCanceled indicates that the channel never was fully opened - // before it was marked as closed in the database. This can happen if - // we or the remote fail at some point during the opening workflow, or - // we timeout waiting for the funding transaction to be confirmed. - FundingCanceled ClosureType = 3 - - // Abandoned indicates that the channel state was removed without - // any further actions. This is intended to clean up unusable - // channels during development. - Abandoned ClosureType = 5 ) // ChannelCloseSummary contains the final state of a channel at the point it @@ -2160,214 +591,6 @@ type ChannelCloseSummary struct { LastChanSyncMsg *lnwire.ChannelReestablish } -// CloseChannel closes a previously active Lightning channel. Closing a channel -// entails deleting all saved state within the database concerning this -// channel. This method also takes a struct that summarizes the state of the -// channel at closing, this compact representation will be the only component -// of a channel left over after a full closing. -func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoChanDBExists - } - - nodePub := c.IdentityPub.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(nodePub) - if nodeChanBucket == nil { - return ErrNoActiveChannels - } - - chainBucket := nodeChanBucket.Bucket(c.ChainHash[:]) - if chainBucket == nil { - return ErrNoActiveChannels - } - - var chanPointBuf bytes.Buffer - err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint) - if err != nil { - return err - } - chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) - if chanBucket == nil { - return ErrNoActiveChannels - } - - // Before we delete the channel state, we'll read out the full - // details, as we'll also store portions of this information - // for record keeping. - chanState, err := fetchOpenChannel( - chanBucket, &c.FundingOutpoint, - ) - if err != nil { - return err - } - - // Now that the index to this channel has been deleted, purge - // the remaining channel metadata from the database. - err = deleteOpenChannel(chanBucket, chanPointBuf.Bytes()) - if err != nil { - return err - } - - // With the base channel data deleted, attempt to delete the - // information stored within the revocation log. - logBucket := chanBucket.Bucket(revocationLogBucket) - if logBucket != nil { - err = chanBucket.DeleteBucket(revocationLogBucket) - if err != nil { - return err - } - } - - err = chainBucket.DeleteBucket(chanPointBuf.Bytes()) - if err != nil { - return err - } - - // Finally, create a summary of this channel in the closed - // channel bucket for this node. - return putChannelCloseSummary( - tx, chanPointBuf.Bytes(), summary, chanState, - ) - }) -} - -// ChannelSnapshot is a frozen snapshot of the current channel state. A -// snapshot is detached from the original channel that generated it, providing -// read-only access to the current or prior state of an active channel. -// -// TODO(roasbeef): remove all together? pretty much just commitment -type ChannelSnapshot struct { - // RemoteIdentity is the identity public key of the remote node that we - // are maintaining the open channel with. - RemoteIdentity btcec.PublicKey - - // ChanPoint is the outpoint that created the channel. This output is - // found within the funding transaction and uniquely identified the - // channel on the resident chain. - ChannelPoint wire.OutPoint - - // ChainHash is the genesis hash of the chain that the channel resides - // within. - ChainHash chainhash.Hash - - // Capacity is the total capacity of the channel. - Capacity btcutil.Amount - - // TotalMSatSent is the total number of milli-satoshis we've sent - // within this channel. - TotalMSatSent lnwire.MilliSatoshi - - // TotalMSatReceived is the total number of milli-satoshis we've - // received within this channel. - TotalMSatReceived lnwire.MilliSatoshi - - // ChannelCommitment is the current up-to-date commitment for the - // target channel. - ChannelCommitment -} - -// Snapshot returns a read-only snapshot of the current channel state. This -// snapshot includes information concerning the current settled balance within -// the channel, metadata detailing total flows, and any outstanding HTLCs. -func (c *OpenChannel) Snapshot() *ChannelSnapshot { - c.RLock() - defer c.RUnlock() - - localCommit := c.LocalCommitment - snapshot := &ChannelSnapshot{ - RemoteIdentity: *c.IdentityPub, - ChannelPoint: c.FundingOutpoint, - Capacity: c.Capacity, - TotalMSatSent: c.TotalMSatSent, - TotalMSatReceived: c.TotalMSatReceived, - ChainHash: c.ChainHash, - ChannelCommitment: ChannelCommitment{ - LocalBalance: localCommit.LocalBalance, - RemoteBalance: localCommit.RemoteBalance, - CommitHeight: localCommit.CommitHeight, - CommitFee: localCommit.CommitFee, - }, - } - - // Copy over the current set of HTLCs to ensure the caller can't mutate - // our internal state. - snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) - for i, h := range localCommit.Htlcs { - snapshot.Htlcs[i] = h.Copy() - } - - return snapshot -} - -// LatestCommitments returns the two latest commitments for both the local and -// remote party. These commitments are read from disk to ensure that only the -// latest fully committed state is returned. The first commitment returned is -// the local commitment, and the second returned is the remote commitment. -func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return fetchChanCommitments(chanBucket, c) - }) - if err != nil { - return nil, nil, err - } - - return &c.LocalCommitment, &c.RemoteCommitment, nil -} - -// RemoteRevocationStore returns the most up to date commitment version of the -// revocation storage tree for the remote party. This method can be used when -// acting on a possible contract breach to ensure, that the caller has the most -// up to date information required to deliver justice. -func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return fetchChanRevocationState(chanBucket, c) - }) - if err != nil { - return nil, err - } - - return c.RevocationStore, nil -} - -func putChannelCloseSummary(tx *bbolt.Tx, chanID []byte, - summary *ChannelCloseSummary, lastChanState *OpenChannel) error { - - closedChanBucket, err := tx.CreateBucketIfNotExists(closedChannelBucket) - if err != nil { - return err - } - - summary.RemoteCurrentRevocation = lastChanState.RemoteCurrentRevocation - summary.RemoteNextRevocation = lastChanState.RemoteNextRevocation - summary.LocalChanConfig = lastChanState.LocalChanCfg - - var b bytes.Buffer - if err := serializeChannelCloseSummary(&b, summary); err != nil { - return err - } - - return closedChanBucket.Put(chanID, b.Bytes()) -} - func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { err := WriteElements(w, cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, @@ -2517,113 +740,6 @@ func writeChanConfig(b io.Writer, c *ChannelConfig) error { ) } -func putChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - var w bytes.Buffer - if err := WriteElements(&w, - channel.ChanType, channel.ChainHash, channel.FundingOutpoint, - channel.ShortChannelID, channel.IsPending, channel.IsInitiator, - channel.chanStatus, channel.FundingBroadcastHeight, - channel.NumConfsRequired, channel.ChannelFlags, - channel.IdentityPub, channel.Capacity, channel.TotalMSatSent, - channel.TotalMSatReceived, - ); err != nil { - return err - } - - // For single funder channels that we initiated, write the funding txn. - if channel.ChanType.IsSingleFunder() && channel.IsInitiator && - !channel.hasChanStatus(ChanStatusRestored) { - - if err := WriteElement(&w, channel.FundingTxn); err != nil { - return err - } - } - - if err := writeChanConfig(&w, &channel.LocalChanCfg); err != nil { - return err - } - if err := writeChanConfig(&w, &channel.RemoteChanCfg); err != nil { - return err - } - - return chanBucket.Put(chanInfoKey, w.Bytes()) -} - -func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { - if err := WriteElements(w, - c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, - c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, - c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, - c.CommitSig, - ); err != nil { - return err - } - - return SerializeHtlcs(w, c.Htlcs...) -} - -func putChanCommitment(chanBucket *bbolt.Bucket, c *ChannelCommitment, - local bool) error { - - var commitKey []byte - if local { - commitKey = append(chanCommitmentKey, byte(0x00)) - } else { - commitKey = append(chanCommitmentKey, byte(0x01)) - } - - var b bytes.Buffer - if err := serializeChanCommit(&b, c); err != nil { - return err - } - - return chanBucket.Put(commitKey, b.Bytes()) -} - -func putChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - // If this is a restored channel, then we don't have any commitments to - // write. - if channel.hasChanStatus(ChanStatusRestored) { - return nil - } - - err := putChanCommitment( - chanBucket, &channel.LocalCommitment, true, - ) - if err != nil { - return err - } - - return putChanCommitment( - chanBucket, &channel.RemoteCommitment, false, - ) -} - -func putChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - - var b bytes.Buffer - err := WriteElements( - &b, channel.RemoteCurrentRevocation, channel.RevocationProducer, - channel.RevocationStore, - ) - if err != nil { - return err - } - - // TODO(roasbeef): don't keep producer on disk - - // If the next revocation is present, which is only the case after the - // FundingLocked message has been sent, then we'll write it to disk. - if channel.RemoteNextRevocation != nil { - err = WriteElements(&b, channel.RemoteNextRevocation) - if err != nil { - return err - } - } - - return chanBucket.Put(revocationStateKey, b.Bytes()) -} - func readChanConfig(b io.Reader, c *ChannelConfig) error { return ReadElements(b, &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, @@ -2633,185 +749,3 @@ func readChanConfig(b io.Reader, c *ChannelConfig) error { &c.HtlcBasePoint, ) } - -func fetchChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - infoBytes := chanBucket.Get(chanInfoKey) - if infoBytes == nil { - return ErrNoChanInfoFound - } - r := bytes.NewReader(infoBytes) - - if err := ReadElements(r, - &channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint, - &channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator, - &channel.chanStatus, &channel.FundingBroadcastHeight, - &channel.NumConfsRequired, &channel.ChannelFlags, - &channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent, - &channel.TotalMSatReceived, - ); err != nil { - return err - } - - // For single funder channels that we initiated, read the funding txn. - if channel.ChanType.IsSingleFunder() && channel.IsInitiator && - !channel.hasChanStatus(ChanStatusRestored) { - - if err := ReadElement(r, &channel.FundingTxn); err != nil { - return err - } - } - - if err := readChanConfig(r, &channel.LocalChanCfg); err != nil { - return err - } - if err := readChanConfig(r, &channel.RemoteChanCfg); err != nil { - return err - } - - channel.Packager = NewChannelPackager(channel.ShortChannelID) - - return nil -} - -func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { - var c ChannelCommitment - - err := ReadElements(r, - &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, - &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, - &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, - ) - if err != nil { - return c, err - } - - c.Htlcs, err = DeserializeHtlcs(r) - if err != nil { - return c, err - } - - return c, nil -} - -func fetchChanCommitment(chanBucket *bbolt.Bucket, local bool) (ChannelCommitment, error) { - var commitKey []byte - if local { - commitKey = append(chanCommitmentKey, byte(0x00)) - } else { - commitKey = append(chanCommitmentKey, byte(0x01)) - } - - commitBytes := chanBucket.Get(commitKey) - if commitBytes == nil { - return ChannelCommitment{}, ErrNoCommitmentsFound - } - - r := bytes.NewReader(commitBytes) - return deserializeChanCommit(r) -} - -func fetchChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - var err error - - // If this is a restored channel, then we don't have any commitments to - // read. - if channel.hasChanStatus(ChanStatusRestored) { - return nil - } - - channel.LocalCommitment, err = fetchChanCommitment(chanBucket, true) - if err != nil { - return err - } - channel.RemoteCommitment, err = fetchChanCommitment(chanBucket, false) - if err != nil { - return err - } - - return nil -} - -func fetchChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - revBytes := chanBucket.Get(revocationStateKey) - if revBytes == nil { - return ErrNoRevocationsFound - } - r := bytes.NewReader(revBytes) - - err := ReadElements( - r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer, - &channel.RevocationStore, - ) - if err != nil { - return err - } - - // If there aren't any bytes left in the buffer, then we don't yet have - // the next remote revocation, so we can exit early here. - if r.Len() == 0 { - return nil - } - - // Otherwise we'll read the next revocation for the remote party which - // is always the last item within the buffer. - return ReadElements(r, &channel.RemoteNextRevocation) -} - -func deleteOpenChannel(chanBucket *bbolt.Bucket, chanPointBytes []byte) error { - - if err := chanBucket.Delete(chanInfoKey); err != nil { - return err - } - - err := chanBucket.Delete(append(chanCommitmentKey, byte(0x00))) - if err != nil { - return err - } - err = chanBucket.Delete(append(chanCommitmentKey, byte(0x01))) - if err != nil { - return err - } - - if err := chanBucket.Delete(revocationStateKey); err != nil { - return err - } - - if diff := chanBucket.Get(commitDiffKey); diff != nil { - return chanBucket.Delete(commitDiffKey) - } - - return nil - -} - -// makeLogKey converts a uint64 into an 8 byte array. -func makeLogKey(updateNum uint64) [8]byte { - var key [8]byte - byteOrder.PutUint64(key[:], updateNum) - return key -} - -func appendChannelLogEntry(log *bbolt.Bucket, - commit *ChannelCommitment) error { - - var b bytes.Buffer - if err := serializeChanCommit(&b, commit); err != nil { - return err - } - - logEntrykey := makeLogKey(commit.CommitHeight) - return log.Put(logEntrykey[:], b.Bytes()) -} - -func fetchChannelLogEntry(log *bbolt.Bucket, - updateNum uint64) (ChannelCommitment, error) { - - logEntrykey := makeLogKey(updateNum) - commitBytes := log.Get(logEntrykey[:]) - if commitBytes == nil { - return ChannelCommitment{}, fmt.Errorf("log entry not found") - } - - commitReader := bytes.NewReader(commitBytes) - return deserializeChanCommit(commitReader) -} diff --git a/channeldb/migration_01_to_11/channel_cache.go b/channeldb/migration_01_to_11/channel_cache.go deleted file mode 100644 index 5d391e00..00000000 --- a/channeldb/migration_01_to_11/channel_cache.go +++ /dev/null @@ -1,50 +0,0 @@ -package migration_01_to_11 - -// channelCache is an in-memory cache used to improve the performance of -// ChanUpdatesInHorizon. It caches the chan info and edge policies for a -// particular channel. -type channelCache struct { - n int - channels map[uint64]ChannelEdge -} - -// newChannelCache creates a new channelCache with maximum capacity of n -// channels. -func newChannelCache(n int) *channelCache { - return &channelCache{ - n: n, - channels: make(map[uint64]ChannelEdge), - } -} - -// get returns the channel from the cache, if it exists. -func (c *channelCache) get(chanid uint64) (ChannelEdge, bool) { - channel, ok := c.channels[chanid] - return channel, ok -} - -// insert adds the entry to the channel cache. If an entry for chanid already -// exists, it will be replaced with the new entry. If the entry doesn't exist, -// it will be inserted to the cache, performing a random eviction if the cache -// is at capacity. -func (c *channelCache) insert(chanid uint64, channel ChannelEdge) { - // If entry exists, replace it. - if _, ok := c.channels[chanid]; ok { - c.channels[chanid] = channel - return - } - - // Otherwise, evict an entry at random and insert. - if len(c.channels) == c.n { - for id := range c.channels { - delete(c.channels, id) - break - } - } - c.channels[chanid] = channel -} - -// remove deletes an edge for chanid from the cache, if it exists. -func (c *channelCache) remove(chanid uint64) { - delete(c.channels, chanid) -} diff --git a/channeldb/migration_01_to_11/channel_cache_test.go b/channeldb/migration_01_to_11/channel_cache_test.go deleted file mode 100644 index b2929635..00000000 --- a/channeldb/migration_01_to_11/channel_cache_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package migration_01_to_11 - -import ( - "reflect" - "testing" -) - -// TestChannelCache checks the behavior of the channelCache with respect to -// insertion, eviction, and removal of cache entries. -func TestChannelCache(t *testing.T) { - const cacheSize = 100 - - // Create a new channel cache with the configured max size. - c := newChannelCache(cacheSize) - - // As a sanity check, assert that querying the empty cache does not - // return an entry. - _, ok := c.get(0) - if ok { - t.Fatalf("channel cache should be empty") - } - - // Now, fill up the cache entirely. - for i := uint64(0); i < cacheSize; i++ { - c.insert(i, channelForInt(i)) - } - - // Assert that the cache has all of the entries just inserted, since no - // eviction should occur until we try to surpass the max size. - assertHasChanEntries(t, c, 0, cacheSize) - - // Now, insert a new element that causes the cache to evict an element. - c.insert(cacheSize, channelForInt(cacheSize)) - - // Assert that the cache has this last entry, as the cache should evict - // some prior element and not the newly inserted one. - assertHasChanEntries(t, c, cacheSize, cacheSize) - - // Iterate over all inserted elements and construct a set of the evicted - // elements. - evicted := make(map[uint64]struct{}) - for i := uint64(0); i < cacheSize+1; i++ { - _, ok := c.get(i) - if !ok { - evicted[i] = struct{}{} - } - } - - // Assert that exactly one element has been evicted. - numEvicted := len(evicted) - if numEvicted != 1 { - t.Fatalf("expected one evicted entry, got: %d", numEvicted) - } - - // Remove the highest item which initially caused the eviction and - // reinsert the element that was evicted prior. - c.remove(cacheSize) - for i := range evicted { - c.insert(i, channelForInt(i)) - } - - // Since the removal created an extra slot, the last insertion should - // not have caused an eviction and the entries for all channels in the - // original set that filled the cache should be present. - assertHasChanEntries(t, c, 0, cacheSize) - - // Finally, reinsert the existing set back into the cache and test that - // the cache still has all the entries. If the randomized eviction were - // happening on inserts for existing cache items, we expect this to fail - // with high probability. - for i := uint64(0); i < cacheSize; i++ { - c.insert(i, channelForInt(i)) - } - assertHasChanEntries(t, c, 0, cacheSize) - -} - -// assertHasEntries queries the edge cache for all channels in the range [start, -// end), asserting that they exist and their value matches the entry produced by -// entryForInt. -func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) { - t.Helper() - - for i := start; i < end; i++ { - entry, ok := c.get(i) - if !ok { - t.Fatalf("channel cache should contain chan %d", i) - } - - expEntry := channelForInt(i) - if !reflect.DeepEqual(entry, expEntry) { - t.Fatalf("entry mismatch, want: %v, got: %v", - expEntry, entry) - } - } -} - -// channelForInt generates a unique ChannelEdge given an integer. -func channelForInt(i uint64) ChannelEdge { - return ChannelEdge{ - Info: &ChannelEdgeInfo{ - ChannelID: i, - }, - } -} diff --git a/channeldb/migration_01_to_11/channel_test.go b/channeldb/migration_01_to_11/channel_test.go index 53fb39d7..1380828e 100644 --- a/channeldb/migration_01_to_11/channel_test.go +++ b/channeldb/migration_01_to_11/channel_test.go @@ -4,18 +4,13 @@ import ( "bytes" "io/ioutil" "math/rand" - "net" "os" - "reflect" - "runtime" - "testing" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" _ "github.com/btcsuite/btcwallet/walletdb/bdb" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" @@ -66,8 +61,6 @@ var ( LockTime: 5, } privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - - wireSig, _ = lnwire.NewSigFromSignature(testSig) ) // makeTestDB creates a new instance of the ChannelDB for testing purposes. A @@ -223,819 +216,6 @@ func createTestChannelState(cdb *DB) (*OpenChannel, error) { RevocationProducer: producer, RevocationStore: store, Db: cdb, - Packager: NewChannelPackager(chanID), FundingTxn: testTx, }, nil } - -func TestOpenChannelPutGetDelete(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // Create the test channel state, then add an additional fake HTLC - // before syncing to disk. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - state.LocalCommitment.Htlcs = []HTLC{ - { - Signature: testSig.Serialize(), - Incoming: true, - Amt: 10, - RHash: key, - RefundTimeout: 1, - OnionBlob: []byte("onionblob"), - }, - } - state.RemoteCommitment.Htlcs = []HTLC{ - { - Signature: testSig.Serialize(), - Incoming: false, - Amt: 10, - RHash: key, - RefundTimeout: 1, - OnionBlob: []byte("onionblob"), - }, - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := state.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch open channel: %v", err) - } - - newState := openChannels[0] - - // The decoded channel state should be identical to what we stored - // above. - if !reflect.DeepEqual(state, newState) { - t.Fatalf("channel state doesn't match:: %v vs %v", - spew.Sdump(state), spew.Sdump(newState)) - } - - // We'll also test that the channel is properly able to hot swap the - // next revocation for the state machine. This tests the initial - // post-funding revocation exchange. - nextRevKey, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - t.Fatalf("unable to create new private key: %v", err) - } - if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil { - t.Fatalf("unable to update revocation: %v", err) - } - - openChannels, err = cdb.FetchOpenChannels(state.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch open channel: %v", err) - } - updatedChan := openChannels[0] - - // Ensure that the revocation was set properly. - if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) { - t.Fatalf("next revocation wasn't updated") - } - - // Finally to wrap up the test, delete the state of the channel within - // the database. This involves "closing" the channel which removes all - // written state, and creates a small "summary" elsewhere within the - // database. - closeSummary := &ChannelCloseSummary{ - ChanPoint: state.FundingOutpoint, - RemotePub: state.IdentityPub, - SettledBalance: btcutil.Amount(500), - TimeLockedBalance: btcutil.Amount(10000), - IsPending: false, - CloseType: CooperativeClose, - } - if err := state.CloseChannel(closeSummary); err != nil { - t.Fatalf("unable to close channel: %v", err) - } - - // As the channel is now closed, attempting to fetch all open channels - // for our fake node ID should return an empty slice. - openChans, err := cdb.FetchOpenChannels(state.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch open channels: %v", err) - } - if len(openChans) != 0 { - t.Fatalf("all channels not deleted, found %v", len(openChans)) - } - - // Additionally, attempting to fetch all the open channels globally - // should yield no results. - openChans, err = cdb.FetchAllChannels() - if err != nil { - t.Fatal("unable to fetch all open chans") - } - if len(openChans) != 0 { - t.Fatalf("all channels not deleted, found %v", len(openChans)) - } -} - -func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { - if !reflect.DeepEqual(a, b) { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: commitments don't match: %v vs %v", - line, spew.Sdump(a), spew.Sdump(b)) - } -} - -func TestChannelStateTransition(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First create a minimal channel, then perform a full sync in order to - // persist the data. - channel, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := channel.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - // Add some HTLCs which were added during this new state transition. - // Half of the HTLCs are incoming, while the other half are outgoing. - var ( - htlcs []HTLC - htlcAmt lnwire.MilliSatoshi - ) - for i := uint32(0); i < 10; i++ { - var incoming bool - if i > 5 { - incoming = true - } - htlc := HTLC{ - Signature: testSig.Serialize(), - Incoming: incoming, - Amt: 10, - RHash: key, - RefundTimeout: i, - OutputIndex: int32(i * 3), - LogIndex: uint64(i * 2), - HtlcIndex: uint64(i), - } - htlc.OnionBlob = make([]byte, 10) - copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10)) - htlcs = append(htlcs, htlc) - htlcAmt += htlc.Amt - } - - // Create a new channel delta which includes the above HTLCs, some - // balance updates, and an increment of the current commitment height. - // Additionally, modify the signature and commitment transaction. - newSequence := uint32(129498) - newSig := bytes.Repeat([]byte{3}, 71) - newTx := channel.LocalCommitment.CommitTx.Copy() - newTx.TxIn[0].Sequence = newSequence - commitment := ChannelCommitment{ - CommitHeight: 1, - LocalLogIndex: 2, - LocalHtlcIndex: 1, - RemoteLogIndex: 2, - RemoteHtlcIndex: 1, - LocalBalance: lnwire.MilliSatoshi(1e8), - RemoteBalance: lnwire.MilliSatoshi(1e8), - CommitFee: 55, - FeePerKw: 99, - CommitTx: newTx, - CommitSig: newSig, - Htlcs: htlcs, - } - - // First update the local node's broadcastable state and also add a - // CommitDiff remote node's as well in order to simulate a proper state - // transition. - if err := channel.UpdateCommitment(&commitment); err != nil { - t.Fatalf("unable to update commitment: %v", err) - } - - // The balances, new update, the HTLCs and the changes to the fake - // commitment transaction along with the modified signature should all - // have been updated. - updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch updated channel: %v", err) - } - assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) - numDiskUpdates, err := updatedChannel[0].CommitmentHeight() - if err != nil { - t.Fatalf("unable to read commitment height from disk: %v", err) - } - if numDiskUpdates != uint64(commitment.CommitHeight) { - t.Fatalf("num disk updates doesn't match: %v vs %v", - numDiskUpdates, commitment.CommitHeight) - } - - // Attempting to query for a commitment diff should return - // ErrNoPendingCommit as we haven't yet created a new state for them. - _, err = channel.RemoteCommitChainTip() - if err != ErrNoPendingCommit { - t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) - } - - // To simulate us extending a new state to the remote party, we'll also - // create a new commit diff for them. - remoteCommit := commitment - remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8) - remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8) - remoteCommit.CommitHeight = 1 - commitDiff := &CommitDiff{ - Commitment: remoteCommit, - CommitSig: &lnwire.CommitSig{ - ChanID: lnwire.ChannelID(key), - CommitSig: wireSig, - HtlcSigs: []lnwire.Sig{ - wireSig, - wireSig, - }, - }, - LogUpdates: []LogUpdate{ - { - LogIndex: 1, - UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 1, - Amount: lnwire.NewMSatFromSatoshis(100), - Expiry: 25, - }, - }, - { - LogIndex: 2, - UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 2, - Amount: lnwire.NewMSatFromSatoshis(200), - Expiry: 50, - }, - }, - }, - OpenedCircuitKeys: []CircuitKey{}, - ClosedCircuitKeys: []CircuitKey{}, - } - copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], - bytes.Repeat([]byte{1}, 32)) - copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], - bytes.Repeat([]byte{2}, 32)) - if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { - t.Fatalf("unable to add to commit chain: %v", err) - } - - // The commitment tip should now match the commitment that we just - // inserted. - diskCommitDiff, err := channel.RemoteCommitChainTip() - if err != nil { - t.Fatalf("unable to fetch commit diff: %v", err) - } - if !reflect.DeepEqual(commitDiff, diskCommitDiff) { - t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit), - spew.Sdump(diskCommitDiff)) - } - - // We'll save the old remote commitment as this will be added to the - // revocation log shortly. - oldRemoteCommit := channel.RemoteCommitment - - // Next, write to the log which tracks the necessary revocation state - // needed to rectify any fishy behavior by the remote party. Modify the - // current uncollapsed revocation state to simulate a state transition - // by the remote party. - channel.RemoteCurrentRevocation = channel.RemoteNextRevocation - newPriv, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - t.Fatalf("unable to generate key: %v", err) - } - channel.RemoteNextRevocation = newPriv.PubKey() - - fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, - diskCommitDiff.LogUpdates, nil) - - err = channel.AdvanceCommitChainTail(fwdPkg) - if err != nil { - t.Fatalf("unable to append to revocation log: %v", err) - } - - // At this point, the remote commit chain should be nil, and the posted - // remote commitment should match the one we added as a diff above. - if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit { - t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) - } - - // We should be able to fetch the channel delta created above by its - // update number with all the state properly reconstructed. - diskPrevCommit, err := channel.FindPreviousState( - oldRemoteCommit.CommitHeight, - ) - if err != nil { - t.Fatalf("unable to fetch past delta: %v", err) - } - - // The two deltas (the original vs the on-disk version) should - // identical, and all HTLC data should properly be retained. - assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit) - - // The state number recovered from the tail of the revocation log - // should be identical to this current state. - logTail, err := channel.RevocationLogTail() - if err != nil { - t.Fatalf("unable to retrieve log: %v", err) - } - if logTail.CommitHeight != oldRemoteCommit.CommitHeight { - t.Fatal("update number doesn't match") - } - - oldRemoteCommit = channel.RemoteCommitment - - // Next modify the posted diff commitment slightly, then create a new - // commitment diff and advance the tail. - commitDiff.Commitment.CommitHeight = 2 - commitDiff.Commitment.LocalBalance -= htlcAmt - commitDiff.Commitment.RemoteBalance += htlcAmt - commitDiff.LogUpdates = []LogUpdate{} - if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { - t.Fatalf("unable to add to commit chain: %v", err) - } - - fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil) - - err = channel.AdvanceCommitChainTail(fwdPkg) - if err != nil { - t.Fatalf("unable to append to revocation log: %v", err) - } - - // Once again, fetch the state and ensure it has been properly updated. - prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight) - if err != nil { - t.Fatalf("unable to fetch past delta: %v", err) - } - assertCommitmentEqual(t, &oldRemoteCommit, prevCommit) - - // Once again, state number recovered from the tail of the revocation - // log should be identical to this current state. - logTail, err = channel.RevocationLogTail() - if err != nil { - t.Fatalf("unable to retrieve log: %v", err) - } - if logTail.CommitHeight != oldRemoteCommit.CommitHeight { - t.Fatal("update number doesn't match") - } - - // The revocation state stored on-disk should now also be identical. - updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch updated channel: %v", err) - } - if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) { - t.Fatalf("revocation state was not synced") - } - if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) { - t.Fatalf("revocation state was not synced") - } - - // Now attempt to delete the channel from the database. - closeSummary := &ChannelCloseSummary{ - ChanPoint: channel.FundingOutpoint, - RemotePub: channel.IdentityPub, - SettledBalance: btcutil.Amount(500), - TimeLockedBalance: btcutil.Amount(10000), - IsPending: false, - CloseType: RemoteForceClose, - } - if err := updatedChannel[0].CloseChannel(closeSummary); err != nil { - t.Fatalf("unable to delete updated channel: %v", err) - } - - // If we attempt to fetch the target channel again, it shouldn't be - // found. - channels, err := cdb.FetchOpenChannels(channel.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch updated channels: %v", err) - } - if len(channels) != 0 { - t.Fatalf("%v channels, found, but none should be", - len(channels)) - } - - // Attempting to find previous states on the channel should fail as the - // revocation log has been deleted. - _, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight) - if err == nil { - t.Fatal("revocation log search should have failed") - } -} - -func TestFetchPendingChannels(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // Create first test channel state - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - - const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - pendingChannels, err := cdb.FetchPendingChannels() - if err != nil { - t.Fatalf("unable to list pending channels: %v", err) - } - - if len(pendingChannels) != 1 { - t.Fatalf("incorrect number of pending channels: expecting %v,"+ - "got %v", 1, len(pendingChannels)) - } - - // The broadcast height of the pending channel should have been set - // properly. - if pendingChannels[0].FundingBroadcastHeight != broadcastHeight { - t.Fatalf("broadcast height mismatch: expected %v, got %v", - pendingChannels[0].FundingBroadcastHeight, - broadcastHeight) - } - - chanOpenLoc := lnwire.ShortChannelID{ - BlockHeight: 5, - TxIndex: 10, - TxPosition: 15, - } - err = pendingChannels[0].MarkAsOpen(chanOpenLoc) - if err != nil { - t.Fatalf("unable to mark channel as open: %v", err) - } - - if pendingChannels[0].IsPending { - t.Fatalf("channel marked open should no longer be pending") - } - - if pendingChannels[0].ShortChanID() != chanOpenLoc { - t.Fatalf("channel opening height not updated: expected %v, "+ - "got %v", spew.Sdump(pendingChannels[0].ShortChanID()), - chanOpenLoc) - } - - // Next, we'll re-fetch the channel to ensure that the open height was - // properly set. - openChans, err := cdb.FetchAllChannels() - if err != nil { - t.Fatalf("unable to fetch channels: %v", err) - } - if openChans[0].ShortChanID() != chanOpenLoc { - t.Fatalf("channel opening heights don't match: expected %v, "+ - "got %v", spew.Sdump(openChans[0].ShortChanID()), - chanOpenLoc) - } - if openChans[0].FundingBroadcastHeight != broadcastHeight { - t.Fatalf("broadcast height mismatch: expected %v, got %v", - openChans[0].FundingBroadcastHeight, - broadcastHeight) - } - - pendingChannels, err = cdb.FetchPendingChannels() - if err != nil { - t.Fatalf("unable to list pending channels: %v", err) - } - - if len(pendingChannels) != 0 { - t.Fatalf("incorrect number of pending channels: expecting %v,"+ - "got %v", 0, len(pendingChannels)) - } -} - -func TestFetchClosedChannels(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First create a test channel, that we'll be closing within this pull - // request. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - // Next sync the channel to disk, marking it as being in a pending open - // state. - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - // Next, simulate the confirmation of the channel by marking it as - // pending within the database. - chanOpenLoc := lnwire.ShortChannelID{ - BlockHeight: 5, - TxIndex: 10, - TxPosition: 15, - } - err = state.MarkAsOpen(chanOpenLoc) - if err != nil { - t.Fatalf("unable to mark channel as open: %v", err) - } - - // Next, close the channel by including a close channel summary in the - // database. - summary := &ChannelCloseSummary{ - ChanPoint: state.FundingOutpoint, - ClosingTXID: rev, - RemotePub: state.IdentityPub, - Capacity: state.Capacity, - SettledBalance: state.LocalCommitment.LocalBalance.ToSatoshis(), - TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000, - CloseType: RemoteForceClose, - IsPending: true, - LocalChanConfig: state.LocalChanCfg, - } - if err := state.CloseChannel(summary); err != nil { - t.Fatalf("unable to close channel: %v", err) - } - - // Query the database to ensure that the channel has now been properly - // closed. We should get the same result whether querying for pending - // channels only, or not. - pendingClosed, err := cdb.FetchClosedChannels(true) - if err != nil { - t.Fatalf("failed fetching closed channels: %v", err) - } - if len(pendingClosed) != 1 { - t.Fatalf("incorrect number of pending closed channels: expecting %v,"+ - "got %v", 1, len(pendingClosed)) - } - if !reflect.DeepEqual(summary, pendingClosed[0]) { - t.Fatalf("database summaries don't match: expected %v got %v", - spew.Sdump(summary), spew.Sdump(pendingClosed[0])) - } - closed, err := cdb.FetchClosedChannels(false) - if err != nil { - t.Fatalf("failed fetching all closed channels: %v", err) - } - if len(closed) != 1 { - t.Fatalf("incorrect number of closed channels: expecting %v, "+ - "got %v", 1, len(closed)) - } - if !reflect.DeepEqual(summary, closed[0]) { - t.Fatalf("database summaries don't match: expected %v got %v", - spew.Sdump(summary), spew.Sdump(closed[0])) - } - - // Mark the channel as fully closed. - err = cdb.MarkChanFullyClosed(&state.FundingOutpoint) - if err != nil { - t.Fatalf("failed fully closing channel: %v", err) - } - - // The channel should no longer be considered pending, but should still - // be retrieved when fetching all the closed channels. - closed, err = cdb.FetchClosedChannels(false) - if err != nil { - t.Fatalf("failed fetching closed channels: %v", err) - } - if len(closed) != 1 { - t.Fatalf("incorrect number of closed channels: expecting %v, "+ - "got %v", 1, len(closed)) - } - pendingClose, err := cdb.FetchClosedChannels(true) - if err != nil { - t.Fatalf("failed fetching channels pending close: %v", err) - } - if len(pendingClose) != 0 { - t.Fatalf("incorrect number of closed channels: expecting %v, "+ - "got %v", 0, len(closed)) - } -} - -// TestFetchWaitingCloseChannels ensures that the correct channels that are -// waiting to be closed are returned. -func TestFetchWaitingCloseChannels(t *testing.T) { - t.Parallel() - - const numChannels = 2 - const broadcastHeight = 99 - addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555} - - // We'll start by creating two channels within our test database. One of - // them will have their funding transaction confirmed on-chain, while - // the other one will remain unconfirmed. - db, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - channels := make([]*OpenChannel, numChannels) - for i := 0; i < numChannels; i++ { - channel, err := createTestChannelState(db) - if err != nil { - t.Fatalf("unable to create channel: %v", err) - } - err = channel.SyncPending(addr, broadcastHeight) - if err != nil { - t.Fatalf("unable to sync channel: %v", err) - } - channels[i] = channel - } - - // We'll only confirm the first one. - channelConf := lnwire.ShortChannelID{ - BlockHeight: broadcastHeight + 1, - TxIndex: 10, - TxPosition: 15, - } - if err := channels[0].MarkAsOpen(channelConf); err != nil { - t.Fatalf("unable to mark channel as open: %v", err) - } - - // Then, we'll mark the channels as if their commitments were broadcast. - // This would happen in the event of a force close and should make the - // channels enter a state of waiting close. - for _, channel := range channels { - closeTx := wire.NewMsgTx(2) - closeTx.AddTxIn( - &wire.TxIn{ - PreviousOutPoint: channel.FundingOutpoint, - }, - ) - if err := channel.MarkCommitmentBroadcasted(closeTx); err != nil { - t.Fatalf("unable to mark commitment broadcast: %v", err) - } - } - - // Now, we'll fetch all the channels waiting to be closed from the - // database. We should expect to see both channels above, even if any of - // them haven't had their funding transaction confirm on-chain. - waitingCloseChannels, err := db.FetchWaitingCloseChannels() - if err != nil { - t.Fatalf("unable to fetch all waiting close channels: %v", err) - } - if len(waitingCloseChannels) != 2 { - t.Fatalf("expected %d channels waiting to be closed, got %d", 2, - len(waitingCloseChannels)) - } - expectedChannels := make(map[wire.OutPoint]struct{}) - for _, channel := range channels { - expectedChannels[channel.FundingOutpoint] = struct{}{} - } - for _, channel := range waitingCloseChannels { - if _, ok := expectedChannels[channel.FundingOutpoint]; !ok { - t.Fatalf("expected channel %v to be waiting close", - channel.FundingOutpoint) - } - - // Finally, make sure we can retrieve the closing tx for the - // channel. - closeTx, err := channel.BroadcastedCommitment() - if err != nil { - t.Fatalf("Unable to retrieve commitment: %v", err) - } - - if closeTx.TxIn[0].PreviousOutPoint != channel.FundingOutpoint { - t.Fatalf("expected outpoint %v, got %v", - channel.FundingOutpoint, - closeTx.TxIn[0].PreviousOutPoint) - } - } -} - -// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory -// short channel ID of another OpenChannel to reflect a preceding call to -// MarkOpen on a different OpenChannel. -func TestRefreshShortChanID(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First create a test channel. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - - // Mark the channel as pending within the channeldb. - const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - // Next, locate the pending channel with the database. - pendingChannels, err := cdb.FetchPendingChannels() - if err != nil { - t.Fatalf("unable to load pending channels; %v", err) - } - - var pendingChannel *OpenChannel - for _, channel := range pendingChannels { - if channel.FundingOutpoint == state.FundingOutpoint { - pendingChannel = channel - break - } - } - if pendingChannel == nil { - t.Fatalf("unable to find pending channel with funding "+ - "outpoint=%v: %v", state.FundingOutpoint, err) - } - - // Next, simulate the confirmation of the channel by marking it as - // pending within the database. - chanOpenLoc := lnwire.ShortChannelID{ - BlockHeight: 105, - TxIndex: 10, - TxPosition: 15, - } - - err = state.MarkAsOpen(chanOpenLoc) - if err != nil { - t.Fatalf("unable to mark channel open: %v", err) - } - - // The short_chan_id of the receiver to MarkAsOpen should reflect the - // open location, but the other pending channel should remain unchanged. - if state.ShortChanID() == pendingChannel.ShortChanID() { - t.Fatalf("pending channel short_chan_ID should not have been " + - "updated before refreshing short_chan_id") - } - - // Now that the receiver's short channel id has been updated, check to - // ensure that the channel packager's source has been updated as well. - // This ensures that the packager will read and write to buckets - // corresponding to the new short chan id, instead of the prior. - if state.Packager.(*ChannelPackager).source != chanOpenLoc { - t.Fatalf("channel packager source was not updated: want %v, "+ - "got %v", chanOpenLoc, - state.Packager.(*ChannelPackager).source) - } - - // Now, refresh the short channel ID of the pending channel. - err = pendingChannel.RefreshShortChanID() - if err != nil { - t.Fatalf("unable to refresh short_chan_id: %v", err) - } - - // This should result in both OpenChannel's now having the same - // ShortChanID. - if state.ShortChanID() != pendingChannel.ShortChanID() { - t.Fatalf("expected pending channel short_chan_id to be "+ - "refreshed: want %v, got %v", state.ShortChanID(), - pendingChannel.ShortChanID()) - } - - // Check to ensure that the _other_ OpenChannel channel packager's - // source has also been updated after the refresh. This ensures that the - // other packagers will read and write to buckets corresponding to the - // updated short chan id. - if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { - t.Fatalf("channel packager source was not updated: want %v, "+ - "got %v", chanOpenLoc, - pendingChannel.Packager.(*ChannelPackager).source) - } -} diff --git a/channeldb/migration_01_to_11/codec.go b/channeldb/migration_01_to_11/codec.go index cfef35e0..1727c8c9 100644 --- a/channeldb/migration_01_to_11/codec.go +++ b/channeldb/migration_01_to_11/codec.go @@ -48,12 +48,6 @@ type UnknownElementType struct { element interface{} } -// NewUnknownElementType creates a new UnknownElementType error from the passed -// method name and element. -func NewUnknownElementType(method string, el interface{}) UnknownElementType { - return UnknownElementType{method: method, element: el} -} - // Error returns the name of the method that encountered the error, as well as // the type that was unsupported. func (e UnknownElementType) Error() string { diff --git a/channeldb/migration_01_to_11/db.go b/channeldb/migration_01_to_11/db.go index e1057d65..623b33bc 100644 --- a/channeldb/migration_01_to_11/db.go +++ b/channeldb/migration_01_to_11/db.go @@ -4,16 +4,11 @@ import ( "bytes" "encoding/binary" "fmt" - "net" "os" "path/filepath" "time" - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/wire" "github.com/coreos/bbolt" - "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/lnwire" ) const ( @@ -87,57 +82,6 @@ func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { return chanDB, nil } -// Path returns the file path to the channel database. -func (d *DB) Path() string { - return d.dbPath -} - -// Wipe completely deletes all saved state within all used buckets within the -// database. The deletion is done in a single transaction, therefore this -// operation is fully atomic. -func (d *DB) Wipe() error { - return d.Update(func(tx *bbolt.Tx) error { - err := tx.DeleteBucket(openChannelBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - err = tx.DeleteBucket(closedChannelBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - err = tx.DeleteBucket(invoiceBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - err = tx.DeleteBucket(nodeInfoBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - err = tx.DeleteBucket(nodeBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - err = tx.DeleteBucket(edgeBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - err = tx.DeleteBucket(edgeIndexBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - err = tx.DeleteBucket(graphMetaBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - return nil - }) -} - // createChannelDB creates and initializes a fresh version of channeldb. In // the case that the target path has not yet been created or doesn't yet exist, // then the path is created. Additionally, all required top-level buckets used @@ -163,14 +107,6 @@ func createChannelDB(dbPath string) error { return err } - if _, err := tx.CreateBucket(forwardingLogBucket); err != nil { - return err - } - - if _, err := tx.CreateBucket(fwdPackagesKey); err != nil { - return err - } - if _, err := tx.CreateBucket(invoiceBucket); err != nil { return err } @@ -179,10 +115,6 @@ func createChannelDB(dbPath string) error { return err } - if _, err := tx.CreateBucket(nodeInfoBucket); err != nil { - return err - } - nodes, err := tx.CreateBucket(nodeBucket) if err != nil { return err @@ -249,359 +181,6 @@ func fileExists(path string) bool { return true } -// FetchOpenChannels starts a new database transaction and returns all stored -// currently active/open channels associated with the target nodeID. In the case -// that no active channels are known to have been created with this node, then a -// zero-length slice is returned. -func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { - var channels []*OpenChannel - err := d.View(func(tx *bbolt.Tx) error { - var err error - channels, err = d.fetchOpenChannels(tx, nodeID) - return err - }) - - return channels, err -} - -// fetchOpenChannels uses and existing database transaction and returns all -// stored currently active/open channels associated with the target nodeID. In -// the case that no active channels are known to have been created with this -// node, then a zero-length slice is returned. -func (d *DB) fetchOpenChannels(tx *bbolt.Tx, - nodeID *btcec.PublicKey) ([]*OpenChannel, error) { - - // Get the bucket dedicated to storing the metadata for open channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return nil, nil - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - pub := nodeID.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(pub) - if nodeChanBucket == nil { - return nil, nil - } - - // Next, we'll need to go down an additional layer in order to retrieve - // the channels for each chain the node knows of. - var channels []*OpenChannel - err := nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so ignore it. - if v != nil { - return nil - } - - // If we've found a valid chainhash bucket, then we'll retrieve - // that so we can extract all the channels. - chainBucket := nodeChanBucket.Bucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read bucket for chain=%x", - chainHash[:]) - } - - // Finally, we both of the necessary buckets retrieved, fetch - // all the active channels related to this node. - nodeChannels, err := d.fetchNodeChannels(chainBucket) - if err != nil { - return fmt.Errorf("unable to read channel for "+ - "chain_hash=%x, node_key=%x: %v", - chainHash[:], pub, err) - } - - channels = append(channels, nodeChannels...) - return nil - }) - - return channels, err -} - -// fetchNodeChannels retrieves all active channels from the target chainBucket -// which is under a node's dedicated channel bucket. This function is typically -// used to fetch all the active channels related to a particular node. -func (d *DB) fetchNodeChannels(chainBucket *bbolt.Bucket) ([]*OpenChannel, error) { - - var channels []*OpenChannel - - // A node may have channels on several chains, so for each known chain, - // we'll extract all the channels. - err := chainBucket.ForEach(func(chanPoint, v []byte) error { - // If there's a value, it's not a bucket so ignore it. - if v != nil { - return nil - } - - // Once we've found a valid channel bucket, we'll extract it - // from the node's chain bucket. - chanBucket := chainBucket.Bucket(chanPoint) - - var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(chanPoint), &outPoint) - if err != nil { - return err - } - oChannel, err := fetchOpenChannel(chanBucket, &outPoint) - if err != nil { - return fmt.Errorf("unable to read channel data for "+ - "chan_point=%v: %v", outPoint, err) - } - oChannel.Db = d - - channels = append(channels, oChannel) - - return nil - }) - if err != nil { - return nil, err - } - - return channels, nil -} - -// FetchChannel attempts to locate a channel specified by the passed channel -// point. If the channel cannot be found, then an error will be returned. -func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { - var ( - targetChan *OpenChannel - targetChanPoint bytes.Buffer - ) - - if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { - return nil, err - } - - // chanScan will traverse the following bucket structure: - // * nodePub => chainHash => chanPoint - // - // At each level we go one further, ensuring that we're traversing the - // proper key (that's actually a bucket). By only reading the bucket - // structure and skipping fully decoding each channel, we save a good - // bit of CPU as we don't need to do things like decompress public - // keys. - chanScan := func(tx *bbolt.Tx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoActiveChannels - } - - // Within the node channel bucket, are the set of node pubkeys - // we have channels with, we don't know the entire set, so - // we'll check them all. - return openChanBucket.ForEach(func(nodePub, v []byte) error { - // Ensure that this is a key the same size as a pubkey, - // and also that it leads directly to a bucket. - if len(nodePub) != 33 || v != nil { - return nil - } - - nodeChanBucket := openChanBucket.Bucket(nodePub) - if nodeChanBucket == nil { - return nil - } - - // The next layer down is all the chains that this node - // has channels on with us. - return nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so - // ignore it. - if v != nil { - return nil - } - - chainBucket := nodeChanBucket.Bucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read "+ - "bucket for chain=%x", chainHash[:]) - } - - // Finally we reach the leaf bucket that stores - // all the chanPoints for this node. - chanBucket := chainBucket.Bucket( - targetChanPoint.Bytes(), - ) - if chanBucket == nil { - return nil - } - - channel, err := fetchOpenChannel( - chanBucket, &chanPoint, - ) - if err != nil { - return err - } - - targetChan = channel - targetChan.Db = d - - return nil - }) - }) - } - - err := d.View(chanScan) - if err != nil { - return nil, err - } - - if targetChan != nil { - return targetChan, nil - } - - // If we can't find the channel, then we return with an error, as we - // have nothing to backup. - return nil, ErrChannelNotFound -} - -// FetchAllChannels attempts to retrieve all open channels currently stored -// within the database, including pending open, fully open and channels waiting -// for a closing transaction to confirm. -func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { - var channels []*OpenChannel - - // TODO(halseth): fetch all in one db tx. - openChannels, err := d.FetchAllOpenChannels() - if err != nil { - return nil, err - } - channels = append(channels, openChannels...) - - pendingChannels, err := d.FetchPendingChannels() - if err != nil { - return nil, err - } - channels = append(channels, pendingChannels...) - - waitingClose, err := d.FetchWaitingCloseChannels() - if err != nil { - return nil, err - } - channels = append(channels, waitingClose...) - - return channels, nil -} - -// FetchAllOpenChannels will return all channels that have the funding -// transaction confirmed, and is not waiting for a closing transaction to be -// confirmed. -func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { - return fetchChannels(d, false, false) -} - -// FetchPendingChannels will return channels that have completed the process of -// generating and broadcasting funding transactions, but whose funding -// transactions have yet to be confirmed on the blockchain. -func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { - return fetchChannels(d, true, false) -} - -// FetchWaitingCloseChannels will return all channels that have been opened, -// but are now waiting for a closing transaction to be confirmed. -// -// NOTE: This includes channels that are also pending to be opened. -func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { - waitingClose, err := fetchChannels(d, false, true) - if err != nil { - return nil, err - } - pendingWaitingClose, err := fetchChannels(d, true, true) - if err != nil { - return nil, err - } - - return append(waitingClose, pendingWaitingClose...), nil -} - -// fetchChannels attempts to retrieve channels currently stored in the -// database. The pending parameter determines whether only pending channels -// will be returned, or only open channels will be returned. The waitingClose -// parameter determines whether only channels waiting for a closing transaction -// to be confirmed should be returned. If no active channels exist within the -// network, then ErrNoActiveChannels is returned. -func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { - var channels []*OpenChannel - - err := d.View(func(tx *bbolt.Tx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoActiveChannels - } - - // Next, fetch the bucket dedicated to storing metadata related - // to all nodes. All keys within this bucket are the serialized - // public keys of all our direct counterparties. - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return fmt.Errorf("node bucket not created") - } - - // Finally for each node public key in the bucket, fetch all - // the channels related to this particular node. - return nodeMetaBucket.ForEach(func(k, v []byte) error { - nodeChanBucket := openChanBucket.Bucket(k) - if nodeChanBucket == nil { - return nil - } - - return nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so - // ignore it. - if v != nil { - return nil - } - - // If we've found a valid chainhash bucket, - // then we'll retrieve that so we can extract - // all the channels. - chainBucket := nodeChanBucket.Bucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read "+ - "bucket for chain=%x", chainHash[:]) - } - - nodeChans, err := d.fetchNodeChannels(chainBucket) - if err != nil { - return fmt.Errorf("unable to read "+ - "channel for chain_hash=%x, "+ - "node_key=%x: %v", chainHash[:], k, err) - } - for _, channel := range nodeChans { - if channel.IsPending != pending { - continue - } - - // If the channel is in any other state - // than Default, then it means it is - // waiting to be closed. - channelWaitingClose := - channel.ChanStatus() != ChanStatusDefault - - // Only include it if we requested - // channels with the same waitingClose - // status. - if channelWaitingClose != waitingClose { - continue - } - - channels = append(channels, channel) - } - return nil - }) - - }) - }) - if err != nil { - return nil, err - } - - return channels, nil -} - // FetchClosedChannels attempts to fetch all closed channels from the database. // The pendingOnly bool toggles if channels that aren't yet fully closed should // be returned in the response or not. When a channel was cooperatively closed, @@ -641,371 +220,6 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro return chanSummaries, nil } -// ErrClosedChannelNotFound signals that a closed channel could not be found in -// the channeldb. -var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary") - -// FetchClosedChannel queries for a channel close summary using the channel -// point of the channel in question. -func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) { - var chanSummary *ChannelCloseSummary - if err := d.View(func(tx *bbolt.Tx) error { - closeBucket := tx.Bucket(closedChannelBucket) - if closeBucket == nil { - return ErrClosedChannelNotFound - } - - var b bytes.Buffer - var err error - if err = writeOutpoint(&b, chanID); err != nil { - return err - } - - summaryBytes := closeBucket.Get(b.Bytes()) - if summaryBytes == nil { - return ErrClosedChannelNotFound - } - - summaryReader := bytes.NewReader(summaryBytes) - chanSummary, err = deserializeCloseChannelSummary(summaryReader) - - return err - }); err != nil { - return nil, err - } - - return chanSummary, nil -} - -// FetchClosedChannelForID queries for a channel close summary using the -// channel ID of the channel in question. -func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( - *ChannelCloseSummary, error) { - - var chanSummary *ChannelCloseSummary - if err := d.View(func(tx *bbolt.Tx) error { - closeBucket := tx.Bucket(closedChannelBucket) - if closeBucket == nil { - return ErrClosedChannelNotFound - } - - // The first 30 bytes of the channel ID and outpoint will be - // equal. - cursor := closeBucket.Cursor() - op, c := cursor.Seek(cid[:30]) - - // We scan over all possible candidates for this channel ID. - for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() { - var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(op), &outPoint) - if err != nil { - return err - } - - // If the found outpoint does not correspond to this - // channel ID, we continue. - if !cid.IsChanPoint(&outPoint) { - continue - } - - // Deserialize the close summary and return. - r := bytes.NewReader(c) - chanSummary, err = deserializeCloseChannelSummary(r) - if err != nil { - return err - } - - return nil - } - return ErrClosedChannelNotFound - }); err != nil { - return nil, err - } - - return chanSummary, nil -} - -// MarkChanFullyClosed marks a channel as fully closed within the database. A -// channel should be marked as fully closed if the channel was initially -// cooperatively closed and it's reached a single confirmation, or after all -// the pending funds in a channel that has been forcibly closed have been -// swept. -func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { - return d.Update(func(tx *bbolt.Tx) error { - var b bytes.Buffer - if err := writeOutpoint(&b, chanPoint); err != nil { - return err - } - - chanID := b.Bytes() - - closedChanBucket, err := tx.CreateBucketIfNotExists( - closedChannelBucket, - ) - if err != nil { - return err - } - - chanSummaryBytes := closedChanBucket.Get(chanID) - if chanSummaryBytes == nil { - return fmt.Errorf("no closed channel for "+ - "chan_point=%v found", chanPoint) - } - - chanSummaryReader := bytes.NewReader(chanSummaryBytes) - chanSummary, err := deserializeCloseChannelSummary( - chanSummaryReader, - ) - if err != nil { - return err - } - - chanSummary.IsPending = false - - var newSummary bytes.Buffer - err = serializeChannelCloseSummary(&newSummary, chanSummary) - if err != nil { - return err - } - - err = closedChanBucket.Put(chanID, newSummary.Bytes()) - if err != nil { - return err - } - - // Now that the channel is closed, we'll check if we have any - // other open channels with this peer. If we don't we'll - // garbage collect it to ensure we don't establish persistent - // connections to peers without open channels. - return d.pruneLinkNode(tx, chanSummary.RemotePub) - }) -} - -// pruneLinkNode determines whether we should garbage collect a link node from -// the database due to no longer having any open channels with it. If there are -// any left, then this acts as a no-op. -func (d *DB) pruneLinkNode(tx *bbolt.Tx, remotePub *btcec.PublicKey) error { - openChannels, err := d.fetchOpenChannels(tx, remotePub) - if err != nil { - return fmt.Errorf("unable to fetch open channels for peer %x: "+ - "%v", remotePub.SerializeCompressed(), err) - } - - if len(openChannels) > 0 { - return nil - } - - log.Infof("Pruning link node %x with zero open channels from database", - remotePub.SerializeCompressed()) - - return d.deleteLinkNode(tx, remotePub) -} - -// PruneLinkNodes attempts to prune all link nodes found within the databse with -// whom we no longer have any open channels with. -func (d *DB) PruneLinkNodes() error { - return d.Update(func(tx *bbolt.Tx) error { - linkNodes, err := d.fetchAllLinkNodes(tx) - if err != nil { - return err - } - - for _, linkNode := range linkNodes { - err := d.pruneLinkNode(tx, linkNode.IdentityPub) - if err != nil { - return err - } - } - - return nil - }) -} - -// ChannelShell is a shell of a channel that is meant to be used for channel -// recovery purposes. It contains a minimal OpenChannel instance along with -// addresses for that target node. -type ChannelShell struct { - // NodeAddrs the set of addresses that this node has known to be - // reachable at in the past. - NodeAddrs []net.Addr - - // Chan is a shell of an OpenChannel, it contains only the items - // required to restore the channel on disk. - Chan *OpenChannel -} - -// RestoreChannelShells is a method that allows the caller to reconstruct the -// state of an OpenChannel from the ChannelShell. We'll attempt to write the -// new channel to disk, create a LinkNode instance with the passed node -// addresses, and finally create an edge within the graph for the channel as -// well. This method is idempotent, so repeated calls with the same set of -// channel shells won't modify the database after the initial call. -func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { - chanGraph := d.ChannelGraph() - - // TODO(conner): find way to do this w/o accessing internal members? - chanGraph.cacheMu.Lock() - defer chanGraph.cacheMu.Unlock() - - var chansRestored []uint64 - err := d.Update(func(tx *bbolt.Tx) error { - for _, channelShell := range channelShells { - channel := channelShell.Chan - - // When we make a channel, we mark that the channel has - // been restored, this will signal to other sub-systems - // to not attempt to use the channel as if it was a - // regular one. - channel.chanStatus |= ChanStatusRestored - - // First, we'll attempt to create a new open channel - // and link node for this channel. If the channel - // already exists, then in order to ensure this method - // is idempotent, we'll continue to the next step. - channel.Db = d - err := syncNewChannel( - tx, channel, channelShell.NodeAddrs, - ) - if err != nil { - return err - } - - // Next, we'll create an active edge in the graph - // database for this channel in order to restore our - // partial view of the network. - // - // TODO(roasbeef): if we restore *after* the channel - // has been closed on chain, then need to inform the - // router that it should try and prune these values as - // we can detect them - edgeInfo := ChannelEdgeInfo{ - ChannelID: channel.ShortChannelID.ToUint64(), - ChainHash: channel.ChainHash, - ChannelPoint: channel.FundingOutpoint, - Capacity: channel.Capacity, - } - - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - selfNode, err := chanGraph.sourceNode(nodes) - if err != nil { - return err - } - - // Depending on which pub key is smaller, we'll assign - // our roles as "node1" and "node2". - chanPeer := channel.IdentityPub.SerializeCompressed() - selfIsSmaller := bytes.Compare( - selfNode.PubKeyBytes[:], chanPeer, - ) == -1 - if selfIsSmaller { - copy(edgeInfo.NodeKey1Bytes[:], selfNode.PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], chanPeer) - } else { - copy(edgeInfo.NodeKey1Bytes[:], chanPeer) - copy(edgeInfo.NodeKey2Bytes[:], selfNode.PubKeyBytes[:]) - } - - // With the edge info shell constructed, we'll now add - // it to the graph. - err = chanGraph.addChannelEdge(tx, &edgeInfo) - if err != nil && err != ErrEdgeAlreadyExist { - return err - } - - // Similarly, we'll construct a channel edge shell and - // add that itself to the graph. - chanEdge := ChannelEdgePolicy{ - ChannelID: edgeInfo.ChannelID, - LastUpdate: time.Now(), - } - - // If their pubkey is larger, then we'll flip the - // direction bit to indicate that us, the "second" node - // is updating their policy. - if !selfIsSmaller { - chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection - } - - _, err = updateEdgePolicy(tx, &chanEdge) - if err != nil { - return err - } - - chansRestored = append(chansRestored, edgeInfo.ChannelID) - } - - return nil - }) - if err != nil { - return err - } - - for _, chanid := range chansRestored { - chanGraph.rejectCache.remove(chanid) - chanGraph.chanCache.remove(chanid) - } - - return nil -} - -// AddrsForNode consults the graph and channel database for all addresses known -// to the passed node public key. -func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { - var ( - linkNode *LinkNode - graphNode LightningNode - ) - - dbErr := d.View(func(tx *bbolt.Tx) error { - var err error - - linkNode, err = fetchLinkNode(tx, nodePub) - if err != nil { - return err - } - - // We'll also query the graph for this peer to see if they have - // any addresses that we don't currently have stored within the - // link node database. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - compressedPubKey := nodePub.SerializeCompressed() - graphNode, err = fetchLightningNode(nodes, compressedPubKey) - if err != nil && err != ErrGraphNodeNotFound { - // If the node isn't found, then that's OK, as we still - // have the link node data. - return err - } - - return nil - }) - if dbErr != nil { - return nil, dbErr - } - - // Now that we have both sources of addrs for this node, we'll use a - // map to de-duplicate any addresses between the two sources, and - // produce a final list of the combined addrs. - addrs := make(map[string]net.Addr) - for _, addr := range linkNode.Addresses { - addrs[addr.String()] = addr - } - for _, addr := range graphNode.Addresses { - addrs[addr.String()] = addr - } - dedupedAddrs := make([]net.Addr, 0, len(addrs)) - for _, addr := range addrs { - dedupedAddrs = append(dedupedAddrs, addr) - } - - return dedupedAddrs, nil -} - // syncVersions function is used for safe db version synchronization. It // applies migration functions to the current database and recovers the // previous state of db if at least one error/panic appeared during migration. diff --git a/channeldb/migration_01_to_11/db_test.go b/channeldb/migration_01_to_11/db_test.go deleted file mode 100644 index 721546e7..00000000 --- a/channeldb/migration_01_to_11/db_test.go +++ /dev/null @@ -1,471 +0,0 @@ -package migration_01_to_11 - -import ( - "io/ioutil" - "math" - "math/rand" - "net" - "os" - "path/filepath" - "reflect" - "testing" - - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/wire" - "github.com/btcsuite/btcutil" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/shachain" -) - -func TestOpenWithCreate(t *testing.T) { - t.Parallel() - - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - t.Fatalf("unable to create temp dir: %v", err) - } - defer os.RemoveAll(tempDirName) - - // Next, open thereby creating channeldb for the first time. - dbPath := filepath.Join(tempDirName, "cdb") - cdb, err := Open(dbPath) - if err != nil { - t.Fatalf("unable to create channeldb: %v", err) - } - if err := cdb.Close(); err != nil { - t.Fatalf("unable to close channeldb: %v", err) - } - - // The path should have been successfully created. - if !fileExists(dbPath) { - t.Fatalf("channeldb failed to create data directory") - } -} - -// TestWipe tests that the database wipe operation completes successfully -// and that the buckets are deleted. It also checks that attempts to fetch -// information while the buckets are not set return the correct errors. -func TestWipe(t *testing.T) { - t.Parallel() - - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - t.Fatalf("unable to create temp dir: %v", err) - } - defer os.RemoveAll(tempDirName) - - // Next, open thereby creating channeldb for the first time. - dbPath := filepath.Join(tempDirName, "cdb") - cdb, err := Open(dbPath) - if err != nil { - t.Fatalf("unable to create channeldb: %v", err) - } - defer cdb.Close() - - if err := cdb.Wipe(); err != nil { - t.Fatalf("unable to wipe channeldb: %v", err) - } - // Check correct errors are returned - _, err = cdb.FetchAllOpenChannels() - if err != ErrNoActiveChannels { - t.Fatalf("fetching open channels: expected '%v' instead got '%v'", - ErrNoActiveChannels, err) - } - _, err = cdb.FetchClosedChannels(false) - if err != ErrNoClosedChannels { - t.Fatalf("fetching closed channels: expected '%v' instead got '%v'", - ErrNoClosedChannels, err) - } -} - -// TestFetchClosedChannelForID tests that we are able to properly retrieve a -// ChannelCloseSummary from the DB given a ChannelID. -func TestFetchClosedChannelForID(t *testing.T) { - t.Parallel() - - const numChans = 101 - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // Create the test channel state, that we will mutate the index of the - // funding point. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - // Now run through the number of channels, and modify the outpoint index - // to create new channel IDs. - for i := uint32(0); i < numChans; i++ { - // Save the open channel to disk. - state.FundingOutpoint.Index = i - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := state.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel "+ - "state: %v", err) - } - - // Close the channel. To make sure we retrieve the correct - // summary later, we make them differ in the SettledBalance. - closeSummary := &ChannelCloseSummary{ - ChanPoint: state.FundingOutpoint, - RemotePub: state.IdentityPub, - SettledBalance: btcutil.Amount(500 + i), - } - if err := state.CloseChannel(closeSummary); err != nil { - t.Fatalf("unable to close channel: %v", err) - } - } - - // Now run though them all again and make sure we are able to retrieve - // summaries from the DB. - for i := uint32(0); i < numChans; i++ { - state.FundingOutpoint.Index = i - - // We calculate the ChannelID and use it to fetch the summary. - cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) - fetchedSummary, err := cdb.FetchClosedChannelForID(cid) - if err != nil { - t.Fatalf("unable to fetch close summary: %v", err) - } - - // Make sure we retrieved the correct one by checking the - // SettledBalance. - if fetchedSummary.SettledBalance != btcutil.Amount(500+i) { - t.Fatalf("summaries don't match: expected %v got %v", - btcutil.Amount(500+i), - fetchedSummary.SettledBalance) - } - } - - // As a final test we make sure that we get ErrClosedChannelNotFound - // for a ChannelID we didn't add to the DB. - state.FundingOutpoint.Index++ - cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) - _, err = cdb.FetchClosedChannelForID(cid) - if err != ErrClosedChannelNotFound { - t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err) - } -} - -// TestAddrsForNode tests the we're able to properly obtain all the addresses -// for a target node. -func TestAddrsForNode(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - graph := cdb.ChannelGraph() - - // We'll make a test vertex to insert into the database, as the source - // node, but this node will only have half the number of addresses it - // usually does. - testNode, err := createTestVertex(cdb) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - testNode.Addresses = []net.Addr{testAddr} - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // Next, we'll make a link node with the same pubkey, but with an - // additional address. - nodePub, err := testNode.PubKey() - if err != nil { - t.Fatalf("unable to recv node pub: %v", err) - } - linkNode := cdb.NewLinkNode( - wire.MainNet, nodePub, anotherAddr, - ) - if err := linkNode.Sync(); err != nil { - t.Fatalf("unable to sync link node: %v", err) - } - - // Now that we've created a link node, as well as a vertex for the - // node, we'll query for all its addresses. - nodeAddrs, err := cdb.AddrsForNode(nodePub) - if err != nil { - t.Fatalf("unable to obtain node addrs: %v", err) - } - - expectedAddrs := make(map[string]struct{}) - expectedAddrs[testAddr.String()] = struct{}{} - expectedAddrs[anotherAddr.String()] = struct{}{} - - // Finally, ensure that all the expected addresses are found. - if len(nodeAddrs) != len(expectedAddrs) { - t.Fatalf("expected %v addrs, got %v", - len(expectedAddrs), len(nodeAddrs)) - } - for _, addr := range nodeAddrs { - if _, ok := expectedAddrs[addr.String()]; !ok { - t.Fatalf("unexpected addr: %v", addr) - } - } -} - -// TestFetchChannel tests that we're able to fetch an arbitrary channel from -// disk. -func TestFetchChannel(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // Create the test channel state that we'll sync to the database - // shortly. - channelState, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - // Mark the channel as pending, then immediately mark it as open to it - // can be fully visible. - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - if err := channelState.SyncPending(addr, 9); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99)) - if err != nil { - t.Fatalf("unable to mark channel open: %v", err) - } - - // Next, attempt to fetch the channel by its chan point. - dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint) - if err != nil { - t.Fatalf("unable to fetch channel: %v", err) - } - - // The decoded channel state should be identical to what we stored - // above. - if !reflect.DeepEqual(channelState, dbChannel) { - t.Fatalf("channel state doesn't match:: %v vs %v", - spew.Sdump(channelState), spew.Sdump(dbChannel)) - } - - // If we attempt to query for a non-exist ante channel, then we should - // get an error. - channelState2, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - channelState2.FundingOutpoint.Index ^= 1 - - _, err = cdb.FetchChannel(channelState2.FundingOutpoint) - if err == nil { - t.Fatalf("expected query to fail") - } -} - -func genRandomChannelShell() (*ChannelShell, error) { - var testPriv [32]byte - if _, err := rand.Read(testPriv[:]); err != nil { - return nil, err - } - - _, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:]) - - var chanPoint wire.OutPoint - if _, err := rand.Read(chanPoint.Hash[:]); err != nil { - return nil, err - } - - pub.Curve = nil - - chanPoint.Index = uint32(rand.Intn(math.MaxUint16)) - - chanStatus := ChanStatusDefault | ChanStatusRestored - - var shaChainPriv [32]byte - if _, err := rand.Read(testPriv[:]); err != nil { - return nil, err - } - revRoot, err := chainhash.NewHash(shaChainPriv[:]) - if err != nil { - return nil, err - } - shaChainProducer := shachain.NewRevocationProducer(*revRoot) - - return &ChannelShell{ - NodeAddrs: []net.Addr{&net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }}, - Chan: &OpenChannel{ - chanStatus: chanStatus, - ChainHash: rev, - FundingOutpoint: chanPoint, - ShortChannelID: lnwire.NewShortChanIDFromInt( - uint64(rand.Int63()), - ), - IdentityPub: pub, - LocalChanCfg: ChannelConfig{ - ChannelConstraints: ChannelConstraints{ - CsvDelay: uint16(rand.Int63()), - }, - PaymentBasePoint: keychain.KeyDescriptor{ - KeyLocator: keychain.KeyLocator{ - Family: keychain.KeyFamily(rand.Int63()), - Index: uint32(rand.Int63()), - }, - }, - }, - RemoteCurrentRevocation: pub, - IsPending: false, - RevocationStore: shachain.NewRevocationStore(), - RevocationProducer: shaChainProducer, - }, - }, nil -} - -// TestRestoreChannelShells tests that we're able to insert a partially channel -// populated to disk. This is useful for channel recovery purposes. We should -// find the new channel shell on disk, and also the db should be populated with -// an edge for that channel. -func TestRestoreChannelShells(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First, we'll make our channel shell, it will only have the minimal - // amount of information required for us to initiate the data loss - // protection feature. - channelShell, err := genRandomChannelShell() - if err != nil { - t.Fatalf("unable to gen channel shell: %v", err) - } - - graph := cdb.ChannelGraph() - - // Before we can restore the channel, we'll need to make a source node - // in the graph as the channel edge we create will need to have a - // origin. - testNode, err := createTestVertex(cdb) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // With the channel shell constructed, we'll now insert it into the - // database with the restoration method. - if err := cdb.RestoreChannelShells(channelShell); err != nil { - t.Fatalf("unable to restore channel shell: %v", err) - } - - // Now that the channel has been inserted, we'll attempt to query for - // it to ensure we can properly locate it via various means. - // - // First, we'll attempt to query for all channels that we have with the - // node public key that was restored. - nodeChans, err := cdb.FetchOpenChannels(channelShell.Chan.IdentityPub) - if err != nil { - t.Fatalf("unable find channel: %v", err) - } - - // We should now find a single channel from the database. - if len(nodeChans) != 1 { - t.Fatalf("unable to find restored channel by node "+ - "pubkey: %v", err) - } - - // Ensure that it isn't possible to modify the commitment state machine - // of this restored channel. - channel := nodeChans[0] - err = channel.UpdateCommitment(nil) - if err != ErrNoRestoredChannelMutation { - t.Fatalf("able to mutate restored channel") - } - err = channel.AppendRemoteCommitChain(nil) - if err != ErrNoRestoredChannelMutation { - t.Fatalf("able to mutate restored channel") - } - err = channel.AdvanceCommitChainTail(nil) - if err != ErrNoRestoredChannelMutation { - t.Fatalf("able to mutate restored channel") - } - - // That single channel should have the proper channel point, and also - // the expected set of flags to indicate that it was a restored - // channel. - if nodeChans[0].FundingOutpoint != channelShell.Chan.FundingOutpoint { - t.Fatalf("wrong funding outpoint: expected %v, got %v", - nodeChans[0].FundingOutpoint, - channelShell.Chan.FundingOutpoint) - } - if !nodeChans[0].HasChanStatus(ChanStatusRestored) { - t.Fatalf("node has wrong status flags: %v", - nodeChans[0].chanStatus) - } - - // We should also be able to find the channel if we query for it - // directly. - _, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint) - if err != nil { - t.Fatalf("unable to fetch channel: %v", err) - } - - // We should also be able to find the link node that was inserted by - // its public key. - linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch link node: %v", err) - } - - // The node should have the same address, as specified in the channel - // shell. - if reflect.DeepEqual(linkNode.Addresses, channelShell.NodeAddrs) { - t.Fatalf("addr mismach: expected %v, got %v", - linkNode.Addresses, channelShell.NodeAddrs) - } - - // Finally, we'll ensure that the edge for the channel was properly - // inserted. - chanInfos, err := graph.FetchChanInfos( - []uint64{channelShell.Chan.ShortChannelID.ToUint64()}, - ) - if err != nil { - t.Fatalf("unable to find edges: %v", err) - } - - if len(chanInfos) != 1 { - t.Fatalf("wrong amount of chan infos: expected %v got %v", - len(chanInfos), 1) - } - - // We should only find a single edge. - if chanInfos[0].Policy1 != nil && chanInfos[0].Policy2 != nil { - t.Fatalf("only a single edge should be inserted: %v", err) - } -} diff --git a/channeldb/migration_01_to_11/doc.go b/channeldb/migration_01_to_11/doc.go deleted file mode 100644 index c90412f2..00000000 --- a/channeldb/migration_01_to_11/doc.go +++ /dev/null @@ -1 +0,0 @@ -package migration_01_to_11 diff --git a/channeldb/migration_01_to_11/error.go b/channeldb/migration_01_to_11/error.go index f264fb70..232aaa2b 100644 --- a/channeldb/migration_01_to_11/error.go +++ b/channeldb/migration_01_to_11/error.go @@ -1,55 +1,23 @@ package migration_01_to_11 import ( - "errors" "fmt" ) var ( - // ErrNoChanDBExists is returned when a channel bucket hasn't been - // created. - ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created") // ErrDBReversion is returned when detecting an attempt to revert to a // prior database version. ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version") - // ErrLinkNodesNotFound is returned when node info bucket hasn't been - // created. - ErrLinkNodesNotFound = fmt.Errorf("no link nodes exist") - - // ErrNoActiveChannels is returned when there is no active (open) - // channels within the database. - ErrNoActiveChannels = fmt.Errorf("no active channels exist") - - // ErrNoPastDeltas is returned when the channel delta bucket hasn't been - // created. - ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas") - - // ErrInvoiceNotFound is returned when a targeted invoice can't be - // found. - ErrInvoiceNotFound = fmt.Errorf("unable to locate invoice") - // ErrNoInvoicesCreated is returned when we don't have invoices in // our database to return. ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices") - // ErrDuplicateInvoice is returned when an invoice with the target - // payment hash already exists. - ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists") - // ErrNoPaymentsCreated is returned when bucket of payments hasn't been // created. ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") - // ErrNodeNotFound is returned when node bucket exists, but node with - // specific identity can't be found. - ErrNodeNotFound = fmt.Errorf("link node with target identity not found") - - // ErrChannelNotFound is returned when we attempt to locate a channel - // for a specific chain, but it is not found. - ErrChannelNotFound = fmt.Errorf("channel not found") - // ErrMetaNotFound is returned when meta bucket hasn't been // created. ErrMetaNotFound = fmt.Errorf("unable to locate meta information") @@ -58,22 +26,11 @@ var ( // graph doesn't exist. ErrGraphNotFound = fmt.Errorf("graph bucket not initialized") - // ErrGraphNeverPruned is returned when graph was never pruned. - ErrGraphNeverPruned = fmt.Errorf("graph never pruned") - // ErrSourceNodeNotSet is returned if the source node of the graph // hasn't been added The source node is the center node within a // star-graph. ErrSourceNodeNotSet = fmt.Errorf("source node does not exist") - // ErrGraphNodesNotFound is returned in case none of the nodes has - // been added in graph node bucket. - ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist") - - // ErrGraphNoEdgesFound is returned in case of none of the channel/edges - // has been added in graph edge bucket. - ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist") - // ErrGraphNodeNotFound is returned when we're unable to find the target // node. ErrGraphNodeNotFound = fmt.Errorf("unable to find node") @@ -82,17 +39,6 @@ var ( // can't be found. ErrEdgeNotFound = fmt.Errorf("edge not found") - // ErrZombieEdge is an error returned when we attempt to look up an edge - // but it is marked as a zombie within the zombie index. - ErrZombieEdge = errors.New("edge marked as zombie") - - // ErrEdgeAlreadyExist is returned when edge with specific - // channel id can't be added because it already exist. - ErrEdgeAlreadyExist = fmt.Errorf("edge already exist") - - // ErrNodeAliasNotFound is returned when alias for node can't be found. - ErrNodeAliasNotFound = fmt.Errorf("alias for node not found") - // ErrUnknownAddressType is returned when a node's addressType is not // an expected value. ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved") @@ -101,20 +47,11 @@ var ( // channels it has closed, but it hasn't yet closed any channels. ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet") - // ErrNoForwardingEvents is returned in the case that a query fails due - // to the log not having any recorded events. - ErrNoForwardingEvents = fmt.Errorf("no recorded forwarding events") - // ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel // policy field is not found in the db even though its message flags // indicate it should be. ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " + "present") - - // ErrChanAlreadyExists is return when the caller attempts to create a - // channel with a channel point that is already present in the - // database. - ErrChanAlreadyExists = fmt.Errorf("channel already exists") ) // ErrTooManyExtraOpaqueBytes creates an error which should be returned if the diff --git a/channeldb/migration_01_to_11/fees.go b/channeldb/migration_01_to_11/fees.go deleted file mode 100644 index c90412f2..00000000 --- a/channeldb/migration_01_to_11/fees.go +++ /dev/null @@ -1 +0,0 @@ -package migration_01_to_11 diff --git a/channeldb/migration_01_to_11/forwarding_log.go b/channeldb/migration_01_to_11/forwarding_log.go deleted file mode 100644 index 6b9f8f5d..00000000 --- a/channeldb/migration_01_to_11/forwarding_log.go +++ /dev/null @@ -1,274 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "io" - "sort" - "time" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lnwire" -) - -var ( - // forwardingLogBucket is the bucket that we'll use to store the - // forwarding log. The forwarding log contains a time series database - // of the forwarding history of a lightning daemon. Each key within the - // bucket is a timestamp (in nano seconds since the unix epoch), and - // the value a slice of a forwarding event for that timestamp. - forwardingLogBucket = []byte("circuit-fwd-log") -) - -const ( - // forwardingEventSize is the size of a forwarding event. The breakdown - // is as follows: - // - // * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in - // || 8 byte value out - // - // From the value in and value out, callers can easily compute the - // total fee extract from a forwarding event. - forwardingEventSize = 32 - - // MaxResponseEvents is the max number of forwarding events that will - // be returned by a single query response. This size was selected to - // safely remain under gRPC's 4MiB message size response limit. As each - // full forwarding event (including the timestamp) is 40 bytes, we can - // safely return 50k entries in a single response. - MaxResponseEvents = 50000 -) - -// ForwardingLog returns an instance of the ForwardingLog object backed by the -// target database instance. -func (d *DB) ForwardingLog() *ForwardingLog { - return &ForwardingLog{ - db: d, - } -} - -// ForwardingLog is a time series database that logs the fulfilment of payment -// circuits by a lightning network daemon. The log contains a series of -// forwarding events which map a timestamp to a forwarding event. A forwarding -// event describes which channels were used to create+settle a circuit, and the -// amount involved. Subtracting the outgoing amount from the incoming amount -// reveals the fee charged for the forwarding service. -type ForwardingLog struct { - db *DB -} - -// ForwardingEvent is an event in the forwarding log's time series. Each -// forwarding event logs the creation and tear-down of a payment circuit. A -// circuit is created once an incoming HTLC has been fully forwarded, and -// destroyed once the payment has been settled. -type ForwardingEvent struct { - // Timestamp is the settlement time of this payment circuit. - Timestamp time.Time - - // IncomingChanID is the incoming channel ID of the payment circuit. - IncomingChanID lnwire.ShortChannelID - - // OutgoingChanID is the outgoing channel ID of the payment circuit. - OutgoingChanID lnwire.ShortChannelID - - // AmtIn is the amount of the incoming HTLC. Subtracting this from the - // outgoing amount gives the total fees of this payment circuit. - AmtIn lnwire.MilliSatoshi - - // AmtOut is the amount of the outgoing HTLC. Subtracting the incoming - // amount from this gives the total fees for this payment circuit. - AmtOut lnwire.MilliSatoshi -} - -// encodeForwardingEvent writes out the target forwarding event to the passed -// io.Writer, using the expected DB format. Note that the timestamp isn't -// serialized as this will be the key value within the bucket. -func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { - return WriteElements( - w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut, - ) -} - -// decodeForwardingEvent attempts to decode the raw bytes of a serialized -// forwarding event into the target ForwardingEvent. Note that the timestamp -// won't be decoded, as the caller is expected to set this due to the bucket -// structure of the forwarding log. -func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error { - return ReadElements( - r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut, - ) -} - -// AddForwardingEvents adds a series of forwarding events to the database. -// Before inserting, the set of events will be sorted according to their -// timestamp. This ensures that all writes to disk are sequential. -func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error { - // Before we create the database transaction, we'll ensure that the set - // of forwarding events are properly sorted according to their - // timestamp. - sort.Slice(events, func(i, j int) bool { - return events[i].Timestamp.Before(events[j].Timestamp) - }) - - var timestamp [8]byte - - return f.db.Batch(func(tx *bbolt.Tx) error { - // First, we'll fetch the bucket that stores our time series - // log. - logBucket, err := tx.CreateBucketIfNotExists( - forwardingLogBucket, - ) - if err != nil { - return err - } - - // With the bucket obtained, we can now begin to write out the - // series of events. - for _, event := range events { - var eventBytes [forwardingEventSize]byte - eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize]) - - // First, we'll serialize this timestamp into our - // timestamp buffer. - byteOrder.PutUint64( - timestamp[:], uint64(event.Timestamp.UnixNano()), - ) - - // With the key encoded, we'll then encode the event - // into our buffer, then write it out to disk. - err := encodeForwardingEvent(eventBuf, &event) - if err != nil { - return err - } - err = logBucket.Put(timestamp[:], eventBuf.Bytes()) - if err != nil { - return err - } - } - - return nil - }) -} - -// ForwardingEventQuery represents a query to the forwarding log payment -// circuit time series database. The query allows a caller to retrieve all -// records for a particular time slice, offset in that time slice, limiting the -// total number of responses returned. -type ForwardingEventQuery struct { - // StartTime is the start time of the time slice. - StartTime time.Time - - // EndTime is the end time of the time slice. - EndTime time.Time - - // IndexOffset is the offset within the time slice to start at. This - // can be used to start the response at a particular record. - IndexOffset uint32 - - // NumMaxEvents is the max number of events to return. - NumMaxEvents uint32 -} - -// ForwardingLogTimeSlice is the response to a forwarding query. It includes -// the original query, the set events that match the query, and an integer -// which represents the offset index of the last item in the set of retuned -// events. This integer allows callers to resume their query using this offset -// in the event that the query's response exceeds the max number of returnable -// events. -type ForwardingLogTimeSlice struct { - ForwardingEventQuery - - // ForwardingEvents is the set of events in our time series that answer - // the query embedded above. - ForwardingEvents []ForwardingEvent - - // LastIndexOffset is the index of the last element in the set of - // returned ForwardingEvents above. Callers can use this to resume - // their query in the event that the time slice has too many events to - // fit into a single response. - LastIndexOffset uint32 -} - -// Query allows a caller to query the forwarding event time series for a -// particular time slice. The caller can control the precise time as well as -// the number of events to be returned. -// -// TODO(roasbeef): rename? -func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) { - resp := ForwardingLogTimeSlice{ - ForwardingEventQuery: q, - } - - // If the user provided an index offset, then we'll not know how many - // records we need to skip. We'll also keep track of the record offset - // as that's part of the final return value. - recordsToSkip := q.IndexOffset - recordOffset := q.IndexOffset - - err := f.db.View(func(tx *bbolt.Tx) error { - // If the bucket wasn't found, then there aren't any events to - // be returned. - logBucket := tx.Bucket(forwardingLogBucket) - if logBucket == nil { - return ErrNoForwardingEvents - } - - // We'll be using a cursor to seek into the database, so we'll - // populate byte slices that represent the start of the key - // space we're interested in, and the end. - var startTime, endTime [8]byte - byteOrder.PutUint64(startTime[:], uint64(q.StartTime.UnixNano())) - byteOrder.PutUint64(endTime[:], uint64(q.EndTime.UnixNano())) - - // If we know that a set of log events exists, then we'll begin - // our seek through the log in order to satisfy the query. - // We'll continue until either we reach the end of the range, - // or reach our max number of events. - logCursor := logBucket.Cursor() - timestamp, events := logCursor.Seek(startTime[:]) - for ; timestamp != nil && bytes.Compare(timestamp, endTime[:]) <= 0; timestamp, events = logCursor.Next() { - // If our current return payload exceeds the max number - // of events, then we'll exit now. - if uint32(len(resp.ForwardingEvents)) >= q.NumMaxEvents { - return nil - } - - // If we're not yet past the user defined offset, then - // we'll continue to seek forward. - if recordsToSkip > 0 { - recordsToSkip-- - continue - } - - currentTime := time.Unix( - 0, int64(byteOrder.Uint64(timestamp)), - ) - - // At this point, we've skipped enough records to start - // to collate our query. For each record, we'll - // increment the final record offset so the querier can - // utilize pagination to seek further. - readBuf := bytes.NewReader(events) - for readBuf.Len() != 0 { - var event ForwardingEvent - err := decodeForwardingEvent(readBuf, &event) - if err != nil { - return err - } - - event.Timestamp = currentTime - resp.ForwardingEvents = append(resp.ForwardingEvents, event) - - recordOffset++ - } - } - - return nil - }) - if err != nil && err != ErrNoForwardingEvents { - return ForwardingLogTimeSlice{}, err - } - - resp.LastIndexOffset = recordOffset - - return resp, nil -} diff --git a/channeldb/migration_01_to_11/forwarding_log_test.go b/channeldb/migration_01_to_11/forwarding_log_test.go deleted file mode 100644 index 9e0de7c4..00000000 --- a/channeldb/migration_01_to_11/forwarding_log_test.go +++ /dev/null @@ -1,265 +0,0 @@ -package migration_01_to_11 - -import ( - "math/rand" - "reflect" - "testing" - - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/lnwire" - - "time" -) - -// TestForwardingLogBasicStorageAndQuery tests that we're able to store and -// then query for items that have previously been added to the event log. -func TestForwardingLogBasicStorageAndQuery(t *testing.T) { - t.Parallel() - - // First, we'll set up a test database, and use that to instantiate the - // forwarding event log that we'll be using for the duration of the - // test. - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - log := ForwardingLog{ - db: db, - } - - initialTime := time.Unix(1234, 0) - timestamp := time.Unix(1234, 0) - - // We'll create 100 random events, which each event being spaced 10 - // minutes after the prior event. - numEvents := 100 - events := make([]ForwardingEvent, numEvents) - for i := 0; i < numEvents; i++ { - events[i] = ForwardingEvent{ - Timestamp: timestamp, - IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - AmtIn: lnwire.MilliSatoshi(rand.Int63()), - AmtOut: lnwire.MilliSatoshi(rand.Int63()), - } - - timestamp = timestamp.Add(time.Minute * 10) - } - - // Now that all of our set of events constructed, we'll add them to the - // database in a batch manner. - if err := log.AddForwardingEvents(events); err != nil { - t.Fatalf("unable to add events: %v", err) - } - - // With our events added we'll now construct a basic query to retrieve - // all of the events. - eventQuery := ForwardingEventQuery{ - StartTime: initialTime, - EndTime: timestamp, - IndexOffset: 0, - NumMaxEvents: 1000, - } - timeSlice, err := log.Query(eventQuery) - if err != nil { - t.Fatalf("unable to query for events: %v", err) - } - - // The set of returned events should match identically, as they should - // be returned in sorted order. - if !reflect.DeepEqual(events, timeSlice.ForwardingEvents) { - t.Fatalf("event mismatch: expected %v vs %v", - spew.Sdump(events), spew.Sdump(timeSlice.ForwardingEvents)) - } - - // The offset index of the final entry should be numEvents, so the - // number of total events we've written. - if timeSlice.LastIndexOffset != uint32(numEvents) { - t.Fatalf("wrong final offset: expected %v, got %v", - timeSlice.LastIndexOffset, numEvents) - } -} - -// TestForwardingLogQueryOptions tests that the query offset works properly. So -// if we add a series of events, then we should be able to seek within the -// timeslice accordingly. This exercises the index offset and num max event -// field in the query, and also the last index offset field int he response. -func TestForwardingLogQueryOptions(t *testing.T) { - t.Parallel() - - // First, we'll set up a test database, and use that to instantiate the - // forwarding event log that we'll be using for the duration of the - // test. - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - log := ForwardingLog{ - db: db, - } - - initialTime := time.Unix(1234, 0) - endTime := time.Unix(1234, 0) - - // We'll create 20 random events, which each event being spaced 10 - // minutes after the prior event. - numEvents := 20 - events := make([]ForwardingEvent, numEvents) - for i := 0; i < numEvents; i++ { - events[i] = ForwardingEvent{ - Timestamp: endTime, - IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - AmtIn: lnwire.MilliSatoshi(rand.Int63()), - AmtOut: lnwire.MilliSatoshi(rand.Int63()), - } - - endTime = endTime.Add(time.Minute * 10) - } - - // Now that all of our set of events constructed, we'll add them to the - // database in a batch manner. - if err := log.AddForwardingEvents(events); err != nil { - t.Fatalf("unable to add events: %v", err) - } - - // With all of our events added, we should be able to query for the - // first 10 events using the max event query field. - eventQuery := ForwardingEventQuery{ - StartTime: initialTime, - EndTime: endTime, - IndexOffset: 0, - NumMaxEvents: 10, - } - timeSlice, err := log.Query(eventQuery) - if err != nil { - t.Fatalf("unable to query for events: %v", err) - } - - // We should get exactly 10 events back. - if len(timeSlice.ForwardingEvents) != 10 { - t.Fatalf("wrong number of events: expected %v, got %v", 10, - len(timeSlice.ForwardingEvents)) - } - - // The set of events returned should be the first 10 events that we - // added. - if !reflect.DeepEqual(events[:10], timeSlice.ForwardingEvents) { - t.Fatalf("wrong response: expected %v, got %v", - spew.Sdump(events[:10]), - spew.Sdump(timeSlice.ForwardingEvents)) - } - - // The final offset should be the exact number of events returned. - if timeSlice.LastIndexOffset != 10 { - t.Fatalf("wrong index offset: expected %v, got %v", 10, - timeSlice.LastIndexOffset) - } - - // If we use the final offset to query again, then we should get 10 - // more events, that are the last 10 events we wrote. - eventQuery.IndexOffset = 10 - timeSlice, err = log.Query(eventQuery) - if err != nil { - t.Fatalf("unable to query for events: %v", err) - } - - // We should get exactly 10 events back once again. - if len(timeSlice.ForwardingEvents) != 10 { - t.Fatalf("wrong number of events: expected %v, got %v", 10, - len(timeSlice.ForwardingEvents)) - } - - // The events that we got back should be the last 10 events that we - // wrote out. - if !reflect.DeepEqual(events[10:], timeSlice.ForwardingEvents) { - t.Fatalf("wrong response: expected %v, got %v", - spew.Sdump(events[10:]), - spew.Sdump(timeSlice.ForwardingEvents)) - } - - // Finally, the last index offset should be 20, or the number of - // records we've written out. - if timeSlice.LastIndexOffset != 20 { - t.Fatalf("wrong index offset: expected %v, got %v", 20, - timeSlice.LastIndexOffset) - } -} - -// TestForwardingLogQueryLimit tests that we're able to properly limit the -// number of events that are returned as part of a query. -func TestForwardingLogQueryLimit(t *testing.T) { - t.Parallel() - - // First, we'll set up a test database, and use that to instantiate the - // forwarding event log that we'll be using for the duration of the - // test. - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - log := ForwardingLog{ - db: db, - } - - initialTime := time.Unix(1234, 0) - endTime := time.Unix(1234, 0) - - // We'll create 200 random events, which each event being spaced 10 - // minutes after the prior event. - numEvents := 200 - events := make([]ForwardingEvent, numEvents) - for i := 0; i < numEvents; i++ { - events[i] = ForwardingEvent{ - Timestamp: endTime, - IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - AmtIn: lnwire.MilliSatoshi(rand.Int63()), - AmtOut: lnwire.MilliSatoshi(rand.Int63()), - } - - endTime = endTime.Add(time.Minute * 10) - } - - // Now that all of our set of events constructed, we'll add them to the - // database in a batch manner. - if err := log.AddForwardingEvents(events); err != nil { - t.Fatalf("unable to add events: %v", err) - } - - // Once the events have been written out, we'll issue a query over the - // entire range, but restrict the number of events to the first 100. - eventQuery := ForwardingEventQuery{ - StartTime: initialTime, - EndTime: endTime, - IndexOffset: 0, - NumMaxEvents: 100, - } - timeSlice, err := log.Query(eventQuery) - if err != nil { - t.Fatalf("unable to query for events: %v", err) - } - - // We should get exactly 100 events back. - if len(timeSlice.ForwardingEvents) != 100 { - t.Fatalf("wrong number of events: expected %v, got %v", 10, - len(timeSlice.ForwardingEvents)) - } - - // The set of events returned should be the first 100 events that we - // added. - if !reflect.DeepEqual(events[:100], timeSlice.ForwardingEvents) { - t.Fatalf("wrong response: expected %v, got %v", - spew.Sdump(events[:100]), - spew.Sdump(timeSlice.ForwardingEvents)) - } - - // The final offset should be the exact number of events returned. - if timeSlice.LastIndexOffset != 100 { - t.Fatalf("wrong index offset: expected %v, got %v", 100, - timeSlice.LastIndexOffset) - } -} diff --git a/channeldb/migration_01_to_11/forwarding_package.go b/channeldb/migration_01_to_11/forwarding_package.go deleted file mode 100644 index cbbf90cf..00000000 --- a/channeldb/migration_01_to_11/forwarding_package.go +++ /dev/null @@ -1,928 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lnwire" -) - -// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding -// package has potentially been mangled. -var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") - -// FwdState is an enum used to describe the lifecycle of a FwdPkg. -type FwdState byte - -const ( - // FwdStateLockedIn is the starting state for all forwarding packages. - // Packages in this state have not yet committed to the exact set of - // Adds to forward to the switch. - FwdStateLockedIn FwdState = iota - - // FwdStateProcessed marks the state in which all Adds have been - // locally processed and the forwarding decision to the switch has been - // persisted. - FwdStateProcessed - - // FwdStateCompleted signals that all Adds have been acked, and that all - // settles and fails have been delivered to their sources. Packages in - // this state can be removed permanently. - FwdStateCompleted -) - -var ( - // fwdPackagesKey is the root-level bucket that all forwarding packages - // are written. This bucket is further subdivided based on the short - // channel ID of each channel. - fwdPackagesKey = []byte("fwd-packages") - - // addBucketKey is the bucket to which all Add log updates are written. - addBucketKey = []byte("add-updates") - - // failSettleBucketKey is the bucket to which all Settle/Fail log - // updates are written. - failSettleBucketKey = []byte("fail-settle-updates") - - // fwdFilterKey is a key used to write the set of Adds that passed - // validation and are to be forwarded to the switch. - // NOTE: The presence of this key within a forwarding package indicates - // that the package has reached FwdStateProcessed. - fwdFilterKey = []byte("fwd-filter-key") - - // ackFilterKey is a key used to access the PkgFilter indicating which - // Adds have received a Settle/Fail. This response may come from a - // number of sources, including: exitHop settle/fails, switch failures, - // chain arbiter interjections, as well as settle/fails from the - // next hop in the route. - ackFilterKey = []byte("ack-filter-key") - - // settleFailFilterKey is a key used to access the PkgFilter indicating - // which Settles/Fails in have been received and processed by the link - // that originally received the Add. - settleFailFilterKey = []byte("settle-fail-filter-key") -) - -// PkgFilter is used to compactly represent a particular subset of the Adds in a -// forwarding package. Each filter is represented as a simple, statically-sized -// bitvector, where the elements are intended to be the indices of the Adds as -// they are written in the FwdPkg. -type PkgFilter struct { - count uint16 - filter []byte -} - -// NewPkgFilter initializes an empty PkgFilter supporting `count` elements. -func NewPkgFilter(count uint16) *PkgFilter { - // We add 7 to ensure that the integer division yields properly rounded - // values. - filterLen := (count + 7) / 8 - - return &PkgFilter{ - count: count, - filter: make([]byte, filterLen), - } -} - -// Count returns the number of elements represented by this PkgFilter. -func (f *PkgFilter) Count() uint16 { - return f.count -} - -// Set marks the `i`-th element as included by this filter. -// NOTE: It is assumed that i is always less than count. -func (f *PkgFilter) Set(i uint16) { - byt := i / 8 - bit := i % 8 - - // Set the i-th bit in the filter. - // TODO(conner): ignore if > count to prevent panic? - f.filter[byt] |= byte(1 << (7 - bit)) -} - -// Contains queries the filter for membership of index `i`. -// NOTE: It is assumed that i is always less than count. -func (f *PkgFilter) Contains(i uint16) bool { - byt := i / 8 - bit := i % 8 - - // Read the i-th bit in the filter. - // TODO(conner): ignore if > count to prevent panic? - return f.filter[byt]&(1<<(7-bit)) != 0 -} - -// Equal checks two PkgFilters for equality. -func (f *PkgFilter) Equal(f2 *PkgFilter) bool { - if f == f2 { - return true - } - if f.count != f2.count { - return false - } - - return bytes.Equal(f.filter, f2.filter) -} - -// IsFull returns true if every element in the filter has been Set, and false -// otherwise. -func (f *PkgFilter) IsFull() bool { - // Batch validate bytes that are fully used. - for i := uint16(0); i < f.count/8; i++ { - if f.filter[i] != 0xFF { - return false - } - } - - // If the count is not a multiple of 8, check that the filter contains - // all remaining bits. - rem := f.count % 8 - for idx := f.count - rem; idx < f.count; idx++ { - if !f.Contains(idx) { - return false - } - } - - return true -} - -// Size returns number of bytes produced when the PkgFilter is serialized. -func (f *PkgFilter) Size() uint16 { - // 2 bytes for uint16 `count`, then round up number of bytes required to - // represent `count` bits. - return 2 + (f.count+7)/8 -} - -// Encode writes the filter to the provided io.Writer. -func (f *PkgFilter) Encode(w io.Writer) error { - if err := binary.Write(w, binary.BigEndian, f.count); err != nil { - return err - } - - _, err := w.Write(f.filter) - - return err -} - -// Decode reads the filter from the provided io.Reader. -func (f *PkgFilter) Decode(r io.Reader) error { - if err := binary.Read(r, binary.BigEndian, &f.count); err != nil { - return err - } - - f.filter = make([]byte, f.Size()-2) - _, err := io.ReadFull(r, f.filter) - - return err -} - -// FwdPkg records all adds, settles, and fails that were locked in as a result -// of the remote peer sending us a revocation. Each package is identified by -// the short chanid and remote commitment height corresponding to the revocation -// that locked in the HTLCs. For everything except a locally initiated payment, -// settles and fails in a forwarding package must have a corresponding Add in -// another package, and can be removed individually once the source link has -// received the fail/settle. -// -// Adds cannot be removed, as we need to present the same batch of Adds to -// properly handle replay protection. Instead, we use a PkgFilter to mark that -// we have finished processing a particular Add. A FwdPkg should only be deleted -// after the AckFilter is full and all settles and fails have been persistently -// removed. -type FwdPkg struct { - // Source identifies the channel that wrote this forwarding package. - Source lnwire.ShortChannelID - - // Height is the height of the remote commitment chain that locked in - // this forwarding package. - Height uint64 - - // State signals the persistent condition of the package and directs how - // to reprocess the package in the event of failures. - State FwdState - - // Adds contains all add messages which need to be processed and - // forwarded to the switch. Adds does not change over the life of a - // forwarding package. - Adds []LogUpdate - - // FwdFilter is a filter containing the indices of all Adds that were - // forwarded to the switch. - FwdFilter *PkgFilter - - // AckFilter is a filter containing the indices of all Adds for which - // the source has received a settle or fail and is reflected in the next - // commitment txn. A package should not be removed until IsFull() - // returns true. - AckFilter *PkgFilter - - // SettleFails contains all settle and fail messages that should be - // forwarded to the switch. - SettleFails []LogUpdate - - // SettleFailFilter is a filter containing the indices of all Settle or - // Fails originating in this package that have been received and locked - // into the incoming link's commitment state. - SettleFailFilter *PkgFilter -} - -// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This -// should be used to create a package at the time we receive a revocation. -func NewFwdPkg(source lnwire.ShortChannelID, height uint64, - addUpdates, settleFailUpdates []LogUpdate) *FwdPkg { - - nAddUpdates := uint16(len(addUpdates)) - nSettleFailUpdates := uint16(len(settleFailUpdates)) - - return &FwdPkg{ - Source: source, - Height: height, - State: FwdStateLockedIn, - Adds: addUpdates, - FwdFilter: NewPkgFilter(nAddUpdates), - AckFilter: NewPkgFilter(nAddUpdates), - SettleFails: settleFailUpdates, - SettleFailFilter: NewPkgFilter(nSettleFailUpdates), - } -} - -// ID returns an unique identifier for this package, used to ensure that sphinx -// replay processing of this batch is idempotent. -func (f *FwdPkg) ID() []byte { - var id = make([]byte, 16) - byteOrder.PutUint64(id[:8], f.Source.ToUint64()) - byteOrder.PutUint64(id[8:], f.Height) - return id -} - -// String returns a human-readable description of the forwarding package. -func (f *FwdPkg) String() string { - return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)", - f, f.Source, f.Height, len(f.Adds), len(f.SettleFails)) -} - -// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID -// is assumed to be that of the packager. -type AddRef struct { - // Height is the remote commitment height that locked in the Add. - Height uint64 - - // Index is the index of the Add within the fwd pkg's Adds. - // - // NOTE: This index is static over the lifetime of a forwarding package. - Index uint16 -} - -// Encode serializes the AddRef to the given io.Writer. -func (a *AddRef) Encode(w io.Writer) error { - if err := binary.Write(w, binary.BigEndian, a.Height); err != nil { - return err - } - - return binary.Write(w, binary.BigEndian, a.Index) -} - -// Decode deserializes the AddRef from the given io.Reader. -func (a *AddRef) Decode(r io.Reader) error { - if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil { - return err - } - - return binary.Read(r, binary.BigEndian, &a.Index) -} - -// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A -// channel does not remove its own Settle/Fail htlcs, so the source is provided -// to locate a db bucket belonging to another channel. -type SettleFailRef struct { - // Source identifies the outgoing link that locked in the settle or - // fail. This is then used by the *incoming* link to find the settle - // fail in another link's forwarding packages. - Source lnwire.ShortChannelID - - // Height is the remote commitment height that locked in this - // Settle/Fail. - Height uint64 - - // Index is the index of the Add with the fwd pkg's SettleFails. - // - // NOTE: This index is static over the lifetime of a forwarding package. - Index uint16 -} - -// SettleFailAcker is a generic interface providing the ability to acknowledge -// settle/fail HTLCs stored in forwarding packages. -type SettleFailAcker interface { - // AckSettleFails atomically updates the settle-fail filters in *other* - // channels' forwarding packages. - AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error -} - -// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages -// of any active channel. -type GlobalFwdPkgReader interface { - // LoadChannelFwdPkgs loads all known forwarding packages for the given - // channel. - LoadChannelFwdPkgs(tx *bbolt.Tx, - source lnwire.ShortChannelID) ([]*FwdPkg, error) -} - -// FwdOperator defines the interfaces for managing forwarding packages that are -// external to a particular channel. This interface is used by the switch to -// read forwarding packages from arbitrary channels, and acknowledge settles and -// fails for locally-sourced payments. -type FwdOperator interface { - // GlobalFwdPkgReader provides read access to all known forwarding - // packages - GlobalFwdPkgReader - - // SettleFailAcker grants the ability to acknowledge settles or fails - // residing in arbitrary forwarding packages. - SettleFailAcker -} - -// SwitchPackager is a concrete implementation of the FwdOperator interface. -// A SwitchPackager offers the ability to read any forwarding package, and ack -// arbitrary settle and fail HTLCs. -type SwitchPackager struct{} - -// NewSwitchPackager instantiates a new SwitchPackager. -func NewSwitchPackager() *SwitchPackager { - return &SwitchPackager{} -} - -// AckSettleFails atomically updates the settle-fail filters in *other* -// channels' forwarding packages, to mark that the switch has received a settle -// or fail residing in the forwarding package of a link. -func (*SwitchPackager) AckSettleFails(tx *bbolt.Tx, - settleFailRefs ...SettleFailRef) error { - - return ackSettleFails(tx, settleFailRefs) -} - -// LoadChannelFwdPkgs loads all forwarding packages for a particular channel. -func (*SwitchPackager) LoadChannelFwdPkgs(tx *bbolt.Tx, - source lnwire.ShortChannelID) ([]*FwdPkg, error) { - - return loadChannelFwdPkgs(tx, source) -} - -// FwdPackager supports all operations required to modify fwd packages, such as -// creation, updates, reading, and removal. The interfaces are broken down in -// this way to support future delegation of the subinterfaces. -type FwdPackager interface { - // AddFwdPkg serializes and writes a FwdPkg for this channel at the - // remote commitment height included in the forwarding package. - AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error - - // SetFwdFilter looks up the forwarding package at the remote `height` - // and sets the `fwdFilter`, marking the Adds for which: - // 1) We are not the exit node - // 2) Passed all validation - // 3) Should be forwarded to the switch immediately after a failure - SetFwdFilter(tx *bbolt.Tx, height uint64, fwdFilter *PkgFilter) error - - // AckAddHtlcs atomically updates the add filters in this channel's - // forwarding packages to mark the resolution of an Add that was - // received from the remote party. - AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error - - // SettleFailAcker allows a link to acknowledge settle/fail HTLCs - // belonging to other channels. - SettleFailAcker - - // LoadFwdPkgs loads all known forwarding packages owned by this - // channel. - LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) - - // RemovePkg deletes a forwarding package owned by this channel at - // the provided remote `height`. - RemovePkg(tx *bbolt.Tx, height uint64) error -} - -// ChannelPackager is used by a channel to manage the lifecycle of its forwarding -// packages. The packager is tied to a particular source channel ID, allowing it -// to create and edit its own packages. Each packager also has the ability to -// remove fail/settle htlcs that correspond to an add contained in one of -// source's packages. -type ChannelPackager struct { - source lnwire.ShortChannelID -} - -// NewChannelPackager creates a new packager for a single channel. -func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { - return &ChannelPackager{ - source: source, - } -} - -// AddFwdPkg writes a newly locked in forwarding package to disk. -func (*ChannelPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error { - fwdPkgBkt, err := tx.CreateBucketIfNotExists(fwdPackagesKey) - if err != nil { - return err - } - - source := makeLogKey(fwdPkg.Source.ToUint64()) - sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) - if err != nil { - return err - } - - heightKey := makeLogKey(fwdPkg.Height) - heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) - if err != nil { - return err - } - - // Write ADD updates we received at this commit height. - addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) - if err != nil { - return err - } - - // Write SETTLE/FAIL updates we received at this commit height. - failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey) - if err != nil { - return err - } - - for i := range fwdPkg.Adds { - err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) - if err != nil { - return err - } - } - - // Persist the initialized pkg filter, which will be used to determine - // when we can remove this forwarding package from disk. - var ackFilterBuf bytes.Buffer - if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { - return err - } - - if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil { - return err - } - - for i := range fwdPkg.SettleFails { - err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i]) - if err != nil { - return err - } - } - - var settleFailFilterBuf bytes.Buffer - err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) - if err != nil { - return err - } - - return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) -} - -// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. -func putLogUpdate(bkt *bbolt.Bucket, idx uint16, htlc *LogUpdate) error { - var b bytes.Buffer - if err := htlc.Encode(&b); err != nil { - return err - } - - return bkt.Put(uint16Key(idx), b.Bytes()) -} - -// LoadFwdPkgs scans the forwarding log for any packages that haven't been -// processed, and returns their deserialized log updates in a map indexed by the -// remote commitment height at which the updates were locked in. -func (p *ChannelPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) { - return loadChannelFwdPkgs(tx, p.source) -} - -// loadChannelFwdPkgs loads all forwarding packages owned by `source`. -func loadChannelFwdPkgs(tx *bbolt.Tx, source lnwire.ShortChannelID) ([]*FwdPkg, error) { - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return nil, nil - } - - sourceKey := makeLogKey(source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) - if sourceBkt == nil { - return nil, nil - } - - var heights []uint64 - if err := sourceBkt.ForEach(func(k, _ []byte) error { - if len(k) != 8 { - return ErrCorruptedFwdPkg - } - - heights = append(heights, byteOrder.Uint64(k)) - - return nil - }); err != nil { - return nil, err - } - - // Load the forwarding package for each retrieved height. - fwdPkgs := make([]*FwdPkg, 0, len(heights)) - for _, height := range heights { - fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) - if err != nil { - return nil, err - } - - fwdPkgs = append(fwdPkgs, fwdPkg) - } - - return fwdPkgs, nil -} - -// loadFwPkg reads the packager's fwd pkg at a given height, and determines the -// appropriate FwdState. -func loadFwdPkg(fwdPkgBkt *bbolt.Bucket, source lnwire.ShortChannelID, - height uint64) (*FwdPkg, error) { - - sourceKey := makeLogKey(source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) - if sourceBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.Bucket(heightKey[:]) - if heightBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - // Load ADDs from disk. - addBkt := heightBkt.Bucket(addBucketKey) - if addBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - adds, err := loadHtlcs(addBkt) - if err != nil { - return nil, err - } - - // Load ack filter from disk. - ackFilterBytes := heightBkt.Get(ackFilterKey) - if ackFilterBytes == nil { - return nil, ErrCorruptedFwdPkg - } - ackFilterReader := bytes.NewReader(ackFilterBytes) - - ackFilter := &PkgFilter{} - if err := ackFilter.Decode(ackFilterReader); err != nil { - return nil, err - } - - // Load SETTLE/FAILs from disk. - failSettleBkt := heightBkt.Bucket(failSettleBucketKey) - if failSettleBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - failSettles, err := loadHtlcs(failSettleBkt) - if err != nil { - return nil, err - } - - // Load settle fail filter from disk. - settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) - if settleFailFilterBytes == nil { - return nil, ErrCorruptedFwdPkg - } - settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) - - settleFailFilter := &PkgFilter{} - if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { - return nil, err - } - - // Initialize the fwding package, which always starts in the - // FwdStateLockedIn. We can determine what state the package was left in - // by examining constraints on the information loaded from disk. - fwdPkg := &FwdPkg{ - Source: source, - State: FwdStateLockedIn, - Height: height, - Adds: adds, - AckFilter: ackFilter, - SettleFails: failSettles, - SettleFailFilter: settleFailFilter, - } - - // Check to see if we have written the set exported filter adds to - // disk. If we haven't, processing of this package was never started, or - // failed during the last attempt. - fwdFilterBytes := heightBkt.Get(fwdFilterKey) - if fwdFilterBytes == nil { - nAdds := uint16(len(adds)) - fwdPkg.FwdFilter = NewPkgFilter(nAdds) - return fwdPkg, nil - } - - fwdFilterReader := bytes.NewReader(fwdFilterBytes) - fwdPkg.FwdFilter = &PkgFilter{} - if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { - return nil, err - } - - // Otherwise, a complete round of processing was completed, and we - // advance the package to FwdStateProcessed. - fwdPkg.State = FwdStateProcessed - - // If every add, settle, and fail has been fully acknowledged, we can - // safely set the package's state to FwdStateCompleted, signalling that - // it can be garbage collected. - if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { - fwdPkg.State = FwdStateCompleted - } - - return fwdPkg, nil -} - -// loadHtlcs retrieves all serialized htlcs in a bucket, returning -// them in order of the indexes they were written under. -func loadHtlcs(bkt *bbolt.Bucket) ([]LogUpdate, error) { - var htlcs []LogUpdate - if err := bkt.ForEach(func(_, v []byte) error { - var htlc LogUpdate - if err := htlc.Decode(bytes.NewReader(v)); err != nil { - return err - } - - htlcs = append(htlcs, htlc) - - return nil - }); err != nil { - return nil, err - } - - return htlcs, nil -} - -// SetFwdFilter writes the set of indexes corresponding to Adds at the -// `height` that are to be forwarded to the switch. Calling this method causes -// the forwarding package at `height` to be in FwdStateProcessed. We write this -// forwarding decision so that we always arrive at the same behavior for HTLCs -// leaving this channel. After a restart, we skip validation of these Adds, -// since they are assumed to have already been validated, and make the switch or -// outgoing link responsible for handling replays. -func (p *ChannelPackager) SetFwdFilter(tx *bbolt.Tx, height uint64, - fwdFilter *PkgFilter) error { - - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - source := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(source[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.Bucket(heightKey[:]) - if heightBkt == nil { - return ErrCorruptedFwdPkg - } - - // If the fwd filter has already been written, we return early to avoid - // modifying the persistent state. - forwardedAddsBytes := heightBkt.Get(fwdFilterKey) - if forwardedAddsBytes != nil { - return nil - } - - // Otherwise we serialize and write the provided fwd filter. - var b bytes.Buffer - if err := fwdFilter.Encode(&b); err != nil { - return err - } - - return heightBkt.Put(fwdFilterKey, b.Bytes()) -} - -// AckAddHtlcs accepts a list of references to add htlcs, and updates the -// AckAddFilter of those forwarding packages to indicate that a settle or fail -// has been received in response to the add. -func (p *ChannelPackager) AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error { - if len(addRefs) == 0 { - return nil - } - - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - sourceKey := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - // Organize the forward references such that we just get a single slice - // of indexes for each unique height. - heightDiffs := make(map[uint64][]uint16) - for _, addRef := range addRefs { - heightDiffs[addRef.Height] = append( - heightDiffs[addRef.Height], - addRef.Index, - ) - } - - // Load each height bucket once and remove all acked htlcs at that - // height. - for height, indexes := range heightDiffs { - err := ackAddHtlcsAtHeight(sourceBkt, height, indexes) - if err != nil { - return err - } - } - - return nil -} - -// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package -// with a list of indexes, writing the resulting filter back in its place. -func ackAddHtlcsAtHeight(sourceBkt *bbolt.Bucket, height uint64, - indexes []uint16) error { - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.Bucket(heightKey[:]) - if heightBkt == nil { - // If the height bucket isn't found, this could be because the - // forwarding package was already removed. We'll return nil to - // signal that the operation is successful, as there is nothing - // to ack. - return nil - } - - // Load ack filter from disk. - ackFilterBytes := heightBkt.Get(ackFilterKey) - if ackFilterBytes == nil { - return ErrCorruptedFwdPkg - } - - ackFilter := &PkgFilter{} - ackFilterReader := bytes.NewReader(ackFilterBytes) - if err := ackFilter.Decode(ackFilterReader); err != nil { - return err - } - - // Update the ack filter for this height. - for _, index := range indexes { - ackFilter.Set(index) - } - - // Write the resulting filter to disk. - var ackFilterBuf bytes.Buffer - if err := ackFilter.Encode(&ackFilterBuf); err != nil { - return err - } - - return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()) -} - -// AckSettleFails persistently acknowledges settles or fails from a remote forwarding -// package. This should only be called after the source of the Add has locked in -// the settle/fail, or it becomes otherwise safe to forgo retransmitting the -// settle/fail after a restart. -func (p *ChannelPackager) AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error { - return ackSettleFails(tx, settleFailRefs) -} - -// ackSettleFails persistently acknowledges a batch of settle fail references. -func ackSettleFails(tx *bbolt.Tx, settleFailRefs []SettleFailRef) error { - if len(settleFailRefs) == 0 { - return nil - } - - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - // Organize the forward references such that we just get a single slice - // of indexes for each unique destination-height pair. - destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16) - for _, settleFailRef := range settleFailRefs { - destHeights, ok := destHeightDiffs[settleFailRef.Source] - if !ok { - destHeights = make(map[uint64][]uint16) - destHeightDiffs[settleFailRef.Source] = destHeights - } - - destHeights[settleFailRef.Height] = append( - destHeights[settleFailRef.Height], - settleFailRef.Index, - ) - } - - // With the references organized by destination and height, we now load - // each remote bucket, and update the settle fail filter for any - // settle/fail htlcs. - for dest, destHeights := range destHeightDiffs { - destKey := makeLogKey(dest.ToUint64()) - destBkt := fwdPkgBkt.Bucket(destKey[:]) - if destBkt == nil { - // If the destination bucket is not found, this is - // likely the result of the destination channel being - // closed and having it's forwarding packages wiped. We - // won't treat this as an error, because the response - // will no longer be retransmitted internally. - continue - } - - for height, indexes := range destHeights { - err := ackSettleFailsAtHeight(destBkt, height, indexes) - if err != nil { - return err - } - } - } - - return nil -} - -// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes -// at particular a height by updating the settle fail filter. -func ackSettleFailsAtHeight(destBkt *bbolt.Bucket, height uint64, - indexes []uint16) error { - - heightKey := makeLogKey(height) - heightBkt := destBkt.Bucket(heightKey[:]) - if heightBkt == nil { - // If the height bucket isn't found, this could be because the - // forwarding package was already removed. We'll return nil to - // signal that the operation is as there is nothing to ack. - return nil - } - - // Load ack filter from disk. - settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) - if settleFailFilterBytes == nil { - return ErrCorruptedFwdPkg - } - - settleFailFilter := &PkgFilter{} - settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) - if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { - return err - } - - // Update the ack filter for this height. - for _, index := range indexes { - settleFailFilter.Set(index) - } - - // Write the resulting filter to disk. - var settleFailFilterBuf bytes.Buffer - if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil { - return err - } - - return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) -} - -// RemovePkg deletes the forwarding package at the given height from the -// packager's source bucket. -func (p *ChannelPackager) RemovePkg(tx *bbolt.Tx, height uint64) error { - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return nil - } - - sourceBytes := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(sourceBytes[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) - - return sourceBkt.DeleteBucket(heightKey[:]) -} - -// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. -func uint16Key(i uint16) []byte { - key := make([]byte, 2) - byteOrder.PutUint16(key, i) - return key -} - -// Compile-time constraint to ensure that ChannelPackager implements the public -// FwdPackager interface. -var _ FwdPackager = (*ChannelPackager)(nil) - -// Compile-time constraint to ensure that SwitchPackager implements the public -// FwdOperator interface. -var _ FwdOperator = (*SwitchPackager)(nil) diff --git a/channeldb/migration_01_to_11/forwarding_package_test.go b/channeldb/migration_01_to_11/forwarding_package_test.go deleted file mode 100644 index 1128aad3..00000000 --- a/channeldb/migration_01_to_11/forwarding_package_test.go +++ /dev/null @@ -1,815 +0,0 @@ -package migration_01_to_11_test - -import ( - "bytes" - "io/ioutil" - "path/filepath" - "runtime" - "testing" - - "github.com/btcsuite/btcd/wire" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/lnwire" -) - -// TestPkgFilterBruteForce tests the behavior of a pkg filter up to size 1000, -// which is greater than the number of HTLCs we permit on a commitment txn. -// This should encapsulate every potential filter used in practice. -func TestPkgFilterBruteForce(t *testing.T) { - t.Parallel() - - checkPkgFilterRange(t, 1000) -} - -// checkPkgFilterRange verifies the behavior of a pkg filter when doing a linear -// insertion of `high` elements. This is primarily to test that IsFull functions -// properly for all relevant sizes of `high`. -func checkPkgFilterRange(t *testing.T, high int) { - for i := uint16(0); i < uint16(high); i++ { - f := channeldb.NewPkgFilter(i) - - if f.Count() != i { - t.Fatalf("pkg filter count=%d is actually %d", - i, f.Count()) - } - checkPkgFilterEncodeDecode(t, i, f) - - for j := uint16(0); j < i; j++ { - if f.Contains(j) { - t.Fatalf("pkg filter count=%d contains %d "+ - "before being added", i, j) - } - - f.Set(j) - checkPkgFilterEncodeDecode(t, i, f) - - if !f.Contains(j) { - t.Fatalf("pkg filter count=%d missing %d "+ - "after being added", i, j) - } - - if j < i-1 && f.IsFull() { - t.Fatalf("pkg filter count=%d already full", i) - } - } - - if !f.IsFull() { - t.Fatalf("pkg filter count=%d not full", i) - } - checkPkgFilterEncodeDecode(t, i, f) - } -} - -// TestPkgFilterRand uses a random permutation to verify the proper behavior of -// the pkg filter if the entries are not inserted in-order. -func TestPkgFilterRand(t *testing.T) { - t.Parallel() - - checkPkgFilterRand(t, 3, 17) -} - -// checkPkgFilterRand checks the behavior of a pkg filter by randomly inserting -// indices and asserting the invariants. The order in which indices are inserted -// is parameterized by a base `b` coprime to `p`, and using modular -// exponentiation to generate all elements in [1,p). -func checkPkgFilterRand(t *testing.T, b, p uint16) { - f := channeldb.NewPkgFilter(p) - var j = b - for i := uint16(1); i < p; i++ { - if f.Contains(j) { - t.Fatalf("pkg filter contains %d-%d "+ - "before being added", i, j) - } - - f.Set(j) - checkPkgFilterEncodeDecode(t, i, f) - - if !f.Contains(j) { - t.Fatalf("pkg filter missing %d-%d "+ - "after being added", i, j) - } - - if i < p-1 && f.IsFull() { - t.Fatalf("pkg filter %d already full", i) - } - checkPkgFilterEncodeDecode(t, i, f) - - j = (b * j) % p - } - - // Set 0 independently, since it will never be emitted by the generator. - f.Set(0) - checkPkgFilterEncodeDecode(t, p, f) - - if !f.IsFull() { - t.Fatalf("pkg filter count=%d not full", p) - } - checkPkgFilterEncodeDecode(t, p, f) -} - -// checkPkgFilterEncodeDecode tests the serialization of a pkg filter by: -// 1) writing it to a buffer -// 2) verifying the number of bytes written matches the filter's Size() -// 3) reconstructing the filter decoding the bytes -// 4) checking that the two filters are the same according to Equal -func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) { - var b bytes.Buffer - if err := f.Encode(&b); err != nil { - t.Fatalf("unable to serialize pkg filter: %v", err) - } - - // +2 for uint16 length - size := uint16(len(b.Bytes())) - if size != f.Size() { - t.Fatalf("pkg filter count=%d serialized size differs, "+ - "Size(): %d, len(bytes): %v", i, f.Size(), size) - } - - reader := bytes.NewReader(b.Bytes()) - - f2 := &channeldb.PkgFilter{} - if err := f2.Decode(reader); err != nil { - t.Fatalf("unable to deserialize pkg filter: %v", err) - } - - if !f.Equal(f2) { - t.Fatalf("pkg filter count=%v does is not equal "+ - "after deserialization, want: %v, got %v", - i, f, f2) - } -} - -var ( - chanID = lnwire.NewChanIDFromOutPoint(&wire.OutPoint{}) - - adds = []channeldb.LogUpdate{ - { - LogIndex: 0, - UpdateMsg: &lnwire.UpdateAddHTLC{ - ChanID: chanID, - ID: 1, - Amount: 100, - Expiry: 1000, - PaymentHash: [32]byte{0}, - }, - }, - { - LogIndex: 1, - UpdateMsg: &lnwire.UpdateAddHTLC{ - ChanID: chanID, - ID: 1, - Amount: 101, - Expiry: 1001, - PaymentHash: [32]byte{1}, - }, - }, - } - - settleFails = []channeldb.LogUpdate{ - { - LogIndex: 2, - UpdateMsg: &lnwire.UpdateFulfillHTLC{ - ChanID: chanID, - ID: 0, - PaymentPreimage: [32]byte{0}, - }, - }, - { - LogIndex: 3, - UpdateMsg: &lnwire.UpdateFailHTLC{ - ChanID: chanID, - ID: 1, - Reason: []byte{}, - }, - }, - } -) - -// TestPackagerEmptyFwdPkg checks that the state transitions exhibited by a -// forwarding package that contains no adds, fails or settles. We expect that -// the fwdpkg reaches FwdStateCompleted immediately after writing the forwarding -// decision via SetFwdFilter. -func TestPackagerEmptyFwdPkg(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package with no htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, nil) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. With no HTLCs, - // the ack filter will have no elements, and should always return true. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Now, write the forwarding decision. In this case, its just an empty - // fwd filter. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - // We should still have one package on disk. Since the forwarding - // decision has been written, it will minimally be in FwdStateProcessed. - // However with no htlcs, it should leap frog to FwdStateCompleted. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// TestPackagerOnlyAdds checks that the fwdpkg does not reach FwdStateCompleted -// as soon as all the adds in the package have been acked using AckAddHtlcs. -func TestPackagerOnlyAdds(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package that only has add - // htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, nil) - - nAdds := len(adds) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. The package - // has unacked add HTLCs, so the ack filter should not be full. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - // Now, write the forwarding decision. Since we have not explicitly - // added any adds to the fwdfilter, this would indicate that all of the - // adds were 1) settled locally by this link (exit hop), or 2) the htlc - // was failed locally. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - for i := range adds { - // We should still have one package on disk. Since the forwarding - // decision has been written, it will minimally be in FwdStateProcessed. - // However not allf of the HTLCs have been acked, so should not - // have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - addRef := channeldb.AddRef{ - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckAddHtlcs(tx, addRef) - }); err != nil { - t.Fatalf("unable to ack add htlc: %v", err) - } - } - - // We should still have one package on disk. Now that all adds have been - // acked, the ack filter should return true and the package should be - // FwdStateCompleted since there are no other settle/fail packets. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// TestPackagerOnlySettleFails asserts that the fwdpkg remains in -// FwdStateProcessed after writing the forwarding decision when there are no -// adds in the fwdpkg. We expect this because an empty FwdFilter will always -// return true, but we are still waiting for the remaining fails and settles to -// be deleted. -func TestPackagerOnlySettleFails(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package that only has add - // htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, settleFails) - - nSettleFails := len(settleFails) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. The package - // has unacked add HTLCs, so the ack filter should not be full. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Now, write the forwarding decision. Since we have not explicitly - // added any adds to the fwdfilter, this would indicate that all of the - // adds were 1) settled locally by this link (exit hop), or 2) the htlc - // was failed locally. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - for i := range settleFails { - // We should still have one package on disk. Since the - // forwarding decision has been written, it will minimally be in - // FwdStateProcessed. However, not all of the HTLCs have been - // acked, so should not have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], false) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - failSettleRef := channeldb.SettleFailRef{ - Source: shortChanID, - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckSettleFails(tx, failSettleRef) - }); err != nil { - t.Fatalf("unable to ack add htlc: %v", err) - } - } - - // We should still have one package on disk. Now that all settles and - // fails have been removed, package should be FwdStateCompleted since - // there are no other add packets. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], true) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// TestPackagerAddsThenSettleFails writes a fwdpkg containing both adds and -// settle/fails, then checks the behavior when the adds are acked before any of -// the settle fails. Here we expect pkg to remain in FwdStateProcessed while the -// remainder of the fail/settles are being deleted. -func TestPackagerAddsThenSettleFails(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package that only has add - // htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) - - nAdds := len(adds) - nSettleFails := len(settleFails) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. The package - // has unacked add HTLCs, so the ack filter should not be full. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - // Now, write the forwarding decision. Since we have not explicitly - // added any adds to the fwdfilter, this would indicate that all of the - // adds were 1) settled locally by this link (exit hop), or 2) the htlc - // was failed locally. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - for i := range adds { - // We should still have one package on disk. Since the forwarding - // decision has been written, it will minimally be in FwdStateProcessed. - // However not allf of the HTLCs have been acked, so should not - // have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], false) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - addRef := channeldb.AddRef{ - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckAddHtlcs(tx, addRef) - }); err != nil { - t.Fatalf("unable to ack add htlc: %v", err) - } - } - - for i := range settleFails { - // We should still have one package on disk. Since the - // forwarding decision has been written, it will minimally be in - // FwdStateProcessed. However not allf of the HTLCs have been - // acked, so should not have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], false) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - failSettleRef := channeldb.SettleFailRef{ - Source: shortChanID, - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckSettleFails(tx, failSettleRef) - }); err != nil { - t.Fatalf("unable to remove settle/fail htlc: %v", err) - } - } - - // We should still have one package on disk. Now that all settles and - // fails have been removed, package should be FwdStateCompleted since - // there are no other add packets. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], true) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// TestPackagerSettleFailsThenAdds writes a fwdpkg with both adds and -// settle/fails, then checks the behavior when the settle/fails are removed -// before any of the adds have been acked. This should cause the fwdpkg to -// remain in FwdStateProcessed until the final ack is recorded, at which point -// it should be promoted directly to FwdStateCompleted.since all adds have been -// removed. -func TestPackagerSettleFailsThenAdds(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package that has both add - // and settle/fail htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) - - nAdds := len(adds) - nSettleFails := len(settleFails) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. The package - // has unacked add HTLCs, so the ack filter should not be full. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - // Now, write the forwarding decision. Since we have not explicitly - // added any adds to the fwdfilter, this would indicate that all of the - // adds were 1) settled locally by this link (exit hop), or 2) the htlc - // was failed locally. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - // Simulate another channel deleting the settle/fails it received from - // the original fwd pkg. - // TODO(conner): use different packager/s? - for i := range settleFails { - // We should still have one package on disk. Since the - // forwarding decision has been written, it will minimally be in - // FwdStateProcessed. However none all of the add HTLCs have - // been acked, so should not have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], false) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - failSettleRef := channeldb.SettleFailRef{ - Source: shortChanID, - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckSettleFails(tx, failSettleRef) - }); err != nil { - t.Fatalf("unable to remove settle/fail htlc: %v", err) - } - } - - // Now simulate this channel receiving a fail/settle for the adds in the - // fwdpkg. - for i := range adds { - // Again, we should still have one package on disk and be in - // FwdStateProcessed. This should not change until all of the - // add htlcs have been acked. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], true) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - addRef := channeldb.AddRef{ - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckAddHtlcs(tx, addRef) - }); err != nil { - t.Fatalf("unable to ack add htlc: %v", err) - } - } - - // We should still have one package on disk. Now that all settles and - // fails have been removed, package should be FwdStateCompleted since - // there are no other add packets. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], true) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// assertFwdPkgState checks the current state of a fwdpkg meets our -// expectations. -func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg, - state channeldb.FwdState) { - _, _, line, _ := runtime.Caller(1) - if fwdPkg.State != state { - t.Fatalf("line %d: expected fwdpkg in state %v, found %v", - line, state, fwdPkg.State) - } -} - -// assertFwdPkgNumAddsSettleFails checks that the number of adds and -// settle/fail log updates are correct. -func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg, - expectedNumAdds, expectedNumSettleFails int) { - _, _, line, _ := runtime.Caller(1) - if len(fwdPkg.Adds) != expectedNumAdds { - t.Fatalf("line %d: expected fwdpkg to have %d adds, found %d", - line, expectedNumAdds, len(fwdPkg.Adds)) - } - - if len(fwdPkg.SettleFails) != expectedNumSettleFails { - t.Fatalf("line %d: expected fwdpkg to have %d settle/fails, found %d", - line, expectedNumSettleFails, len(fwdPkg.SettleFails)) - } -} - -// assertAckFilterIsFull checks whether or not a fwdpkg's ack filter matches our -// expected full-ness. -func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { - _, _, line, _ := runtime.Caller(1) - if fwdPkg.AckFilter.IsFull() != expected { - t.Fatalf("line %d: expected fwdpkg ack filter IsFull to be %v, "+ - "found %v", line, expected, fwdPkg.AckFilter.IsFull()) - } -} - -// assertSettleFailFilterIsFull checks whether or not a fwdpkg's settle fail -// filter matches our expected full-ness. -func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { - _, _, line, _ := runtime.Caller(1) - if fwdPkg.SettleFailFilter.IsFull() != expected { - t.Fatalf("line %d: expected fwdpkg settle/fail filter IsFull to be %v, "+ - "found %v", line, expected, fwdPkg.SettleFailFilter.IsFull()) - } -} - -// loadFwdPkgs is a helper method that reads all forwarding packages for a -// particular packager. -func loadFwdPkgs(t *testing.T, db *bbolt.DB, - packager channeldb.FwdPackager) []*channeldb.FwdPkg { - - var fwdPkgs []*channeldb.FwdPkg - if err := db.View(func(tx *bbolt.Tx) error { - var err error - fwdPkgs, err = packager.LoadFwdPkgs(tx) - return err - }); err != nil { - t.Fatalf("unable to load fwd pkgs: %v", err) - } - - return fwdPkgs -} - -// makeFwdPkgDB initializes a test database for forwarding packages. If the -// provided path is an empty, it will create a temp dir/file to use. -func makeFwdPkgDB(t *testing.T, path string) *bbolt.DB { - if path == "" { - var err error - path, err = ioutil.TempDir("", "fwdpkgdb") - if err != nil { - t.Fatalf("unable to create temp path: %v", err) - } - - path = filepath.Join(path, "fwdpkg.db") - } - - db, err := bbolt.Open(path, 0600, nil) - if err != nil { - t.Fatalf("unable to open boltdb: %v", err) - } - - return db -} diff --git a/channeldb/migration_01_to_11/graph.go b/channeldb/migration_01_to_11/graph.go index d90863c6..8e8f4a4a 100644 --- a/channeldb/migration_01_to_11/graph.go +++ b/channeldb/migration_01_to_11/graph.go @@ -2,20 +2,15 @@ package migration_01_to_11 import ( "bytes" - "crypto/sha256" "encoding/binary" - "errors" "fmt" "image/color" "io" - "math" "net" - "sync" "time" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/coreos/bbolt" @@ -74,11 +69,6 @@ var ( // lookup of incoming channel edges. unknownPolicy = []byte{} - // chanStart is an array of all zero bytes which is used to perform - // range scans within the edgeBucket to obtain all of the outgoing - // edges for a particular node. - chanStart [8]byte - // edgeIndexBucket is an index which can be used to iterate all edges // in the bucket, grouping them according to their in/out nodes. // Additionally, the items in this bucket also contain the complete @@ -155,9 +145,6 @@ const ( // would be possible for a node to create a ton of updates and slowly // fill our disk, and also waste bandwidth due to relaying. MaxAllowedExtraOpaqueBytes = 10000 - - // feeRateParts is the total number of parts used to express fee rates. - feeRateParts = 1e6 ) // ChannelGraph is a persistent, on-disk graph representation of the Lightning @@ -172,200 +159,16 @@ const ( // for that edge. type ChannelGraph struct { db *DB - - cacheMu sync.RWMutex - rejectCache *rejectCache - chanCache *channelCache } // newChannelGraph allocates a new ChannelGraph backed by a DB instance. The // returned instance has its own unique reject cache and channel cache. func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int) *ChannelGraph { return &ChannelGraph{ - db: db, - rejectCache: newRejectCache(rejectCacheSize), - chanCache: newChannelCache(chanCacheSize), + db: db, } } -// Database returns a pointer to the underlying database. -func (c *ChannelGraph) Database() *DB { - return c.db -} - -// ForEachChannel iterates through all the channel edges stored within the -// graph and invokes the passed callback for each edge. The callback takes two -// edges as since this is a directed graph, both the in/out edges are visited. -// If the callback returns an error, then the transaction is aborted and the -// iteration stops early. -// -// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer -// for that particular channel edge routing policy will be passed into the -// callback. -func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - // TODO(roasbeef): ptr map to reduce # of allocs? no duplicates - - return c.db.View(func(tx *bbolt.Tx) error { - // First, grab the node bucket. This will be used to populate - // the Node pointers in each edge read from disk. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // Next, grab the edge bucket which stores the edges, and also - // the index itself so we can group the directed edges together - // logically. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // For each edge pair within the edge index, we fetch each edge - // itself and also the node information in order to fully - // populated the object. - return edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - infoReader := bytes.NewReader(edgeInfoBytes) - edgeInfo, err := deserializeChanEdgeInfo(infoReader) - if err != nil { - return err - } - edgeInfo.db = c.db - - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, - ) - if err != nil { - return err - } - - // With both edges read, execute the call back. IF this - // function returns an error then the transaction will - // be aborted. - return cb(&edgeInfo, edge1, edge2) - }) - }) -} - -// ForEachNodeChannel iterates through all channels of a given node, executing the -// passed callback with an edge info structure and the policies of each end -// of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the -// connecting node. If the callback returns an error, then the iteration is -// halted with the error propagated back up to the caller. -// -// Unknown policies are passed into the callback as nil values. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (c *ChannelGraph) ForEachNodeChannel(tx *bbolt.Tx, nodePub []byte, - cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { - - db := c.db - - return nodeTraversal(tx, nodePub, db, cb) -} - -// DisabledChannelIDs returns the channel ids of disabled channels. -// A channel is disabled when two of the associated ChanelEdgePolicies -// have their disabled bit on. -func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { - var disabledChanIDs []uint64 - chanEdgeFound := make(map[uint64]struct{}) - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - - disabledEdgePolicyIndex := edges.Bucket(disabledEdgePolicyBucket) - if disabledEdgePolicyIndex == nil { - return nil - } - - // We iterate over all disabled policies and we add each channel that - // has more than one disabled policy to disabledChanIDs array. - return disabledEdgePolicyIndex.ForEach(func(k, v []byte) error { - chanID := byteOrder.Uint64(k[:8]) - _, edgeFound := chanEdgeFound[chanID] - if edgeFound { - delete(chanEdgeFound, chanID) - disabledChanIDs = append(disabledChanIDs, chanID) - return nil - } - - chanEdgeFound[chanID] = struct{}{} - return nil - }) - }) - if err != nil { - return nil, err - } - - return disabledChanIDs, nil -} - -// ForEachNode iterates through all the stored vertices/nodes in the graph, -// executing the passed callback with each node encountered. If the callback -// returns an error, then the transaction is aborted and the iteration stops -// early. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal -// -// TODO(roasbeef): add iterator interface to allow for memory efficient graph -// traversal when graph gets mega -func (c *ChannelGraph) ForEachNode(tx *bbolt.Tx, cb func(*bbolt.Tx, *LightningNode) error) error { - traversal := func(tx *bbolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - return nodes.ForEach(func(pubKey, nodeBytes []byte) error { - // If this is the source key, then we skip this - // iteration as the value for this key is a pubKey - // rather than raw node information. - if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { - return nil - } - - nodeReader := bytes.NewReader(nodeBytes) - node, err := deserializeLightningNode(nodeReader) - if err != nil { - return err - } - node.db = c.db - - // Execute the callback, the transaction will abort if - // this returns an error. - return cb(tx, &node) - }) - } - - // If no transaction was provided, then we'll create a new transaction - // to execute the transaction within. - if tx == nil { - return c.db.View(traversal) - } - - // Otherwise, we re-use the existing transaction to execute the graph - // traversal. - return traversal(tx) -} - // SourceNode returns the source node of the graph. The source node is treated // as the center node within a star-graph. This method may be used to kick off // a path finding algorithm in order to explore the reachability of another @@ -442,20 +245,6 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error { }) } -// AddLightningNode adds a vertex/node to the graph database. If the node is not -// in the database from before, this will add a new, unconnected one to the -// graph. If it is present from before, this will update that node's -// information. Note that this method is expected to only be called to update -// an already present node from a node announcement, or to insert a node found -// in a channel update. -// -// TODO(roasbeef): also need sig of announcement -func (c *ChannelGraph) AddLightningNode(node *LightningNode) error { - return c.db.Update(func(tx *bbolt.Tx) error { - return addLightningNode(tx, node) - }) -} - func addLightningNode(tx *bbolt.Tx, node *LightningNode) error { nodes, err := tx.CreateBucketIfNotExists(nodeBucket) if err != nil { @@ -477,1487 +266,6 @@ func addLightningNode(tx *bbolt.Tx, node *LightningNode) error { return putLightningNode(nodes, aliases, updateIndex, node) } -// LookupAlias attempts to return the alias as advertised by the target node. -// TODO(roasbeef): currently assumes that aliases are unique... -func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) { - var alias string - - err := c.db.View(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - - aliases := nodes.Bucket(aliasIndexBucket) - if aliases == nil { - return ErrGraphNodesNotFound - } - - nodePub := pub.SerializeCompressed() - a := aliases.Get(nodePub) - if a == nil { - return ErrNodeAliasNotFound - } - - // TODO(roasbeef): should actually be using the utf-8 - // package... - alias = string(a) - return nil - }) - if err != nil { - return "", err - } - - return alias, nil -} - -// DeleteLightningNode starts a new database transaction to remove a vertex/node -// from the database according to the node's public key. -func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error { - // TODO(roasbeef): ensure dangling edges are removed... - return c.db.Update(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodeNotFound - } - - return c.deleteLightningNode( - nodes, nodePub.SerializeCompressed(), - ) - }) -} - -// deleteLightningNode uses an existing database transaction to remove a -// vertex/node from the database according to the node's public key. -func (c *ChannelGraph) deleteLightningNode(nodes *bbolt.Bucket, - compressedPubKey []byte) error { - - aliases := nodes.Bucket(aliasIndexBucket) - if aliases == nil { - return ErrGraphNodesNotFound - } - - if err := aliases.Delete(compressedPubKey); err != nil { - return err - } - - // Before we delete the node, we'll fetch its current state so we can - // determine when its last update was to clear out the node update - // index. - node, err := fetchLightningNode(nodes, compressedPubKey) - if err != nil { - return err - } - - if err := nodes.Delete(compressedPubKey); err != nil { - - return err - } - - // Finally, we'll delete the index entry for the node within the - // nodeUpdateIndexBucket as this node is no longer active, so we don't - // need to track its last update. - nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) - if nodeUpdateIndex == nil { - return ErrGraphNodesNotFound - } - - // In order to delete the entry, we'll need to reconstruct the key for - // its last update. - updateUnix := uint64(node.LastUpdate.Unix()) - var indexKey [8 + 33]byte - byteOrder.PutUint64(indexKey[:8], updateUnix) - copy(indexKey[8:], compressedPubKey) - - return nodeUpdateIndex.Delete(indexKey[:]) -} - -// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An -// undirected edge from the two target nodes are created. The information -// stored denotes the static attributes of the channel, such as the channelID, -// the keys involved in creation of the channel, and the set of features that -// the channel supports. The chanPoint and chanID are used to uniquely identify -// the edge globally within the database. -func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo) error { - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - err := c.db.Update(func(tx *bbolt.Tx) error { - return c.addChannelEdge(tx, edge) - }) - if err != nil { - return err - } - - c.rejectCache.remove(edge.ChannelID) - c.chanCache.remove(edge.ChannelID) - - return nil -} - -// addChannelEdge is the private form of AddChannelEdge that allows callers to -// utilize an existing db transaction. -func (c *ChannelGraph) addChannelEdge(tx *bbolt.Tx, edge *ChannelEdgeInfo) error { - // Construct the channel's primary key which is the 8-byte channel ID. - var chanKey [8]byte - binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) - - nodes, err := tx.CreateBucketIfNotExists(nodeBucket) - if err != nil { - return err - } - edges, err := tx.CreateBucketIfNotExists(edgeBucket) - if err != nil { - return err - } - edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) - if err != nil { - return err - } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) - if err != nil { - return err - } - - // First, attempt to check if this edge has already been created. If - // so, then we can exit early as this method is meant to be idempotent. - if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo != nil { - return ErrEdgeAlreadyExist - } - - // Before we insert the channel into the database, we'll ensure that - // both nodes already exist in the channel graph. If either node - // doesn't, then we'll insert a "shell" node that just includes its - // public key, so subsequent validation and queries can work properly. - _, node1Err := fetchLightningNode(nodes, edge.NodeKey1Bytes[:]) - switch { - case node1Err == ErrGraphNodeNotFound: - node1Shell := LightningNode{ - PubKeyBytes: edge.NodeKey1Bytes, - HaveNodeAnnouncement: false, - } - err := addLightningNode(tx, &node1Shell) - if err != nil { - return fmt.Errorf("unable to create shell node "+ - "for: %x", edge.NodeKey1Bytes) - - } - case node1Err != nil: - return err - } - - _, node2Err := fetchLightningNode(nodes, edge.NodeKey2Bytes[:]) - switch { - case node2Err == ErrGraphNodeNotFound: - node2Shell := LightningNode{ - PubKeyBytes: edge.NodeKey2Bytes, - HaveNodeAnnouncement: false, - } - err := addLightningNode(tx, &node2Shell) - if err != nil { - return fmt.Errorf("unable to create shell node "+ - "for: %x", edge.NodeKey2Bytes) - - } - case node2Err != nil: - return err - } - - // If the edge hasn't been created yet, then we'll first add it to the - // edge index in order to associate the edge between two nodes and also - // store the static components of the channel. - if err := putChanEdgeInfo(edgeIndex, edge, chanKey); err != nil { - return err - } - - // Mark edge policies for both sides as unknown. This is to enable - // efficient incoming channel lookup for a node. - for _, key := range []*[33]byte{&edge.NodeKey1Bytes, - &edge.NodeKey2Bytes} { - - err := putChanEdgePolicyUnknown(edges, edge.ChannelID, - key[:]) - if err != nil { - return err - } - } - - // Finally we add it to the channel index which maps channel points - // (outpoints) to the shorter channel ID's. - var b bytes.Buffer - if err := writeOutpoint(&b, &edge.ChannelPoint); err != nil { - return err - } - return chanIndex.Put(b.Bytes(), chanKey[:]) -} - -// HasChannelEdge returns true if the database knows of a channel edge with the -// passed channel ID, and false otherwise. If an edge with that ID is found -// within the graph, then two time stamps representing the last time the edge -// was updated for both directed edges are returned along with the boolean. If -// it is not found, then the zombie index is checked and its result is returned -// as the second boolean. -func (c *ChannelGraph) HasChannelEdge( - chanID uint64) (time.Time, time.Time, bool, bool, error) { - - var ( - upd1Time time.Time - upd2Time time.Time - exists bool - isZombie bool - ) - - // We'll query the cache with the shared lock held to allow multiple - // readers to access values in the cache concurrently if they exist. - c.cacheMu.RLock() - if entry, ok := c.rejectCache.get(chanID); ok { - c.cacheMu.RUnlock() - upd1Time = time.Unix(entry.upd1Time, 0) - upd2Time = time.Unix(entry.upd2Time, 0) - exists, isZombie = entry.flags.unpack() - return upd1Time, upd2Time, exists, isZombie, nil - } - c.cacheMu.RUnlock() - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - // The item was not found with the shared lock, so we'll acquire the - // exclusive lock and check the cache again in case another method added - // the entry to the cache while no lock was held. - if entry, ok := c.rejectCache.get(chanID); ok { - upd1Time = time.Unix(entry.upd1Time, 0) - upd2Time = time.Unix(entry.upd2Time, 0) - exists, isZombie = entry.flags.unpack() - return upd1Time, upd2Time, exists, isZombie, nil - } - - if err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - var channelID [8]byte - byteOrder.PutUint64(channelID[:], chanID) - - // If the edge doesn't exist, then we'll also check our zombie - // index. - if edgeIndex.Get(channelID[:]) == nil { - exists = false - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex != nil { - isZombie, _, _ = isZombieEdge( - zombieIndex, chanID, - ) - } - - return nil - } - - exists = true - isZombie = false - - // If the channel has been found in the graph, then retrieve - // the edges itself so we can return the last updated - // timestamps. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodeNotFound - } - - e1, e2, err := fetchChanEdgePolicies(edgeIndex, edges, nodes, - channelID[:], c.db) - if err != nil { - return err - } - - // As we may have only one of the edges populated, only set the - // update time if the edge was found in the database. - if e1 != nil { - upd1Time = e1.LastUpdate - } - if e2 != nil { - upd2Time = e2.LastUpdate - } - - return nil - }); err != nil { - return time.Time{}, time.Time{}, exists, isZombie, err - } - - c.rejectCache.insert(chanID, rejectCacheEntry{ - upd1Time: upd1Time.Unix(), - upd2Time: upd2Time.Unix(), - flags: packRejectFlags(exists, isZombie), - }) - - return upd1Time, upd2Time, exists, isZombie, nil -} - -// UpdateChannelEdge retrieves and update edge of the graph database. Method -// only reserved for updating an edge info after its already been created. -// In order to maintain this constraints, we return an error in the scenario -// that an edge info hasn't yet been created yet, but someone attempts to update -// it. -func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { - // Construct the channel's primary key which is the 8-byte channel ID. - var chanKey [8]byte - binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) - - return c.db.Update(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edge == nil { - return ErrEdgeNotFound - } - - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrEdgeNotFound - } - - if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo == nil { - return ErrEdgeNotFound - } - - return putChanEdgeInfo(edgeIndex, edge, chanKey) - }) -} - -const ( - // pruneTipBytes is the total size of the value which stores a prune - // entry of the graph in the prune log. The "prune tip" is the last - // entry in the prune log, and indicates if the channel graph is in - // sync with the current UTXO state. The structure of the value - // is: blockHash, taking 32 bytes total. - pruneTipBytes = 32 -) - -// PruneGraph prunes newly closed channels from the channel graph in response -// to a new block being solved on the network. Any transactions which spend the -// funding output of any known channels within he graph will be deleted. -// Additionally, the "prune tip", or the last block which has been used to -// prune the graph is stored so callers can ensure the graph is fully in sync -// with the current UTXO state. A slice of channels that have been closed by -// the target block are returned if the function succeeds without error. -func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, - blockHash *chainhash.Hash, blockHeight uint32) ([]*ChannelEdgeInfo, error) { - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - var chansClosed []*ChannelEdgeInfo - - err := c.db.Update(func(tx *bbolt.Tx) error { - // First grab the edges bucket which houses the information - // we'd like to delete - edges, err := tx.CreateBucketIfNotExists(edgeBucket) - if err != nil { - return err - } - - // Next grab the two edge indexes which will also need to be updated. - edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) - if err != nil { - return err - } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) - if err != nil { - return err - } - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrSourceNodeNotSet - } - zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) - if err != nil { - return err - } - - // For each of the outpoints that have been spent within the - // block, we attempt to delete them from the graph as if that - // outpoint was a channel, then it has now been closed. - for _, chanPoint := range spentOutputs { - // TODO(roasbeef): load channel bloom filter, continue - // if NOT if filter - - var opBytes bytes.Buffer - if err := writeOutpoint(&opBytes, chanPoint); err != nil { - return err - } - - // First attempt to see if the channel exists within - // the database, if not, then we can exit early. - chanID := chanIndex.Get(opBytes.Bytes()) - if chanID == nil { - continue - } - - // However, if it does, then we'll read out the full - // version so we can add it to the set of deleted - // channels. - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - return err - } - - // Attempt to delete the channel, an ErrEdgeNotFound - // will be returned if that outpoint isn't known to be - // a channel. If no error is returned, then a channel - // was successfully pruned. - err = delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, - chanID, false, - ) - if err != nil && err != ErrEdgeNotFound { - return err - } - - chansClosed = append(chansClosed, &edgeInfo) - } - - metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) - if err != nil { - return err - } - - pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) - if err != nil { - return err - } - - // With the graph pruned, add a new entry to the prune log, - // which can be used to check if the graph is fully synced with - // the current UTXO state. - var blockHeightBytes [4]byte - byteOrder.PutUint32(blockHeightBytes[:], blockHeight) - - var newTip [pruneTipBytes]byte - copy(newTip[:], blockHash[:]) - - err = pruneBucket.Put(blockHeightBytes[:], newTip[:]) - if err != nil { - return err - } - - // Now that the graph has been pruned, we'll also attempt to - // prune any nodes that have had a channel closed within the - // latest block. - return c.pruneGraphNodes(nodes, edgeIndex) - }) - if err != nil { - return nil, err - } - - for _, channel := range chansClosed { - c.rejectCache.remove(channel.ChannelID) - c.chanCache.remove(channel.ChannelID) - } - - return chansClosed, nil -} - -// PruneGraphNodes is a garbage collection method which attempts to prune out -// any nodes from the channel graph that are currently unconnected. This ensure -// that we only maintain a graph of reachable nodes. In the event that a pruned -// node gains more channels, it will be re-added back to the graph. -func (c *ChannelGraph) PruneGraphNodes() error { - return c.db.Update(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNotFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - return c.pruneGraphNodes(nodes, edgeIndex) - }) -} - -// pruneGraphNodes attempts to remove any nodes from the graph who have had a -// channel closed within the current block. If the node still has existing -// channels in the graph, this will act as a no-op. -func (c *ChannelGraph) pruneGraphNodes(nodes *bbolt.Bucket, - edgeIndex *bbolt.Bucket) error { - - log.Trace("Pruning nodes from graph with no open channels") - - // We'll retrieve the graph's source node to ensure we don't remove it - // even if it no longer has any open channels. - sourceNode, err := c.sourceNode(nodes) - if err != nil { - return err - } - - // We'll use this map to keep count the number of references to a node - // in the graph. A node should only be removed once it has no more - // references in the graph. - nodeRefCounts := make(map[[33]byte]int) - err = nodes.ForEach(func(pubKey, nodeBytes []byte) error { - // If this is the source key, then we skip this - // iteration as the value for this key is a pubKey - // rather than raw node information. - if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { - return nil - } - - var nodePub [33]byte - copy(nodePub[:], pubKey) - nodeRefCounts[nodePub] = 0 - - return nil - }) - if err != nil { - return err - } - - // To ensure we never delete the source node, we'll start off by - // bumping its ref count to 1. - nodeRefCounts[sourceNode.PubKeyBytes] = 1 - - // Next, we'll run through the edgeIndex which maps a channel ID to the - // edge info. We'll use this scan to populate our reference count map - // above. - err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - // The first 66 bytes of the edge info contain the pubkeys of - // the nodes that this edge attaches. We'll extract them, and - // add them to the ref count map. - var node1, node2 [33]byte - copy(node1[:], edgeInfoBytes[:33]) - copy(node2[:], edgeInfoBytes[33:]) - - // With the nodes extracted, we'll increase the ref count of - // each of the nodes. - nodeRefCounts[node1]++ - nodeRefCounts[node2]++ - - return nil - }) - if err != nil { - return err - } - - // Finally, we'll make a second pass over the set of nodes, and delete - // any nodes that have a ref count of zero. - var numNodesPruned int - for nodePubKey, refCount := range nodeRefCounts { - // If the ref count of the node isn't zero, then we can safely - // skip it as it still has edges to or from it within the - // graph. - if refCount != 0 { - continue - } - - // If we reach this point, then there are no longer any edges - // that connect this node, so we can delete it. - if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { - log.Warnf("Unable to prune node %x from the "+ - "graph: %v", nodePubKey, err) - continue - } - - log.Infof("Pruned unconnected node %x from channel graph", - nodePubKey[:]) - - numNodesPruned++ - } - - if numNodesPruned > 0 { - log.Infof("Pruned %v unconnected nodes from the channel graph", - numNodesPruned) - } - - return nil -} - -// DisconnectBlockAtHeight is used to indicate that the block specified -// by the passed height has been disconnected from the main chain. This -// will "rewind" the graph back to the height below, deleting channels -// that are no longer confirmed from the graph. The prune log will be -// set to the last prune height valid for the remaining chain. -// Channels that were removed from the graph resulting from the -// disconnected block are returned. -func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo, - error) { - - // Every channel having a ShortChannelID starting at 'height' - // will no longer be confirmed. - startShortChanID := lnwire.ShortChannelID{ - BlockHeight: height, - } - - // Delete everything after this height from the db. - endShortChanID := lnwire.ShortChannelID{ - BlockHeight: math.MaxUint32 & 0x00ffffff, - TxIndex: math.MaxUint32 & 0x00ffffff, - TxPosition: math.MaxUint16, - } - // The block height will be the 3 first bytes of the channel IDs. - var chanIDStart [8]byte - byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64()) - var chanIDEnd [8]byte - byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64()) - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - // Keep track of the channels that are removed from the graph. - var removedChans []*ChannelEdgeInfo - - if err := c.db.Update(func(tx *bbolt.Tx) error { - edges, err := tx.CreateBucketIfNotExists(edgeBucket) - if err != nil { - return err - } - edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) - if err != nil { - return err - } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) - if err != nil { - return err - } - zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) - if err != nil { - return err - } - nodes, err := tx.CreateBucketIfNotExists(nodeBucket) - if err != nil { - return err - } - - // Scan from chanIDStart to chanIDEnd, deleting every - // found edge. - // NOTE: we must delete the edges after the cursor loop, since - // modifying the bucket while traversing is not safe. - var keys [][]byte - cursor := edgeIndex.Cursor() - for k, v := cursor.Seek(chanIDStart[:]); k != nil && - bytes.Compare(k, chanIDEnd[:]) <= 0; k, v = cursor.Next() { - - edgeInfoReader := bytes.NewReader(v) - edgeInfo, err := deserializeChanEdgeInfo(edgeInfoReader) - if err != nil { - return err - } - - keys = append(keys, k) - removedChans = append(removedChans, &edgeInfo) - } - - for _, k := range keys { - err = delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, - k, false, - ) - if err != nil && err != ErrEdgeNotFound { - return err - } - } - - // Delete all the entries in the prune log having a height - // greater or equal to the block disconnected. - metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) - if err != nil { - return err - } - - pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) - if err != nil { - return err - } - - var pruneKeyStart [4]byte - byteOrder.PutUint32(pruneKeyStart[:], height) - - var pruneKeyEnd [4]byte - byteOrder.PutUint32(pruneKeyEnd[:], math.MaxUint32) - - // To avoid modifying the bucket while traversing, we delete - // the keys in a second loop. - var pruneKeys [][]byte - pruneCursor := pruneBucket.Cursor() - for k, _ := pruneCursor.Seek(pruneKeyStart[:]); k != nil && - bytes.Compare(k, pruneKeyEnd[:]) <= 0; k, _ = pruneCursor.Next() { - - pruneKeys = append(pruneKeys, k) - } - - for _, k := range pruneKeys { - if err := pruneBucket.Delete(k); err != nil { - return err - } - } - - return nil - }); err != nil { - return nil, err - } - - for _, channel := range removedChans { - c.rejectCache.remove(channel.ChannelID) - c.chanCache.remove(channel.ChannelID) - } - - return removedChans, nil -} - -// PruneTip returns the block height and hash of the latest block that has been -// used to prune channels in the graph. Knowing the "prune tip" allows callers -// to tell if the graph is currently in sync with the current best known UTXO -// state. -func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { - var ( - tipHash chainhash.Hash - tipHeight uint32 - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - graphMeta := tx.Bucket(graphMetaBucket) - if graphMeta == nil { - return ErrGraphNotFound - } - pruneBucket := graphMeta.Bucket(pruneLogBucket) - if pruneBucket == nil { - return ErrGraphNeverPruned - } - - pruneCursor := pruneBucket.Cursor() - - // The prune key with the largest block height will be our - // prune tip. - k, v := pruneCursor.Last() - if k == nil { - return ErrGraphNeverPruned - } - - // Once we have the prune tip, the value will be the block hash, - // and the key the block height. - copy(tipHash[:], v[:]) - tipHeight = byteOrder.Uint32(k[:]) - - return nil - }) - if err != nil { - return nil, 0, err - } - - return &tipHash, tipHeight, nil -} - -// DeleteChannelEdges removes edges with the given channel IDs from the database -// and marks them as zombies. This ensures that we're unable to re-add it to our -// database once again. If an edge does not exist within the database, then -// ErrEdgeNotFound will be returned. -func (c *ChannelGraph) DeleteChannelEdges(chanIDs ...uint64) error { - // TODO(roasbeef): possibly delete from node bucket if node has no more - // channels - // TODO(roasbeef): don't delete both edges? - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - err := c.db.Update(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrEdgeNotFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrEdgeNotFound - } - chanIndex := edges.Bucket(channelPointBucket) - if chanIndex == nil { - return ErrEdgeNotFound - } - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodeNotFound - } - zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) - if err != nil { - return err - } - - var rawChanID [8]byte - for _, chanID := range chanIDs { - byteOrder.PutUint64(rawChanID[:], chanID) - err := delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, - rawChanID[:], true, - ) - if err != nil { - return err - } - } - - return nil - }) - if err != nil { - return err - } - - for _, chanID := range chanIDs { - c.rejectCache.remove(chanID) - c.chanCache.remove(chanID) - } - - return nil -} - -// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the -// passed channel point (outpoint). If the passed channel doesn't exist within -// the database, then ErrEdgeNotFound is returned. -func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { - var chanID uint64 - if err := c.db.View(func(tx *bbolt.Tx) error { - var err error - chanID, err = getChanID(tx, chanPoint) - return err - }); err != nil { - return 0, err - } - - return chanID, nil -} - -// getChanID returns the assigned channel ID for a given channel point. -func getChanID(tx *bbolt.Tx, chanPoint *wire.OutPoint) (uint64, error) { - var b bytes.Buffer - if err := writeOutpoint(&b, chanPoint); err != nil { - return 0, err - } - - edges := tx.Bucket(edgeBucket) - if edges == nil { - return 0, ErrGraphNoEdgesFound - } - chanIndex := edges.Bucket(channelPointBucket) - if chanIndex == nil { - return 0, ErrGraphNoEdgesFound - } - - chanIDBytes := chanIndex.Get(b.Bytes()) - if chanIDBytes == nil { - return 0, ErrEdgeNotFound - } - - chanID := byteOrder.Uint64(chanIDBytes) - - return chanID, nil -} - -// TODO(roasbeef): allow updates to use Batch? - -// HighestChanID returns the "highest" known channel ID in the channel graph. -// This represents the "newest" channel from the PoV of the chain. This method -// can be used by peers to quickly determine if they're graphs are in sync. -func (c *ChannelGraph) HighestChanID() (uint64, error) { - var cid uint64 - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // In order to find the highest chan ID, we'll fetch a cursor - // and use that to seek to the "end" of our known rage. - cidCursor := edgeIndex.Cursor() - - lastChanID, _ := cidCursor.Last() - - // If there's no key, then this means that we don't actually - // know of any channels, so we'll return a predicable error. - if lastChanID == nil { - return ErrGraphNoEdgesFound - } - - // Otherwise, we'll de serialize the channel ID and return it - // to the caller. - cid = byteOrder.Uint64(lastChanID) - return nil - }) - if err != nil && err != ErrGraphNoEdgesFound { - return 0, err - } - - return cid, nil -} - -// ChannelEdge represents the complete set of information for a channel edge in -// the known channel graph. This struct couples the core information of the -// edge as well as each of the known advertised edge policies. -type ChannelEdge struct { - // Info contains all the static information describing the channel. - Info *ChannelEdgeInfo - - // Policy1 points to the "first" edge policy of the channel containing - // the dynamic information required to properly route through the edge. - Policy1 *ChannelEdgePolicy - - // Policy2 points to the "second" edge policy of the channel containing - // the dynamic information required to properly route through the edge. - Policy2 *ChannelEdgePolicy -} - -// ChanUpdatesInHorizon returns all the known channel edges which have at least -// one edge that has an update timestamp within the specified horizon. -func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { - // To ensure we don't return duplicate ChannelEdges, we'll use an - // additional map to keep track of the edges already seen to prevent - // re-adding it. - edgesSeen := make(map[uint64]struct{}) - edgesToCache := make(map[uint64]ChannelEdge) - var edgesInHorizon []ChannelEdge - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - var hits int - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) - if edgeUpdateIndex == nil { - return ErrGraphNoEdgesFound - } - - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - - // We'll now obtain a cursor to perform a range query within - // the index to find all channels within the horizon. - updateCursor := edgeUpdateIndex.Cursor() - - var startTimeBytes, endTimeBytes [8 + 8]byte - byteOrder.PutUint64( - startTimeBytes[:8], uint64(startTime.Unix()), - ) - byteOrder.PutUint64( - endTimeBytes[:8], uint64(endTime.Unix()), - ) - - // With our start and end times constructed, we'll step through - // the index collecting the info and policy of each update of - // each channel that has a last update within the time range. - for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && - bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { - - // We have a new eligible entry, so we'll slice of the - // chan ID so we can query it in the DB. - chanID := indexKey[8:] - - // If we've already retrieved the info and policies for - // this edge, then we can skip it as we don't need to do - // so again. - chanIDInt := byteOrder.Uint64(chanID) - if _, ok := edgesSeen[chanIDInt]; ok { - continue - } - - if channel, ok := c.chanCache.get(chanIDInt); ok { - hits++ - edgesSeen[chanIDInt] = struct{}{} - edgesInHorizon = append(edgesInHorizon, channel) - continue - } - - // First, we'll fetch the static edge information. - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - chanID := byteOrder.Uint64(chanID) - return fmt.Errorf("unable to fetch info for "+ - "edge with chan_id=%v: %v", chanID, err) - } - edgeInfo.db = c.db - - // With the static information obtained, we'll now - // fetch the dynamic policy info. - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, - ) - if err != nil { - chanID := byteOrder.Uint64(chanID) - return fmt.Errorf("unable to fetch policies "+ - "for edge with chan_id=%v: %v", chanID, - err) - } - - // Finally, we'll collate this edge with the rest of - // edges to be returned. - edgesSeen[chanIDInt] = struct{}{} - channel := ChannelEdge{ - Info: &edgeInfo, - Policy1: edge1, - Policy2: edge2, - } - edgesInHorizon = append(edgesInHorizon, channel) - edgesToCache[chanIDInt] = channel - } - - return nil - }) - switch { - case err == ErrGraphNoEdgesFound: - fallthrough - case err == ErrGraphNodesNotFound: - break - - case err != nil: - return nil, err - } - - // Insert any edges loaded from disk into the cache. - for chanid, channel := range edgesToCache { - c.chanCache.insert(chanid, channel) - } - - log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)", - float64(hits)/float64(len(edgesInHorizon)), hits, - len(edgesInHorizon)) - - return edgesInHorizon, nil -} - -// NodeUpdatesInHorizon returns all the known lightning node which have an -// update timestamp within the passed range. This method can be used by two -// nodes to quickly determine if they have the same set of up to date node -// announcements. -func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { - var nodesInHorizon []LightningNode - - err := c.db.View(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - - nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) - if nodeUpdateIndex == nil { - return ErrGraphNodesNotFound - } - - // We'll now obtain a cursor to perform a range query within - // the index to find all node announcements within the horizon. - updateCursor := nodeUpdateIndex.Cursor() - - var startTimeBytes, endTimeBytes [8 + 33]byte - byteOrder.PutUint64( - startTimeBytes[:8], uint64(startTime.Unix()), - ) - byteOrder.PutUint64( - endTimeBytes[:8], uint64(endTime.Unix()), - ) - - // With our start and end times constructed, we'll step through - // the index collecting info for each node within the time - // range. - for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && - bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { - - nodePub := indexKey[8:] - node, err := fetchLightningNode(nodes, nodePub) - if err != nil { - return err - } - node.db = c.db - - nodesInHorizon = append(nodesInHorizon, node) - } - - return nil - }) - switch { - case err == ErrGraphNoEdgesFound: - fallthrough - case err == ErrGraphNodesNotFound: - break - - case err != nil: - return nil, err - } - - return nodesInHorizon, nil -} - -// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan -// ID's that we don't know and are not known zombies of the passed set. In other -// words, we perform a set difference of our set of chan ID's and the ones -// passed in. This method can be used by callers to determine the set of -// channels another peer knows of that we don't. -func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { - var newChanIDs []uint64 - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // Fetch the zombie index, it may not exist if no edges have - // ever been marked as zombies. If the index has been - // initialized, we will use it later to skip known zombie edges. - zombieIndex := edges.Bucket(zombieBucket) - - // We'll run through the set of chanIDs and collate only the - // set of channel that are unable to be found within our db. - var cidBytes [8]byte - for _, cid := range chanIDs { - byteOrder.PutUint64(cidBytes[:], cid) - - // If the edge is already known, skip it. - if v := edgeIndex.Get(cidBytes[:]); v != nil { - continue - } - - // If the edge is a known zombie, skip it. - if zombieIndex != nil { - isZombie, _, _ := isZombieEdge(zombieIndex, cid) - if isZombie { - continue - } - } - - newChanIDs = append(newChanIDs, cid) - } - - return nil - }) - switch { - // If we don't know of any edges yet, then we'll return the entire set - // of chan IDs specified. - case err == ErrGraphNoEdgesFound: - return chanIDs, nil - - case err != nil: - return nil, err - } - - return newChanIDs, nil -} - -// FilterChannelRange returns the channel ID's of all known channels which were -// mined in a block height within the passed range. This method can be used to -// quickly share with a peer the set of channels we know of within a particular -// range to catch them up after a period of time offline. -func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) { - var chanIDs []uint64 - - startChanID := &lnwire.ShortChannelID{ - BlockHeight: startHeight, - } - - endChanID := lnwire.ShortChannelID{ - BlockHeight: endHeight, - TxIndex: math.MaxUint32 & 0x00ffffff, - TxPosition: math.MaxUint16, - } - - // As we need to perform a range scan, we'll convert the starting and - // ending height to their corresponding values when encoded using short - // channel ID's. - var chanIDStart, chanIDEnd [8]byte - byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) - byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - cursor := edgeIndex.Cursor() - - // We'll now iterate through the database, and find each - // channel ID that resides within the specified range. - var cid uint64 - for k, _ := cursor.Seek(chanIDStart[:]); k != nil && - bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() { - - // This channel ID rests within the target range, so - // we'll convert it into an integer and add it to our - // returned set. - cid = byteOrder.Uint64(k) - chanIDs = append(chanIDs, cid) - } - - return nil - }) - switch { - // If we don't know of any channels yet, then there's nothing to - // filter, so we'll return an empty slice. - case err == ErrGraphNoEdgesFound: - return chanIDs, nil - - case err != nil: - return nil, err - } - - return chanIDs, nil -} - -// FetchChanInfos returns the set of channel edges that correspond to the passed -// channel ID's. If an edge is the query is unknown to the database, it will -// skipped and the result will contain only those edges that exist at the time -// of the query. This can be used to respond to peer queries that are seeking to -// fill in gaps in their view of the channel graph. -func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { - // TODO(roasbeef): sort cids? - - var ( - chanEdges []ChannelEdge - cidBytes [8]byte - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - for _, cid := range chanIDs { - byteOrder.PutUint64(cidBytes[:], cid) - - // First, we'll fetch the static edge information. If - // the edge is unknown, we will skip the edge and - // continue gathering all known edges. - edgeInfo, err := fetchChanEdgeInfo( - edgeIndex, cidBytes[:], - ) - switch { - case err == ErrEdgeNotFound: - continue - case err != nil: - return err - } - edgeInfo.db = c.db - - // With the static information obtained, we'll now - // fetch the dynamic policy info. - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, cidBytes[:], c.db, - ) - if err != nil { - return err - } - - chanEdges = append(chanEdges, ChannelEdge{ - Info: &edgeInfo, - Policy1: edge1, - Policy2: edge2, - }) - } - return nil - }) - if err != nil { - return nil, err - } - - return chanEdges, nil -} - -func delEdgeUpdateIndexEntry(edgesBucket *bbolt.Bucket, chanID uint64, - edge1, edge2 *ChannelEdgePolicy) error { - - // First, we'll fetch the edge update index bucket which currently - // stores an entry for the channel we're about to delete. - updateIndex := edgesBucket.Bucket(edgeUpdateIndexBucket) - if updateIndex == nil { - // No edges in bucket, return early. - return nil - } - - // Now that we have the bucket, we'll attempt to construct a template - // for the index key: updateTime || chanid. - var indexKey [8 + 8]byte - byteOrder.PutUint64(indexKey[8:], chanID) - - // With the template constructed, we'll attempt to delete an entry that - // would have been created by both edges: we'll alternate the update - // times, as one may had overridden the other. - if edge1 != nil { - byteOrder.PutUint64(indexKey[:8], uint64(edge1.LastUpdate.Unix())) - if err := updateIndex.Delete(indexKey[:]); err != nil { - return err - } - } - - // We'll also attempt to delete the entry that may have been created by - // the second edge. - if edge2 != nil { - byteOrder.PutUint64(indexKey[:8], uint64(edge2.LastUpdate.Unix())) - if err := updateIndex.Delete(indexKey[:]); err != nil { - return err - } - } - - return nil -} - -func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, - nodes *bbolt.Bucket, chanID []byte, isZombie bool) error { - - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - return err - } - - // We'll also remove the entry in the edge update index bucket before - // we delete the edges themselves so we can access their last update - // times. - cid := byteOrder.Uint64(chanID) - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, nil, - ) - if err != nil { - return err - } - err = delEdgeUpdateIndexEntry(edges, cid, edge1, edge2) - if err != nil { - return err - } - - // The edge key is of the format pubKey || chanID. First we construct - // the latter half, populating the channel ID. - var edgeKey [33 + 8]byte - copy(edgeKey[33:], chanID) - - // With the latter half constructed, copy over the first public key to - // delete the edge in this direction, then the second to delete the - // edge in the opposite direction. - copy(edgeKey[:33], edgeInfo.NodeKey1Bytes[:]) - if edges.Get(edgeKey[:]) != nil { - if err := edges.Delete(edgeKey[:]); err != nil { - return err - } - } - copy(edgeKey[:33], edgeInfo.NodeKey2Bytes[:]) - if edges.Get(edgeKey[:]) != nil { - if err := edges.Delete(edgeKey[:]); err != nil { - return err - } - } - - // As part of deleting the edge we also remove all disabled entries - // from the edgePolicyDisabledIndex bucket. We do that for both directions. - updateEdgePolicyDisabledIndex(edges, cid, false, false) - updateEdgePolicyDisabledIndex(edges, cid, true, false) - - // With the edge data deleted, we can purge the information from the two - // edge indexes. - if err := edgeIndex.Delete(chanID); err != nil { - return err - } - var b bytes.Buffer - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { - return err - } - if err := chanIndex.Delete(b.Bytes()); err != nil { - return err - } - - // Finally, we'll mark the edge as a zombie within our index if it's - // being removed due to the channel becoming a zombie. We do this to - // ensure we don't store unnecessary data for spent channels. - if !isZombie { - return nil - } - - return markEdgeZombie( - zombieIndex, byteOrder.Uint64(chanID), edgeInfo.NodeKey1Bytes, - edgeInfo.NodeKey2Bytes, - ) -} - -// UpdateEdgePolicy updates the edge routing policy for a single directed edge -// within the database for the referenced channel. The `flags` attribute within -// the ChannelEdgePolicy determines which of the directed edges are being -// updated. If the flag is 1, then the first node's information is being -// updated, otherwise it's the second node's information. The node ordering is -// determined by the lexicographical ordering of the identity public keys of -// the nodes on either side of the channel. -func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error { - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - var isUpdate1 bool - err := c.db.Update(func(tx *bbolt.Tx) error { - var err error - isUpdate1, err = updateEdgePolicy(tx, edge) - return err - }) - if err != nil { - return err - } - - // If an entry for this channel is found in reject cache, we'll modify - // the entry with the updated timestamp for the direction that was just - // written. If the edge doesn't exist, we'll load the cache entry lazily - // during the next query for this edge. - if entry, ok := c.rejectCache.get(edge.ChannelID); ok { - if isUpdate1 { - entry.upd1Time = edge.LastUpdate.Unix() - } else { - entry.upd2Time = edge.LastUpdate.Unix() - } - c.rejectCache.insert(edge.ChannelID, entry) - } - - // If an entry for this channel is found in channel cache, we'll modify - // the entry with the updated policy for the direction that was just - // written. If the edge doesn't exist, we'll defer loading the info and - // policies and lazily read from disk during the next query. - if channel, ok := c.chanCache.get(edge.ChannelID); ok { - if isUpdate1 { - channel.Policy1 = edge - } else { - channel.Policy2 = edge - } - c.chanCache.insert(edge.ChannelID, channel) - } - - return nil -} - // updateEdgePolicy attempts to update an edge's policy within the relevant // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged @@ -2083,297 +391,6 @@ func (l *LightningNode) PubKey() (*btcec.PublicKey, error) { return key, nil } -// AuthSig is a signature under the advertised public key which serves to -// authenticate the attributes announced by this node. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (l *LightningNode) AuthSig() (*btcec.Signature, error) { - return btcec.ParseSignature(l.AuthSigBytes, btcec.S256()) -} - -// AddPubKey is a setter-link method that can be used to swap out the public -// key for a node. -func (l *LightningNode) AddPubKey(key *btcec.PublicKey) { - l.pubKey = key - copy(l.PubKeyBytes[:], key.SerializeCompressed()) -} - -// NodeAnnouncement retrieves the latest node announcement of the node. -func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement, - error) { - - if !l.HaveNodeAnnouncement { - return nil, fmt.Errorf("node does not have node announcement") - } - - alias, err := lnwire.NewNodeAlias(l.Alias) - if err != nil { - return nil, err - } - - nodeAnn := &lnwire.NodeAnnouncement{ - Features: l.Features.RawFeatureVector, - NodeID: l.PubKeyBytes, - RGBColor: l.Color, - Alias: alias, - Addresses: l.Addresses, - Timestamp: uint32(l.LastUpdate.Unix()), - ExtraOpaqueData: l.ExtraOpaqueData, - } - - if !signed { - return nodeAnn, nil - } - - sig, err := lnwire.NewSigFromRawSignature(l.AuthSigBytes) - if err != nil { - return nil, err - } - - nodeAnn.Signature = sig - - return nodeAnn, nil -} - -// isPublic determines whether the node is seen as public within the graph from -// the source node's point of view. An existing database transaction can also be -// specified. -func (l *LightningNode) isPublic(tx *bbolt.Tx, sourcePubKey []byte) (bool, error) { - // In order to determine whether this node is publicly advertised within - // the graph, we'll need to look at all of its edges and check whether - // they extend to any other node than the source node. errDone will be - // used to terminate the check early. - nodeIsPublic := false - errDone := errors.New("done") - err := l.ForEachChannel(tx, func(_ *bbolt.Tx, info *ChannelEdgeInfo, - _, _ *ChannelEdgePolicy) error { - - // If this edge doesn't extend to the source node, we'll - // terminate our search as we can now conclude that the node is - // publicly advertised within the graph due to the local node - // knowing of the current edge. - if !bytes.Equal(info.NodeKey1Bytes[:], sourcePubKey) && - !bytes.Equal(info.NodeKey2Bytes[:], sourcePubKey) { - - nodeIsPublic = true - return errDone - } - - // Since the edge _does_ extend to the source node, we'll also - // need to ensure that this is a public edge. - if info.AuthProof != nil { - nodeIsPublic = true - return errDone - } - - // Otherwise, we'll continue our search. - return nil - }) - if err != nil && err != errDone { - return false, err - } - - return nodeIsPublic, nil -} - -// FetchLightningNode attempts to look up a target node by its identity public -// key. If the node isn't found in the database, then ErrGraphNodeNotFound is -// returned. -func (c *ChannelGraph) FetchLightningNode(pub *btcec.PublicKey) (*LightningNode, error) { - var node *LightningNode - nodePub := pub.SerializeCompressed() - err := c.db.View(func(tx *bbolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // If a key for this serialized public key isn't found, then - // the target node doesn't exist within the database. - nodeBytes := nodes.Get(nodePub) - if nodeBytes == nil { - return ErrGraphNodeNotFound - } - - // If the node is found, then we can de deserialize the node - // information to return to the user. - nodeReader := bytes.NewReader(nodeBytes) - n, err := deserializeLightningNode(nodeReader) - if err != nil { - return err - } - n.db = c.db - - node = &n - - return nil - }) - if err != nil { - return nil, err - } - - return node, nil -} - -// HasLightningNode determines if the graph has a vertex identified by the -// target node identity public key. If the node exists in the database, a -// timestamp of when the data for the node was lasted updated is returned along -// with a true boolean. Otherwise, an empty time.Time is returned with a false -// boolean. -func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, error) { - var ( - updateTime time.Time - exists bool - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // If a key for this serialized public key isn't found, we can - // exit early. - nodeBytes := nodes.Get(nodePub[:]) - if nodeBytes == nil { - exists = false - return nil - } - - // Otherwise we continue on to obtain the time stamp - // representing the last time the data for this node was - // updated. - nodeReader := bytes.NewReader(nodeBytes) - node, err := deserializeLightningNode(nodeReader) - if err != nil { - return err - } - - exists = true - updateTime = node.LastUpdate - return nil - }) - if err != nil { - return time.Time{}, exists, err - } - - return updateTime, exists, nil -} - -// nodeTraversal is used to traverse all channels of a node given by its -// public key and passes channel information into the specified callback. -func nodeTraversal(tx *bbolt.Tx, nodePub []byte, db *DB, - cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - - traversal := func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNotFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // In order to reach all the edges for this node, we take - // advantage of the construction of the key-space within the - // edge bucket. The keys are stored in the form: pubKey || - // chanID. Therefore, starting from a chanID of zero, we can - // scan forward in the bucket, grabbing all the edges for the - // node. Once the prefix no longer matches, then we know we're - // done. - var nodeStart [33 + 8]byte - copy(nodeStart[:], nodePub) - copy(nodeStart[33:], chanStart[:]) - - // Starting from the key pubKey || 0, we seek forward in the - // bucket until the retrieved key no longer has the public key - // as its prefix. This indicates that we've stepped over into - // another node's edges, so we can terminate our scan. - edgeCursor := edges.Cursor() - for nodeEdge, _ := edgeCursor.Seek(nodeStart[:]); bytes.HasPrefix(nodeEdge, nodePub); nodeEdge, _ = edgeCursor.Next() { - // If the prefix still matches, the channel id is - // returned in nodeEdge. Channel id is used to lookup - // the node at the other end of the channel and both - // edge policies. - chanID := nodeEdge[33:] - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - return err - } - edgeInfo.db = db - - outgoingPolicy, err := fetchChanEdgePolicy( - edges, chanID, nodePub, nodes, - ) - if err != nil { - return err - } - - otherNode, err := edgeInfo.OtherNodeKeyBytes(nodePub) - if err != nil { - return err - } - - incomingPolicy, err := fetchChanEdgePolicy( - edges, chanID, otherNode[:], nodes, - ) - if err != nil { - return err - } - - // Finally, we execute the callback. - err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy) - if err != nil { - return err - } - } - - return nil - } - - // If no transaction was provided, then we'll create a new transaction - // to execute the transaction within. - if tx == nil { - return db.View(traversal) - } - - // Otherwise, we re-use the existing transaction to execute the graph - // traversal. - return traversal(tx) -} - -// ForEachChannel iterates through all channels of this node, executing the -// passed callback with an edge info structure and the policies of each end -// of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the -// connecting node. If the callback returns an error, then the iteration is -// halted with the error propagated back up to the caller. -// -// Unknown policies are passed into the callback as nil values. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, - cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - - nodePub := l.PubKeyBytes[:] - db := l.db - - return nodeTraversal(tx, nodePub, db, cb) -} - // ChannelEdgeInfo represents a fully authenticated channel along with all its // unique attributes. Once an authenticated channel announcement has been // processed on the network, then an instance of ChannelEdgeInfo encapsulating @@ -2395,19 +412,15 @@ type ChannelEdgeInfo struct { // NodeKey1Bytes is the raw public key of the first node. NodeKey1Bytes [33]byte - nodeKey1 *btcec.PublicKey // NodeKey2Bytes is the raw public key of the first node. NodeKey2Bytes [33]byte - nodeKey2 *btcec.PublicKey // BitcoinKey1Bytes is the raw public key of the first node. BitcoinKey1Bytes [33]byte - bitcoinKey1 *btcec.PublicKey // BitcoinKey2Bytes is the raw public key of the first node. BitcoinKey2Bytes [33]byte - bitcoinKey2 *btcec.PublicKey // Features is an opaque byte slice that encodes the set of channel // specific features that this channel edge supports. @@ -2433,173 +446,6 @@ type ChannelEdgeInfo struct { // and ensure we're able to make upgrades to the network in a forwards // compatible manner. ExtraOpaqueData []byte - - db *DB -} - -// AddNodeKeys is a setter-like method that can be used to replace the set of -// keys for the target ChannelEdgeInfo. -func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, - bitcoinKey2 *btcec.PublicKey) { - - c.nodeKey1 = nodeKey1 - copy(c.NodeKey1Bytes[:], c.nodeKey1.SerializeCompressed()) - - c.nodeKey2 = nodeKey2 - copy(c.NodeKey2Bytes[:], nodeKey2.SerializeCompressed()) - - c.bitcoinKey1 = bitcoinKey1 - copy(c.BitcoinKey1Bytes[:], c.bitcoinKey1.SerializeCompressed()) - - c.bitcoinKey2 = bitcoinKey2 - copy(c.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) -} - -// NodeKey1 is the identity public key of the "first" node that was involved in -// the creation of this channel. A node is considered "first" if the -// lexicographical ordering the its serialized public key is "smaller" than -// that of the other node involved in channel creation. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) { - if c.nodeKey1 != nil { - return c.nodeKey1, nil - } - - key, err := btcec.ParsePubKey(c.NodeKey1Bytes[:], btcec.S256()) - if err != nil { - return nil, err - } - c.nodeKey1 = key - - return key, nil -} - -// NodeKey2 is the identity public key of the "second" node that was -// involved in the creation of this channel. A node is considered -// "second" if the lexicographical ordering the its serialized public -// key is "larger" than that of the other node involved in channel -// creation. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) { - if c.nodeKey2 != nil { - return c.nodeKey2, nil - } - - key, err := btcec.ParsePubKey(c.NodeKey2Bytes[:], btcec.S256()) - if err != nil { - return nil, err - } - c.nodeKey2 = key - - return key, nil -} - -// BitcoinKey1 is the Bitcoin multi-sig key belonging to the first -// node, that was involved in the funding transaction that originally -// created the channel that this struct represents. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) { - if c.bitcoinKey1 != nil { - return c.bitcoinKey1, nil - } - - key, err := btcec.ParsePubKey(c.BitcoinKey1Bytes[:], btcec.S256()) - if err != nil { - return nil, err - } - c.bitcoinKey1 = key - - return key, nil -} - -// BitcoinKey2 is the Bitcoin multi-sig key belonging to the second -// node, that was involved in the funding transaction that originally -// created the channel that this struct represents. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { - if c.bitcoinKey2 != nil { - return c.bitcoinKey2, nil - } - - key, err := btcec.ParsePubKey(c.BitcoinKey2Bytes[:], btcec.S256()) - if err != nil { - return nil, err - } - c.bitcoinKey2 = key - - return key, nil -} - -// OtherNodeKeyBytes returns the node key bytes of the other end of -// the channel. -func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( - [33]byte, error) { - - switch { - case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): - return c.NodeKey2Bytes, nil - case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): - return c.NodeKey1Bytes, nil - default: - return [33]byte{}, fmt.Errorf("node not participating in this channel") - } -} - -// FetchOtherNode attempts to fetch the full LightningNode that's opposite of -// the target node in the channel. This is useful when one knows the pubkey of -// one of the nodes, and wishes to obtain the full LightningNode for the other -// end of the channel. -func (c *ChannelEdgeInfo) FetchOtherNode(tx *bbolt.Tx, thisNodeKey []byte) (*LightningNode, error) { - - // Ensure that the node passed in is actually a member of the channel. - var targetNodeBytes [33]byte - switch { - case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): - targetNodeBytes = c.NodeKey2Bytes - case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): - targetNodeBytes = c.NodeKey1Bytes - default: - return nil, fmt.Errorf("node not participating in this channel") - } - - var targetNode *LightningNode - fetchNodeFunc := func(tx *bbolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - node, err := fetchLightningNode(nodes, targetNodeBytes[:]) - if err != nil { - return err - } - node.db = c.db - - targetNode = &node - - return nil - } - - // If the transaction is nil, then we'll need to create a new one, - // otherwise we can use the existing db transaction. - var err error - if tx == nil { - err = c.db.View(fetchNodeFunc) - } else { - err = fetchNodeFunc(tx) - } - - return targetNode, err } // ChannelAuthProof is the authentication proof (the signature portion) for a @@ -2610,117 +456,23 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx *bbolt.Tx, thisNodeKey []byte) (*Lig // nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len || // features. type ChannelAuthProof struct { - // nodeSig1 is a cached instance of the first node signature. - nodeSig1 *btcec.Signature - // NodeSig1Bytes are the raw bytes of the first node signature encoded // in DER format. NodeSig1Bytes []byte - // nodeSig2 is a cached instance of the second node signature. - nodeSig2 *btcec.Signature - // NodeSig2Bytes are the raw bytes of the second node signature // encoded in DER format. NodeSig2Bytes []byte - // bitcoinSig1 is a cached instance of the first bitcoin signature. - bitcoinSig1 *btcec.Signature - // BitcoinSig1Bytes are the raw bytes of the first bitcoin signature // encoded in DER format. BitcoinSig1Bytes []byte - // bitcoinSig2 is a cached instance of the second bitcoin signature. - bitcoinSig2 *btcec.Signature - // BitcoinSig2Bytes are the raw bytes of the second bitcoin signature // encoded in DER format. BitcoinSig2Bytes []byte } -// Node1Sig is the signature using the identity key of the node that is first -// in a lexicographical ordering of the serialized public keys of the two nodes -// that created the channel. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) Node1Sig() (*btcec.Signature, error) { - if c.nodeSig1 != nil { - return c.nodeSig1, nil - } - - sig, err := btcec.ParseSignature(c.NodeSig1Bytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.nodeSig1 = sig - - return sig, nil -} - -// Node2Sig is the signature using the identity key of the node that is second -// in a lexicographical ordering of the serialized public keys of the two nodes -// that created the channel. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) Node2Sig() (*btcec.Signature, error) { - if c.nodeSig2 != nil { - return c.nodeSig2, nil - } - - sig, err := btcec.ParseSignature(c.NodeSig2Bytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.nodeSig2 = sig - - return sig, nil -} - -// BitcoinSig1 is the signature using the public key of the first node that was -// used in the channel's multi-sig output. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) BitcoinSig1() (*btcec.Signature, error) { - if c.bitcoinSig1 != nil { - return c.bitcoinSig1, nil - } - - sig, err := btcec.ParseSignature(c.BitcoinSig1Bytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.bitcoinSig1 = sig - - return sig, nil -} - -// BitcoinSig2 is the signature using the public key of the second node that -// was used in the channel's multi-sig output. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) BitcoinSig2() (*btcec.Signature, error) { - if c.bitcoinSig2 != nil { - return c.bitcoinSig2, nil - } - - sig, err := btcec.ParseSignature(c.BitcoinSig2Bytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.bitcoinSig2 = sig - - return sig, nil -} - // IsEmpty check is the authentication proof is empty Proof is empty if at // least one of the signatures are equal to nil. func (c *ChannelAuthProof) IsEmpty() bool { @@ -2742,9 +494,6 @@ type ChannelEdgePolicy struct { // use SetSigBytes instead to make sure that the cache is invalidated. SigBytes []byte - // sig is a cached fully parsed signature. - sig *btcec.Signature - // ChannelID is the unique channel ID for the channel. The first 3 // bytes are the block height, the next 3 the index within the block, // and the last 2 bytes are the output index for the channel. @@ -2794,35 +543,6 @@ type ChannelEdgePolicy struct { // and ensure we're able to make upgrades to the network in a forwards // compatible manner. ExtraOpaqueData []byte - - db *DB -} - -// Signature is a channel announcement signature, which is needed for proper -// edge policy announcement. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelEdgePolicy) Signature() (*btcec.Signature, error) { - if c.sig != nil { - return c.sig, nil - } - - sig, err := btcec.ParseSignature(c.SigBytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.sig = sig - - return sig, nil -} - -// SetSigBytes updates the signature and invalidates the cached parsed -// signature. -func (c *ChannelEdgePolicy) SetSigBytes(sig []byte) { - c.SigBytes = sig - c.sig = nil } // IsDisabled determines whether the edge has the disabled bit set. @@ -2831,488 +551,6 @@ func (c *ChannelEdgePolicy) IsDisabled() bool { lnwire.ChanUpdateDisabled } -// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over -// the passed active payment channel. This value is currently computed as -// specified in BOLT07, but will likely change in the near future. -func (c *ChannelEdgePolicy) ComputeFee( - amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { - - return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts -} - -// divideCeil divides dividend by factor and rounds the result up. -func divideCeil(dividend, factor lnwire.MilliSatoshi) lnwire.MilliSatoshi { - return (dividend + factor - 1) / factor -} - -// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming -// amount. -func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( - incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { - - return incomingAmt - divideCeil( - feeRateParts*(incomingAmt-c.FeeBaseMSat), - feeRateParts+c.FeeProportionalMillionths, - ) -} - -// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for -// the channel identified by the funding outpoint. If the channel can't be -// found, then ErrEdgeNotFound is returned. A struct which houses the general -// information for the channel itself is returned as well as two structs that -// contain the routing policies for the channel in either direction. -func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, -) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { - - var ( - edgeInfo *ChannelEdgeInfo - policy1 *ChannelEdgePolicy - policy2 *ChannelEdgePolicy - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - // First, grab the node bucket. This will be used to populate - // the Node pointers in each edge read from disk. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // Next, grab the edge bucket which stores the edges, and also - // the index itself so we can group the directed edges together - // logically. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // If the channel's outpoint doesn't exist within the outpoint - // index, then the edge does not exist. - chanIndex := edges.Bucket(channelPointBucket) - if chanIndex == nil { - return ErrGraphNoEdgesFound - } - var b bytes.Buffer - if err := writeOutpoint(&b, op); err != nil { - return err - } - chanID := chanIndex.Get(b.Bytes()) - if chanID == nil { - return ErrEdgeNotFound - } - - // If the channel is found to exists, then we'll first retrieve - // the general information for the channel. - edge, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - return err - } - edgeInfo = &edge - edgeInfo.db = c.db - - // Once we have the information about the channels' parameters, - // we'll fetch the routing policies for each for the directed - // edges. - e1, e2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, - ) - if err != nil { - return err - } - - policy1 = e1 - policy2 = e2 - return nil - }) - if err != nil { - return nil, nil, nil, err - } - - return edgeInfo, policy1, policy2, nil -} - -// FetchChannelEdgesByID attempts to lookup the two directed edges for the -// channel identified by the channel ID. If the channel can't be found, then -// ErrEdgeNotFound is returned. A struct which houses the general information -// for the channel itself is returned as well as two structs that contain the -// routing policies for the channel in either direction. -// -// ErrZombieEdge an be returned if the edge is currently marked as a zombie -// within the database. In this case, the ChannelEdgePolicy's will be nil, and -// the ChannelEdgeInfo will only include the public keys of each node. -func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, -) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { - - var ( - edgeInfo *ChannelEdgeInfo - policy1 *ChannelEdgePolicy - policy2 *ChannelEdgePolicy - channelID [8]byte - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - // First, grab the node bucket. This will be used to populate - // the Node pointers in each edge read from disk. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // Next, grab the edge bucket which stores the edges, and also - // the index itself so we can group the directed edges together - // logically. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - byteOrder.PutUint64(channelID[:], chanID) - - // Now, attempt to fetch edge. - edge, err := fetchChanEdgeInfo(edgeIndex, channelID[:]) - - // If it doesn't exist, we'll quickly check our zombie index to - // see if we've previously marked it as so. - if err == ErrEdgeNotFound { - // If the zombie index doesn't exist, or the edge is not - // marked as a zombie within it, then we'll return the - // original ErrEdgeNotFound error. - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex == nil { - return ErrEdgeNotFound - } - - isZombie, pubKey1, pubKey2 := isZombieEdge( - zombieIndex, chanID, - ) - if !isZombie { - return ErrEdgeNotFound - } - - // Otherwise, the edge is marked as a zombie, so we'll - // populate the edge info with the public keys of each - // party as this is the only information we have about - // it and return an error signaling so. - edgeInfo = &ChannelEdgeInfo{ - NodeKey1Bytes: pubKey1, - NodeKey2Bytes: pubKey2, - } - return ErrZombieEdge - } - - // Otherwise, we'll just return the error if any. - if err != nil { - return err - } - - edgeInfo = &edge - edgeInfo.db = c.db - - // Then we'll attempt to fetch the accompanying policies of this - // edge. - e1, e2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, channelID[:], c.db, - ) - if err != nil { - return err - } - - policy1 = e1 - policy2 = e2 - return nil - }) - if err == ErrZombieEdge { - return edgeInfo, nil, nil, err - } - if err != nil { - return nil, nil, nil, err - } - - return edgeInfo, policy1, policy2, nil -} - -// IsPublicNode is a helper method that determines whether the node with the -// given public key is seen as a public node in the graph from the graph's -// source node's point of view. -func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { - var nodeIsPublic bool - err := c.db.View(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - ourPubKey := nodes.Get(sourceKey) - if ourPubKey == nil { - return ErrSourceNodeNotSet - } - node, err := fetchLightningNode(nodes, pubKey[:]) - if err != nil { - return err - } - - nodeIsPublic, err = node.isPublic(tx, ourPubKey) - return err - }) - if err != nil { - return false, err - } - - return nodeIsPublic, nil -} - -// genMultiSigP2WSH generates the p2wsh'd multisig script for 2 of 2 pubkeys. -func genMultiSigP2WSH(aPub, bPub []byte) ([]byte, error) { - if len(aPub) != 33 || len(bPub) != 33 { - return nil, fmt.Errorf("Pubkey size error. Compressed " + - "pubkeys only") - } - - // Swap to sort pubkeys if needed. Keys are sorted in lexicographical - // order. The signatures within the scriptSig must also adhere to the - // order, ensuring that the signatures for each public key appears in - // the proper order on the stack. - if bytes.Compare(aPub, bPub) == 1 { - aPub, bPub = bPub, aPub - } - - // First, we'll generate the witness script for the multi-sig. - bldr := txscript.NewScriptBuilder() - bldr.AddOp(txscript.OP_2) - bldr.AddData(aPub) // Add both pubkeys (sorted). - bldr.AddData(bPub) - bldr.AddOp(txscript.OP_2) - bldr.AddOp(txscript.OP_CHECKMULTISIG) - witnessScript, err := bldr.Script() - if err != nil { - return nil, err - } - - // With the witness script generated, we'll now turn it into a p2sh - // script: - // * OP_0 - bldr = txscript.NewScriptBuilder() - bldr.AddOp(txscript.OP_0) - scriptHash := sha256.Sum256(witnessScript) - bldr.AddData(scriptHash[:]) - - return bldr.Script() -} - -// EdgePoint couples the outpoint of a channel with the funding script that it -// creates. The FilteredChainView will use this to watch for spends of this -// edge point on chain. We require both of these values as depending on the -// concrete implementation, either the pkScript, or the out point will be used. -type EdgePoint struct { - // FundingPkScript is the p2wsh multi-sig script of the target channel. - FundingPkScript []byte - - // OutPoint is the outpoint of the target channel. - OutPoint wire.OutPoint -} - -// String returns a human readable version of the target EdgePoint. We return -// the outpoint directly as it is enough to uniquely identify the edge point. -func (e *EdgePoint) String() string { - return e.OutPoint.String() -} - -// ChannelView returns the verifiable edge information for each active channel -// within the known channel graph. The set of UTXO's (along with their scripts) -// returned are the ones that need to be watched on chain to detect channel -// closes on the resident blockchain. -func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { - var edgePoints []EdgePoint - if err := c.db.View(func(tx *bbolt.Tx) error { - // We're going to iterate over the entire channel index, so - // we'll need to fetch the edgeBucket to get to the index as - // it's a sub-bucket. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - chanIndex := edges.Bucket(channelPointBucket) - if chanIndex == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // Once we have the proper bucket, we'll range over each key - // (which is the channel point for the channel) and decode it, - // accumulating each entry. - return chanIndex.ForEach(func(chanPointBytes, chanID []byte) error { - chanPointReader := bytes.NewReader(chanPointBytes) - - var chanPoint wire.OutPoint - err := readOutpoint(chanPointReader, &chanPoint) - if err != nil { - return err - } - - edgeInfo, err := fetchChanEdgeInfo( - edgeIndex, chanID, - ) - if err != nil { - return err - } - - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], - edgeInfo.BitcoinKey2Bytes[:], - ) - if err != nil { - return err - } - - edgePoints = append(edgePoints, EdgePoint{ - FundingPkScript: pkScript, - OutPoint: chanPoint, - }) - - return nil - }) - }); err != nil { - return nil, err - } - - return edgePoints, nil -} - -// NewChannelEdgePolicy returns a new blank ChannelEdgePolicy. -func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy { - return &ChannelEdgePolicy{db: c.db} -} - -// markEdgeZombie marks an edge as a zombie within our zombie index. The public -// keys should represent the node public keys of the two parties involved in the -// edge. -func markEdgeZombie(zombieIndex *bbolt.Bucket, chanID uint64, pubKey1, - pubKey2 [33]byte) error { - - var k [8]byte - byteOrder.PutUint64(k[:], chanID) - - var v [66]byte - copy(v[:33], pubKey1[:]) - copy(v[33:], pubKey2[:]) - - return zombieIndex.Put(k[:], v[:]) -} - -// MarkEdgeLive clears an edge from our zombie index, deeming it as live. -func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - err := c.db.Update(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex == nil { - return nil - } - - var k [8]byte - byteOrder.PutUint64(k[:], chanID) - return zombieIndex.Delete(k[:]) - }) - if err != nil { - return err - } - - c.rejectCache.remove(chanID) - c.chanCache.remove(chanID) - - return nil -} - -// IsZombieEdge returns whether the edge is considered zombie. If it is a -// zombie, then the two node public keys corresponding to this edge are also -// returned. -func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) { - var ( - isZombie bool - pubKey1, pubKey2 [33]byte - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex == nil { - return nil - } - - isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID) - return nil - }) - if err != nil { - return false, [33]byte{}, [33]byte{} - } - - return isZombie, pubKey1, pubKey2 -} - -// isZombieEdge returns whether an entry exists for the given channel in the -// zombie index. If an entry exists, then the two node public keys corresponding -// to this edge are also returned. -func isZombieEdge(zombieIndex *bbolt.Bucket, - chanID uint64) (bool, [33]byte, [33]byte) { - - var k [8]byte - byteOrder.PutUint64(k[:], chanID) - - v := zombieIndex.Get(k[:]) - if v == nil { - return false, [33]byte{}, [33]byte{} - } - - var pubKey1, pubKey2 [33]byte - copy(pubKey1[:], v[:33]) - copy(pubKey2[:], v[33:]) - - return true, pubKey1, pubKey2 -} - -// NumZombies returns the current number of zombie channels in the graph. -func (c *ChannelGraph) NumZombies() (uint64, error) { - var numZombies uint64 - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return nil - } - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex == nil { - return nil - } - - return zombieIndex.ForEach(func(_, _ []byte) error { - numZombies++ - return nil - }) - }) - if err != nil { - return 0, err - } - - return numZombies, nil -} - func putLightningNode(nodeBucket *bbolt.Bucket, aliasBucket *bbolt.Bucket, updateIndex *bbolt.Bucket, node *LightningNode) error { @@ -3548,84 +786,6 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { return node, nil } -func putChanEdgeInfo(edgeIndex *bbolt.Bucket, edgeInfo *ChannelEdgeInfo, chanID [8]byte) error { - var b bytes.Buffer - - if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { - return err - } - - if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil { - return err - } - - authProof := edgeInfo.AuthProof - var nodeSig1, nodeSig2, bitcoinSig1, bitcoinSig2 []byte - if authProof != nil { - nodeSig1 = authProof.NodeSig1Bytes - nodeSig2 = authProof.NodeSig2Bytes - bitcoinSig1 = authProof.BitcoinSig1Bytes - bitcoinSig2 = authProof.BitcoinSig2Bytes - } - - if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil { - return err - } - - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { - return err - } - if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { - return err - } - if _, err := b.Write(chanID[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil { - return err - } - - if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { - return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) - } - err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) - if err != nil { - return err - } - - return edgeIndex.Put(chanID[:], b.Bytes()) -} - -func fetchChanEdgeInfo(edgeIndex *bbolt.Bucket, - chanID []byte) (ChannelEdgeInfo, error) { - - edgeInfoBytes := edgeIndex.Get(chanID) - if edgeInfoBytes == nil { - return ChannelEdgeInfo{}, ErrEdgeNotFound - } - - edgeInfoReader := bytes.NewReader(edgeInfoBytes) - return deserializeChanEdgeInfo(edgeInfoReader) -} - func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) { var ( err error @@ -3856,47 +1016,6 @@ func fetchChanEdgePolicy(edges *bbolt.Bucket, chanID []byte, return ep, nil } -func fetchChanEdgePolicies(edgeIndex *bbolt.Bucket, edges *bbolt.Bucket, - nodes *bbolt.Bucket, chanID []byte, - db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { - - edgeInfo := edgeIndex.Get(chanID) - if edgeInfo == nil { - return nil, nil, ErrEdgeNotFound - } - - // The first node is contained within the first half of the edge - // information. We only propagate the error here and below if it's - // something other than edge non-existence. - node1Pub := edgeInfo[:33] - edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub, nodes) - if err != nil { - return nil, nil, err - } - - // As we may have a single direction of the edge but not the other, - // only fill in the database pointers if the edge is found. - if edge1 != nil { - edge1.db = db - edge1.Node.db = db - } - - // Similarly, the second node is contained within the latter - // half of the edge information. - node2Pub := edgeInfo[33:66] - edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub, nodes) - if err != nil { - return nil, nil, err - } - - if edge2 != nil { - edge2.db = db - edge2.Node.db = db - } - - return edge1, edge2, nil -} - func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy, to []byte) error { diff --git a/channeldb/migration_01_to_11/graph_test.go b/channeldb/migration_01_to_11/graph_test.go index 00a8a000..a65f0046 100644 --- a/channeldb/migration_01_to_11/graph_test.go +++ b/channeldb/migration_01_to_11/graph_test.go @@ -1,24 +1,13 @@ package migration_01_to_11 import ( - "bytes" - "crypto/sha256" - "fmt" "image/color" - "math" "math/big" prand "math/rand" "net" - "reflect" - "runtime" - "testing" "time" "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/wire" - "github.com/coreos/bbolt" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnwire" ) @@ -66,3132 +55,3 @@ func createTestVertex(db *DB) (*LightningNode, error) { return createLightningNode(db, priv) } - -func TestNodeInsertionAndDeletion(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test basic insertion/deletion for vertexes from the - // graph, so we'll create a test vertex to start with. - _, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - node := &LightningNode{ - HaveNodeAnnouncement: true, - AuthSigBytes: testSig.Serialize(), - LastUpdate: time.Unix(1232342, 0), - Color: color.RGBA{1, 2, 3, 0}, - Alias: "kek", - Features: testFeatures, - Addresses: testAddrs, - ExtraOpaqueData: []byte("extra new data"), - db: db, - } - copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) - - // First, insert the node into the graph DB. This should succeed - // without any errors. - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Next, fetch the node from the database to ensure everything was - // serialized properly. - dbNode, err := graph.FetchLightningNode(testPub) - if err != nil { - t.Fatalf("unable to locate node: %v", err) - } - - if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { - t.Fatalf("unable to query for node: %v", err) - } else if !exists { - t.Fatalf("node should be found but wasn't") - } - - // The two nodes should match exactly! - if err := compareNodes(node, dbNode); err != nil { - t.Fatalf("nodes don't match: %v", err) - } - - // Next, delete the node from the graph, this should purge all data - // related to the node. - if err := graph.DeleteLightningNode(testPub); err != nil { - t.Fatalf("unable to delete node; %v", err) - } - - // Finally, attempt to fetch the node again. This should fail as the - // node should have been deleted from the database. - _, err = graph.FetchLightningNode(testPub) - if err != ErrGraphNodeNotFound { - t.Fatalf("fetch after delete should fail!") - } -} - -// TestPartialNode checks that we can add and retrieve a LightningNode where -// where only the pubkey is known to the database. -func TestPartialNode(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We want to be able to insert nodes into the graph that only has the - // PubKey set. - _, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - node := &LightningNode{ - HaveNodeAnnouncement: false, - } - copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) - - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Next, fetch the node from the database to ensure everything was - // serialized properly. - dbNode, err := graph.FetchLightningNode(testPub) - if err != nil { - t.Fatalf("unable to locate node: %v", err) - } - - if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { - t.Fatalf("unable to query for node: %v", err) - } else if !exists { - t.Fatalf("node should be found but wasn't") - } - - // The two nodes should match exactly! (with default values for - // LastUpdate and db set to satisfy compareNodes()) - node = &LightningNode{ - HaveNodeAnnouncement: false, - LastUpdate: time.Unix(0, 0), - db: db, - } - copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) - - if err := compareNodes(node, dbNode); err != nil { - t.Fatalf("nodes don't match: %v", err) - } - - // Next, delete the node from the graph, this should purge all data - // related to the node. - if err := graph.DeleteLightningNode(testPub); err != nil { - t.Fatalf("unable to delete node: %v", err) - } - - // Finally, attempt to fetch the node again. This should fail as the - // node should have been deleted from the database. - _, err = graph.FetchLightningNode(testPub) - if err != ErrGraphNodeNotFound { - t.Fatalf("fetch after delete should fail!") - } -} - -func TestAliasLookup(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the alias index within the database, so first - // create a new test node. - testNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // Add the node to the graph's database, this should also insert an - // entry into the alias index for this node. - if err := graph.AddLightningNode(testNode); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Next, attempt to lookup the alias. The alias should exactly match - // the one which the test node was assigned. - nodePub, err := testNode.PubKey() - if err != nil { - t.Fatalf("unable to generate pubkey: %v", err) - } - dbAlias, err := graph.LookupAlias(nodePub) - if err != nil { - t.Fatalf("unable to find alias: %v", err) - } - if dbAlias != testNode.Alias { - t.Fatalf("aliases don't match, expected %v got %v", - testNode.Alias, dbAlias) - } - - // Ensure that looking up a non-existent alias results in an error. - node, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - nodePub, err = node.PubKey() - if err != nil { - t.Fatalf("unable to generate pubkey: %v", err) - } - _, err = graph.LookupAlias(nodePub) - if err != ErrNodeAliasNotFound { - t.Fatalf("alias lookup should fail for non-existent pubkey") - } -} - -func TestSourceNode(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the setting/getting of the source node, so we - // first create a fake node to use within the test. - testNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // Attempt to fetch the source node, this should return an error as the - // source node hasn't yet been set. - if _, err := graph.SourceNode(); err != ErrSourceNodeNotSet { - t.Fatalf("source node shouldn't be set in new graph") - } - - // Set the source the source node, this should insert the node into the - // database in a special way indicating it's the source node. - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // Retrieve the source node from the database, it should exactly match - // the one we set above. - sourceNode, err := graph.SourceNode() - if err != nil { - t.Fatalf("unable to fetch source node: %v", err) - } - if err := compareNodes(testNode, sourceNode); err != nil { - t.Fatalf("nodes don't match: %v", err) - } -} - -func TestEdgeInsertionDeletion(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the insertion/deletion of edges, so we create two - // vertexes to connect. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // In addition to the fake vertexes we create some fake channel - // identifiers. - chanID := uint64(prand.Int63()) - outpoint := wire.OutPoint{ - Hash: rev, - Index: 9, - } - - // Add the new edge to the database, this should proceed without any - // errors. - node1Pub, err := node1.PubKey() - if err != nil { - t.Fatalf("unable to generate node key: %v", err) - } - node2Pub, err := node2.PubKey() - if err != nil { - t.Fatalf("unable to generate node key: %v", err) - } - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: outpoint, - Capacity: 9000, - } - copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) - - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Ensure that both policies are returned as unknown (nil). - _, e1, e2, err := graph.FetchChannelEdgesByID(chanID) - if err != nil { - t.Fatalf("unable to fetch channel edge") - } - if e1 != nil || e2 != nil { - t.Fatalf("channel edges not unknown") - } - - // Next, attempt to delete the edge from the database, again this - // should proceed without any issues. - if err := graph.DeleteChannelEdges(chanID); err != nil { - t.Fatalf("unable to delete edge: %v", err) - } - - // Ensure that any query attempts to lookup the delete channel edge are - // properly deleted. - if _, _, _, err := graph.FetchChannelEdgesByOutpoint(&outpoint); err == nil { - t.Fatalf("channel edge not deleted") - } - if _, _, _, err := graph.FetchChannelEdgesByID(chanID); err == nil { - t.Fatalf("channel edge not deleted") - } - isZombie, _, _ := graph.IsZombieEdge(chanID) - if !isZombie { - t.Fatal("channel edge not marked as zombie") - } - - // Finally, attempt to delete a (now) non-existent edge within the - // database, this should result in an error. - err = graph.DeleteChannelEdges(chanID) - if err != ErrEdgeNotFound { - t.Fatalf("deleting a non-existent edge should fail!") - } -} - -func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, - node1, node2 *LightningNode) (ChannelEdgeInfo, lnwire.ShortChannelID) { - - shortChanID := lnwire.ShortChannelID{ - BlockHeight: height, - TxIndex: txIndex, - TxPosition: txPosition, - } - outpoint := wire.OutPoint{ - Hash: rev, - Index: outPointIndex, - } - - node1Pub, _ := node1.PubKey() - node2Pub, _ := node2.PubKey() - edgeInfo := ChannelEdgeInfo{ - ChannelID: shortChanID.ToUint64(), - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: outpoint, - Capacity: 9000, - } - - copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) - - return edgeInfo, shortChanID -} - -// TestDisconnectBlockAtHeight checks that the pruned state of the channel -// database is what we expect after calling DisconnectBlockAtHeight. -func TestDisconnectBlockAtHeight(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // We'd like to test the insertion/deletion of edges, so we create two - // vertexes to connect. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // In addition to the fake vertexes we create some fake channel - // identifiers. - var spendOutputs []*wire.OutPoint - var blockHash chainhash.Hash - copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) - - // Prune the graph a few times to make sure we have entries in the - // prune log. - _, err = graph.PruneGraph(spendOutputs, &blockHash, 155) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - var blockHash2 chainhash.Hash - copy(blockHash2[:], bytes.Repeat([]byte{2}, 32)) - - _, err = graph.PruneGraph(spendOutputs, &blockHash2, 156) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - - // We'll create 3 almost identical edges, so first create a helper - // method containing all logic for doing so. - - // Create an edge which has its block height at 156. - height := uint32(156) - edgeInfo, _ := createEdge(height, 0, 0, 0, node1, node2) - - // Create an edge with block height 157. We give it - // maximum values for tx index and position, to make - // sure our database range scan get edges from the - // entire range. - edgeInfo2, _ := createEdge( - height+1, math.MaxUint32&0x00ffffff, math.MaxUint16, 1, - node1, node2, - ) - - // Create a third edge, this with a block height of 155. - edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) - - // Now add all these new edges to the database. - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - if err := graph.AddChannelEdge(&edgeInfo2); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - if err := graph.AddChannelEdge(&edgeInfo3); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Call DisconnectBlockAtHeight, which should prune every channel - // that has a funding height of 'height' or greater. - removed, err := graph.DisconnectBlockAtHeight(uint32(height)) - if err != nil { - t.Fatalf("unable to prune %v", err) - } - - // The two edges should have been removed. - if len(removed) != 2 { - t.Fatalf("expected two edges to be removed from graph, "+ - "only %d were", len(removed)) - } - if removed[0].ChannelID != edgeInfo.ChannelID { - t.Fatalf("expected edge to be removed from graph") - } - if removed[1].ChannelID != edgeInfo2.ChannelID { - t.Fatalf("expected edge to be removed from graph") - } - - // The two first edges should be removed from the db. - _, _, has, isZombie, err := graph.HasChannelEdge(edgeInfo.ChannelID) - if err != nil { - t.Fatalf("unable to query for edge: %v", err) - } - if has { - t.Fatalf("edge1 was not pruned from the graph") - } - if isZombie { - t.Fatal("reorged edge1 should not be marked as zombie") - } - _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo2.ChannelID) - if err != nil { - t.Fatalf("unable to query for edge: %v", err) - } - if has { - t.Fatalf("edge2 was not pruned from the graph") - } - if isZombie { - t.Fatal("reorged edge2 should not be marked as zombie") - } - - // Edge 3 should not be removed. - _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo3.ChannelID) - if err != nil { - t.Fatalf("unable to query for edge: %v", err) - } - if !has { - t.Fatalf("edge3 was pruned from the graph") - } - if isZombie { - t.Fatal("edge3 was marked as zombie") - } - - // PruneTip should be set to the blockHash we specified for the block - // at height 155. - hash, h, err := graph.PruneTip() - if err != nil { - t.Fatalf("unable to get prune tip: %v", err) - } - if !blockHash.IsEqual(hash) { - t.Fatalf("expected best block to be %x, was %x", blockHash, hash) - } - if h != height-1 { - t.Fatalf("expected best block height to be %d, was %d", height-1, h) - } -} - -func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, - e2 *ChannelEdgeInfo) { - - if e1.ChannelID != e2.ChannelID { - t.Fatalf("chan id's don't match: %v vs %v", e1.ChannelID, - e2.ChannelID) - } - - if e1.ChainHash != e2.ChainHash { - t.Fatalf("chain hashes don't match: %v vs %v", e1.ChainHash, - e2.ChainHash) - } - - if !bytes.Equal(e1.NodeKey1Bytes[:], e2.NodeKey1Bytes[:]) { - t.Fatalf("nodekey1 doesn't match") - } - if !bytes.Equal(e1.NodeKey2Bytes[:], e2.NodeKey2Bytes[:]) { - t.Fatalf("nodekey2 doesn't match") - } - if !bytes.Equal(e1.BitcoinKey1Bytes[:], e2.BitcoinKey1Bytes[:]) { - t.Fatalf("bitcoinkey1 doesn't match") - } - if !bytes.Equal(e1.BitcoinKey2Bytes[:], e2.BitcoinKey2Bytes[:]) { - t.Fatalf("bitcoinkey2 doesn't match") - } - - if !bytes.Equal(e1.Features, e2.Features) { - t.Fatalf("features doesn't match: %x vs %x", e1.Features, - e2.Features) - } - - if !bytes.Equal(e1.AuthProof.NodeSig1Bytes, e2.AuthProof.NodeSig1Bytes) { - t.Fatalf("nodesig1 doesn't match: %v vs %v", - spew.Sdump(e1.AuthProof.NodeSig1Bytes), - spew.Sdump(e2.AuthProof.NodeSig1Bytes)) - } - if !bytes.Equal(e1.AuthProof.NodeSig2Bytes, e2.AuthProof.NodeSig2Bytes) { - t.Fatalf("nodesig2 doesn't match") - } - if !bytes.Equal(e1.AuthProof.BitcoinSig1Bytes, e2.AuthProof.BitcoinSig1Bytes) { - t.Fatalf("bitcoinsig1 doesn't match") - } - if !bytes.Equal(e1.AuthProof.BitcoinSig2Bytes, e2.AuthProof.BitcoinSig2Bytes) { - t.Fatalf("bitcoinsig2 doesn't match") - } - - if e1.ChannelPoint != e2.ChannelPoint { - t.Fatalf("channel point match: %v vs %v", e1.ChannelPoint, - e2.ChannelPoint) - } - - if e1.Capacity != e2.Capacity { - t.Fatalf("capacity doesn't match: %v vs %v", e1.Capacity, - e2.Capacity) - } - - if !bytes.Equal(e1.ExtraOpaqueData, e2.ExtraOpaqueData) { - t.Fatalf("extra data doesn't match: %v vs %v", - e2.ExtraOpaqueData, e2.ExtraOpaqueData) - } -} - -func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) { - - var ( - firstNode *LightningNode - secondNode *LightningNode - ) - if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 { - firstNode = node1 - secondNode = node2 - } else { - firstNode = node2 - secondNode = node1 - } - - // In addition to the fake vertexes we create some fake channel - // identifiers. - chanID := uint64(prand.Int63()) - outpoint := wire.OutPoint{ - Hash: rev, - Index: 9, - } - - // Add the new edge to the database, this should proceed without any - // errors. - edgeInfo := &ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: outpoint, - Capacity: 1000, - ExtraOpaqueData: []byte("new unknown feature"), - } - copy(edgeInfo.NodeKey1Bytes[:], firstNode.PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], secondNode.PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], firstNode.PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], secondNode.PubKeyBytes[:]) - - edge1 := &ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: chanID, - LastUpdate: time.Unix(433453, 0), - MessageFlags: 1, - ChannelFlags: 0, - TimeLockDelta: 99, - MinHTLC: 2342135, - MaxHTLC: 13928598, - FeeBaseMSat: 4352345, - FeeProportionalMillionths: 3452352, - Node: secondNode, - ExtraOpaqueData: []byte("new unknown feature2"), - db: db, - } - edge2 := &ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: chanID, - LastUpdate: time.Unix(124234, 0), - MessageFlags: 1, - ChannelFlags: 1, - TimeLockDelta: 99, - MinHTLC: 2342135, - MaxHTLC: 13928598, - FeeBaseMSat: 4352345, - FeeProportionalMillionths: 90392423, - Node: firstNode, - ExtraOpaqueData: []byte("new unknown feature1"), - db: db, - } - - return edgeInfo, edge1, edge2 -} - -func TestEdgeInfoUpdates(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the update of edges inserted into the database, so - // we create two vertexes to connect. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create an edge and add it to the db. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) - - // Make sure inserting the policy at this point, before the edge info - // is added, will fail. - if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound { - t.Fatalf("expected ErrEdgeNotFound, got: %v", err) - } - - // Add the edge info. - if err := graph.AddChannelEdge(edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - chanID := edgeInfo.ChannelID - outpoint := edgeInfo.ChannelPoint - - // Next, insert both edge policies into the database, they should both - // be inserted without any issues. - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Check for existence of the edge within the database, it should be - // found. - _, _, found, isZombie, err := graph.HasChannelEdge(chanID) - if err != nil { - t.Fatalf("unable to query for edge: %v", err) - } - if !found { - t.Fatalf("graph should have of inserted edge") - } - if isZombie { - t.Fatal("live edge should not be marked as zombie") - } - - // We should also be able to retrieve the channelID only knowing the - // channel point of the channel. - dbChanID, err := graph.ChannelID(&outpoint) - if err != nil { - t.Fatalf("unable to retrieve channel ID: %v", err) - } - if dbChanID != chanID { - t.Fatalf("chan ID's mismatch, expected %v got %v", dbChanID, - chanID) - } - - // With the edges inserted, perform some queries to ensure that they've - // been inserted properly. - dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID) - if err != nil { - t.Fatalf("unable to fetch channel by ID: %v", err) - } - if err := compareEdgePolicies(dbEdge1, edge1); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - if err := compareEdgePolicies(dbEdge2, edge2); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) - - // Next, attempt to query the channel edges according to the outpoint - // of the channel. - dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByOutpoint(&outpoint) - if err != nil { - t.Fatalf("unable to fetch channel by ID: %v", err) - } - if err := compareEdgePolicies(dbEdge1, edge1); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - if err := compareEdgePolicies(dbEdge2, edge2); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) -} - -func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy { - update := prand.Int63() - - return newEdgePolicy(chanID, op, db, update) -} - -func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, - updateTime int64) *ChannelEdgePolicy { - - return &ChannelEdgePolicy{ - ChannelID: chanID, - LastUpdate: time.Unix(updateTime, 0), - MessageFlags: 1, - ChannelFlags: 0, - TimeLockDelta: uint16(prand.Int63()), - MinHTLC: lnwire.MilliSatoshi(prand.Int63()), - MaxHTLC: lnwire.MilliSatoshi(prand.Int63()), - FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), - FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()), - db: db, - } -} - -func TestGraphTraversal(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test some of the graph traversal capabilities within - // the DB, so we'll create a series of fake nodes to insert into the - // graph. - const numNodes = 20 - nodes := make([]*LightningNode, numNodes) - nodeIndex := map[string]struct{}{} - for i := 0; i < numNodes; i++ { - node, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - nodes[i] = node - nodeIndex[node.Alias] = struct{}{} - } - - // Add each of the nodes into the graph, they should be inserted - // without error. - for _, node := range nodes { - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - } - - // Iterate over each node as returned by the graph, if all nodes are - // reached, then the map created above should be empty. - err = graph.ForEachNode(nil, func(_ *bbolt.Tx, node *LightningNode) error { - delete(nodeIndex, node.Alias) - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(nodeIndex) != 0 { - t.Fatalf("all nodes not reached within ForEach") - } - - // Determine which node is "smaller", we'll need this in order to - // properly create the edges for the graph. - var firstNode, secondNode *LightningNode - if bytes.Compare(nodes[0].PubKeyBytes[:], nodes[1].PubKeyBytes[:]) == -1 { - firstNode = nodes[0] - secondNode = nodes[1] - } else { - firstNode = nodes[0] - secondNode = nodes[1] - } - - // Create 5 channels between the first two nodes we generated above. - const numChannels = 5 - chanIndex := map[uint64]struct{}{} - for i := 0; i < numChannels; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64(i + 1) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: op, - Capacity: 1000, - } - copy(edgeInfo.NodeKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], nodes[1].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], nodes[1].PubKeyBytes[:]) - err := graph.AddChannelEdge(&edgeInfo) - if err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create and add an edge with random data that points from - // node1 -> node2. - edge := randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 0 - edge.Node = secondNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Create another random edge that points from node2 -> node1 - // this time. - edge = randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 1 - edge.Node = firstNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - chanIndex[chanID] = struct{}{} - } - - // Iterate through all the known channels within the graph DB, once - // again if the map is empty that indicates that all edges have - // properly been reached. - err = graph.ForEachChannel(func(ei *ChannelEdgeInfo, _ *ChannelEdgePolicy, - _ *ChannelEdgePolicy) error { - - delete(chanIndex, ei.ChannelID) - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(chanIndex) != 0 { - t.Fatalf("all edges not reached within ForEach") - } - - // Finally, we want to test the ability to iterate over all the - // outgoing channels for a particular node. - numNodeChans := 0 - err = firstNode.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo, - outEdge, inEdge *ChannelEdgePolicy) error { - - // All channels between first and second node should have fully - // (both sides) specified policies. - if inEdge == nil || outEdge == nil { - return fmt.Errorf("channel policy not present") - } - - // Each should indicate that it's outgoing (pointed - // towards the second node). - if !bytes.Equal(outEdge.Node.PubKeyBytes[:], secondNode.PubKeyBytes[:]) { - return fmt.Errorf("wrong outgoing edge") - } - - // The incoming edge should also indicate that it's pointing to - // the origin node. - if !bytes.Equal(inEdge.Node.PubKeyBytes[:], firstNode.PubKeyBytes[:]) { - return fmt.Errorf("wrong outgoing edge") - } - - numNodeChans++ - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if numNodeChans != numChannels { - t.Fatalf("all edges for node not reached within ForEach: "+ - "expected %v, got %v", numChannels, numNodeChans) - } -} - -func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash, - blockHeight uint32) { - - pruneHash, pruneHeight, err := graph.PruneTip() - if err != nil { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: unable to fetch prune tip: %v", line, err) - } - if !bytes.Equal(blockHash[:], pruneHash[:]) { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line: %v, prune tips don't match, expected %x got %x", - line, blockHash, pruneHash) - } - if pruneHeight != blockHeight { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: prune heights don't match, expected %v "+ - "got %v", line, blockHeight, pruneHeight) - } -} - -func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { - numChans := 0 - if err := graph.ForEachChannel(func(*ChannelEdgeInfo, *ChannelEdgePolicy, - *ChannelEdgePolicy) error { - - numChans++ - return nil - }); err != nil { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: unable to scan channels: %v", line, err) - } - if numChans != n { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: expected %v chans instead have %v", line, - n, numChans) - } -} - -func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { - numNodes := 0 - err := graph.ForEachNode(nil, func(_ *bbolt.Tx, _ *LightningNode) error { - numNodes++ - return nil - }) - if err != nil { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: unable to scan nodes: %v", line, err) - } - - if numNodes != n { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: expected %v nodes, got %v", line, n, numNodes) - } -} - -func assertChanViewEqual(t *testing.T, a []EdgePoint, b []EdgePoint) { - if len(a) != len(b) { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chan views don't match", line) - } - - chanViewSet := make(map[wire.OutPoint]struct{}) - for _, op := range a { - chanViewSet[op.OutPoint] = struct{}{} - } - - for _, op := range b { - if _, ok := chanViewSet[op.OutPoint]; !ok { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chanPoint(%v) not found in first "+ - "view", line, op) - } - } -} - -func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoint) { - if len(a) != len(b) { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chan views don't match", line) - } - - chanViewSet := make(map[wire.OutPoint]struct{}) - for _, op := range a { - chanViewSet[op.OutPoint] = struct{}{} - } - - for _, op := range b { - if _, ok := chanViewSet[*op]; !ok { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chanPoint(%v) not found in first "+ - "view", line, op) - } - } -} - -func TestGraphPruning(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // As initial set up for the test, we'll create a graph with 5 vertexes - // and enough edges to create a fully connected graph. The graph will - // be rather simple, representing a straight line. - const numNodes = 5 - graphNodes := make([]*LightningNode, numNodes) - for i := 0; i < numNodes; i++ { - node, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - graphNodes[i] = node - } - - // With the vertexes created, we'll next create a series of channels - // between them. - channelPoints := make([]*wire.OutPoint, 0, numNodes-1) - edgePoints := make([]EdgePoint, 0, numNodes-1) - for i := 0; i < numNodes-1; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64(i + 1) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - channelPoints = append(channelPoints, &op) - - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: op, - Capacity: 1000, - } - copy(edgeInfo.NodeKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], edgeInfo.BitcoinKey2Bytes[:], - ) - if err != nil { - t.Fatalf("unable to gen multi-sig p2wsh: %v", err) - } - edgePoints = append(edgePoints, EdgePoint{ - FundingPkScript: pkScript, - OutPoint: op, - }) - - // Create and add an edge with random data that points from - // node_i -> node_i+1 - edge := randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 0 - edge.Node = graphNodes[i] - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Create another random edge that points from node_i+1 -> - // node_i this time. - edge = randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 1 - edge.Node = graphNodes[i] - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - } - - // With all the channel points added, we'll consult the graph to ensure - // it has the same channel view as the one we just constructed. - channelView, err := graph.ChannelView() - if err != nil { - t.Fatalf("unable to get graph channel view: %v", err) - } - assertChanViewEqual(t, channelView, edgePoints) - - // Now with our test graph created, we can test the pruning - // capabilities of the channel graph. - - // First we create a mock block that ends up closing the first two - // channels. - var blockHash chainhash.Hash - copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) - blockHeight := uint32(1) - block := channelPoints[:2] - prunedChans, err := graph.PruneGraph(block, &blockHash, blockHeight) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - if len(prunedChans) != 2 { - t.Fatalf("incorrect number of channels pruned: "+ - "expected %v, got %v", 2, prunedChans) - } - - // Now ensure that the prune tip has been updated. - assertPruneTip(t, graph, &blockHash, blockHeight) - - // Count up the number of channels known within the graph, only 2 - // should be remaining. - assertNumChans(t, graph, 2) - - // Those channels should also be missing from the channel view. - channelView, err = graph.ChannelView() - if err != nil { - t.Fatalf("unable to get graph channel view: %v", err) - } - assertChanViewEqualChanPoints(t, channelView, channelPoints[2:]) - - // Next we'll create a block that doesn't close any channels within the - // graph to test the negative error case. - fakeHash := sha256.Sum256([]byte("test prune")) - nonChannel := &wire.OutPoint{ - Hash: fakeHash, - Index: 9, - } - blockHash = sha256.Sum256(blockHash[:]) - blockHeight = 2 - prunedChans, err = graph.PruneGraph( - []*wire.OutPoint{nonChannel}, &blockHash, blockHeight, - ) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - - // No channels should have been detected as pruned. - if len(prunedChans) != 0 { - t.Fatalf("channels were pruned but shouldn't have been") - } - - // Once again, the prune tip should have been updated. We should still - // see both channels and their participants, along with the source node. - assertPruneTip(t, graph, &blockHash, blockHeight) - assertNumChans(t, graph, 2) - assertNumNodes(t, graph, 4) - - // Finally, create a block that prunes the remainder of the channels - // from the graph. - blockHash = sha256.Sum256(blockHash[:]) - blockHeight = 3 - prunedChans, err = graph.PruneGraph( - channelPoints[2:], &blockHash, blockHeight, - ) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - - // The remainder of the channels should have been pruned from the - // graph. - if len(prunedChans) != 2 { - t.Fatalf("incorrect number of channels pruned: "+ - "expected %v, got %v", 2, len(prunedChans)) - } - - // The prune tip should be updated, no channels should be found, and - // only the source node should remain within the current graph. - assertPruneTip(t, graph, &blockHash, blockHeight) - assertNumChans(t, graph, 0) - assertNumNodes(t, graph, 1) - - // Finally, the channel view at this point in the graph should now be - // completely empty. Those channels should also be missing from the - // channel view. - channelView, err = graph.ChannelView() - if err != nil { - t.Fatalf("unable to get graph channel view: %v", err) - } - if len(channelView) != 0 { - t.Fatalf("channel view should be empty, instead have: %v", - channelView) - } -} - -// TestHighestChanID tests that we're able to properly retrieve the highest -// known channel ID in the database. -func TestHighestChanID(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // If we don't yet have any channels in the database, then we should - // get a channel ID of zero if we ask for the highest channel ID. - bestID, err := graph.HighestChanID() - if err != nil { - t.Fatalf("unable to get highest ID: %v", err) - } - if bestID != 0 { - t.Fatalf("best ID w/ no chan should be zero, is instead: %v", - bestID) - } - - // Next, we'll insert two channels into the database, with each channel - // connecting the same two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // The first channel with be at height 10, while the other will be at - // height 100. - edge1, _ := createEdge(10, 0, 0, 0, node1, node2) - edge2, chanID2 := createEdge(100, 0, 0, 0, node1, node2) - - if err := graph.AddChannelEdge(&edge1); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - if err := graph.AddChannelEdge(&edge2); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Now that the edges has been inserted, we'll query for the highest - // known channel ID in the database. - bestID, err = graph.HighestChanID() - if err != nil { - t.Fatalf("unable to get highest ID: %v", err) - } - - if bestID != chanID2.ToUint64() { - t.Fatalf("expected %v got %v for best chan ID: ", - chanID2.ToUint64(), bestID) - } - - // If we add another edge, then the current best chan ID should be - // updated as well. - edge3, chanID3 := createEdge(1000, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edge3); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - bestID, err = graph.HighestChanID() - if err != nil { - t.Fatalf("unable to get highest ID: %v", err) - } - - if bestID != chanID3.ToUint64() { - t.Fatalf("expected %v got %v for best chan ID: ", - chanID3.ToUint64(), bestID) - } -} - -// TestChanUpdatesInHorizon tests the we're able to properly retrieve all known -// channel updates within a specific time horizon. It also tests that upon -// insertion of a new edge, the edge update index is updated properly. -func TestChanUpdatesInHorizon(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // If we issue an arbitrary query before any channel updates are - // inserted in the database, we should get zero results. - chanUpdates, err := graph.ChanUpdatesInHorizon( - time.Unix(999, 0), time.Unix(9999, 0), - ) - if err != nil { - t.Fatalf("unable to updates for updates: %v", err) - } - if len(chanUpdates) != 0 { - t.Fatalf("expected 0 chan updates, instead got %v", - len(chanUpdates)) - } - - // We'll start by creating two nodes which will seed our test graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // We'll now create 10 channels between the two nodes, with update - // times 10 seconds after each other. - const numChans = 10 - startTime := time.Unix(1234, 0) - endTime := startTime - edges := make([]ChannelEdge, 0, numChans) - for i := 0; i < numChans; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - channel, chanID := createEdge( - uint32(i*10), 0, 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - edge1UpdateTime := endTime - edge2UpdateTime := edge1UpdateTime.Add(time.Second) - endTime = endTime.Add(time.Second * 10) - - edge1 := newEdgePolicy( - chanID.ToUint64(), op, db, edge1UpdateTime.Unix(), - ) - edge1.ChannelFlags = 0 - edge1.Node = node2 - edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edge2 := newEdgePolicy( - chanID.ToUint64(), op, db, edge2UpdateTime.Unix(), - ) - edge2.ChannelFlags = 1 - edge2.Node = node1 - edge2.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edges = append(edges, ChannelEdge{ - Info: &channel, - Policy1: edge1, - Policy2: edge2, - }) - } - - // With our channels loaded, we'll now start our series of queries. - queryCases := []struct { - start time.Time - end time.Time - - resp []ChannelEdge - }{ - // If we query for a time range that's strictly below our set - // of updates, then we'll get an empty result back. - { - start: time.Unix(100, 0), - end: time.Unix(200, 0), - }, - - // If we query for a time range that's well beyond our set of - // updates, we should get an empty set of results back. - { - start: time.Unix(99999, 0), - end: time.Unix(999999, 0), - }, - - // If we query for the start time, and 10 seconds directly - // after it, we should only get a single update, that first - // one. - { - start: time.Unix(1234, 0), - end: startTime.Add(time.Second * 10), - - resp: []ChannelEdge{edges[0]}, - }, - - // If we add 10 seconds past the first update, and then - // subtract 10 from the last update, then we should only get - // the 8 edges in the middle. - { - start: startTime.Add(time.Second * 10), - end: endTime.Add(-time.Second * 10), - - resp: edges[1:9], - }, - - // If we use the start and end time as is, we should get the - // entire range. - { - start: startTime, - end: endTime, - - resp: edges, - }, - } - for _, queryCase := range queryCases { - resp, err := graph.ChanUpdatesInHorizon( - queryCase.start, queryCase.end, - ) - if err != nil { - t.Fatalf("unable to query for updates: %v", err) - } - - if len(resp) != len(queryCase.resp) { - t.Fatalf("expected %v chans, got %v chans", - len(queryCase.resp), len(resp)) - - } - - for i := 0; i < len(resp); i++ { - chanExp := queryCase.resp[i] - chanRet := resp[i] - - assertEdgeInfoEqual(t, chanExp.Info, chanRet.Info) - - err := compareEdgePolicies(chanExp.Policy1, chanRet.Policy1) - if err != nil { - t.Fatal(err) - } - compareEdgePolicies(chanExp.Policy2, chanRet.Policy2) - if err != nil { - t.Fatal(err) - } - } - } -} - -// TestNodeUpdatesInHorizon tests that we're able to properly scan and retrieve -// the most recent node updates within a particular time horizon. -func TestNodeUpdatesInHorizon(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - startTime := time.Unix(1234, 0) - endTime := startTime - - // If we issue an arbitrary query before we insert any nodes into the - // database, then we shouldn't get any results back. - nodeUpdates, err := graph.NodeUpdatesInHorizon( - time.Unix(999, 0), time.Unix(9999, 0), - ) - if err != nil { - t.Fatalf("unable to query for node updates: %v", err) - } - if len(nodeUpdates) != 0 { - t.Fatalf("expected 0 node updates, instead got %v", - len(nodeUpdates)) - } - - // We'll create 10 node announcements, each with an update timestamp 10 - // seconds after the other. - const numNodes = 10 - nodeAnns := make([]LightningNode, 0, numNodes) - for i := 0; i < numNodes; i++ { - nodeAnn, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test vertex: %v", err) - } - - // The node ann will use the current end time as its last - // update them, then we'll add 10 seconds in order to create - // the proper update time for the next node announcement. - updateTime := endTime - endTime = updateTime.Add(time.Second * 10) - - nodeAnn.LastUpdate = updateTime - - nodeAnns = append(nodeAnns, *nodeAnn) - - if err := graph.AddLightningNode(nodeAnn); err != nil { - t.Fatalf("unable to add lightning node: %v", err) - } - } - - queryCases := []struct { - start time.Time - end time.Time - - resp []LightningNode - }{ - // If we query for a time range that's strictly below our set - // of updates, then we'll get an empty result back. - { - start: time.Unix(100, 0), - end: time.Unix(200, 0), - }, - - // If we query for a time range that's well beyond our set of - // updates, we should get an empty set of results back. - { - start: time.Unix(99999, 0), - end: time.Unix(999999, 0), - }, - - // If we skip he first time epoch with out start time, then we - // should get back every now but the first. - { - start: startTime.Add(time.Second * 10), - end: endTime, - - resp: nodeAnns[1:], - }, - - // If we query for the range as is, we should get all 10 - // announcements back. - { - start: startTime, - end: endTime, - - resp: nodeAnns, - }, - - // If we reduce the ending time by 10 seconds, then we should - // get all but the last node we inserted. - { - start: startTime, - end: endTime.Add(-time.Second * 10), - - resp: nodeAnns[:9], - }, - } - for _, queryCase := range queryCases { - resp, err := graph.NodeUpdatesInHorizon(queryCase.start, queryCase.end) - if err != nil { - t.Fatalf("unable to query for nodes: %v", err) - } - - if len(resp) != len(queryCase.resp) { - t.Fatalf("expected %v nodes, got %v nodes", - len(queryCase.resp), len(resp)) - - } - - for i := 0; i < len(resp); i++ { - err := compareNodes(&queryCase.resp[i], &resp[i]) - if err != nil { - t.Fatal(err) - } - } - } -} - -// TestFilterKnownChanIDs tests that we're able to properly perform the set -// differences of an incoming set of channel ID's, and those that we already -// know of on disk. -func TestFilterKnownChanIDs(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // If we try to filter out a set of channel ID's before we even know of - // any channels, then we should get the entire set back. - preChanIDs := []uint64{1, 2, 3, 4} - filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs) - if err != nil { - t.Fatalf("unable to filter chan IDs: %v", err) - } - if !reflect.DeepEqual(preChanIDs, filteredIDs) { - t.Fatalf("chan IDs shouldn't have been filtered!") - } - - // We'll start by creating two nodes which will seed our test graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Next, we'll add 5 channel ID's to the graph, each of them having a - // block height 10 blocks after the previous. - const numChans = 5 - chanIDs := make([]uint64, 0, numChans) - for i := 0; i < numChans; i++ { - channel, chanID := createEdge( - uint32(i*10), 0, 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - chanIDs = append(chanIDs, chanID.ToUint64()) - } - - const numZombies = 5 - zombieIDs := make([]uint64, 0, numZombies) - for i := 0; i < numZombies; i++ { - channel, chanID := createEdge( - uint32(i*10+1), 0, 0, 0, node1, node2, - ) - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - err := graph.DeleteChannelEdges(channel.ChannelID) - if err != nil { - t.Fatalf("unable to mark edge zombie: %v", err) - } - - zombieIDs = append(zombieIDs, chanID.ToUint64()) - } - - queryCases := []struct { - queryIDs []uint64 - - resp []uint64 - }{ - // If we attempt to filter out all chanIDs we know of, the - // response should be the empty set. - { - queryIDs: chanIDs, - }, - // If we attempt to filter out all zombies that we know of, the - // response should be the empty set. - { - queryIDs: zombieIDs, - }, - - // If we query for a set of ID's that we didn't insert, we - // should get the same set back. - { - queryIDs: []uint64{99, 100}, - resp: []uint64{99, 100}, - }, - - // If we query for a super-set of our the chan ID's inserted, - // we should only get those new chanIDs back. - { - queryIDs: append(chanIDs, []uint64{99, 101}...), - resp: []uint64{99, 101}, - }, - } - - for _, queryCase := range queryCases { - resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs) - if err != nil { - t.Fatalf("unable to filter chan IDs: %v", err) - } - - if !reflect.DeepEqual(resp, queryCase.resp) { - t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp), - spew.Sdump(resp)) - } - } -} - -// TestFilterChannelRange tests that we're able to properly retrieve the full -// set of short channel ID's for a given block range. -func TestFilterChannelRange(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'll first populate our graph with two nodes. All channels created - // below will be made between these two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // If we try to filter a channel range before we have any channels - // inserted, we should get an empty slice of results. - resp, err := graph.FilterChannelRange(10, 100) - if err != nil { - t.Fatalf("unable to filter channels: %v", err) - } - if len(resp) != 0 { - t.Fatalf("expected zero chans, instead got %v", len(resp)) - } - - // To start, we'll create a set of channels, each mined in a block 10 - // blocks after the prior one. - startHeight := uint32(100) - endHeight := startHeight - const numChans = 10 - chanIDs := make([]uint64, 0, numChans) - for i := 0; i < numChans; i++ { - chanHeight := endHeight - channel, chanID := createEdge( - uint32(chanHeight), uint32(i+1), 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - chanIDs = append(chanIDs, chanID.ToUint64()) - - endHeight += 10 - } - - // With our channels inserted, we'll construct a series of queries that - // we'll execute below in order to exercise the features of the - // FilterKnownChanIDs method. - queryCases := []struct { - startHeight uint32 - endHeight uint32 - - resp []uint64 - }{ - // If we query for the entire range, then we should get the same - // set of short channel IDs back. - { - startHeight: startHeight, - endHeight: endHeight, - - resp: chanIDs, - }, - - // If we query for a range of channels right before our range, we - // shouldn't get any results back. - { - startHeight: 0, - endHeight: 10, - }, - - // If we only query for the last height (range wise), we should - // only get that last channel. - { - startHeight: endHeight - 10, - endHeight: endHeight - 10, - - resp: chanIDs[9:], - }, - - // If we query for just the first height, we should only get a - // single channel back (the first one). - { - startHeight: startHeight, - endHeight: startHeight, - - resp: chanIDs[:1], - }, - } - for i, queryCase := range queryCases { - resp, err := graph.FilterChannelRange( - queryCase.startHeight, queryCase.endHeight, - ) - if err != nil { - t.Fatalf("unable to issue range query: %v", err) - } - - if !reflect.DeepEqual(resp, queryCase.resp) { - t.Fatalf("case #%v: expected %v, got %v", i, - queryCase.resp, resp) - } - } -} - -// TestFetchChanInfos tests that we're able to properly retrieve the full set -// of ChannelEdge structs for a given set of short channel ID's. -func TestFetchChanInfos(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'll first populate our graph with two nodes. All channels created - // below will be made between these two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // We'll make 5 test channels, ensuring we keep track of which channel - // ID corresponds to a particular ChannelEdge. - const numChans = 5 - startTime := time.Unix(1234, 0) - endTime := startTime - edges := make([]ChannelEdge, 0, numChans) - edgeQuery := make([]uint64, 0, numChans) - for i := 0; i < numChans; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - channel, chanID := createEdge( - uint32(i*10), 0, 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - updateTime := endTime - endTime = updateTime.Add(time.Second * 10) - - edge1 := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), - ) - edge1.ChannelFlags = 0 - edge1.Node = node2 - edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edge2 := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), - ) - edge2.ChannelFlags = 1 - edge2.Node = node1 - edge2.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edges = append(edges, ChannelEdge{ - Info: &channel, - Policy1: edge1, - Policy2: edge2, - }) - - edgeQuery = append(edgeQuery, chanID.ToUint64()) - } - - // Add an additional edge that does not exist. The query should skip - // this channel and return only infos for the edges that exist. - edgeQuery = append(edgeQuery, 500) - - // Add an another edge to the query that has been marked as a zombie - // edge. The query should also skip this channel. - zombieChan, zombieChanID := createEdge( - 666, 0, 0, 0, node1, node2, - ) - if err := graph.AddChannelEdge(&zombieChan); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - err = graph.DeleteChannelEdges(zombieChan.ChannelID) - if err != nil { - t.Fatalf("unable to delete and mark edge zombie: %v", err) - } - edgeQuery = append(edgeQuery, zombieChanID.ToUint64()) - - // We'll now attempt to query for the range of channel ID's we just - // inserted into the database. We should get the exact same set of - // edges back. - resp, err := graph.FetchChanInfos(edgeQuery) - if err != nil { - t.Fatalf("unable to fetch chan edges: %v", err) - } - if len(resp) != len(edges) { - t.Fatalf("expected %v edges, instead got %v", len(edges), - len(resp)) - } - - for i := 0; i < len(resp); i++ { - err := compareEdgePolicies(resp[i].Policy1, edges[i].Policy1) - if err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - err = compareEdgePolicies(resp[i].Policy2, edges[i].Policy2) - if err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, resp[i].Info, edges[i].Info) - } -} - -// TestIncompleteChannelPolicies tests that a channel that only has a policy -// specified on one end is properly returned in ForEachChannel calls from -// both sides. -func TestIncompleteChannelPolicies(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // Create two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create channel between nodes. - txHash := sha256.Sum256([]byte{0}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - channel, chanID := createEdge( - uint32(0), 0, 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Ensure that channel is reported with unknown policies. - - checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { - calls := 0 - node.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo, - outEdge, inEdge *ChannelEdgePolicy) error { - - if !expectedOut && outEdge != nil { - t.Fatalf("Expected no outgoing policy") - } - - if expectedOut && outEdge == nil { - t.Fatalf("Expected an outgoing policy") - } - - if !expectedIn && inEdge != nil { - t.Fatalf("Expected no incoming policy") - } - - if expectedIn && inEdge == nil { - t.Fatalf("Expected an incoming policy") - } - - calls++ - - return nil - }) - - if calls != 1 { - t.Fatalf("Expected only one callback call") - } - } - - checkPolicies(node2, false, false) - - // Only create an edge policy for node1 and leave the policy for node2 - // unknown. - updateTime := time.Unix(1234, 0) - - edgePolicy := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), - ) - edgePolicy.ChannelFlags = 0 - edgePolicy.Node = node2 - edgePolicy.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - checkPolicies(node1, false, true) - checkPolicies(node2, true, false) - - // Create second policy and assert that both policies are reported - // as present. - edgePolicy = newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), - ) - edgePolicy.ChannelFlags = 1 - edgePolicy.Node = node1 - edgePolicy.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - checkPolicies(node1, true, true) - checkPolicies(node2, true, true) -} - -// TestChannelEdgePruningUpdateIndexDeletion tests that once edges are deleted -// from the graph, then their entries within the update index are also cleaned -// up. -func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // We'll first populate our graph with two nodes. All channels created - // below will be made between these two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // With the two nodes created, we'll now create a random channel, as - // well as two edges in the database with distinct update times. - edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) - edge1.ChannelFlags = 0 - edge1.Node = node1 - edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edge2 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) - edge2.ChannelFlags = 1 - edge2.Node = node2 - edge2.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // checkIndexTimestamps is a helper function that checks the edge update - // index only includes the given timestamps. - checkIndexTimestamps := func(timestamps ...uint64) { - timestampSet := make(map[uint64]struct{}) - for _, t := range timestamps { - timestampSet[t] = struct{}{} - } - - err := db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) - if edgeUpdateIndex == nil { - return ErrGraphNoEdgesFound - } - - numEntries := edgeUpdateIndex.Stats().KeyN - expectedEntries := len(timestampSet) - if numEntries != expectedEntries { - return fmt.Errorf("expected %v entries in the "+ - "update index, got %v", expectedEntries, - numEntries) - } - - return edgeUpdateIndex.ForEach(func(k, _ []byte) error { - t := byteOrder.Uint64(k[:8]) - if _, ok := timestampSet[t]; !ok { - return fmt.Errorf("found unexpected "+ - "timestamp "+"%d", t) - } - - return nil - }) - }) - if err != nil { - t.Fatal(err) - } - } - - // With both edges policies added, we'll make sure to check they exist - // within the edge update index. - checkIndexTimestamps( - uint64(edge1.LastUpdate.Unix()), - uint64(edge2.LastUpdate.Unix()), - ) - - // Now, we'll update the edge policies to ensure the old timestamps are - // removed from the update index. - edge1.ChannelFlags = 2 - edge1.LastUpdate = time.Now() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - edge2.ChannelFlags = 3 - edge2.LastUpdate = edge1.LastUpdate.Add(time.Hour) - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // With the policies updated, we should now be able to find their - // updated entries within the update index. - checkIndexTimestamps( - uint64(edge1.LastUpdate.Unix()), - uint64(edge2.LastUpdate.Unix()), - ) - - // Now we'll prune the graph, removing the edges, and also the update - // index entries from the database all together. - var blockHash chainhash.Hash - copy(blockHash[:], bytes.Repeat([]byte{2}, 32)) - _, err = graph.PruneGraph( - []*wire.OutPoint{&edgeInfo.ChannelPoint}, &blockHash, 101, - ) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - - // Finally, we'll check the database state one last time to conclude - // that we should no longer be able to locate _any_ entries within the - // edge update index. - checkIndexTimestamps() -} - -// TestPruneGraphNodes tests that unconnected vertexes are pruned via the -// PruneSyncState method. -func TestPruneGraphNodes(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - // We'll start off by inserting our source node, to ensure that it's - // the only node left after we prune the graph. - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // With the source node inserted, we'll now add three nodes to the - // channel graph, at the end of the scenario, only two of these nodes - // should still be in the graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node3, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node3); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // We'll now add a new edge to the graph, but only actually advertise - // the edge of *one* of the nodes. - edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - // We'll now insert an advertised edge, but it'll only be the edge that - // points from the first to the second node. - edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) - edge1.ChannelFlags = 0 - edge1.Node = node1 - edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // We'll now initiate a around of graph pruning. - if err := graph.PruneGraphNodes(); err != nil { - t.Fatalf("unable to prune graph nodes: %v", err) - } - - // At this point, there should be 3 nodes left in the graph still: the - // source node (which can't be pruned), and node 1+2. Nodes 1 and two - // should still be left in the graph as there's half of an advertised - // edge between them. - assertNumNodes(t, graph, 3) - - // Finally, we'll ensure that node3, the only fully unconnected node as - // properly deleted from the graph and not another node in its place. - node3Pub, err := node3.PubKey() - if err != nil { - t.Fatalf("unable to fetch the pubkey of node3: %v", err) - } - if _, err := graph.FetchLightningNode(node3Pub); err == nil { - t.Fatalf("node 3 should have been deleted!") - } -} - -// TestAddChannelEdgeShellNodes tests that when we attempt to add a ChannelEdge -// to the graph, one or both of the nodes the edge involves aren't found in the -// database, then shell edges are created for each node if needed. -func TestAddChannelEdgeShellNodes(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // To start, we'll create two nodes, and only add one of them to the - // channel graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // We'll now create an edge between the two nodes, as a result, node2 - // should be inserted into the database as a shell node. - edgeInfo, _ := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - node1Pub, err := node1.PubKey() - if err != nil { - t.Fatalf("unable to parse node 1 pub: %v", err) - } - node2Pub, err := node2.PubKey() - if err != nil { - t.Fatalf("unable to parse node 2 pub: %v", err) - } - - // Ensure that node1 was inserted as a full node, while node2 only has - // a shell node present. - node1, err = graph.FetchLightningNode(node1Pub) - if err != nil { - t.Fatalf("unable to fetch node1: %v", err) - } - if !node1.HaveNodeAnnouncement { - t.Fatalf("have shell announcement for node1, shouldn't") - } - - node2, err = graph.FetchLightningNode(node2Pub) - if err != nil { - t.Fatalf("unable to fetch node2: %v", err) - } - if node2.HaveNodeAnnouncement { - t.Fatalf("should have shell announcement for node2, but is full") - } -} - -// TestNodePruningUpdateIndexDeletion tests that once a node has been removed -// from the channel graph, we also remove the entry from the update index as -// well. -func TestNodePruningUpdateIndexDeletion(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'll first populate our graph with a single node that will be - // removed shortly. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // We'll confirm that we can retrieve the node using - // NodeUpdatesInHorizon, using a time that's slightly beyond the last - // update time of our test node. - startTime := time.Unix(9, 0) - endTime := node1.LastUpdate.Add(time.Minute) - nodesInHorizon, err := graph.NodeUpdatesInHorizon(startTime, endTime) - if err != nil { - t.Fatalf("unable to fetch nodes in horizon: %v", err) - } - - // We should only have a single node, and that node should exactly - // match the node we just inserted. - if len(nodesInHorizon) != 1 { - t.Fatalf("should have 1 nodes instead have: %v", - len(nodesInHorizon)) - } - if err := compareNodes(node1, &nodesInHorizon[0]); err != nil { - t.Fatalf("nodes don't match: %v", err) - } - - // We'll now delete the node from the graph, this should result in it - // being removed from the update index as well. - nodePub, _ := node1.PubKey() - if err := graph.DeleteLightningNode(nodePub); err != nil { - t.Fatalf("unable to delete node: %v", err) - } - - // Now that the node has been deleted, we'll again query the nodes in - // the horizon. This time we should have no nodes at all. - nodesInHorizon, err = graph.NodeUpdatesInHorizon(startTime, endTime) - if err != nil { - t.Fatalf("unable to fetch nodes in horizon: %v", err) - } - - if len(nodesInHorizon) != 0 { - t.Fatalf("should have zero nodes instead have: %v", - len(nodesInHorizon)) - } -} - -// TestNodeIsPublic ensures that we properly detect nodes that are seen as -// public within the network graph. -func TestNodeIsPublic(t *testing.T) { - t.Parallel() - - // We'll start off the test by creating a small network of 3 - // participants with the following graph: - // - // Alice <-> Bob <-> Carol - // - // We'll need to create a separate database and channel graph for each - // participant to replicate real-world scenarios (private edges being in - // some graphs but not others, etc.). - aliceDB, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - aliceNode, err := createTestVertex(aliceDB) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - aliceGraph := aliceDB.ChannelGraph() - if err := aliceGraph.SetSourceNode(aliceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - bobDB, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - bobNode, err := createTestVertex(bobDB) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - bobGraph := bobDB.ChannelGraph() - if err := bobGraph.SetSourceNode(bobNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - carolDB, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - carolNode, err := createTestVertex(carolDB) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - carolGraph := carolDB.ChannelGraph() - if err := carolGraph.SetSourceNode(carolNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - aliceBobEdge, _ := createEdge(10, 0, 0, 0, aliceNode, bobNode) - bobCarolEdge, _ := createEdge(10, 1, 0, 1, bobNode, carolNode) - - // After creating all of our nodes and edges, we'll add them to each - // participant's graph. - nodes := []*LightningNode{aliceNode, bobNode, carolNode} - edges := []*ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} - dbs := []*DB{aliceDB, bobDB, carolDB} - graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} - for i, graph := range graphs { - for _, node := range nodes { - node.db = dbs[i] - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - } - for _, edge := range edges { - edge.db = dbs[i] - if err := graph.AddChannelEdge(edge); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - } - } - - // checkNodes is a helper closure that will be used to assert that the - // given nodes are seen as public/private within the given graphs. - checkNodes := func(nodes []*LightningNode, graphs []*ChannelGraph, - public bool) { - - t.Helper() - - for _, node := range nodes { - for _, graph := range graphs { - isPublic, err := graph.IsPublicNode(node.PubKeyBytes) - if err != nil { - t.Fatalf("unable to determine if pivot "+ - "is public: %v", err) - } - - switch { - case isPublic && !public: - t.Fatalf("expected %x to be private", - node.PubKeyBytes) - case !isPublic && public: - t.Fatalf("expected %x to be public", - node.PubKeyBytes) - } - } - } - } - - // Due to the way the edges were set up above, we'll make sure each node - // can correctly determine that every other node is public. - checkNodes(nodes, graphs, true) - - // Now, we'll remove the edge between Alice and Bob from everyone's - // graph. This will make Alice be seen as a private node as it no longer - // has any advertised edges. - for _, graph := range graphs { - err := graph.DeleteChannelEdges(aliceBobEdge.ChannelID) - if err != nil { - t.Fatalf("unable to remove edge: %v", err) - } - } - checkNodes( - []*LightningNode{aliceNode}, - []*ChannelGraph{bobGraph, carolGraph}, - false, - ) - - // We'll also make the edge between Bob and Carol private. Within Bob's - // and Carol's graph, the edge will exist, but it will not have a proof - // that allows it to be advertised. Within Alice's graph, we'll - // completely remove the edge as it is not possible for her to know of - // it without it being advertised. - for i, graph := range graphs { - err := graph.DeleteChannelEdges(bobCarolEdge.ChannelID) - if err != nil { - t.Fatalf("unable to remove edge: %v", err) - } - - if graph == aliceGraph { - continue - } - - bobCarolEdge.AuthProof = nil - bobCarolEdge.db = dbs[i] - if err := graph.AddChannelEdge(&bobCarolEdge); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - } - - // With the modifications above, Bob should now be seen as a private - // node from both Alice's and Carol's perspective. - checkNodes( - []*LightningNode{bobNode}, - []*ChannelGraph{aliceGraph, carolGraph}, - false, - ) -} - -// TestDisabledChannelIDs ensures that the disabled channels within the -// disabledEdgePolicyBucket are managed properly and the list returned from -// DisabledChannelIDs is correct. -func TestDisabledChannelIDs(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - graph := db.ChannelGraph() - - // Create first node and add it to the graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create second node and add it to the graph. - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Adding a new channel edge to the graph. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - if err := graph.AddChannelEdge(edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Ensure no disabled channels exist in the bucket on start. - disabledChanIds, err := graph.DisabledChannelIDs() - if err != nil { - t.Fatalf("unable to get disabled channel ids: %v", err) - } - if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) - } - - // Add one disabled policy and ensure the channel is still not in the - // disabled list. - edge1.ChannelFlags |= lnwire.ChanUpdateDisabled - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - disabledChanIds, err = graph.DisabledChannelIDs() - if err != nil { - t.Fatalf("unable to get disabled channel ids: %v", err) - } - if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) - } - - // Add second disabled policy and ensure the channel is now in the - // disabled list. - edge2.ChannelFlags |= lnwire.ChanUpdateDisabled - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - disabledChanIds, err = graph.DisabledChannelIDs() - if err != nil { - t.Fatalf("unable to get disabled channel ids: %v", err) - } - if len(disabledChanIds) != 1 || disabledChanIds[0] != edgeInfo.ChannelID { - t.Fatalf("expected disabled channel with id %v, "+ - "got %v", edgeInfo.ChannelID, disabledChanIds) - } - - // Delete the channel edge and ensure it is removed from the disabled list. - if err = graph.DeleteChannelEdges(edgeInfo.ChannelID); err != nil { - t.Fatalf("unable to delete channel edge: %v", err) - } - disabledChanIds, err = graph.DisabledChannelIDs() - if err != nil { - t.Fatalf("unable to get disabled channel ids: %v", err) - } - if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) - } -} - -// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in -// the DB that indicates that it should support the htlc_maximum_value_msat -// field, but it is not part of the opaque data, then we'll handle it as it is -// unknown. It also checks that we are correctly able to overwrite it when we -// receive the proper update. -func TestEdgePolicyMissingMaxHtcl(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the update of edges inserted into the database, so - // we create two vertexes to connect. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - if err := graph.AddChannelEdge(edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - chanID := edgeInfo.ChannelID - from := edge2.Node.PubKeyBytes[:] - to := edge1.Node.PubKeyBytes[:] - - // We'll remove the no max_htlc field from the first edge policy, and - // all other opaque data, and serialize it. - edge1.MessageFlags = 0 - edge1.ExtraOpaqueData = nil - - var b bytes.Buffer - err = serializeChanEdgePolicy(&b, edge1, to) - if err != nil { - t.Fatalf("unable to serialize policy") - } - - // Set the max_htlc field. The extra bytes added to the serialization - // will be the opaque data containing the serialized field. - edge1.MessageFlags = lnwire.ChanUpdateOptionMaxHtlc - edge1.MaxHTLC = 13928598 - var b2 bytes.Buffer - err = serializeChanEdgePolicy(&b2, edge1, to) - if err != nil { - t.Fatalf("unable to serialize policy") - } - - withMaxHtlc := b2.Bytes() - - // Remove the opaque data from the serialization. - stripped := withMaxHtlc[:len(b.Bytes())] - - // Attempting to deserialize these bytes should return an error. - r := bytes.NewReader(stripped) - err = db.View(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - _, err = deserializeChanEdgePolicy(r, nodes) - if err != ErrEdgePolicyOptionalFieldNotFound { - t.Fatalf("expected "+ - "ErrEdgePolicyOptionalFieldNotFound, got %v", - err) - } - - return nil - }) - if err != nil { - t.Fatalf("error reading db: %v", err) - } - - // Put the stripped bytes in the DB. - err = db.Update(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrEdgeNotFound - } - - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrEdgeNotFound - } - - var edgeKey [33 + 8]byte - copy(edgeKey[:], from) - byteOrder.PutUint64(edgeKey[33:], edge1.ChannelID) - - var scratch [8]byte - var indexKey [8 + 8]byte - copy(indexKey[:], scratch[:]) - byteOrder.PutUint64(indexKey[8:], edge1.ChannelID) - - updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) - if err != nil { - return err - } - - if err := updateIndex.Put(indexKey[:], nil); err != nil { - return err - } - - return edges.Put(edgeKey[:], stripped) - }) - if err != nil { - t.Fatalf("error writing db: %v", err) - } - - // And add the second, unmodified edge. - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Attempt to fetch the edge and policies from the DB. Since the policy - // we added is invalid according to the new format, it should be as we - // are not aware of the policy (indicated by the policy returned being - // nil) - dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID) - if err != nil { - t.Fatalf("unable to fetch channel by ID: %v", err) - } - - // The first edge should have a nil-policy returned - if dbEdge1 != nil { - t.Fatalf("expected db edge to be nil") - } - if err := compareEdgePolicies(dbEdge2, edge2); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) - - // Now add the original, unmodified edge policy, and make sure the edge - // policies then become fully populated. - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByID(chanID) - if err != nil { - t.Fatalf("unable to fetch channel by ID: %v", err) - } - if err := compareEdgePolicies(dbEdge1, edge1); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - if err := compareEdgePolicies(dbEdge2, edge2); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) -} - -// assertNumZombies queries the provided ChannelGraph for NumZombies, and -// asserts that the returned number is equal to expZombies. -func assertNumZombies(t *testing.T, graph *ChannelGraph, expZombies uint64) { - t.Helper() - - numZombies, err := graph.NumZombies() - if err != nil { - t.Fatalf("unable to query number of zombies: %v", err) - } - - if numZombies != expZombies { - t.Fatalf("expected %d zombies, found %d", - expZombies, numZombies) - } -} - -// TestGraphZombieIndex ensures that we can mark edges correctly as zombie/live. -func TestGraphZombieIndex(t *testing.T) { - t.Parallel() - - // We'll start by creating our test graph along with a test edge. - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to create test database: %v", err) - } - graph := db.ChannelGraph() - - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test vertex: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test vertex: %v", err) - } - - // Swap the nodes if the second's pubkey is smaller than the first. - // Without this, the comparisons at the end will fail probabilistically. - if bytes.Compare(node2.PubKeyBytes[:], node1.PubKeyBytes[:]) < 0 { - node1, node2 = node2, node1 - } - - edge, _, _ := createChannelEdge(db, node1, node2) - if err := graph.AddChannelEdge(edge); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Since the edge is known the graph and it isn't a zombie, IsZombieEdge - // should not report the channel as a zombie. - isZombie, _, _ := graph.IsZombieEdge(edge.ChannelID) - if isZombie { - t.Fatal("expected edge to not be marked as zombie") - } - assertNumZombies(t, graph, 0) - - // If we delete the edge and mark it as a zombie, then we should expect - // to see it within the index. - err = graph.DeleteChannelEdges(edge.ChannelID) - if err != nil { - t.Fatalf("unable to mark edge as zombie: %v", err) - } - isZombie, pubKey1, pubKey2 := graph.IsZombieEdge(edge.ChannelID) - if !isZombie { - t.Fatal("expected edge to be marked as zombie") - } - if pubKey1 != node1.PubKeyBytes { - t.Fatalf("expected pubKey1 %x, got %x", node1.PubKeyBytes, - pubKey1) - } - if pubKey2 != node2.PubKeyBytes { - t.Fatalf("expected pubKey2 %x, got %x", node2.PubKeyBytes, - pubKey2) - } - assertNumZombies(t, graph, 1) - - // Similarly, if we mark the same edge as live, we should no longer see - // it within the index. - if err := graph.MarkEdgeLive(edge.ChannelID); err != nil { - t.Fatalf("unable to mark edge as live: %v", err) - } - isZombie, _, _ = graph.IsZombieEdge(edge.ChannelID) - if isZombie { - t.Fatal("expected edge to not be marked as zombie") - } - assertNumZombies(t, graph, 0) -} - -// compareNodes is used to compare two LightningNodes while excluding the -// Features struct, which cannot be compared as the semantics for reserializing -// the featuresMap have not been defined. -func compareNodes(a, b *LightningNode) error { - if a.LastUpdate != b.LastUpdate { - return fmt.Errorf("node LastUpdate doesn't match: expected %v, \n"+ - "got %v", a.LastUpdate, b.LastUpdate) - } - if !reflect.DeepEqual(a.Addresses, b.Addresses) { - return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+ - "got %#v", a.Addresses, b.Addresses) - } - if !reflect.DeepEqual(a.PubKeyBytes, b.PubKeyBytes) { - return fmt.Errorf("PubKey doesn't match: expected %#v, \n "+ - "got %#v", a.PubKeyBytes, b.PubKeyBytes) - } - if !reflect.DeepEqual(a.Color, b.Color) { - return fmt.Errorf("Color doesn't match: expected %#v, \n "+ - "got %#v", a.Color, b.Color) - } - if !reflect.DeepEqual(a.Alias, b.Alias) { - return fmt.Errorf("Alias doesn't match: expected %#v, \n "+ - "got %#v", a.Alias, b.Alias) - } - if !reflect.DeepEqual(a.db, b.db) { - return fmt.Errorf("db doesn't match: expected %#v, \n "+ - "got %#v", a.db, b.db) - } - if !reflect.DeepEqual(a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) { - return fmt.Errorf("HaveNodeAnnouncement doesn't match: expected %#v, \n "+ - "got %#v", a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) - } - if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { - return fmt.Errorf("extra data doesn't match: %v vs %v", - a.ExtraOpaqueData, b.ExtraOpaqueData) - } - - return nil -} - -// compareEdgePolicies is used to compare two ChannelEdgePolices using -// compareNodes, so as to exclude comparisons of the Nodes' Features struct. -func compareEdgePolicies(a, b *ChannelEdgePolicy) error { - if a.ChannelID != b.ChannelID { - return fmt.Errorf("ChannelID doesn't match: expected %v, "+ - "got %v", a.ChannelID, b.ChannelID) - } - if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { - return fmt.Errorf("edge LastUpdate doesn't match: expected %#v, \n "+ - "got %#v", a.LastUpdate, b.LastUpdate) - } - if a.MessageFlags != b.MessageFlags { - return fmt.Errorf("MessageFlags doesn't match: expected %v, "+ - "got %v", a.MessageFlags, b.MessageFlags) - } - if a.ChannelFlags != b.ChannelFlags { - return fmt.Errorf("ChannelFlags doesn't match: expected %v, "+ - "got %v", a.ChannelFlags, b.ChannelFlags) - } - if a.TimeLockDelta != b.TimeLockDelta { - return fmt.Errorf("TimeLockDelta doesn't match: expected %v, "+ - "got %v", a.TimeLockDelta, b.TimeLockDelta) - } - if a.MinHTLC != b.MinHTLC { - return fmt.Errorf("MinHTLC doesn't match: expected %v, "+ - "got %v", a.MinHTLC, b.MinHTLC) - } - if a.MaxHTLC != b.MaxHTLC { - return fmt.Errorf("MaxHTLC doesn't match: expected %v, "+ - "got %v", a.MaxHTLC, b.MaxHTLC) - } - if a.FeeBaseMSat != b.FeeBaseMSat { - return fmt.Errorf("FeeBaseMSat doesn't match: expected %v, "+ - "got %v", a.FeeBaseMSat, b.FeeBaseMSat) - } - if a.FeeProportionalMillionths != b.FeeProportionalMillionths { - return fmt.Errorf("FeeProportionalMillionths doesn't match: "+ - "expected %v, got %v", a.FeeProportionalMillionths, - b.FeeProportionalMillionths) - } - if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { - return fmt.Errorf("extra data doesn't match: %v vs %v", - a.ExtraOpaqueData, b.ExtraOpaqueData) - } - if err := compareNodes(a.Node, b.Node); err != nil { - return err - } - if !reflect.DeepEqual(a.db, b.db) { - return fmt.Errorf("db doesn't match: expected %#v, \n "+ - "got %#v", a.db, b.db) - } - return nil -} - -// TestLightningNodeSigVerifcation checks that we can use the LightningNode's -// pubkey to verify signatures. -func TestLightningNodeSigVerification(t *testing.T) { - t.Parallel() - - // Create some dummy data to sign. - var data [32]byte - if _, err := prand.Read(data[:]); err != nil { - t.Fatalf("unable to read prand: %v", err) - } - - // Create private key and sign the data with it. - priv, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - t.Fatalf("unable to crete priv key: %v", err) - } - - sign, err := priv.Sign(data[:]) - if err != nil { - t.Fatalf("unable to sign: %v", err) - } - - // Sanity check that the signature checks out. - if !sign.Verify(data[:], priv.PubKey()) { - t.Fatalf("signature doesn't check out") - } - - // Create a LightningNode from the same private key. - db, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - node, err := createLightningNode(db, priv) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - // And finally check that we can verify the same signature from the - // pubkey returned from the lightning node. - nodePub, err := node.PubKey() - if err != nil { - t.Fatalf("unable to get pubkey: %v", err) - } - - if !sign.Verify(data[:], nodePub) { - t.Fatalf("unable to verify sig") - } -} - -// TestComputeFee tests fee calculation based on both in- and outgoing amt. -func TestComputeFee(t *testing.T) { - var ( - policy = ChannelEdgePolicy{ - FeeBaseMSat: 10000, - FeeProportionalMillionths: 30000, - } - outgoingAmt = lnwire.MilliSatoshi(1000000) - expectedFee = lnwire.MilliSatoshi(40000) - ) - - fee := policy.ComputeFee(outgoingAmt) - if fee != expectedFee { - t.Fatalf("expected fee %v, got %v", expectedFee, fee) - } - - fwdFee := policy.ComputeFeeFromIncoming(outgoingAmt + fee) - if fwdFee != expectedFee { - t.Fatalf("expected fee %v, but got %v", fee, fwdFee) - } -} diff --git a/channeldb/migration_01_to_11/invoice_test.go b/channeldb/migration_01_to_11/invoice_test.go deleted file mode 100644 index 795fe493..00000000 --- a/channeldb/migration_01_to_11/invoice_test.go +++ /dev/null @@ -1,694 +0,0 @@ -package migration_01_to_11 - -import ( - "crypto/rand" - "reflect" - "testing" - "time" - - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/lnwire" -) - -func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { - var pre [32]byte - if _, err := rand.Read(pre[:]); err != nil { - return nil, err - } - - i := &Invoice{ - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), - Terms: ContractTerm{ - PaymentPreimage: pre, - Value: value, - }, - Htlcs: map[CircuitKey]*InvoiceHTLC{}, - Expiry: 4000, - } - i.Memo = []byte("memo") - i.Receipt = []byte("receipt") - - // Create a random byte slice of MaxPaymentRequestSize bytes to be used - // as a dummy paymentrequest, and determine if it should be set based - // on one of the random bytes. - var r [MaxPaymentRequestSize]byte - if _, err := rand.Read(r[:]); err != nil { - return nil, err - } - if r[0]&1 == 0 { - i.PaymentRequest = r[:] - } else { - i.PaymentRequest = []byte("") - } - - return i, nil -} - -func TestInvoiceWorkflow(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - // Create a fake invoice which we'll use several times in the tests - // below. - fakeInvoice := &Invoice{ - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), - Htlcs: map[CircuitKey]*InvoiceHTLC{}, - } - fakeInvoice.Memo = []byte("memo") - fakeInvoice.Receipt = []byte("receipt") - fakeInvoice.PaymentRequest = []byte("") - copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) - fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) - - paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() - - // Add the invoice to the database, this should succeed as there aren't - // any existing invoices within the database with the same payment - // hash. - if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil { - t.Fatalf("unable to find invoice: %v", err) - } - - // Attempt to retrieve the invoice which was just added to the - // database. It should be found, and the invoice returned should be - // identical to the one created above. - dbInvoice, err := db.LookupInvoice(paymentHash) - if err != nil { - t.Fatalf("unable to find invoice: %v", err) - } - if !reflect.DeepEqual(*fakeInvoice, dbInvoice) { - t.Fatalf("invoice fetched from db doesn't match original %v vs %v", - spew.Sdump(fakeInvoice), spew.Sdump(dbInvoice)) - } - - // The add index of the invoice retrieved from the database should now - // be fully populated. As this is the first index written to the DB, - // the addIndex should be 1. - if dbInvoice.AddIndex != 1 { - t.Fatalf("wrong add index: expected %v, got %v", 1, - dbInvoice.AddIndex) - } - - // Settle the invoice, the version retrieved from the database should - // now have the settled bit toggle to true and a non-default - // SettledDate - payAmt := fakeInvoice.Terms.Value * 2 - _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) - if err != nil { - t.Fatalf("unable to settle invoice: %v", err) - } - dbInvoice2, err := db.LookupInvoice(paymentHash) - if err != nil { - t.Fatalf("unable to fetch invoice: %v", err) - } - if dbInvoice2.Terms.State != ContractSettled { - t.Fatalf("invoice should now be settled but isn't") - } - if dbInvoice2.SettleDate.IsZero() { - t.Fatalf("invoice should have non-zero SettledDate but isn't") - } - - // Our 2x payment should be reflected, and also the settle index of 1 - // should also have been committed for this index. - if dbInvoice2.AmtPaid != payAmt { - t.Fatalf("wrong amt paid: expected %v, got %v", payAmt, - dbInvoice2.AmtPaid) - } - if dbInvoice2.SettleIndex != 1 { - t.Fatalf("wrong settle index: expected %v, got %v", 1, - dbInvoice2.SettleIndex) - } - - // Attempt to insert generated above again, this should fail as - // duplicates are rejected by the processing logic. - if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice { - t.Fatalf("invoice insertion should fail due to duplication, "+ - "instead %v", err) - } - - // Attempt to look up a non-existent invoice, this should also fail but - // with a "not found" error. - var fakeHash [32]byte - if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound { - t.Fatalf("lookup should have failed, instead %v", err) - } - - // Add 10 random invoices. - const numInvoices = 10 - amt := lnwire.NewMSatFromSatoshis(1000) - invoices := make([]*Invoice, numInvoices+1) - invoices[0] = &dbInvoice2 - for i := 1; i < len(invoices)-1; i++ { - invoice, err := randInvoice(amt) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - hash := invoice.Terms.PaymentPreimage.Hash() - if _, err := db.AddInvoice(invoice, hash); err != nil { - t.Fatalf("unable to add invoice %v", err) - } - - invoices[i] = invoice - } - - // Perform a scan to collect all the active invoices. - dbInvoices, err := db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch all invoices: %v", err) - } - - // The retrieve list of invoices should be identical as since we're - // using big endian, the invoices should be retrieved in ascending - // order (and the primary key should be incremented with each - // insertion). - for i := 0; i < len(invoices)-1; i++ { - if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) { - t.Fatalf("retrieved invoices don't match %v vs %v", - spew.Sdump(invoices[i]), - spew.Sdump(dbInvoices[i])) - } - } -} - -// TestInvoiceTimeSeries tests that newly added invoices invoices, as well as -// settled invoices are added to the database are properly placed in the add -// add or settle index which serves as an event time series. -func TestInvoiceAddTimeSeries(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - // We'll start off by creating 20 random invoices, and inserting them - // into the database. - const numInvoices = 20 - amt := lnwire.NewMSatFromSatoshis(1000) - invoices := make([]Invoice, numInvoices) - for i := 0; i < len(invoices); i++ { - invoice, err := randInvoice(amt) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - paymentHash := invoice.Terms.PaymentPreimage.Hash() - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice %v", err) - } - - invoices[i] = *invoice - } - - // With the invoices constructed, we'll now create a series of queries - // that we'll use to assert expected return values of - // InvoicesAddedSince. - addQueries := []struct { - sinceAddIndex uint64 - - resp []Invoice - }{ - // If we specify a value of zero, we shouldn't get any invoices - // back. - { - sinceAddIndex: 0, - }, - - // If we specify a value well beyond the number of inserted - // invoices, we shouldn't get any invoices back. - { - sinceAddIndex: 99999999, - }, - - // Using an index of 1 should result in all values, but the - // first one being returned. - { - sinceAddIndex: 1, - resp: invoices[1:], - }, - - // If we use an index of 10, then we should retrieve the - // reaming 10 invoices. - { - sinceAddIndex: 10, - resp: invoices[10:], - }, - } - - for i, query := range addQueries { - resp, err := db.InvoicesAddedSince(query.sinceAddIndex) - if err != nil { - t.Fatalf("unable to query: %v", err) - } - - if !reflect.DeepEqual(query.resp, resp) { - t.Fatalf("test #%v: expected %v, got %v", i, - spew.Sdump(query.resp), spew.Sdump(resp)) - } - } - - // We'll now only settle the latter half of each of those invoices. - for i := 10; i < len(invoices); i++ { - invoice := &invoices[i] - - paymentHash := invoice.Terms.PaymentPreimage.Hash() - - _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(0), - ) - if err != nil { - t.Fatalf("unable to settle invoice: %v", err) - } - } - - invoices, err = db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch invoices: %v", err) - } - - // We'll slice off the first 10 invoices, as we only settled the last - // 10. - invoices = invoices[10:] - - // We'll now prepare an additional set of queries to ensure the settle - // time series has properly been maintained in the database. - settleQueries := []struct { - sinceSettleIndex uint64 - - resp []Invoice - }{ - // If we specify a value of zero, we shouldn't get any settled - // invoices back. - { - sinceSettleIndex: 0, - }, - - // If we specify a value well beyond the number of settled - // invoices, we shouldn't get any invoices back. - { - sinceSettleIndex: 99999999, - }, - - // Using an index of 1 should result in the final 10 invoices - // being returned, as we only settled those. - { - sinceSettleIndex: 1, - resp: invoices[1:], - }, - } - - for i, query := range settleQueries { - resp, err := db.InvoicesSettledSince(query.sinceSettleIndex) - if err != nil { - t.Fatalf("unable to query: %v", err) - } - - if !reflect.DeepEqual(query.resp, resp) { - t.Fatalf("test #%v: expected %v, got %v", i, - spew.Sdump(query.resp), spew.Sdump(resp)) - } - } -} - -// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it -// twice, then the second time we also receive the invoice that we settled as a -// return argument. -func TestDuplicateSettleInvoice(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - db.now = func() time.Time { return time.Unix(1, 0) } - - // We'll start out by creating an invoice and writing it to the DB. - amt := lnwire.NewMSatFromSatoshis(1000) - invoice, err := randInvoice(amt) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - payHash := invoice.Terms.PaymentPreimage.Hash() - - if _, err := db.AddInvoice(invoice, payHash); err != nil { - t.Fatalf("unable to add invoice %v", err) - } - - // With the invoice in the DB, we'll now attempt to settle the invoice. - dbInvoice, err := db.UpdateInvoice( - payHash, getUpdateInvoice(amt), - ) - if err != nil { - t.Fatalf("unable to settle invoice: %v", err) - } - - // We'll update what we expect the settle invoice to be so that our - // comparison below has the correct assumption. - invoice.SettleIndex = 1 - invoice.Terms.State = ContractSettled - invoice.AmtPaid = amt - invoice.SettleDate = dbInvoice.SettleDate - invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{ - {}: { - Amt: amt, - AcceptTime: time.Unix(1, 0), - ResolveTime: time.Unix(1, 0), - State: HtlcStateSettled, - }, - } - - // We should get back the exact same invoice that we just inserted. - if !reflect.DeepEqual(dbInvoice, invoice) { - t.Fatalf("wrong invoice after settle, expected %v got %v", - spew.Sdump(invoice), spew.Sdump(dbInvoice)) - } - - // If we try to settle the invoice again, then we should get the very - // same invoice back, but with an error this time. - dbInvoice, err = db.UpdateInvoice( - payHash, getUpdateInvoice(amt), - ) - if err != ErrInvoiceAlreadySettled { - t.Fatalf("expected ErrInvoiceAlreadySettled") - } - - if dbInvoice == nil { - t.Fatalf("invoice from db is nil after settle!") - } - - invoice.SettleDate = dbInvoice.SettleDate - if !reflect.DeepEqual(dbInvoice, invoice) { - t.Fatalf("wrong invoice after second settle, expected %v got %v", - spew.Sdump(invoice), spew.Sdump(dbInvoice)) - } -} - -// TestQueryInvoices ensures that we can properly query the invoice database for -// invoices using different types of queries. -func TestQueryInvoices(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - // To begin the test, we'll add 50 invoices to the database. We'll - // assume that the index of the invoice within the database is the same - // as the amount of the invoice itself. - const numInvoices = 50 - for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ { - invoice, err := randInvoice(i) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - paymentHash := invoice.Terms.PaymentPreimage.Hash() - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice: %v", err) - } - - // We'll only settle half of all invoices created. - if i%2 == 0 { - _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(i), - ) - if err != nil { - t.Fatalf("unable to settle invoice: %v", err) - } - } - } - - // We'll then retrieve the set of all invoices and pending invoices. - // This will serve useful when comparing the expected responses of the - // query with the actual ones. - invoices, err := db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to retrieve invoices: %v", err) - } - pendingInvoices, err := db.FetchAllInvoices(true) - if err != nil { - t.Fatalf("unable to retrieve pending invoices: %v", err) - } - - // The test will consist of several queries along with their respective - // expected response. Each query response should match its expected one. - testCases := []struct { - query InvoiceQuery - expected []Invoice - }{ - // Fetch all invoices with a single query. - { - query: InvoiceQuery{ - NumMaxInvoices: numInvoices, - }, - expected: invoices, - }, - // Fetch all invoices with a single query, reversed. - { - query: InvoiceQuery{ - Reversed: true, - NumMaxInvoices: numInvoices, - }, - expected: invoices, - }, - // Fetch the first 25 invoices. - { - query: InvoiceQuery{ - NumMaxInvoices: numInvoices / 2, - }, - expected: invoices[:numInvoices/2], - }, - // Fetch the first 10 invoices, but this time iterating - // backwards. - { - query: InvoiceQuery{ - IndexOffset: 11, - Reversed: true, - NumMaxInvoices: numInvoices, - }, - expected: invoices[:10], - }, - // Fetch the last 40 invoices. - { - query: InvoiceQuery{ - IndexOffset: 10, - NumMaxInvoices: numInvoices, - }, - expected: invoices[10:], - }, - // Fetch all but the first invoice. - { - query: InvoiceQuery{ - IndexOffset: 1, - NumMaxInvoices: numInvoices, - }, - expected: invoices[1:], - }, - // Fetch one invoice, reversed, with index offset 3. This - // should give us the second invoice in the array. - { - query: InvoiceQuery{ - IndexOffset: 3, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[1:2], - }, - // Same as above, at index 2. - { - query: InvoiceQuery{ - IndexOffset: 2, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[0:1], - }, - // Fetch one invoice, at index 1, reversed. Since invoice#1 is - // the very first, there won't be any left in a reverse search, - // so we expect no invoices to be returned. - { - query: InvoiceQuery{ - IndexOffset: 1, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: nil, - }, - // Same as above, but don't restrict the number of invoices to - // 1. - { - query: InvoiceQuery{ - IndexOffset: 1, - Reversed: true, - NumMaxInvoices: numInvoices, - }, - expected: nil, - }, - // Fetch one invoice, reversed, with no offset set. We expect - // the last invoice in the response. - { - query: InvoiceQuery{ - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[numInvoices-1:], - }, - // Fetch one invoice, reversed, the offset set at numInvoices+1. - // We expect this to return the last invoice. - { - query: InvoiceQuery{ - IndexOffset: numInvoices + 1, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[numInvoices-1:], - }, - // Same as above, at offset numInvoices. - { - query: InvoiceQuery{ - IndexOffset: numInvoices, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[numInvoices-2 : numInvoices-1], - }, - // Fetch one invoice, at no offset (same as offset 0). We - // expect the first invoice only in the response. - { - query: InvoiceQuery{ - NumMaxInvoices: 1, - }, - expected: invoices[:1], - }, - // Same as above, at offset 1. - { - query: InvoiceQuery{ - IndexOffset: 1, - NumMaxInvoices: 1, - }, - expected: invoices[1:2], - }, - // Same as above, at offset 2. - { - query: InvoiceQuery{ - IndexOffset: 2, - NumMaxInvoices: 1, - }, - expected: invoices[2:3], - }, - // Same as above, at offset numInvoices-1. Expect the last - // invoice to be returned. - { - query: InvoiceQuery{ - IndexOffset: numInvoices - 1, - NumMaxInvoices: 1, - }, - expected: invoices[numInvoices-1:], - }, - // Same as above, at offset numInvoices. No invoices should be - // returned, as there are no invoices after this offset. - { - query: InvoiceQuery{ - IndexOffset: numInvoices, - NumMaxInvoices: 1, - }, - expected: nil, - }, - // Fetch all pending invoices with a single query. - { - query: InvoiceQuery{ - PendingOnly: true, - NumMaxInvoices: numInvoices, - }, - expected: pendingInvoices, - }, - // Fetch the first 12 pending invoices. - { - query: InvoiceQuery{ - PendingOnly: true, - NumMaxInvoices: numInvoices / 4, - }, - expected: pendingInvoices[:len(pendingInvoices)/2], - }, - // Fetch the first 5 pending invoices, but this time iterating - // backwards. - { - query: InvoiceQuery{ - IndexOffset: 10, - PendingOnly: true, - Reversed: true, - NumMaxInvoices: numInvoices, - }, - // Since we seek to the invoice with index 10 and - // iterate backwards, there should only be 5 pending - // invoices before it as every other invoice within the - // index is settled. - expected: pendingInvoices[:5], - }, - // Fetch the last 15 invoices. - { - query: InvoiceQuery{ - IndexOffset: 20, - PendingOnly: true, - NumMaxInvoices: numInvoices, - }, - // Since we seek to the invoice with index 20, there are - // 30 invoices left. From these 30, only 15 of them are - // still pending. - expected: pendingInvoices[len(pendingInvoices)-15:], - }, - } - - for i, testCase := range testCases { - response, err := db.QueryInvoices(testCase.query) - if err != nil { - t.Fatalf("unable to query invoice database: %v", err) - } - - if !reflect.DeepEqual(response.Invoices, testCase.expected) { - t.Fatalf("test #%d: query returned incorrect set of "+ - "invoices: expcted %v, got %v", i, - spew.Sdump(response.Invoices), - spew.Sdump(testCase.expected)) - } - } -} - -// getUpdateInvoice returns an invoice update callback that, when called, -// settles the invoice with the given amount. -func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { - return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - if invoice.Terms.State == ContractSettled { - return nil, ErrInvoiceAlreadySettled - } - - update := &InvoiceUpdateDesc{ - Preimage: invoice.Terms.PaymentPreimage, - State: ContractSettled, - Htlcs: map[CircuitKey]*HtlcAcceptDesc{ - {}: { - Amt: amt, - }, - }, - } - - return update, nil - } -} diff --git a/channeldb/migration_01_to_11/invoices.go b/channeldb/migration_01_to_11/invoices.go index 5f40454a..f60457ff 100644 --- a/channeldb/migration_01_to_11/invoices.go +++ b/channeldb/migration_01_to_11/invoices.go @@ -3,7 +3,6 @@ package migration_01_to_11 import ( "bytes" "encoding/binary" - "errors" "fmt" "io" "time" @@ -16,9 +15,6 @@ import ( ) var ( - // UnknownPreimage is an all-zeroes preimage that indicates that the - // preimage for this invoice is not yet known. - UnknownPreimage lntypes.Preimage // invoiceBucket is the name of the bucket within the database that // stores all data related to invoices no matter their final state. @@ -26,23 +22,6 @@ var ( // which is a monotonically increasing uint32. invoiceBucket = []byte("invoices") - // paymentHashIndexBucket is the name of the sub-bucket within the - // invoiceBucket which indexes all invoices by their payment hash. The - // payment hash is the sha256 of the invoice's payment preimage. This - // index is used to detect duplicates, and also to provide a fast path - // for looking up incoming HTLCs to determine if we're able to settle - // them fully. - // - // maps: payHash => invoiceKey - invoiceIndexBucket = []byte("paymenthashes") - - // numInvoicesKey is the name of key which houses the auto-incrementing - // invoice ID which is essentially used as a primary key. With each - // invoice inserted, the primary key is incremented by one. This key is - // stored within the invoiceIndexBucket. Within the invoiceBucket - // invoices are uniquely identified by the invoice ID. - numInvoicesKey = []byte("nik") - // addIndexBucket is an index bucket that we'll use to create a // monotonically increasing set of add indexes. Each time we add a new // invoice, this sequence number will be incremented and then populated @@ -62,21 +41,6 @@ var ( // // settleIndexNo => invoiceKey settleIndexBucket = []byte("invoice-settle-index") - - // ErrInvoiceAlreadySettled is returned when the invoice is already - // settled. - ErrInvoiceAlreadySettled = errors.New("invoice already settled") - - // ErrInvoiceAlreadyCanceled is returned when the invoice is already - // canceled. - ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled") - - // ErrInvoiceAlreadyAccepted is returned when the invoice is already - // accepted. - ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted") - - // ErrInvoiceStillOpen is returned when the invoice is still open. - ErrInvoiceStillOpen = errors.New("invoice still open") ) const ( @@ -237,18 +201,6 @@ type Invoice struct { // HtlcState defines the states an htlc paying to an invoice can be in. type HtlcState uint8 -const ( - // HtlcStateAccepted indicates the htlc is locked-in, but not resolved. - HtlcStateAccepted HtlcState = iota - - // HtlcStateCanceled indicates the htlc is canceled back to the - // sender. - HtlcStateCanceled - - // HtlcStateSettled indicates the htlc is settled. - HtlcStateSettled -) - // InvoiceHTLC contains details about an htlc paying to this invoice. type InvoiceHTLC struct { // Amt is the amount that is carried by this htlc. @@ -276,37 +228,6 @@ type InvoiceHTLC struct { State HtlcState } -// HtlcAcceptDesc describes the details of a newly accepted htlc. -type HtlcAcceptDesc struct { - // AcceptHeight is the block height at which this htlc was accepted. - AcceptHeight int32 - - // Amt is the amount that is carried by this htlc. - Amt lnwire.MilliSatoshi - - // Expiry is the expiry height of this htlc. - Expiry uint32 -} - -// InvoiceUpdateDesc describes the changes that should be applied to the -// invoice. -type InvoiceUpdateDesc struct { - // State is the new state that this invoice should progress to. - State ContractState - - // Htlcs describes the changes that need to be made to the invoice htlcs - // in the database. Htlc map entries with their value set should be - // added. If the map value is nil, the htlc should be canceled. - Htlcs map[CircuitKey]*HtlcAcceptDesc - - // Preimage must be set to the preimage when state is settled. - Preimage lntypes.Preimage -} - -// InvoiceUpdateCallback is a callback used in the db transaction to update the -// invoice. -type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) - func validateInvoice(i *Invoice) error { if len(i.Memo) > MaxMemoSize { return fmt.Errorf("max length a memo is %v, and invoice "+ @@ -325,186 +246,6 @@ func validateInvoice(i *Invoice) error { return nil } -// AddInvoice inserts the targeted invoice into the database. If the invoice has -// *any* payment hashes which already exists within the database, then the -// insertion will be aborted and rejected due to the strict policy banning any -// duplicate payment hashes. A side effect of this function is that it sets -// AddIndex on newInvoice. -func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( - uint64, error) { - - if err := validateInvoice(newInvoice); err != nil { - return 0, err - } - - var invoiceAddIndex uint64 - err := d.Update(func(tx *bbolt.Tx) error { - invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) - if err != nil { - return err - } - - invoiceIndex, err := invoices.CreateBucketIfNotExists( - invoiceIndexBucket, - ) - if err != nil { - return err - } - addIndex, err := invoices.CreateBucketIfNotExists( - addIndexBucket, - ) - if err != nil { - return err - } - - // Ensure that an invoice an identical payment hash doesn't - // already exist within the index. - if invoiceIndex.Get(paymentHash[:]) != nil { - return ErrDuplicateInvoice - } - - // If the current running payment ID counter hasn't yet been - // created, then create it now. - var invoiceNum uint32 - invoiceCounter := invoiceIndex.Get(numInvoicesKey) - if invoiceCounter == nil { - var scratch [4]byte - byteOrder.PutUint32(scratch[:], invoiceNum) - err := invoiceIndex.Put(numInvoicesKey, scratch[:]) - if err != nil { - return err - } - } else { - invoiceNum = byteOrder.Uint32(invoiceCounter) - } - - newIndex, err := putInvoice( - invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, - paymentHash, - ) - if err != nil { - return err - } - - invoiceAddIndex = newIndex - return nil - }) - if err != nil { - return 0, err - } - - return invoiceAddIndex, err -} - -// InvoicesAddedSince can be used by callers to seek into the event time series -// of all the invoices added in the database. The specified sinceAddIndex -// should be the highest add index that the caller knows of. This method will -// return all invoices with an add index greater than the specified -// sinceAddIndex. -// -// NOTE: The index starts from 1, as a result. We enforce that specifying a -// value below the starting index value is a noop. -func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { - var newInvoices []Invoice - - // If an index of zero was specified, then in order to maintain - // backwards compat, we won't send out any new invoices. - if sinceAddIndex == 0 { - return newInvoices, nil - } - - var startIndex [8]byte - byteOrder.PutUint64(startIndex[:], sinceAddIndex) - - err := d.DB.View(func(tx *bbolt.Tx) error { - invoices := tx.Bucket(invoiceBucket) - if invoices == nil { - return ErrNoInvoicesCreated - } - - addIndex := invoices.Bucket(addIndexBucket) - if addIndex == nil { - return ErrNoInvoicesCreated - } - - // We'll now run through each entry in the add index starting - // at our starting index. We'll continue until we reach the - // very end of the current key space. - invoiceCursor := addIndex.Cursor() - - // We'll seek to the starting index, then manually advance the - // cursor in order to skip the entry with the since add index. - invoiceCursor.Seek(startIndex[:]) - addSeqNo, invoiceKey := invoiceCursor.Next() - - for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() { - - // For each key found, we'll look up the actual - // invoice, then accumulate it into our return value. - invoice, err := fetchInvoice(invoiceKey, invoices) - if err != nil { - return err - } - - newInvoices = append(newInvoices, invoice) - } - - return nil - }) - switch { - // If no invoices have been created, then we'll return the empty set of - // invoices. - case err == ErrNoInvoicesCreated: - - case err != nil: - return nil, err - } - - return newInvoices, nil -} - -// LookupInvoice attempts to look up an invoice according to its 32 byte -// payment hash. If an invoice which can settle the HTLC identified by the -// passed payment hash isn't found, then an error is returned. Otherwise, the -// full invoice is returned. Before setting the incoming HTLC, the values -// SHOULD be checked to ensure the payer meets the agreed upon contractual -// terms of the payment. -func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { - var invoice Invoice - err := d.View(func(tx *bbolt.Tx) error { - invoices := tx.Bucket(invoiceBucket) - if invoices == nil { - return ErrNoInvoicesCreated - } - invoiceIndex := invoices.Bucket(invoiceIndexBucket) - if invoiceIndex == nil { - return ErrNoInvoicesCreated - } - - // Check the invoice index to see if an invoice paying to this - // hash exists within the DB. - invoiceNum := invoiceIndex.Get(paymentHash[:]) - if invoiceNum == nil { - return ErrInvoiceNotFound - } - - // An invoice matching the payment hash has been found, so - // retrieve the record of the invoice itself. - i, err := fetchInvoice(invoiceNum, invoices) - if err != nil { - return err - } - invoice = i - - return nil - }) - if err != nil { - return invoice, err - } - - return invoice, nil -} - // FetchAllInvoices returns all invoices currently stored within the database. // If the pendingOnly param is true, then only unsettled invoices will be // returned, skipping all invoices that are fully settled. @@ -549,343 +290,6 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { return invoices, nil } -// InvoiceQuery represents a query to the invoice database. The query allows a -// caller to retrieve all invoices starting from a particular add index and -// limit the number of results returned. -type InvoiceQuery struct { - // IndexOffset is the offset within the add indices to start at. This - // can be used to start the response at a particular invoice. - IndexOffset uint64 - - // NumMaxInvoices is the maximum number of invoices that should be - // starting from the add index. - NumMaxInvoices uint64 - - // PendingOnly, if set, returns unsettled invoices starting from the - // add index. - PendingOnly bool - - // Reversed, if set, indicates that the invoices returned should start - // from the IndexOffset and go backwards. - Reversed bool -} - -// InvoiceSlice is the response to a invoice query. It includes the original -// query, the set of invoices that match the query, and an integer which -// represents the offset index of the last item in the set of returned invoices. -// This integer allows callers to resume their query using this offset in the -// event that the query's response exceeds the maximum number of returnable -// invoices. -type InvoiceSlice struct { - InvoiceQuery - - // Invoices is the set of invoices that matched the query above. - Invoices []Invoice - - // FirstIndexOffset is the index of the first element in the set of - // returned Invoices above. Callers can use this to resume their query - // in the event that the slice has too many events to fit into a single - // response. - FirstIndexOffset uint64 - - // LastIndexOffset is the index of the last element in the set of - // returned Invoices above. Callers can use this to resume their query - // in the event that the slice has too many events to fit into a single - // response. - LastIndexOffset uint64 -} - -// QueryInvoices allows a caller to query the invoice database for invoices -// within the specified add index range. -func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { - resp := InvoiceSlice{ - InvoiceQuery: q, - } - - err := d.View(func(tx *bbolt.Tx) error { - // If the bucket wasn't found, then there aren't any invoices - // within the database yet, so we can simply exit. - invoices := tx.Bucket(invoiceBucket) - if invoices == nil { - return ErrNoInvoicesCreated - } - invoiceAddIndex := invoices.Bucket(addIndexBucket) - if invoiceAddIndex == nil { - return ErrNoInvoicesCreated - } - - // keyForIndex is a helper closure that retrieves the invoice - // key for the given add index of an invoice. - keyForIndex := func(c *bbolt.Cursor, index uint64) []byte { - var keyIndex [8]byte - byteOrder.PutUint64(keyIndex[:], index) - _, invoiceKey := c.Seek(keyIndex[:]) - return invoiceKey - } - - // nextKey is a helper closure to determine what the next - // invoice key is when iterating over the invoice add index. - nextKey := func(c *bbolt.Cursor) ([]byte, []byte) { - if q.Reversed { - return c.Prev() - } - return c.Next() - } - - // We'll be using a cursor to seek into the database and return - // a slice of invoices. We'll need to determine where to start - // our cursor depending on the parameters set within the query. - c := invoiceAddIndex.Cursor() - invoiceKey := keyForIndex(c, q.IndexOffset+1) - - // If the query is specifying reverse iteration, then we must - // handle a few offset cases. - if q.Reversed { - switch q.IndexOffset { - - // This indicates the default case, where no offset was - // specified. In that case we just start from the last - // invoice. - case 0: - _, invoiceKey = c.Last() - - // This indicates the offset being set to the very - // first invoice. Since there are no invoices before - // this offset, and the direction is reversed, we can - // return without adding any invoices to the response. - case 1: - return nil - - // Otherwise we start iteration at the invoice prior to - // the offset. - default: - invoiceKey = keyForIndex(c, q.IndexOffset-1) - } - } - - // If we know that a set of invoices exists, then we'll begin - // our seek through the bucket in order to satisfy the query. - // We'll continue until either we reach the end of the range, or - // reach our max number of invoices. - for ; invoiceKey != nil; _, invoiceKey = nextKey(c) { - // If our current return payload exceeds the max number - // of invoices, then we'll exit now. - if uint64(len(resp.Invoices)) >= q.NumMaxInvoices { - break - } - - invoice, err := fetchInvoice(invoiceKey, invoices) - if err != nil { - return err - } - - // Skip any settled invoices if the caller is only - // interested in unsettled. - if q.PendingOnly && - invoice.Terms.State == ContractSettled { - - continue - } - - // At this point, we've exhausted the offset, so we'll - // begin collecting invoices found within the range. - resp.Invoices = append(resp.Invoices, invoice) - } - - // If we iterated through the add index in reverse order, then - // we'll need to reverse the slice of invoices to return them in - // forward order. - if q.Reversed { - numInvoices := len(resp.Invoices) - for i := 0; i < numInvoices/2; i++ { - opposite := numInvoices - i - 1 - resp.Invoices[i], resp.Invoices[opposite] = - resp.Invoices[opposite], resp.Invoices[i] - } - } - - return nil - }) - if err != nil && err != ErrNoInvoicesCreated { - return resp, err - } - - // Finally, record the indexes of the first and last invoices returned - // so that the caller can resume from this point later on. - if len(resp.Invoices) > 0 { - resp.FirstIndexOffset = resp.Invoices[0].AddIndex - resp.LastIndexOffset = resp.Invoices[len(resp.Invoices)-1].AddIndex - } - - return resp, nil -} - -// UpdateInvoice attempts to update an invoice corresponding to the passed -// payment hash. If an invoice matching the passed payment hash doesn't exist -// within the database, then the action will fail with a "not found" error. -// -// The update is performed inside the same database transaction that fetches the -// invoice and is therefore atomic. The fields to update are controlled by the -// supplied callback. -func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, - callback InvoiceUpdateCallback) (*Invoice, error) { - - var updatedInvoice *Invoice - err := d.Update(func(tx *bbolt.Tx) error { - invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) - if err != nil { - return err - } - invoiceIndex, err := invoices.CreateBucketIfNotExists( - invoiceIndexBucket, - ) - if err != nil { - return err - } - settleIndex, err := invoices.CreateBucketIfNotExists( - settleIndexBucket, - ) - if err != nil { - return err - } - - // Check the invoice index to see if an invoice paying to this - // hash exists within the DB. - invoiceNum := invoiceIndex.Get(paymentHash[:]) - if invoiceNum == nil { - return ErrInvoiceNotFound - } - - updatedInvoice, err = d.updateInvoice( - paymentHash, invoices, settleIndex, invoiceNum, - callback, - ) - - return err - }) - - return updatedInvoice, err -} - -// InvoicesSettledSince can be used by callers to catch up any settled invoices -// they missed within the settled invoice time series. We'll return all known -// settled invoice that have a settle index higher than the passed -// sinceSettleIndex. -// -// NOTE: The index starts from 1, as a result. We enforce that specifying a -// value below the starting index value is a noop. -func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { - var settledInvoices []Invoice - - // If an index of zero was specified, then in order to maintain - // backwards compat, we won't send out any new invoices. - if sinceSettleIndex == 0 { - return settledInvoices, nil - } - - var startIndex [8]byte - byteOrder.PutUint64(startIndex[:], sinceSettleIndex) - - err := d.DB.View(func(tx *bbolt.Tx) error { - invoices := tx.Bucket(invoiceBucket) - if invoices == nil { - return ErrNoInvoicesCreated - } - - settleIndex := invoices.Bucket(settleIndexBucket) - if settleIndex == nil { - return ErrNoInvoicesCreated - } - - // We'll now run through each entry in the add index starting - // at our starting index. We'll continue until we reach the - // very end of the current key space. - invoiceCursor := settleIndex.Cursor() - - // We'll seek to the starting index, then manually advance the - // cursor in order to skip the entry with the since add index. - invoiceCursor.Seek(startIndex[:]) - seqNo, invoiceKey := invoiceCursor.Next() - - for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, invoiceKey = invoiceCursor.Next() { - - // For each key found, we'll look up the actual - // invoice, then accumulate it into our return value. - invoice, err := fetchInvoice(invoiceKey, invoices) - if err != nil { - return err - } - - settledInvoices = append(settledInvoices, invoice) - } - - return nil - }) - if err != nil { - return nil, err - } - - return settledInvoices, nil -} - -func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket, - i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( - uint64, error) { - - // Create the invoice key which is just the big-endian representation - // of the invoice number. - var invoiceKey [4]byte - byteOrder.PutUint32(invoiceKey[:], invoiceNum) - - // Increment the num invoice counter index so the next invoice bares - // the proper ID. - var scratch [4]byte - invoiceCounter := invoiceNum + 1 - byteOrder.PutUint32(scratch[:], invoiceCounter) - if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil { - return 0, err - } - - // Add the payment hash to the invoice index. This will let us quickly - // identify if we can settle an incoming payment, and also to possibly - // allow a single invoice to have multiple payment installations. - err := invoiceIndex.Put(paymentHash[:], invoiceKey[:]) - if err != nil { - return 0, err - } - - // Next, we'll obtain the next add invoice index (sequence - // number), so we can properly place this invoice within this - // event stream. - nextAddSeqNo, err := addIndex.NextSequence() - if err != nil { - return 0, err - } - - // With the next sequence obtained, we'll updating the event series in - // the add index bucket to map this current add counter to the index of - // this new invoice. - var seqNoBytes [8]byte - byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) - if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil { - return 0, err - } - - i.AddIndex = nextAddSeqNo - - // Finally, serialize the invoice itself to be written to the disk. - var buf bytes.Buffer - if err := serializeInvoice(&buf, i); err != nil { - return 0, err - } - - if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil { - return 0, err - } - - return nextAddSeqNo, nil -} - // serializeInvoice serializes an invoice to a writer. // // Note: this function is in use for a migration. Before making changes that @@ -1006,17 +410,6 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error { return nil } -func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) { - invoiceBytes := invoices.Get(invoiceNum) - if invoiceBytes == nil { - return Invoice{}, ErrInvoiceNotFound - } - - invoiceReader := bytes.NewReader(invoiceBytes) - - return deserializeInvoice(invoiceReader) -} - func deserializeInvoice(r io.Reader) (Invoice, error) { var err error invoice := Invoice{} @@ -1155,166 +548,3 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { return htlcs, nil } - -// copySlice allocates a new slice and copies the source into it. -func copySlice(src []byte) []byte { - dest := make([]byte, len(src)) - copy(dest, src) - return dest -} - -// copyInvoice makes a deep copy of the supplied invoice. -func copyInvoice(src *Invoice) *Invoice { - dest := Invoice{ - Memo: copySlice(src.Memo), - Receipt: copySlice(src.Receipt), - PaymentRequest: copySlice(src.PaymentRequest), - FinalCltvDelta: src.FinalCltvDelta, - CreationDate: src.CreationDate, - SettleDate: src.SettleDate, - Terms: src.Terms, - AddIndex: src.AddIndex, - SettleIndex: src.SettleIndex, - AmtPaid: src.AmtPaid, - Htlcs: make( - map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), - ), - } - - for k, v := range src.Htlcs { - dest.Htlcs[k] = v - } - - return &dest -} - -// updateInvoice fetches the invoice, obtains the update descriptor from the -// callback and applies the updates in a single db transaction. -func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucket, - invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) { - - invoice, err := fetchInvoice(invoiceNum, invoices) - if err != nil { - return nil, err - } - - preUpdateState := invoice.Terms.State - - // Create deep copy to prevent any accidental modification in the - // callback. - copy := copyInvoice(&invoice) - - // Call the callback and obtain the update descriptor. - update, err := callback(copy) - if err != nil { - return &invoice, err - } - - // Update invoice state. - invoice.Terms.State = update.State - - now := d.now() - - // Update htlc set. - for key, htlcUpdate := range update.Htlcs { - htlc, ok := invoice.Htlcs[key] - - // No update means the htlc needs to be canceled. - if htlcUpdate == nil { - if !ok { - return nil, fmt.Errorf("unknown htlc %v", key) - } - if htlc.State != HtlcStateAccepted { - return nil, fmt.Errorf("can only cancel " + - "accepted htlcs") - } - - htlc.State = HtlcStateCanceled - htlc.ResolveTime = now - invoice.AmtPaid -= htlc.Amt - - continue - } - - // Add new htlc paying to the invoice. - if ok { - return nil, fmt.Errorf("htlc %v already exists", key) - } - htlc = &InvoiceHTLC{ - Amt: htlcUpdate.Amt, - Expiry: htlcUpdate.Expiry, - AcceptHeight: uint32(htlcUpdate.AcceptHeight), - AcceptTime: now, - } - if preUpdateState == ContractSettled { - htlc.State = HtlcStateSettled - htlc.ResolveTime = now - } else { - htlc.State = HtlcStateAccepted - } - - invoice.Htlcs[key] = htlc - invoice.AmtPaid += htlc.Amt - } - - // If invoice moved to the settled state, update settle index and settle - // time. - if preUpdateState != invoice.Terms.State && - invoice.Terms.State == ContractSettled { - - if update.Preimage.Hash() != hash { - return nil, fmt.Errorf("preimage does not match") - } - invoice.Terms.PaymentPreimage = update.Preimage - - // Settle all accepted htlcs. - for _, htlc := range invoice.Htlcs { - if htlc.State != HtlcStateAccepted { - continue - } - - htlc.State = HtlcStateSettled - htlc.ResolveTime = now - } - - err := setSettleFields(settleIndex, invoiceNum, &invoice, now) - if err != nil { - return nil, err - } - } - - var buf bytes.Buffer - if err := serializeInvoice(&buf, &invoice); err != nil { - return nil, err - } - - if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil { - return nil, err - } - - return &invoice, nil -} - -func setSettleFields(settleIndex *bbolt.Bucket, invoiceNum []byte, - invoice *Invoice, now time.Time) error { - - // Now that we know the invoice hasn't already been settled, we'll - // update the settle index so we can place this settle event in the - // proper location within our time series. - nextSettleSeqNo, err := settleIndex.NextSequence() - if err != nil { - return err - } - - var seqNoBytes [8]byte - byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) - if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil { - return err - } - - invoice.Terms.State = ContractSettled - invoice.SettleDate = now - invoice.SettleIndex = nextSettleSeqNo - - return nil -} diff --git a/channeldb/migration_01_to_11/nodes.go b/channeldb/migration_01_to_11/nodes.go deleted file mode 100644 index f40359e8..00000000 --- a/channeldb/migration_01_to_11/nodes.go +++ /dev/null @@ -1,316 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "io" - "net" - "time" - - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/wire" - "github.com/coreos/bbolt" -) - -var ( - // nodeInfoBucket stores metadata pertaining to nodes that we've had - // direct channel-based correspondence with. This bucket allows one to - // query for all open channels pertaining to the node by exploring each - // node's sub-bucket within the openChanBucket. - nodeInfoBucket = []byte("nib") -) - -// LinkNode stores metadata related to node's that we have/had a direct -// channel open with. Information such as the Bitcoin network the node -// advertised, and its identity public key are also stored. Additionally, this -// struct and the bucket its stored within have store data similar to that of -// Bitcoin's addrmanager. The TCP address information stored within the struct -// can be used to establish persistent connections will all channel -// counterparties on daemon startup. -// -// TODO(roasbeef): also add current OnionKey plus rotation schedule? -// TODO(roasbeef): add bitfield for supported services -// * possibly add a wire.NetAddress type, type -type LinkNode struct { - // Network indicates the Bitcoin network that the LinkNode advertises - // for incoming channel creation. - Network wire.BitcoinNet - - // IdentityPub is the node's current identity public key. Any - // channel/topology related information received by this node MUST be - // signed by this public key. - IdentityPub *btcec.PublicKey - - // LastSeen tracks the last time this node was seen within the network. - // A node should be marked as seen if the daemon either is able to - // establish an outgoing connection to the node or receives a new - // incoming connection from the node. This timestamp (stored in unix - // epoch) may be used within a heuristic which aims to determine when a - // channel should be unilaterally closed due to inactivity. - // - // TODO(roasbeef): replace with block hash/height? - // * possibly add a time-value metric into the heuristic? - LastSeen time.Time - - // Addresses is a list of IP address in which either we were able to - // reach the node over in the past, OR we received an incoming - // authenticated connection for the stored identity public key. - Addresses []net.Addr - - db *DB -} - -// NewLinkNode creates a new LinkNode from the provided parameters, which is -// backed by an instance of channeldb. -func (db *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, - addrs ...net.Addr) *LinkNode { - - return &LinkNode{ - Network: bitNet, - IdentityPub: pub, - LastSeen: time.Now(), - Addresses: addrs, - db: db, - } -} - -// UpdateLastSeen updates the last time this node was directly encountered on -// the Lightning Network. -func (l *LinkNode) UpdateLastSeen(lastSeen time.Time) error { - l.LastSeen = lastSeen - - return l.Sync() -} - -// AddAddress appends the specified TCP address to the list of known addresses -// this node is/was known to be reachable at. -func (l *LinkNode) AddAddress(addr net.Addr) error { - for _, a := range l.Addresses { - if a.String() == addr.String() { - return nil - } - } - - l.Addresses = append(l.Addresses, addr) - - return l.Sync() -} - -// Sync performs a full database sync which writes the current up-to-date data -// within the struct to the database. -func (l *LinkNode) Sync() error { - - // Finally update the database by storing the link node and updating - // any relevant indexes. - return l.db.Update(func(tx *bbolt.Tx) error { - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return ErrLinkNodesNotFound - } - - return putLinkNode(nodeMetaBucket, l) - }) -} - -// putLinkNode serializes then writes the encoded version of the passed link -// node into the nodeMetaBucket. This function is provided in order to allow -// the ability to re-use a database transaction across many operations. -func putLinkNode(nodeMetaBucket *bbolt.Bucket, l *LinkNode) error { - // First serialize the LinkNode into its raw-bytes encoding. - var b bytes.Buffer - if err := serializeLinkNode(&b, l); err != nil { - return err - } - - // Finally insert the link-node into the node metadata bucket keyed - // according to the its pubkey serialized in compressed form. - nodePub := l.IdentityPub.SerializeCompressed() - return nodeMetaBucket.Put(nodePub, b.Bytes()) -} - -// DeleteLinkNode removes the link node with the given identity from the -// database. -func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error { - return db.Update(func(tx *bbolt.Tx) error { - return db.deleteLinkNode(tx, identity) - }) -} - -func (db *DB) deleteLinkNode(tx *bbolt.Tx, identity *btcec.PublicKey) error { - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return ErrLinkNodesNotFound - } - - pubKey := identity.SerializeCompressed() - return nodeMetaBucket.Delete(pubKey) -} - -// FetchLinkNode attempts to lookup the data for a LinkNode based on a target -// identity public key. If a particular LinkNode for the passed identity public -// key cannot be found, then ErrNodeNotFound if returned. -func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { - var linkNode *LinkNode - err := db.View(func(tx *bbolt.Tx) error { - node, err := fetchLinkNode(tx, identity) - if err != nil { - return err - } - - linkNode = node - return nil - }) - - return linkNode, err -} - -func fetchLinkNode(tx *bbolt.Tx, targetPub *btcec.PublicKey) (*LinkNode, error) { - // First fetch the bucket for storing node metadata, bailing out early - // if it hasn't been created yet. - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return nil, ErrLinkNodesNotFound - } - - // If a link node for that particular public key cannot be located, - // then exit early with an ErrNodeNotFound. - pubKey := targetPub.SerializeCompressed() - nodeBytes := nodeMetaBucket.Get(pubKey) - if nodeBytes == nil { - return nil, ErrNodeNotFound - } - - // Finally, decode and allocate a fresh LinkNode object to be returned - // to the caller. - nodeReader := bytes.NewReader(nodeBytes) - return deserializeLinkNode(nodeReader) -} - -// TODO(roasbeef): update link node addrs in server upon connection - -// FetchAllLinkNodes starts a new database transaction to fetch all nodes with -// whom we have active channels with. -func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { - var linkNodes []*LinkNode - err := db.View(func(tx *bbolt.Tx) error { - nodes, err := db.fetchAllLinkNodes(tx) - if err != nil { - return err - } - - linkNodes = nodes - return nil - }) - if err != nil { - return nil, err - } - - return linkNodes, nil -} - -// fetchAllLinkNodes uses an existing database transaction to fetch all nodes -// with whom we have active channels with. -func (db *DB) fetchAllLinkNodes(tx *bbolt.Tx) ([]*LinkNode, error) { - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return nil, ErrLinkNodesNotFound - } - - var linkNodes []*LinkNode - err := nodeMetaBucket.ForEach(func(k, v []byte) error { - if v == nil { - return nil - } - - nodeReader := bytes.NewReader(v) - linkNode, err := deserializeLinkNode(nodeReader) - if err != nil { - return err - } - - linkNodes = append(linkNodes, linkNode) - return nil - }) - if err != nil { - return nil, err - } - - return linkNodes, nil -} - -func serializeLinkNode(w io.Writer, l *LinkNode) error { - var buf [8]byte - - byteOrder.PutUint32(buf[:4], uint32(l.Network)) - if _, err := w.Write(buf[:4]); err != nil { - return err - } - - serializedID := l.IdentityPub.SerializeCompressed() - if _, err := w.Write(serializedID); err != nil { - return err - } - - seenUnix := uint64(l.LastSeen.Unix()) - byteOrder.PutUint64(buf[:], seenUnix) - if _, err := w.Write(buf[:]); err != nil { - return err - } - - numAddrs := uint32(len(l.Addresses)) - byteOrder.PutUint32(buf[:4], numAddrs) - if _, err := w.Write(buf[:4]); err != nil { - return err - } - - for _, addr := range l.Addresses { - if err := serializeAddr(w, addr); err != nil { - return err - } - } - - return nil -} - -func deserializeLinkNode(r io.Reader) (*LinkNode, error) { - var ( - err error - buf [8]byte - ) - - node := &LinkNode{} - - if _, err := io.ReadFull(r, buf[:4]); err != nil { - return nil, err - } - node.Network = wire.BitcoinNet(byteOrder.Uint32(buf[:4])) - - var pub [33]byte - if _, err := io.ReadFull(r, pub[:]); err != nil { - return nil, err - } - node.IdentityPub, err = btcec.ParsePubKey(pub[:], btcec.S256()) - if err != nil { - return nil, err - } - - if _, err := io.ReadFull(r, buf[:]); err != nil { - return nil, err - } - node.LastSeen = time.Unix(int64(byteOrder.Uint64(buf[:])), 0) - - if _, err := io.ReadFull(r, buf[:4]); err != nil { - return nil, err - } - numAddrs := byteOrder.Uint32(buf[:4]) - - node.Addresses = make([]net.Addr, numAddrs) - for i := uint32(0); i < numAddrs; i++ { - addr, err := deserializeAddr(r) - if err != nil { - return nil, err - } - node.Addresses[i] = addr - } - - return node, nil -} diff --git a/channeldb/migration_01_to_11/nodes_test.go b/channeldb/migration_01_to_11/nodes_test.go deleted file mode 100644 index 481dc5bd..00000000 --- a/channeldb/migration_01_to_11/nodes_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "net" - "testing" - "time" - - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/wire" -) - -func TestLinkNodeEncodeDecode(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First we'll create some initial data to use for populating our test - // LinkNode instances. - _, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - _, pub2 := btcec.PrivKeyFromBytes(btcec.S256(), rev[:]) - addr1, err := net.ResolveTCPAddr("tcp", "10.0.0.1:9000") - if err != nil { - t.Fatalf("unable to create test addr: %v", err) - } - addr2, err := net.ResolveTCPAddr("tcp", "10.0.0.2:9000") - if err != nil { - t.Fatalf("unable to create test addr: %v", err) - } - - // Create two fresh link node instances with the above dummy data, then - // fully sync both instances to disk. - node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1) - node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2) - if err := node1.Sync(); err != nil { - t.Fatalf("unable to sync node: %v", err) - } - if err := node2.Sync(); err != nil { - t.Fatalf("unable to sync node: %v", err) - } - - // Fetch all current link nodes from the database, they should exactly - // match the two created above. - originalNodes := []*LinkNode{node2, node1} - linkNodes, err := cdb.FetchAllLinkNodes() - if err != nil { - t.Fatalf("unable to fetch nodes: %v", err) - } - for i, node := range linkNodes { - if originalNodes[i].Network != node.Network { - t.Fatalf("node networks don't match: expected %v, got %v", - originalNodes[i].Network, node.Network) - } - - originalPubkey := originalNodes[i].IdentityPub.SerializeCompressed() - dbPubkey := node.IdentityPub.SerializeCompressed() - if !bytes.Equal(originalPubkey, dbPubkey) { - t.Fatalf("node pubkeys don't match: expected %x, got %x", - originalPubkey, dbPubkey) - } - if originalNodes[i].LastSeen.Unix() != node.LastSeen.Unix() { - t.Fatalf("last seen timestamps don't match: expected %v got %v", - originalNodes[i].LastSeen.Unix(), node.LastSeen.Unix()) - } - if originalNodes[i].Addresses[0].String() != node.Addresses[0].String() { - t.Fatalf("addresses don't match: expected %v, got %v", - originalNodes[i].Addresses, node.Addresses) - } - } - - // Next, we'll exercise the methods to append additional IP - // addresses, and also to update the last seen time. - if err := node1.UpdateLastSeen(time.Now()); err != nil { - t.Fatalf("unable to update last seen: %v", err) - } - if err := node1.AddAddress(addr2); err != nil { - t.Fatalf("unable to update addr: %v", err) - } - - // Fetch the same node from the database according to its public key. - node1DB, err := cdb.FetchLinkNode(pub1) - if err != nil { - t.Fatalf("unable to find node: %v", err) - } - - // Both the last seen timestamp and the list of reachable addresses for - // the node should be updated. - if node1DB.LastSeen.Unix() != node1.LastSeen.Unix() { - t.Fatalf("last seen timestamps don't match: expected %v got %v", - node1.LastSeen.Unix(), node1DB.LastSeen.Unix()) - } - if len(node1DB.Addresses) != 2 { - t.Fatalf("wrong length for node1 addresses: expected %v, got %v", - 2, len(node1DB.Addresses)) - } - if node1DB.Addresses[0].String() != addr1.String() { - t.Fatalf("wrong address for node: expected %v, got %v", - addr1.String(), node1DB.Addresses[0].String()) - } - if node1DB.Addresses[1].String() != addr2.String() { - t.Fatalf("wrong address for node: expected %v, got %v", - addr2.String(), node1DB.Addresses[1].String()) - } -} - -func TestDeleteLinkNode(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 1337, - } - linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr) - if err := linkNode.Sync(); err != nil { - t.Fatalf("unable to write link node to db: %v", err) - } - - if _, err := cdb.FetchLinkNode(pubKey); err != nil { - t.Fatalf("unable to find link node: %v", err) - } - - if err := cdb.DeleteLinkNode(pubKey); err != nil { - t.Fatalf("unable to delete link node from db: %v", err) - } - - if _, err := cdb.FetchLinkNode(pubKey); err == nil { - t.Fatal("should not have found link node in db, but did") - } -} diff --git a/channeldb/migration_01_to_11/options.go b/channeldb/migration_01_to_11/options.go index c3cc2c4a..03b287e0 100644 --- a/channeldb/migration_01_to_11/options.go +++ b/channeldb/migration_01_to_11/options.go @@ -39,24 +39,3 @@ func DefaultOptions() Options { // OptionModifier is a function signature for modifying the default Options. type OptionModifier func(*Options) - -// OptionSetRejectCacheSize sets the RejectCacheSize to n. -func OptionSetRejectCacheSize(n int) OptionModifier { - return func(o *Options) { - o.RejectCacheSize = n - } -} - -// OptionSetChannelCacheSize sets the ChannelCacheSize to n. -func OptionSetChannelCacheSize(n int) OptionModifier { - return func(o *Options) { - o.ChannelCacheSize = n - } -} - -// OptionSetSyncFreelist allows the database to sync its freelist. -func OptionSetSyncFreelist(b bool) OptionModifier { - return func(o *Options) { - o.NoFreelistSync = !b - } -} diff --git a/channeldb/migration_01_to_11/payment_control.go b/channeldb/migration_01_to_11/payment_control.go index 83b1649a..7b069d24 100644 --- a/channeldb/migration_01_to_11/payment_control.go +++ b/channeldb/migration_01_to_11/payment_control.go @@ -1,373 +1,9 @@ package migration_01_to_11 import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/routing/route" ) -var ( - // ErrAlreadyPaid signals we have already paid this payment hash. - ErrAlreadyPaid = errors.New("invoice is already paid") - - // ErrPaymentInFlight signals that payment for this payment hash is - // already "in flight" on the network. - ErrPaymentInFlight = errors.New("payment is in transition") - - // ErrPaymentNotInitiated is returned if payment wasn't initiated in - // switch. - ErrPaymentNotInitiated = errors.New("payment isn't initiated") - - // ErrPaymentAlreadySucceeded is returned in the event we attempt to - // change the status of a payment already succeeded. - ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded") - - // ErrPaymentAlreadyFailed is returned in the event we attempt to - // re-fail a failed payment. - ErrPaymentAlreadyFailed = errors.New("payment has already failed") - - // ErrUnknownPaymentStatus is returned when we do not recognize the - // existing state of a payment. - ErrUnknownPaymentStatus = errors.New("unknown payment status") - - // errNoAttemptInfo is returned when no attempt info is stored yet. - errNoAttemptInfo = errors.New("unable to find attempt info for " + - "inflight payment") -) - -// PaymentControl implements persistence for payments and payment attempts. -type PaymentControl struct { - db *DB -} - -// NewPaymentControl creates a new instance of the PaymentControl. -func NewPaymentControl(db *DB) *PaymentControl { - return &PaymentControl{ - db: db, - } -} - -// InitPayment checks or records the given PaymentCreationInfo with the DB, -// making sure it does not already exist as an in-flight payment. Then this -// method returns successfully, the payment is guranteeed to be in the InFlight -// state. -func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, - info *PaymentCreationInfo) error { - - var b bytes.Buffer - if err := serializePaymentCreationInfo(&b, info); err != nil { - return err - } - infoBytes := b.Bytes() - - var updateErr error - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - bucket, err := createPaymentBucket(tx, paymentHash) - if err != nil { - return err - } - - // Get the existing status of this payment, if any. - paymentStatus := fetchPaymentStatus(bucket) - - switch paymentStatus { - - // We allow retrying failed payments. - case StatusFailed: - - // This is a new payment that is being initialized for the - // first time. - case StatusUnknown: - - // We already have an InFlight payment on the network. We will - // disallow any new payments. - case StatusInFlight: - updateErr = ErrPaymentInFlight - return nil - - // We've already succeeded a payment to this payment hash, - // forbid the switch from sending another. - case StatusSucceeded: - updateErr = ErrAlreadyPaid - return nil - - default: - updateErr = ErrUnknownPaymentStatus - return nil - } - - // Obtain a new sequence number for this payment. This is used - // to sort the payments in order of creation, and also acts as - // a unique identifier for each payment. - sequenceNum, err := nextPaymentSequence(tx) - if err != nil { - return err - } - - err = bucket.Put(paymentSequenceKey, sequenceNum) - if err != nil { - return err - } - - // Add the payment info to the bucket, which contains the - // static information for this payment - err = bucket.Put(paymentCreationInfoKey, infoBytes) - if err != nil { - return err - } - - // We'll delete any lingering attempt info to start with, in - // case we are initializing a payment that was attempted - // earlier, but left in a state where we could retry. - err = bucket.Delete(paymentAttemptInfoKey) - if err != nil { - return err - } - - // Also delete any lingering failure info now that we are - // re-attempting. - return bucket.Delete(paymentFailInfoKey) - }) - if err != nil { - return err - } - - return updateErr -} - -// RegisterAttempt atomically records the provided PaymentAttemptInfo to the -// DB. -func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, - attempt *PaymentAttemptInfo) error { - - // Serialize the information before opening the db transaction. - var a bytes.Buffer - if err := serializePaymentAttemptInfo(&a, attempt); err != nil { - return err - } - attemptBytes := a.Bytes() - - var updateErr error - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - bucket, err := fetchPaymentBucket(tx, paymentHash) - if err == ErrPaymentNotInitiated { - updateErr = ErrPaymentNotInitiated - return nil - } else if err != nil { - return err - } - - // We can only register attempts for payments that are - // in-flight. - if err := ensureInFlight(bucket); err != nil { - updateErr = err - return nil - } - - // Add the payment attempt to the payments bucket. - return bucket.Put(paymentAttemptInfoKey, attemptBytes) - }) - if err != nil { - return err - } - - return updateErr -} - -// Success transitions a payment into the Succeeded state. After invoking this -// method, InitPayment should always return an error to prevent us from making -// duplicate payments to the same payment hash. The provided preimage is -// atomically saved to the DB for record keeping. -func (p *PaymentControl) Success(paymentHash lntypes.Hash, - preimage lntypes.Preimage) (*route.Route, error) { - - var ( - updateErr error - route *route.Route - ) - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - bucket, err := fetchPaymentBucket(tx, paymentHash) - if err == ErrPaymentNotInitiated { - updateErr = ErrPaymentNotInitiated - return nil - } else if err != nil { - return err - } - - // We can only mark in-flight payments as succeeded. - if err := ensureInFlight(bucket); err != nil { - updateErr = err - return nil - } - - // Record the successful payment info atomically to the - // payments record. - err = bucket.Put(paymentSettleInfoKey, preimage[:]) - if err != nil { - return err - } - - // Retrieve attempt info for the notification. - attempt, err := fetchPaymentAttempt(bucket) - if err != nil { - return err - } - - route = &attempt.Route - - return nil - }) - if err != nil { - return nil, err - } - - return route, updateErr -} - -// Fail transitions a payment into the Failed state, and records the reason the -// payment failed. After invoking this method, InitPayment should return nil on -// its next call for this payment hash, allowing the switch to make a -// subsequent payment. -func (p *PaymentControl) Fail(paymentHash lntypes.Hash, - reason FailureReason) (*route.Route, error) { - - var ( - updateErr error - route *route.Route - ) - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - bucket, err := fetchPaymentBucket(tx, paymentHash) - if err == ErrPaymentNotInitiated { - updateErr = ErrPaymentNotInitiated - return nil - } else if err != nil { - return err - } - - // We can only mark in-flight payments as failed. - if err := ensureInFlight(bucket); err != nil { - updateErr = err - return nil - } - - // Put the failure reason in the bucket for record keeping. - v := []byte{byte(reason)} - err = bucket.Put(paymentFailInfoKey, v) - if err != nil { - return err - } - - // Retrieve attempt info for the notification, if available. - attempt, err := fetchPaymentAttempt(bucket) - if err != nil && err != errNoAttemptInfo { - return err - } - if err != errNoAttemptInfo { - route = &attempt.Route - } - - return nil - }) - if err != nil { - return nil, err - } - - return route, updateErr -} - -// FetchPayment returns information about a payment from the database. -func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) ( - *Payment, error) { - - var payment *Payment - err := p.db.View(func(tx *bbolt.Tx) error { - bucket, err := fetchPaymentBucket(tx, paymentHash) - if err != nil { - return err - } - - payment, err = fetchPayment(bucket) - - return err - }) - if err != nil { - return nil, err - } - - return payment, nil -} - -// createPaymentBucket creates or fetches the sub-bucket assigned to this -// payment hash. -func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( - *bbolt.Bucket, error) { - - payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) - if err != nil { - return nil, err - } - - return payments.CreateBucketIfNotExists(paymentHash[:]) -} - -// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If -// the bucket does not exist, it returns ErrPaymentNotInitiated. -func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( - *bbolt.Bucket, error) { - - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil, ErrPaymentNotInitiated - } - - bucket := payments.Bucket(paymentHash[:]) - if bucket == nil { - return nil, ErrPaymentNotInitiated - } - - return bucket, nil - -} - -// nextPaymentSequence returns the next sequence number to store for a new -// payment. -func nextPaymentSequence(tx *bbolt.Tx) ([]byte, error) { - payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) - if err != nil { - return nil, err - } - - seq, err := payments.NextSequence() - if err != nil { - return nil, err - } - - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, seq) - return b, nil -} - // fetchPaymentStatus fetches the payment status of the payment. If the payment // isn't found, it will default to "StatusUnknown". func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus { @@ -385,113 +21,3 @@ func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus { return StatusUnknown } - -// ensureInFlight checks whether the payment found in the given bucket has -// status InFlight, and returns an error otherwise. This should be used to -// ensure we only mark in-flight payments as succeeded or failed. -func ensureInFlight(bucket *bbolt.Bucket) error { - paymentStatus := fetchPaymentStatus(bucket) - - switch { - - // The payment was indeed InFlight, return. - case paymentStatus == StatusInFlight: - return nil - - // Our records show the payment as unknown, meaning it never - // should have left the switch. - case paymentStatus == StatusUnknown: - return ErrPaymentNotInitiated - - // The payment succeeded previously. - case paymentStatus == StatusSucceeded: - return ErrPaymentAlreadySucceeded - - // The payment was already failed. - case paymentStatus == StatusFailed: - return ErrPaymentAlreadyFailed - - default: - return ErrUnknownPaymentStatus - } -} - -// fetchPaymentAttempt fetches the payment attempt from the bucket. -func fetchPaymentAttempt(bucket *bbolt.Bucket) (*PaymentAttemptInfo, error) { - attemptData := bucket.Get(paymentAttemptInfoKey) - if attemptData == nil { - return nil, errNoAttemptInfo - } - - r := bytes.NewReader(attemptData) - return deserializePaymentAttemptInfo(r) -} - -// InFlightPayment is a wrapper around a payment that has status InFlight. -type InFlightPayment struct { - // Info is the PaymentCreationInfo of the in-flight payment. - Info *PaymentCreationInfo - - // Attempt contains information about the last payment attempt that was - // made to this payment hash. - // - // NOTE: Might be nil. - Attempt *PaymentAttemptInfo -} - -// FetchInFlightPayments returns all payments with status InFlight. -func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { - var inFlights []*InFlightPayment - err := p.db.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil - } - - return payments.ForEach(func(k, _ []byte) error { - bucket := payments.Bucket(k) - if bucket == nil { - return fmt.Errorf("non bucket element") - } - - // If the status is not InFlight, we can return early. - paymentStatus := fetchPaymentStatus(bucket) - if paymentStatus != StatusInFlight { - return nil - } - - var ( - inFlight = &InFlightPayment{} - err error - ) - - // Get the CreationInfo. - b := bucket.Get(paymentCreationInfoKey) - if b == nil { - return fmt.Errorf("unable to find creation " + - "info for inflight payment") - } - - r := bytes.NewReader(b) - inFlight.Info, err = deserializePaymentCreationInfo(r) - if err != nil { - return err - } - - // Now get the attempt info. It could be that there is - // no attempt info yet. - inFlight.Attempt, err = fetchPaymentAttempt(bucket) - if err != nil && err != errNoAttemptInfo { - return err - } - - inFlights = append(inFlights, inFlight) - return nil - }) - }) - if err != nil { - return nil, err - } - - return inFlights, nil -} diff --git a/channeldb/migration_01_to_11/payment_control_test.go b/channeldb/migration_01_to_11/payment_control_test.go deleted file mode 100644 index 9868475e..00000000 --- a/channeldb/migration_01_to_11/payment_control_test.go +++ /dev/null @@ -1,550 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" - "io/ioutil" - "reflect" - "testing" - "time" - - "github.com/btcsuite/fastsha256" - "github.com/coreos/bbolt" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/routing/route" -) - -func initDB() (*DB, error) { - tempPath, err := ioutil.TempDir("", "switchdb") - if err != nil { - return nil, err - } - - db, err := Open(tempPath) - if err != nil { - return nil, err - } - - return db, err -} - -func genPreimage() ([32]byte, error) { - var preimage [32]byte - if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { - return preimage, err - } - return preimage, nil -} - -func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo, - lntypes.Preimage, error) { - - preimage, err := genPreimage() - if err != nil { - return nil, nil, preimage, fmt.Errorf("unable to "+ - "generate preimage: %v", err) - } - - rhash := fastsha256.Sum256(preimage[:]) - return &PaymentCreationInfo{ - PaymentHash: rhash, - Value: 1, - CreationDate: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte("hola"), - }, - &PaymentAttemptInfo{ - PaymentID: 1, - SessionKey: priv, - Route: testRoute, - }, preimage, nil -} - -// TestPaymentControlSwitchFail checks that payment status returns to Failed -// status after failing, and that InitPayment allows another HTLC for the -// same payment hash. -func TestPaymentControlSwitchFail(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, attempt, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate StatusInFlight. - err = pControl.InitPayment(info.PaymentHash, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, - nil, - ) - - // Fail the payment, which should moved it to Failed. - failReason := FailureReasonNoRoute - _, err = pControl.Fail(info.PaymentHash, failReason) - if err != nil { - t.Fatalf("unable to fail payment hash: %v", err) - } - - // Verify the status is indeed Failed. - assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) - assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, - &failReason, - ) - - // Sends the htlc again, which should succeed since the prior payment - // failed. - err = pControl.InitPayment(info.PaymentHash, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, - nil, - ) - - // Record a new attempt. - attempt.PaymentID = 2 - err = pControl.RegisterAttempt(info.PaymentHash, attempt) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, - nil, - ) - - // Verifies that status was changed to StatusSucceeded. - var route *route.Route - route, err = pControl.Success(info.PaymentHash, preimg) - if err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) - } - - err = assertRouteEqual(route, &attempt.Route) - if err != nil { - t.Fatalf("unexpected route returned: %v vs %v: %v", - spew.Sdump(attempt.Route), spew.Sdump(*route), err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) - assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) - - // Attempt a final payment, which should now fail since the prior - // payment succeed. - err = pControl.InitPayment(info.PaymentHash, info) - if err != ErrAlreadyPaid { - t.Fatalf("unable to send htlc message: %v", err) - } -} - -// TestPaymentControlSwitchDoubleSend checks the ability of payment control to -// prevent double sending of htlc message, when message is in StatusInFlight. -func TestPaymentControlSwitchDoubleSend(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, attempt, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate base status and move it to - // StatusInFlight and verifies that it was changed. - err = pControl.InitPayment(info.PaymentHash, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, - nil, - ) - - // Try to initiate double sending of htlc message with the same - // payment hash, should result in error indicating that payment has - // already been sent. - err = pControl.InitPayment(info.PaymentHash, info) - if err != ErrPaymentInFlight { - t.Fatalf("payment control wrong behaviour: " + - "double sending must trigger ErrPaymentInFlight error") - } - - // Record an attempt. - err = pControl.RegisterAttempt(info.PaymentHash, attempt) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, - nil, - ) - - // Sends base htlc message which initiate StatusInFlight. - err = pControl.InitPayment(info.PaymentHash, info) - if err != ErrPaymentInFlight { - t.Fatalf("payment control wrong behaviour: " + - "double sending must trigger ErrPaymentInFlight error") - } - - // After settling, the error should be ErrAlreadyPaid. - if _, err := pControl.Success(info.PaymentHash, preimg); err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) - } - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) - assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) - - err = pControl.InitPayment(info.PaymentHash, info) - if err != ErrAlreadyPaid { - t.Fatalf("unable to send htlc message: %v", err) - } -} - -// TestPaymentControlSuccessesWithoutInFlight checks that the payment -// control will disallow calls to Success when no payment is in flight. -func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, _, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Attempt to complete the payment should fail. - _, err = pControl.Success(info.PaymentHash, preimg) - if err != ErrPaymentNotInitiated { - t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) - assertPaymentInfo( - t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, - nil, - ) -} - -// TestPaymentControlFailsWithoutInFlight checks that a strict payment -// control will disallow calls to Fail when no payment is in flight. -func TestPaymentControlFailsWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, _, _, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Calling Fail should return an error. - _, err = pControl.Fail(info.PaymentHash, FailureReasonNoRoute) - if err != ErrPaymentNotInitiated { - t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) - assertPaymentInfo( - t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil, - ) -} - -// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only -// deletes payments from the database that are not in-flight. -func TestPaymentControlDeleteNonInFligt(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - payments := []struct { - failed bool - success bool - }{ - { - failed: true, - success: false, - }, - { - failed: false, - success: true, - }, - { - failed: false, - success: false, - }, - } - - for _, p := range payments { - info, attempt, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate StatusInFlight. - err = pControl.InitPayment(info.PaymentHash, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - err = pControl.RegisterAttempt(info.PaymentHash, attempt) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - if p.failed { - // Fail the payment, which should moved it to Failed. - failReason := FailureReasonNoRoute - _, err = pControl.Fail(info.PaymentHash, failReason) - if err != nil { - t.Fatalf("unable to fail payment hash: %v", err) - } - - // Verify the status is indeed Failed. - assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, - lntypes.Preimage{}, &failReason, - ) - } else if p.success { - // Verifies that status was changed to StatusSucceeded. - _, err := pControl.Success(info.PaymentHash, preimg) - if err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, preimg, nil, - ) - } else { - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, - lntypes.Preimage{}, nil, - ) - } - } - - // Delete payments. - if err := db.DeletePayments(); err != nil { - t.Fatal(err) - } - - // This should leave the in-flight payment. - dbPayments, err := db.FetchPayments() - if err != nil { - t.Fatal(err) - } - - if len(dbPayments) != 1 { - t.Fatalf("expected one payment, got %d", len(dbPayments)) - } - - status := dbPayments[0].Status - if status != StatusInFlight { - t.Fatalf("expected in-fligth status, got %v", status) - } -} - -func assertPaymentStatus(t *testing.T, db *DB, - hash [32]byte, expStatus PaymentStatus) { - - t.Helper() - - var paymentStatus = StatusUnknown - err := db.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil - } - - bucket := payments.Bucket(hash[:]) - if bucket == nil { - return nil - } - - // Get the existing status of this payment, if any. - paymentStatus = fetchPaymentStatus(bucket) - return nil - }) - if err != nil { - t.Fatalf("unable to fetch payment status: %v", err) - } - - if paymentStatus != expStatus { - t.Fatalf("payment status mismatch: expected %v, got %v", - expStatus, paymentStatus) - } -} - -func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error { - b := bucket.Get(paymentCreationInfoKey) - switch { - case b == nil && c == nil: - return nil - case b == nil: - return fmt.Errorf("expected creation info not found") - case c == nil: - return fmt.Errorf("unexpected creation info found") - } - - r := bytes.NewReader(b) - c2, err := deserializePaymentCreationInfo(r) - if err != nil { - return err - } - if !reflect.DeepEqual(c, c2) { - return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v", - spew.Sdump(c), spew.Sdump(c2)) - } - - return nil -} - -func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error { - b := bucket.Get(paymentAttemptInfoKey) - switch { - case b == nil && a == nil: - return nil - case b == nil: - return fmt.Errorf("expected attempt info not found") - case a == nil: - return fmt.Errorf("unexpected attempt info found") - } - - r := bytes.NewReader(b) - a2, err := deserializePaymentAttemptInfo(r) - if err != nil { - return err - } - - return assertRouteEqual(&a.Route, &a2.Route) -} - -func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error { - zero := lntypes.Preimage{} - b := bucket.Get(paymentSettleInfoKey) - switch { - case b == nil && preimg == zero: - return nil - case b == nil: - return fmt.Errorf("expected preimage not found") - case preimg == zero: - return fmt.Errorf("unexpected preimage found") - } - - var pre2 lntypes.Preimage - copy(pre2[:], b[:]) - if preimg != pre2 { - return fmt.Errorf("Preimages don't match: %x vs %x", - preimg, pre2) - } - - return nil -} - -func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error { - b := bucket.Get(paymentFailInfoKey) - switch { - case b == nil && failReason == nil: - return nil - case b == nil: - return fmt.Errorf("expected fail info not found") - case failReason == nil: - return fmt.Errorf("unexpected fail info found") - } - - failReason2 := FailureReason(b[0]) - if *failReason != failReason2 { - return fmt.Errorf("Failure infos don't match: %v vs %v", - *failReason, failReason2) - } - - return nil -} - -func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash, - c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage, - f *FailureReason) { - - t.Helper() - - err := db.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil && c == nil { - return nil - } - if payments == nil { - return fmt.Errorf("sent payments not found") - } - - bucket := payments.Bucket(hash[:]) - if bucket == nil && c == nil { - return nil - } - - if bucket == nil { - return fmt.Errorf("payment not found") - } - - if err := checkPaymentCreationInfo(bucket, c); err != nil { - return err - } - - if err := checkPaymentAttemptInfo(bucket, a); err != nil { - return err - } - - if err := checkSettleInfo(bucket, s); err != nil { - return err - } - - if err := checkFailInfo(bucket, f); err != nil { - return err - } - return nil - }) - if err != nil { - t.Fatalf("assert payment info failed: %v", err) - } - -} diff --git a/channeldb/migration_01_to_11/payments.go b/channeldb/migration_01_to_11/payments.go index fd3db5a1..d34cd6e9 100644 --- a/channeldb/migration_01_to_11/payments.go +++ b/channeldb/migration_01_to_11/payments.go @@ -375,48 +375,6 @@ func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) { return p, nil } -// DeletePayments deletes all completed and failed payments from the DB. -func (db *DB) DeletePayments() error { - return db.Update(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil - } - - var deleteBuckets [][]byte - err := payments.ForEach(func(k, _ []byte) error { - bucket := payments.Bucket(k) - if bucket == nil { - // We only expect sub-buckets to be found in - // this top-level bucket. - return fmt.Errorf("non bucket element in " + - "payments bucket") - } - - // If the status is InFlight, we cannot safely delete - // the payment information, so we return early. - paymentStatus := fetchPaymentStatus(bucket) - if paymentStatus == StatusInFlight { - return nil - } - - deleteBuckets = append(deleteBuckets, k) - return nil - }) - if err != nil { - return err - } - - for _, k := range deleteBuckets { - if err := payments.DeleteBucket(k); err != nil { - return err - } - } - - return nil - }) -} - func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { var scratch [8]byte diff --git a/channeldb/migration_01_to_11/payments_test.go b/channeldb/migration_01_to_11/payments_test.go index 07307941..c5584079 100644 --- a/channeldb/migration_01_to_11/payments_test.go +++ b/channeldb/migration_01_to_11/payments_test.go @@ -2,55 +2,17 @@ package migration_01_to_11 import ( "bytes" - "errors" "fmt" "math/rand" - "reflect" - "testing" "time" "github.com/btcsuite/btcd/btcec" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing/route" - "github.com/lightningnetwork/lnd/tlv" ) var ( priv, _ = btcec.NewPrivateKey(btcec.S256()) pub = priv.PubKey() - - tlvBytes = []byte{1, 2, 3} - tlvEncoder = tlv.StubEncoder(tlvBytes) - testHop1 = &route.Hop{ - PubKeyBytes: route.NewVertex(pub), - ChannelID: 12345, - OutgoingTimeLock: 111, - AmtToForward: 555, - TLVRecords: []tlv.Record{ - tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil), - tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil), - }, - } - - testHop2 = &route.Hop{ - PubKeyBytes: route.NewVertex(pub), - ChannelID: 12345, - OutgoingTimeLock: 111, - AmtToForward: 555, - LegacyPayload: true, - } - - testRoute = route.Route{ - TotalTimeLock: 123, - TotalAmount: 1234567, - SourcePubKey: route.NewVertex(pub), - Hops: []*route.Hop{ - testHop1, - testHop2, - }, - } ) func makeFakePayment() *outgoingPayment { @@ -81,27 +43,6 @@ func makeFakePayment() *outgoingPayment { return fakePayment } -func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) { - var preimg lntypes.Preimage - copy(preimg[:], rev[:]) - - c := &PaymentCreationInfo{ - PaymentHash: preimg.Hash(), - Value: 1000, - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte(""), - } - - a := &PaymentAttemptInfo{ - PaymentID: 44, - SessionKey: priv, - Route: testRoute, - } - return c, a -} - // randomBytes creates random []byte with length in range [minLen, maxLen) func randomBytes(minLen, maxLen int) ([]byte, error) { randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen)) @@ -165,160 +106,3 @@ func makeRandomFakePayment() (*outgoingPayment, error) { return fakePayment, nil } - -func TestSentPaymentSerialization(t *testing.T) { - t.Parallel() - - c, s := makeFakeInfo() - - var b bytes.Buffer - if err := serializePaymentCreationInfo(&b, c); err != nil { - t.Fatalf("unable to serialize creation info: %v", err) - } - - newCreationInfo, err := deserializePaymentCreationInfo(&b) - if err != nil { - t.Fatalf("unable to deserialize creation info: %v", err) - } - - if !reflect.DeepEqual(c, newCreationInfo) { - t.Fatalf("Payments do not match after "+ - "serialization/deserialization %v vs %v", - spew.Sdump(c), spew.Sdump(newCreationInfo), - ) - } - - b.Reset() - if err := serializePaymentAttemptInfo(&b, s); err != nil { - t.Fatalf("unable to serialize info: %v", err) - } - - newAttemptInfo, err := deserializePaymentAttemptInfo(&b) - if err != nil { - t.Fatalf("unable to deserialize info: %v", err) - } - - // First we verify all the records match up porperly, as they aren't - // able to be properly compared using reflect.DeepEqual. - err = assertRouteEqual(&s.Route, &newAttemptInfo.Route) - if err != nil { - t.Fatalf("Routes do not match after "+ - "serialization/deserialization: %v", err) - } - - // Clear routes to allow DeepEqual to compare the remaining fields. - newAttemptInfo.Route = route.Route{} - s.Route = route.Route{} - - if !reflect.DeepEqual(s, newAttemptInfo) { - s.SessionKey.Curve = nil - newAttemptInfo.SessionKey.Curve = nil - t.Fatalf("Payments do not match after "+ - "serialization/deserialization %v vs %v", - spew.Sdump(s), spew.Sdump(newAttemptInfo), - ) - } -} - -// assertRouteEquals compares to routes for equality and returns an error if -// they are not equal. -func assertRouteEqual(a, b *route.Route) error { - err := assertRouteHopRecordsEqual(a, b) - if err != nil { - return err - } - - // TLV records have already been compared and need to be cleared to - // properly compare the remaining fields using DeepEqual. - copyRouteNoHops := func(r *route.Route) *route.Route { - copy := *r - copy.Hops = make([]*route.Hop, len(r.Hops)) - for i, hop := range r.Hops { - hopCopy := *hop - hopCopy.TLVRecords = nil - copy.Hops[i] = &hopCopy - } - return © - } - - if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) { - return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", - spew.Sdump(a), spew.Sdump(b)) - } - - return nil -} - -func assertRouteHopRecordsEqual(r1, r2 *route.Route) error { - if len(r1.Hops) != len(r2.Hops) { - return errors.New("route hop count mismatch") - } - - for i := 0; i < len(r1.Hops); i++ { - records1 := r1.Hops[i].TLVRecords - records2 := r2.Hops[i].TLVRecords - if len(records1) != len(records2) { - return fmt.Errorf("route record count for hop %v "+ - "mismatch", i) - } - - for j := 0; j < len(records1); j++ { - expectedRecord := records1[j] - newRecord := records2[j] - - err := assertHopRecordsEqual(expectedRecord, newRecord) - if err != nil { - return fmt.Errorf("route record mismatch: %v", err) - } - } - } - - return nil -} - -func assertHopRecordsEqual(h1, h2 tlv.Record) error { - if h1.Type() != h2.Type() { - return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(), - h2.Type()) - } - - var b bytes.Buffer - if err := h2.Encode(&b); err != nil { - return fmt.Errorf("unable to encode record: %v", err) - } - - if !bytes.Equal(b.Bytes(), tlvBytes) { - return fmt.Errorf("wrong raw record: expected %x, got %x", - tlvBytes, b.Bytes()) - } - - if h1.Size() != h2.Size() { - return fmt.Errorf("wrong size: expected %v, "+ - "got %v", h1.Size(), h2.Size()) - } - - return nil -} - -func TestRouteSerialization(t *testing.T) { - t.Parallel() - - var b bytes.Buffer - if err := SerializeRoute(&b, testRoute); err != nil { - t.Fatal(err) - } - - r := bytes.NewReader(b.Bytes()) - route2, err := DeserializeRoute(r) - if err != nil { - t.Fatal(err) - } - - // First we verify all the records match up porperly, as they aren't - // able to be properly compared using reflect.DeepEqual. - err = assertRouteEqual(&testRoute, &route2) - if err != nil { - t.Fatalf("routes not equal: \n%v vs \n%v", - spew.Sdump(testRoute), spew.Sdump(route2)) - } -} diff --git a/channeldb/migration_01_to_11/reject_cache.go b/channeldb/migration_01_to_11/reject_cache.go deleted file mode 100644 index c54d78a8..00000000 --- a/channeldb/migration_01_to_11/reject_cache.go +++ /dev/null @@ -1,95 +0,0 @@ -package migration_01_to_11 - -// rejectFlags is a compact representation of various metadata stored by the -// reject cache about a particular channel. -type rejectFlags uint8 - -const ( - // rejectFlagExists is a flag indicating whether the channel exists, - // i.e. the channel is open and has a recent channel update. If this - // flag is not set, the channel is either a zombie or unknown. - rejectFlagExists rejectFlags = 1 << iota - - // rejectFlagZombie is a flag indicating whether the channel is a - // zombie, i.e. the channel is open but has no recent channel updates. - rejectFlagZombie -) - -// packRejectFlags computes the rejectFlags corresponding to the passed boolean -// values indicating whether the edge exists or is a zombie. -func packRejectFlags(exists, isZombie bool) rejectFlags { - var flags rejectFlags - if exists { - flags |= rejectFlagExists - } - if isZombie { - flags |= rejectFlagZombie - } - - return flags -} - -// unpack returns the booleans packed into the rejectFlags. The first indicates -// if the edge exists in our graph, the second indicates if the edge is a -// zombie. -func (f rejectFlags) unpack() (bool, bool) { - return f&rejectFlagExists == rejectFlagExists, - f&rejectFlagZombie == rejectFlagZombie -} - -// rejectCacheEntry caches frequently accessed information about a channel, -// including the timestamps of its latest edge policies and whether or not the -// channel exists in the graph. -type rejectCacheEntry struct { - upd1Time int64 - upd2Time int64 - flags rejectFlags -} - -// rejectCache is an in-memory cache used to improve the performance of -// HasChannelEdge. It caches information about the whether or channel exists, as -// well as the most recent timestamps for each policy (if they exists). -type rejectCache struct { - n int - edges map[uint64]rejectCacheEntry -} - -// newRejectCache creates a new rejectCache with maximum capacity of n entries. -func newRejectCache(n int) *rejectCache { - return &rejectCache{ - n: n, - edges: make(map[uint64]rejectCacheEntry, n), - } -} - -// get returns the entry from the cache for chanid, if it exists. -func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) { - entry, ok := c.edges[chanid] - return entry, ok -} - -// insert adds the entry to the reject cache. If an entry for chanid already -// exists, it will be replaced with the new entry. If the entry doesn't exists, -// it will be inserted to the cache, performing a random eviction if the cache -// is at capacity. -func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) { - // If entry exists, replace it. - if _, ok := c.edges[chanid]; ok { - c.edges[chanid] = entry - return - } - - // Otherwise, evict an entry at random and insert. - if len(c.edges) == c.n { - for id := range c.edges { - delete(c.edges, id) - break - } - } - c.edges[chanid] = entry -} - -// remove deletes an entry for chanid from the cache, if it exists. -func (c *rejectCache) remove(chanid uint64) { - delete(c.edges, chanid) -} diff --git a/channeldb/migration_01_to_11/reject_cache_test.go b/channeldb/migration_01_to_11/reject_cache_test.go deleted file mode 100644 index e15e0a10..00000000 --- a/channeldb/migration_01_to_11/reject_cache_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package migration_01_to_11 - -import ( - "reflect" - "testing" -) - -// TestRejectCache checks the behavior of the rejectCache with respect to insertion, -// eviction, and removal of cache entries. -func TestRejectCache(t *testing.T) { - const cacheSize = 100 - - // Create a new reject cache with the configured max size. - c := newRejectCache(cacheSize) - - // As a sanity check, assert that querying the empty cache does not - // return an entry. - _, ok := c.get(0) - if ok { - t.Fatalf("reject cache should be empty") - } - - // Now, fill up the cache entirely. - for i := uint64(0); i < cacheSize; i++ { - c.insert(i, entryForInt(i)) - } - - // Assert that the cache has all of the entries just inserted, since no - // eviction should occur until we try to surpass the max size. - assertHasEntries(t, c, 0, cacheSize) - - // Now, insert a new element that causes the cache to evict an element. - c.insert(cacheSize, entryForInt(cacheSize)) - - // Assert that the cache has this last entry, as the cache should evict - // some prior element and not the newly inserted one. - assertHasEntries(t, c, cacheSize, cacheSize) - - // Iterate over all inserted elements and construct a set of the evicted - // elements. - evicted := make(map[uint64]struct{}) - for i := uint64(0); i < cacheSize+1; i++ { - _, ok := c.get(i) - if !ok { - evicted[i] = struct{}{} - } - } - - // Assert that exactly one element has been evicted. - numEvicted := len(evicted) - if numEvicted != 1 { - t.Fatalf("expected one evicted entry, got: %d", numEvicted) - } - - // Remove the highest item which initially caused the eviction and - // reinsert the element that was evicted prior. - c.remove(cacheSize) - for i := range evicted { - c.insert(i, entryForInt(i)) - } - - // Since the removal created an extra slot, the last insertion should - // not have caused an eviction and the entries for all channels in the - // original set that filled the cache should be present. - assertHasEntries(t, c, 0, cacheSize) - - // Finally, reinsert the existing set back into the cache and test that - // the cache still has all the entries. If the randomized eviction were - // happening on inserts for existing cache items, we expect this to fail - // with high probability. - for i := uint64(0); i < cacheSize; i++ { - c.insert(i, entryForInt(i)) - } - assertHasEntries(t, c, 0, cacheSize) - -} - -// assertHasEntries queries the reject cache for all channels in the range [start, -// end), asserting that they exist and their value matches the entry produced by -// entryForInt. -func assertHasEntries(t *testing.T, c *rejectCache, start, end uint64) { - t.Helper() - - for i := start; i < end; i++ { - entry, ok := c.get(i) - if !ok { - t.Fatalf("reject cache should contain chan %d", i) - } - - expEntry := entryForInt(i) - if !reflect.DeepEqual(entry, expEntry) { - t.Fatalf("entry mismatch, want: %v, got: %v", - expEntry, entry) - } - } -} - -// entryForInt generates a unique rejectCacheEntry given an integer. -func entryForInt(i uint64) rejectCacheEntry { - exists := i%2 == 0 - isZombie := i%3 == 0 - return rejectCacheEntry{ - upd1Time: int64(2 * i), - upd2Time: int64(2*i + 1), - flags: packRejectFlags(exists, isZombie), - } -} diff --git a/channeldb/migration_01_to_11/waitingproof.go b/channeldb/migration_01_to_11/waitingproof.go deleted file mode 100644 index 64729116..00000000 --- a/channeldb/migration_01_to_11/waitingproof.go +++ /dev/null @@ -1,251 +0,0 @@ -package migration_01_to_11 - -import ( - "encoding/binary" - "sync" - - "io" - - "bytes" - - "github.com/coreos/bbolt" - "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/lnwire" -) - -var ( - // waitingProofsBucketKey byte string name of the waiting proofs store. - waitingProofsBucketKey = []byte("waitingproofs") - - // ErrWaitingProofNotFound is returned if waiting proofs haven't been - // found by db. - ErrWaitingProofNotFound = errors.New("waiting proofs haven't been " + - "found") - - // ErrWaitingProofAlreadyExist is returned if waiting proofs haven't been - // found by db. - ErrWaitingProofAlreadyExist = errors.New("waiting proof with such " + - "key already exist") -) - -// WaitingProofStore is the bold db map-like storage for half announcement -// signatures. The one responsibility of this storage is to be able to -// retrieve waiting proofs after client restart. -type WaitingProofStore struct { - // cache is used in order to reduce the number of redundant get - // calls, when object isn't stored in it. - cache map[WaitingProofKey]struct{} - db *DB - mu sync.RWMutex -} - -// NewWaitingProofStore creates new instance of proofs storage. -func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { - s := &WaitingProofStore{ - db: db, - cache: make(map[WaitingProofKey]struct{}), - } - - if err := s.ForAll(func(proof *WaitingProof) error { - s.cache[proof.Key()] = struct{}{} - return nil - }); err != nil && err != ErrWaitingProofNotFound { - return nil, err - } - - return s, nil -} - -// Add adds new waiting proof in the storage. -func (s *WaitingProofStore) Add(proof *WaitingProof) error { - s.mu.Lock() - defer s.mu.Unlock() - - err := s.db.Update(func(tx *bbolt.Tx) error { - var err error - var b bytes.Buffer - - // Get or create the bucket. - bucket, err := tx.CreateBucketIfNotExists(waitingProofsBucketKey) - if err != nil { - return err - } - - // Encode the objects and place it in the bucket. - if err := proof.Encode(&b); err != nil { - return err - } - - key := proof.Key() - - return bucket.Put(key[:], b.Bytes()) - }) - if err != nil { - return err - } - - // Knowing that the write succeeded, we can now update the in-memory - // cache with the proof's key. - s.cache[proof.Key()] = struct{}{} - - return nil -} - -// Remove removes the proof from storage by its key. -func (s *WaitingProofStore) Remove(key WaitingProofKey) error { - s.mu.Lock() - defer s.mu.Unlock() - - if _, ok := s.cache[key]; !ok { - return ErrWaitingProofNotFound - } - - err := s.db.Update(func(tx *bbolt.Tx) error { - // Get or create the top bucket. - bucket := tx.Bucket(waitingProofsBucketKey) - if bucket == nil { - return ErrWaitingProofNotFound - } - - return bucket.Delete(key[:]) - }) - if err != nil { - return err - } - - // Since the proof was successfully deleted from the store, we can now - // remove it from the in-memory cache. - delete(s.cache, key) - - return nil -} - -// ForAll iterates thought all waiting proofs and passing the waiting proof -// in the given callback. -func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { - return s.db.View(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(waitingProofsBucketKey) - if bucket == nil { - return ErrWaitingProofNotFound - } - - // Iterate over objects buckets. - return bucket.ForEach(func(k, v []byte) error { - // Skip buckets fields. - if v == nil { - return nil - } - - r := bytes.NewReader(v) - proof := &WaitingProof{} - if err := proof.Decode(r); err != nil { - return err - } - - return cb(proof) - }) - }) -} - -// Get returns the object which corresponds to the given index. -func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { - proof := &WaitingProof{} - - s.mu.RLock() - defer s.mu.RUnlock() - - if _, ok := s.cache[key]; !ok { - return nil, ErrWaitingProofNotFound - } - - err := s.db.View(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(waitingProofsBucketKey) - if bucket == nil { - return ErrWaitingProofNotFound - } - - // Iterate over objects buckets. - v := bucket.Get(key[:]) - if v == nil { - return ErrWaitingProofNotFound - } - - r := bytes.NewReader(v) - return proof.Decode(r) - }) - - return proof, err -} - -// WaitingProofKey is the proof key which uniquely identifies the waiting -// proof object. The goal of this key is distinguish the local and remote -// proof for the same channel id. -type WaitingProofKey [9]byte - -// WaitingProof is the storable object, which encapsulate the half proof and -// the information about from which side this proof came. This structure is -// needed to make channel proof exchange persistent, so that after client -// restart we may receive remote/local half proof and process it. -type WaitingProof struct { - *lnwire.AnnounceSignatures - isRemote bool -} - -// NewWaitingProof constructs a new waiting prof instance. -func NewWaitingProof(isRemote bool, proof *lnwire.AnnounceSignatures) *WaitingProof { - return &WaitingProof{ - AnnounceSignatures: proof, - isRemote: isRemote, - } -} - -// OppositeKey returns the key which uniquely identifies opposite waiting proof. -func (p *WaitingProof) OppositeKey() WaitingProofKey { - var key [9]byte - binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64()) - - if !p.isRemote { - key[8] = 1 - } - return key -} - -// Key returns the key which uniquely identifies waiting proof. -func (p *WaitingProof) Key() WaitingProofKey { - var key [9]byte - binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64()) - - if p.isRemote { - key[8] = 1 - } - return key -} - -// Encode writes the internal representation of waiting proof in byte stream. -func (p *WaitingProof) Encode(w io.Writer) error { - if err := binary.Write(w, byteOrder, p.isRemote); err != nil { - return err - } - - if err := p.AnnounceSignatures.Encode(w, 0); err != nil { - return err - } - - return nil -} - -// Decode reads the data from the byte stream and initializes the -// waiting proof object with it. -func (p *WaitingProof) Decode(r io.Reader) error { - if err := binary.Read(r, byteOrder, &p.isRemote); err != nil { - return err - } - - msg := &lnwire.AnnounceSignatures{} - if err := msg.Decode(r, 0); err != nil { - return err - } - - (*p).AnnounceSignatures = msg - return nil -} diff --git a/channeldb/migration_01_to_11/waitingproof_test.go b/channeldb/migration_01_to_11/waitingproof_test.go deleted file mode 100644 index 968f1157..00000000 --- a/channeldb/migration_01_to_11/waitingproof_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package migration_01_to_11 - -import ( - "testing" - - "reflect" - - "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/lnwire" -) - -// TestWaitingProofStore tests add/get/remove functions of the waiting proof -// storage. -func TestWaitingProofStore(t *testing.T) { - t.Parallel() - - db, cleanup, err := makeTestDB() - if err != nil { - t.Fatalf("failed to make test database: %s", err) - } - defer cleanup() - - proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ - NodeSignature: wireSig, - BitcoinSignature: wireSig, - }) - - store, err := NewWaitingProofStore(db) - if err != nil { - t.Fatalf("unable to create the waiting proofs storage: %v", - err) - } - - if err := store.Add(proof1); err != nil { - t.Fatalf("unable add proof to storage: %v", err) - } - - proof2, err := store.Get(proof1.Key()) - if err != nil { - t.Fatalf("unable retrieve proof from storage: %v", err) - } - if !reflect.DeepEqual(proof1, proof2) { - t.Fatal("wrong proof retrieved") - } - - if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound { - t.Fatalf("proof shouldn't be found: %v", err) - } - - if err := store.Remove(proof1.Key()); err != nil { - t.Fatalf("unable remove proof from storage: %v", err) - } - - if err := store.ForAll(func(proof *WaitingProof) error { - return errors.New("storage should be empty") - }); err != nil && err != ErrWaitingProofNotFound { - t.Fatal(err) - } -} diff --git a/channeldb/migration_01_to_11/witness_cache.go b/channeldb/migration_01_to_11/witness_cache.go deleted file mode 100644 index 69de1054..00000000 --- a/channeldb/migration_01_to_11/witness_cache.go +++ /dev/null @@ -1,229 +0,0 @@ -package migration_01_to_11 - -import ( - "fmt" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lntypes" -) - -var ( - // ErrNoWitnesses is an error that's returned when no new witnesses have - // been added to the WitnessCache. - ErrNoWitnesses = fmt.Errorf("no witnesses") - - // ErrUnknownWitnessType is returned if a caller attempts to - ErrUnknownWitnessType = fmt.Errorf("unknown witness type") -) - -// WitnessType is enum that denotes what "type" of witness is being -// stored/retrieved. As the WitnessCache itself is agnostic and doesn't enforce -// any structure on added witnesses, we use this type to partition the -// witnesses on disk, and also to know how to map a witness to its look up key. -type WitnessType uint8 - -var ( - // Sha256HashWitness is a witness that is simply the pre image to a - // hash image. In order to map to its key, we'll use sha256. - Sha256HashWitness WitnessType = 1 -) - -// toDBKey is a helper method that maps a witness type to the key that we'll -// use to store it within the database. -func (w WitnessType) toDBKey() ([]byte, error) { - switch w { - - case Sha256HashWitness: - return []byte{byte(w)}, nil - - default: - return nil, ErrUnknownWitnessType - } -} - -var ( - // witnessBucketKey is the name of the bucket that we use to store all - // witnesses encountered. Within this bucket, we'll create a sub-bucket for - // each witness type. - witnessBucketKey = []byte("byte") -) - -// WitnessCache is a persistent cache of all witnesses we've encountered on the -// network. In the case of multi-hop, multi-step contracts, a cache of all -// witnesses can be useful in the case of partial contract resolution. If -// negotiations break down, we may be forced to locate the witness for a -// portion of the contract on-chain. In this case, we'll then add that witness -// to the cache so the incoming contract can fully resolve witness. -// Additionally, as one MUST always use a unique witness on the network, we may -// use this cache to detect duplicate witnesses. -// -// TODO(roasbeef): need expiry policy? -// * encrypt? -type WitnessCache struct { - db *DB -} - -// NewWitnessCache returns a new instance of the witness cache. -func (d *DB) NewWitnessCache() *WitnessCache { - return &WitnessCache{ - db: d, - } -} - -// witnessEntry is a key-value struct that holds each key -> witness pair, used -// when inserting records into the cache. -type witnessEntry struct { - key []byte - witness []byte -} - -// AddSha256Witnesses adds a batch of new sha256 preimages into the witness -// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the -// preimages' witness type. -func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error { - // Optimistically compute the preimages' hashes before attempting to - // start the db transaction. - entries := make([]witnessEntry, 0, len(preimages)) - for i := range preimages { - hash := preimages[i].Hash() - entries = append(entries, witnessEntry{ - key: hash[:], - witness: preimages[i][:], - }) - } - - return w.addWitnessEntries(Sha256HashWitness, entries) -} - -// addWitnessEntries inserts the witnessEntry key-value pairs into the cache, -// using the appropriate witness type to segment the namespace of possible -// witness types. -func (w *WitnessCache) addWitnessEntries(wType WitnessType, - entries []witnessEntry) error { - - // Exit early if there are no witnesses to add. - if len(entries) == 0 { - return nil - } - - return w.db.Batch(func(tx *bbolt.Tx) error { - witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) - if err != nil { - return err - } - - witnessTypeBucketKey, err := wType.toDBKey() - if err != nil { - return err - } - witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists( - witnessTypeBucketKey, - ) - if err != nil { - return err - } - - for _, entry := range entries { - err = witnessTypeBucket.Put(entry.key, entry.witness) - if err != nil { - return err - } - } - - return nil - }) -} - -// LookupSha256Witness attempts to lookup the preimage for a sha256 hash. If -// the witness isn't found, ErrNoWitnesses will be returned. -func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) { - witness, err := w.lookupWitness(Sha256HashWitness, hash[:]) - if err != nil { - return lntypes.Preimage{}, err - } - - return lntypes.MakePreimage(witness) -} - -// lookupWitness attempts to lookup a witness according to its type and also -// its witness key. In the case that the witness isn't found, ErrNoWitnesses -// will be returned. -func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) { - var witness []byte - err := w.db.View(func(tx *bbolt.Tx) error { - witnessBucket := tx.Bucket(witnessBucketKey) - if witnessBucket == nil { - return ErrNoWitnesses - } - - witnessTypeBucketKey, err := wType.toDBKey() - if err != nil { - return err - } - witnessTypeBucket := witnessBucket.Bucket(witnessTypeBucketKey) - if witnessTypeBucket == nil { - return ErrNoWitnesses - } - - dbWitness := witnessTypeBucket.Get(witnessKey) - if dbWitness == nil { - return ErrNoWitnesses - } - - witness = make([]byte, len(dbWitness)) - copy(witness[:], dbWitness) - - return nil - }) - if err != nil { - return nil, err - } - - return witness, nil -} - -// DeleteSha256Witness attempts to delete a sha256 preimage identified by hash. -func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error { - return w.deleteWitness(Sha256HashWitness, hash[:]) -} - -// deleteWitness attempts to delete a particular witness from the database. -func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error { - return w.db.Batch(func(tx *bbolt.Tx) error { - witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) - if err != nil { - return err - } - - witnessTypeBucketKey, err := wType.toDBKey() - if err != nil { - return err - } - witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists( - witnessTypeBucketKey, - ) - if err != nil { - return err - } - - return witnessTypeBucket.Delete(witnessKey) - }) -} - -// DeleteWitnessClass attempts to delete an *entire* class of witnesses. After -// this function return with a non-nil error, -func (w *WitnessCache) DeleteWitnessClass(wType WitnessType) error { - return w.db.Batch(func(tx *bbolt.Tx) error { - witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) - if err != nil { - return err - } - - witnessTypeBucketKey, err := wType.toDBKey() - if err != nil { - return err - } - - return witnessBucket.DeleteBucket(witnessTypeBucketKey) - }) -} diff --git a/channeldb/migration_01_to_11/witness_cache_test.go b/channeldb/migration_01_to_11/witness_cache_test.go deleted file mode 100644 index 92836abe..00000000 --- a/channeldb/migration_01_to_11/witness_cache_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package migration_01_to_11 - -import ( - "crypto/sha256" - "testing" - - "github.com/lightningnetwork/lnd/lntypes" -) - -// TestWitnessCacheSha256Retrieval tests that we're able to add and lookup new -// sha256 preimages to the witness cache. -func TestWitnessCacheSha256Retrieval(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - wCache := cdb.NewWitnessCache() - - // We'll be attempting to add then lookup two simple sha256 preimages - // within this test. - preimage1 := lntypes.Preimage(rev) - preimage2 := lntypes.Preimage(key) - - preimages := []lntypes.Preimage{preimage1, preimage2} - hashes := []lntypes.Hash{preimage1.Hash(), preimage2.Hash()} - - // First, we'll attempt to add the preimages to the database. - err = wCache.AddSha256Witnesses(preimages...) - if err != nil { - t.Fatalf("unable to add witness: %v", err) - } - - // With the preimages stored, we'll now attempt to look them up. - for i, hash := range hashes { - preimage := preimages[i] - - // We should get back the *exact* same preimage as we originally - // stored. - dbPreimage, err := wCache.LookupSha256Witness(hash) - if err != nil { - t.Fatalf("unable to look up witness: %v", err) - } - - if preimage != dbPreimage { - t.Fatalf("witnesses don't match: expected %x, got %x", - preimage[:], dbPreimage[:]) - } - } -} - -// TestWitnessCacheSha256Deletion tests that we're able to delete a single -// sha256 preimage, and also a class of witnesses from the cache. -func TestWitnessCacheSha256Deletion(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - wCache := cdb.NewWitnessCache() - - // We'll start by adding two preimages to the cache. - preimage1 := lntypes.Preimage(key) - hash1 := preimage1.Hash() - - preimage2 := lntypes.Preimage(rev) - hash2 := preimage2.Hash() - - if err := wCache.AddSha256Witnesses(preimage1); err != nil { - t.Fatalf("unable to add witness: %v", err) - } - - if err := wCache.AddSha256Witnesses(preimage2); err != nil { - t.Fatalf("unable to add witness: %v", err) - } - - // We'll now delete the first preimage. If we attempt to look it up, we - // should get ErrNoWitnesses. - err = wCache.DeleteSha256Witness(hash1) - if err != nil { - t.Fatalf("unable to delete witness: %v", err) - } - _, err = wCache.LookupSha256Witness(hash1) - if err != ErrNoWitnesses { - t.Fatalf("expected ErrNoWitnesses instead got: %v", err) - } - - // Next, we'll attempt to delete the entire witness class itself. When - // we try to lookup the second preimage, we should again get - // ErrNoWitnesses. - if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil { - t.Fatalf("unable to delete witness class: %v", err) - } - _, err = wCache.LookupSha256Witness(hash2) - if err != ErrNoWitnesses { - t.Fatalf("expected ErrNoWitnesses instead got: %v", err) - } -} - -// TestWitnessCacheUnknownWitness tests that we get an error if we attempt to -// query/add/delete an unknown witness. -func TestWitnessCacheUnknownWitness(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - wCache := cdb.NewWitnessCache() - - // We'll attempt to add a new, undefined witness type to the database. - // We should get an error. - err = wCache.legacyAddWitnesses(234, key[:]) - if err != ErrUnknownWitnessType { - t.Fatalf("expected ErrUnknownWitnessType, got %v", err) - } -} - -// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves -// identically to the insertion via the generalized interface. -func TestAddSha256Witnesses(t *testing.T) { - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - wCache := cdb.NewWitnessCache() - - // We'll start by adding a witnesses to the cache using the generic - // AddWitnesses method. - witness1 := rev[:] - preimage1 := lntypes.Preimage(rev) - hash1 := preimage1.Hash() - - witness2 := key[:] - preimage2 := lntypes.Preimage(key) - hash2 := preimage2.Hash() - - var ( - witnesses = [][]byte{witness1, witness2} - preimages = []lntypes.Preimage{preimage1, preimage2} - hashes = []lntypes.Hash{hash1, hash2} - ) - - err = wCache.legacyAddWitnesses(Sha256HashWitness, witnesses...) - if err != nil { - t.Fatalf("unable to add witness: %v", err) - } - - for i, hash := range hashes { - preimage := preimages[i] - - dbPreimage, err := wCache.LookupSha256Witness(hash) - if err != nil { - t.Fatalf("unable to lookup witness: %v", err) - } - - // Assert that the retrieved witness matches the original. - if dbPreimage != preimage { - t.Fatalf("retrieved witness mismatch, want: %x, "+ - "got: %x", preimage, dbPreimage) - } - - // We'll now delete the witness, as we'll be reinserting it - // using the specialized AddSha256Witnesses method. - err = wCache.DeleteSha256Witness(hash) - if err != nil { - t.Fatalf("unable to delete witness: %v", err) - } - } - - // Now, add the same witnesses using the type-safe interface for - // lntypes.Preimages.. - err = wCache.AddSha256Witnesses(preimages...) - if err != nil { - t.Fatalf("unable to add sha256 preimage: %v", err) - } - - // Finally, iterate over the keys and assert that the returned witnesses - // match the original witnesses. This asserts that the specialized - // insertion method behaves identically to the generalized interface. - for i, hash := range hashes { - preimage := preimages[i] - - dbPreimage, err := wCache.LookupSha256Witness(hash) - if err != nil { - t.Fatalf("unable to lookup witness: %v", err) - } - - // Assert that the retrieved witness matches the original. - if dbPreimage != preimage { - t.Fatalf("retrieved witness mismatch, want: %x, "+ - "got: %x", preimage, dbPreimage) - } - } -} - -// legacyAddWitnesses adds a batch of new witnesses of wType to the witness -// cache. The type of the witness will be used to map each witness to the key -// that will be used to look it up. All witnesses should be of the same -// WitnessType. -// -// NOTE: Previously this method exposed a generic interface for adding -// witnesses, which has since been deprecated in favor of a strongly typed -// interface for each witness class. We keep this method around to assert the -// correctness of specialized witness adding methods. -func (w *WitnessCache) legacyAddWitnesses(wType WitnessType, - witnesses ...[]byte) error { - - // Optimistically compute the witness keys before attempting to start - // the db transaction. - entries := make([]witnessEntry, 0, len(witnesses)) - for _, witness := range witnesses { - // Map each witness to its key by applying the appropriate - // transformation for the given witness type. - switch wType { - case Sha256HashWitness: - key := sha256.Sum256(witness) - entries = append(entries, witnessEntry{ - key: key[:], - witness: witness, - }) - default: - return ErrUnknownWitnessType - } - } - - return w.addWitnessEntries(wType, entries) -}