diff --git a/.golangci.yml b/.golangci.yml index ca3e6881..23b24810 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -10,6 +10,9 @@ run: skip-files: - "mobile\\/.*generated\\.go" + skip-dirs: + - channeldb/migration_01_to_11 + build-tags: - autopilotrpc - chainrpc diff --git a/channeldb/migration_01_to_11/README.md b/channeldb/migration_01_to_11/README.md new file mode 100644 index 00000000..7e3a81ef --- /dev/null +++ b/channeldb/migration_01_to_11/README.md @@ -0,0 +1,24 @@ +channeldb +========== + +[![Build Status](http://img.shields.io/travis/lightningnetwork/lnd.svg)](https://travis-ci.org/lightningnetwork/lnd) +[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/lightningnetwork/lnd/blob/master/LICENSE) +[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/lightningnetwork/lnd/channeldb) + +The channeldb implements the persistent storage engine for `lnd` and +generically a data storage layer for the required state within the Lightning +Network. The backing storage engine is +[boltdb](https://github.com/coreos/bbolt), an embedded pure-go key-value store +based off of LMDB. + +The package implements an object-oriented storage model with queries and +mutations flowing through a particular object instance rather than the database +itself. The storage implemented by the objects includes: open channels, past +commitment revocation states, the channel graph which includes authenticated +node and channel announcements, outgoing payments, and invoices + +## Installation and Updating + +```bash +$ go get -u github.com/lightningnetwork/lnd/channeldb +``` diff --git a/channeldb/migration_01_to_11/addr.go b/channeldb/migration_01_to_11/addr.go new file mode 100644 index 00000000..2e7def07 --- /dev/null +++ b/channeldb/migration_01_to_11/addr.go @@ -0,0 +1,221 @@ +package migration_01_to_11 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + + "github.com/lightningnetwork/lnd/tor" +) + +// addressType specifies the network protocol and version that should be used +// when connecting to a node at a particular address. +type addressType uint8 + +const ( + // tcp4Addr denotes an IPv4 TCP address. + tcp4Addr addressType = 0 + + // tcp6Addr denotes an IPv6 TCP address. + tcp6Addr addressType = 1 + + // v2OnionAddr denotes a version 2 Tor onion service address. + v2OnionAddr addressType = 2 + + // v3OnionAddr denotes a version 3 Tor (prop224) onion service address. + v3OnionAddr addressType = 3 +) + +// encodeTCPAddr serializes a TCP address into its compact raw bytes +// representation. +func encodeTCPAddr(w io.Writer, addr *net.TCPAddr) error { + var ( + addrType byte + ip []byte + ) + + if addr.IP.To4() != nil { + addrType = byte(tcp4Addr) + ip = addr.IP.To4() + } else { + addrType = byte(tcp6Addr) + ip = addr.IP.To16() + } + + if ip == nil { + return fmt.Errorf("unable to encode IP %v", addr.IP) + } + + if _, err := w.Write([]byte{addrType}); err != nil { + return err + } + + if _, err := w.Write(ip); err != nil { + return err + } + + var port [2]byte + byteOrder.PutUint16(port[:], uint16(addr.Port)) + if _, err := w.Write(port[:]); err != nil { + return err + } + + return nil +} + +// encodeOnionAddr serializes an onion address into its compact raw bytes +// representation. +func encodeOnionAddr(w io.Writer, addr *tor.OnionAddr) error { + var suffixIndex int + hostLen := len(addr.OnionService) + switch hostLen { + case tor.V2Len: + if _, err := w.Write([]byte{byte(v2OnionAddr)}); err != nil { + return err + } + suffixIndex = tor.V2Len - tor.OnionSuffixLen + case tor.V3Len: + if _, err := w.Write([]byte{byte(v3OnionAddr)}); err != nil { + return err + } + suffixIndex = tor.V3Len - tor.OnionSuffixLen + default: + return errors.New("unknown onion service length") + } + + suffix := addr.OnionService[suffixIndex:] + if suffix != tor.OnionSuffix { + return fmt.Errorf("invalid suffix \"%v\"", suffix) + } + + host, err := tor.Base32Encoding.DecodeString( + addr.OnionService[:suffixIndex], + ) + if err != nil { + return err + } + + // Sanity check the decoded length. + switch { + case hostLen == tor.V2Len && len(host) != tor.V2DecodedLen: + return fmt.Errorf("onion service %v decoded to invalid host %x", + addr.OnionService, host) + + case hostLen == tor.V3Len && len(host) != tor.V3DecodedLen: + return fmt.Errorf("onion service %v decoded to invalid host %x", + addr.OnionService, host) + } + + if _, err := w.Write(host); err != nil { + return err + } + + var port [2]byte + byteOrder.PutUint16(port[:], uint16(addr.Port)) + if _, err := w.Write(port[:]); err != nil { + return err + } + + return nil +} + +// deserializeAddr reads the serialized raw representation of an address and +// deserializes it into the actual address. This allows us to avoid address +// resolution within the channeldb package. +func deserializeAddr(r io.Reader) (net.Addr, error) { + var addrType [1]byte + if _, err := r.Read(addrType[:]); err != nil { + return nil, err + } + + var address net.Addr + switch addressType(addrType[0]) { + case tcp4Addr: + var ip [4]byte + if _, err := r.Read(ip[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + address = &net.TCPAddr{ + IP: net.IP(ip[:]), + Port: int(binary.BigEndian.Uint16(port[:])), + } + case tcp6Addr: + var ip [16]byte + if _, err := r.Read(ip[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + address = &net.TCPAddr{ + IP: net.IP(ip[:]), + Port: int(binary.BigEndian.Uint16(port[:])), + } + case v2OnionAddr: + var h [tor.V2DecodedLen]byte + if _, err := r.Read(h[:]); err != nil { + return nil, err + } + + var p [2]byte + if _, err := r.Read(p[:]); err != nil { + return nil, err + } + + onionService := tor.Base32Encoding.EncodeToString(h[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + + address = &tor.OnionAddr{ + OnionService: onionService, + Port: port, + } + case v3OnionAddr: + var h [tor.V3DecodedLen]byte + if _, err := r.Read(h[:]); err != nil { + return nil, err + } + + var p [2]byte + if _, err := r.Read(p[:]); err != nil { + return nil, err + } + + onionService := tor.Base32Encoding.EncodeToString(h[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + + address = &tor.OnionAddr{ + OnionService: onionService, + Port: port, + } + default: + return nil, ErrUnknownAddressType + } + + return address, nil +} + +// serializeAddr serializes an address into its raw bytes representation so that +// it can be deserialized without requiring address resolution. +func serializeAddr(w io.Writer, address net.Addr) error { + switch addr := address.(type) { + case *net.TCPAddr: + return encodeTCPAddr(w, addr) + case *tor.OnionAddr: + return encodeOnionAddr(w, addr) + default: + return ErrUnknownAddressType + } +} diff --git a/channeldb/migration_01_to_11/addr_test.go b/channeldb/migration_01_to_11/addr_test.go new file mode 100644 index 00000000..8cdf99c3 --- /dev/null +++ b/channeldb/migration_01_to_11/addr_test.go @@ -0,0 +1,149 @@ +package migration_01_to_11 + +import ( + "bytes" + "net" + "strings" + "testing" + + "github.com/lightningnetwork/lnd/tor" +) + +type unknownAddrType struct{} + +func (t unknownAddrType) Network() string { return "unknown" } +func (t unknownAddrType) String() string { return "unknown" } + +var testIP4 = net.ParseIP("192.168.1.1") +var testIP6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329") + +var addrTests = []struct { + expAddr net.Addr + serErr string +}{ + // Valid addresses. + { + expAddr: &net.TCPAddr{ + IP: testIP4, + Port: 12345, + }, + }, + { + expAddr: &net.TCPAddr{ + IP: testIP6, + Port: 65535, + }, + }, + { + expAddr: &tor.OnionAddr{ + OnionService: "3g2upl4pq6kufc4m.onion", + Port: 9735, + }, + }, + { + expAddr: &tor.OnionAddr{ + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion", + Port: 80, + }, + }, + + // Invalid addresses. + { + expAddr: unknownAddrType{}, + serErr: ErrUnknownAddressType.Error(), + }, + { + expAddr: &net.TCPAddr{ + // Remove last byte of IPv4 address. + IP: testIP4[:len(testIP4)-1], + Port: 12345, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Add an extra byte of IPv4 address. + IP: append(testIP4, 0xff), + Port: 12345, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Remove last byte of IPv6 address. + IP: testIP6[:len(testIP6)-1], + Port: 65535, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Add an extra byte to the IPv6 address. + IP: append(testIP6, 0xff), + Port: 65535, + }, + serErr: "unable to encode", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid suffix. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.inion", + Port: 80, + }, + serErr: "invalid suffix", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid length. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy.onion", + Port: 80, + }, + serErr: "unknown onion service length", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid encoding. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyA.onion", + Port: 80, + }, + serErr: "illegal base32", + }, +} + +// TestAddrSerialization tests that the serialization method used by channeldb +// for net.Addr's works as intended. +func TestAddrSerialization(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + for _, test := range addrTests { + err := serializeAddr(&b, test.expAddr) + switch { + case err == nil && test.serErr != "": + t.Fatalf("expected serialization err for addr %v", + test.expAddr) + + case err != nil && test.serErr == "": + t.Fatalf("unexpected serialization err for addr %v: %v", + test.expAddr, err) + + case err != nil && !strings.Contains(err.Error(), test.serErr): + t.Fatalf("unexpected serialization err for addr %v, "+ + "want: %v, got %v", test.expAddr, test.serErr, + err) + + case err != nil: + continue + } + + addr, err := deserializeAddr(&b) + if err != nil { + t.Fatalf("unable to deserialize address: %v", err) + } + + if addr.String() != test.expAddr.String() { + t.Fatalf("expected address %v after serialization, "+ + "got %v", addr, test.expAddr) + } + } +} diff --git a/channeldb/migration_01_to_11/channel.go b/channeldb/migration_01_to_11/channel.go new file mode 100644 index 00000000..23d66852 --- /dev/null +++ b/channeldb/migration_01_to_11/channel.go @@ -0,0 +1,2817 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +var ( + // closedChannelBucket stores summarization information concerning + // previously open, but now closed channels. + closedChannelBucket = []byte("closed-chan-bucket") + + // openChanBucket stores all the currently open channels. This bucket + // has a second, nested bucket which is keyed by a node's ID. Within + // that node ID bucket, all attributes required to track, update, and + // close a channel are stored. + // + // openChan -> nodeID -> chanPoint + // + // TODO(roasbeef): flesh out comment + openChannelBucket = []byte("open-chan-bucket") + + // chanInfoKey can be accessed within the bucket for a channel + // (identified by its chanPoint). This key stores all the static + // information for a channel which is decided at the end of the + // funding flow. + chanInfoKey = []byte("chan-info-key") + + // chanCommitmentKey can be accessed within the sub-bucket for a + // particular channel. This key stores the up to date commitment state + // for a particular channel party. Appending a 0 to the end of this key + // indicates it's the commitment for the local party, and appending a 1 + // to the end of this key indicates it's the commitment for the remote + // party. + chanCommitmentKey = []byte("chan-commitment-key") + + // revocationStateKey stores their current revocation hash, our + // preimage producer and their preimage store. + revocationStateKey = []byte("revocation-state-key") + + // dataLossCommitPointKey stores the commitment point received from the + // remote peer during a channel sync in case we have lost channel state. + dataLossCommitPointKey = []byte("data-loss-commit-point-key") + + // closingTxKey points to a the closing tx that we broadcasted when + // moving the channel to state CommitBroadcasted. + closingTxKey = []byte("closing-tx-key") + + // commitDiffKey stores the current pending commitment state we've + // extended to the remote party (if any). Each time we propose a new + // state, we store the information necessary to reconstruct this state + // from the prior commitment. This allows us to resync the remote party + // to their expected state in the case of message loss. + // + // TODO(roasbeef): rename to commit chain? + commitDiffKey = []byte("commit-diff-key") + + // revocationLogBucket is dedicated for storing the necessary delta + // state between channel updates required to re-construct a past state + // in order to punish a counterparty attempting a non-cooperative + // channel closure. This key should be accessed from within the + // sub-bucket of a target channel, identified by its channel point. + revocationLogBucket = []byte("revocation-log-key") +) + +var ( + // ErrNoCommitmentsFound is returned when a channel has not set + // commitment states. + ErrNoCommitmentsFound = fmt.Errorf("no commitments found") + + // ErrNoChanInfoFound is returned when a particular channel does not + // have any channels state. + ErrNoChanInfoFound = fmt.Errorf("no chan info found") + + // ErrNoRevocationsFound is returned when revocation state for a + // particular channel cannot be found. + ErrNoRevocationsFound = fmt.Errorf("no revocations found") + + // ErrNoPendingCommit is returned when there is not a pending + // commitment for a remote party. A new commitment is written to disk + // each time we write a new state in order to be properly fault + // tolerant. + ErrNoPendingCommit = fmt.Errorf("no pending commits found") + + // ErrInvalidCircuitKeyLen signals that a circuit key could not be + // decoded because the byte slice is of an invalid length. + ErrInvalidCircuitKeyLen = fmt.Errorf( + "length of serialized circuit key must be 16 bytes") + + // ErrNoCommitPoint is returned when no data loss commit point is found + // in the database. + ErrNoCommitPoint = fmt.Errorf("no commit point found") + + // ErrNoCloseTx is returned when no closing tx is found for a channel + // in the state CommitBroadcasted. + ErrNoCloseTx = fmt.Errorf("no closing tx found") + + // ErrNoRestoredChannelMutation is returned when a caller attempts to + // mutate a channel that's been recovered. + ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " + + "channel state") + + // ErrChanBorked is returned when a caller attempts to mutate a borked + // channel. + ErrChanBorked = fmt.Errorf("cannot mutate borked channel") +) + +// ChannelType is an enum-like type that describes one of several possible +// channel types. Each open channel is associated with a particular type as the +// channel type may determine how higher level operations are conducted such as +// fee negotiation, channel closing, the format of HTLCs, etc. +// TODO(roasbeef): split up per-chain? +type ChannelType uint8 + +const ( + // NOTE: iota isn't used here for this enum needs to be stable + // long-term as it will be persisted to the database. + + // SingleFunder represents a channel wherein one party solely funds the + // entire capacity of the channel. + SingleFunder ChannelType = 0 + + // DualFunder represents a channel wherein both parties contribute + // funds towards the total capacity of the channel. The channel may be + // funded symmetrically or asymmetrically. + DualFunder ChannelType = 1 + + // SingleFunderTweakless is similar to the basic SingleFunder channel + // type, but it omits the tweak for one's key in the commitment + // transaction of the remote party. + SingleFunderTweakless ChannelType = 2 +) + +// IsSingleFunder returns true if the channel type if one of the known single +// funder variants. +func (c ChannelType) IsSingleFunder() bool { + return c == SingleFunder || c == SingleFunderTweakless +} + +// IsTweakless returns true if the target channel uses a commitment that +// doesn't tweak the key for the remote party. +func (c ChannelType) IsTweakless() bool { + return c == SingleFunderTweakless +} + +// ChannelConstraints represents a set of constraints meant to allow a node to +// limit their exposure, enact flow control and ensure that all HTLCs are +// economically relevant. This struct will be mirrored for both sides of the +// channel, as each side will enforce various constraints that MUST be adhered +// to for the life time of the channel. The parameters for each of these +// constraints are static for the duration of the channel, meaning the channel +// must be torn down for them to change. +type ChannelConstraints struct { + // DustLimit is the threshold (in satoshis) below which any outputs + // should be trimmed. When an output is trimmed, it isn't materialized + // as an actual output, but is instead burned to miner's fees. + DustLimit btcutil.Amount + + // ChanReserve is an absolute reservation on the channel for the + // owner of this set of constraints. This means that the current + // settled balance for this node CANNOT dip below the reservation + // amount. This acts as a defense against costless attacks when + // either side no longer has any skin in the game. + ChanReserve btcutil.Amount + + // MaxPendingAmount is the maximum pending HTLC value that the + // owner of these constraints can offer the remote node at a + // particular time. + MaxPendingAmount lnwire.MilliSatoshi + + // MinHTLC is the minimum HTLC value that the owner of these + // constraints can offer the remote node. If any HTLCs below this + // amount are offered, then the HTLC will be rejected. This, in + // tandem with the dust limit allows a node to regulate the + // smallest HTLC that it deems economically relevant. + MinHTLC lnwire.MilliSatoshi + + // MaxAcceptedHtlcs is the maximum number of HTLCs that the owner of + // this set of constraints can offer the remote node. This allows each + // node to limit their over all exposure to HTLCs that may need to be + // acted upon in the case of a unilateral channel closure or a contract + // breach. + MaxAcceptedHtlcs uint16 + + // CsvDelay is the relative time lock delay expressed in blocks. Any + // settled outputs that pay to the owner of this channel configuration + // MUST ensure that the delay branch uses this value as the relative + // time lock. Similarly, any HTLC's offered by this node should use + // this value as well. + CsvDelay uint16 +} + +// ChannelConfig is a struct that houses the various configuration opens for +// channels. Each side maintains an instance of this configuration file as it +// governs: how the funding and commitment transaction to be created, the +// nature of HTLC's allotted, the keys to be used for delivery, and relative +// time lock parameters. +type ChannelConfig struct { + // ChannelConstraints is the set of constraints that must be upheld for + // the duration of the channel for the owner of this channel + // configuration. Constraints govern a number of flow control related + // parameters, also including the smallest HTLC that will be accepted + // by a participant. + ChannelConstraints + + // MultiSigKey is the key to be used within the 2-of-2 output script + // for the owner of this channel config. + MultiSigKey keychain.KeyDescriptor + + // RevocationBasePoint is the base public key to be used when deriving + // revocation keys for the remote node's commitment transaction. This + // will be combined along with a per commitment secret to derive a + // unique revocation key for each state. + RevocationBasePoint keychain.KeyDescriptor + + // PaymentBasePoint is the base public key to be used when deriving + // the key used within the non-delayed pay-to-self output on the + // commitment transaction for a node. This will be combined with a + // tweak derived from the per-commitment point to ensure unique keys + // for each commitment transaction. + PaymentBasePoint keychain.KeyDescriptor + + // DelayBasePoint is the base public key to be used when deriving the + // key used within the delayed pay-to-self output on the commitment + // transaction for a node. This will be combined with a tweak derived + // from the per-commitment point to ensure unique keys for each + // commitment transaction. + DelayBasePoint keychain.KeyDescriptor + + // HtlcBasePoint is the base public key to be used when deriving the + // local HTLC key. The derived key (combined with the tweak derived + // from the per-commitment point) is used within the "to self" clause + // within any HTLC output scripts. + HtlcBasePoint keychain.KeyDescriptor +} + +// ChannelCommitment is a snapshot of the commitment state at a particular +// point in the commitment chain. With each state transition, a snapshot of the +// current state along with all non-settled HTLCs are recorded. These snapshots +// detail the state of the _remote_ party's commitment at a particular state +// number. For ourselves (the local node) we ONLY store our most recent +// (unrevoked) state for safety purposes. +type ChannelCommitment struct { + // CommitHeight is the update number that this ChannelDelta represents + // the total number of commitment updates to this point. This can be + // viewed as sort of a "commitment height" as this number is + // monotonically increasing. + CommitHeight uint64 + + // LocalLogIndex is the cumulative log index index of the local node at + // this point in the commitment chain. This value will be incremented + // for each _update_ added to the local update log. + LocalLogIndex uint64 + + // LocalHtlcIndex is the current local running HTLC index. This value + // will be incremented for each outgoing HTLC the local node offers. + LocalHtlcIndex uint64 + + // RemoteLogIndex is the cumulative log index index of the remote node + // at this point in the commitment chain. This value will be + // incremented for each _update_ added to the remote update log. + RemoteLogIndex uint64 + + // RemoteHtlcIndex is the current remote running HTLC index. This value + // will be incremented for each outgoing HTLC the remote node offers. + RemoteHtlcIndex uint64 + + // LocalBalance is the current available settled balance within the + // channel directly spendable by us. + LocalBalance lnwire.MilliSatoshi + + // RemoteBalance is the current available settled balance within the + // channel directly spendable by the remote node. + RemoteBalance lnwire.MilliSatoshi + + // CommitFee is the amount calculated to be paid in fees for the + // current set of commitment transactions. The fee amount is persisted + // with the channel in order to allow the fee amount to be removed and + // recalculated with each channel state update, including updates that + // happen after a system restart. + CommitFee btcutil.Amount + + // FeePerKw is the min satoshis/kilo-weight that should be paid within + // the commitment transaction for the entire duration of the channel's + // lifetime. This field may be updated during normal operation of the + // channel as on-chain conditions change. + // + // TODO(halseth): make this SatPerKWeight. Cannot be done atm because + // this will cause the import cycle lnwallet<->channeldb. Fee + // estimation stuff should be in its own package. + FeePerKw btcutil.Amount + + // CommitTx is the latest version of the commitment state, broadcast + // able by us. + CommitTx *wire.MsgTx + + // CommitSig is one half of the signature required to fully complete + // the script for the commitment transaction above. This is the + // signature signed by the remote party for our version of the + // commitment transactions. + CommitSig []byte + + // Htlcs is the set of HTLC's that are pending at this particular + // commitment height. + Htlcs []HTLC + + // TODO(roasbeef): pending commit pointer? + // * lets just walk through +} + +// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in +// the default usable state, or a state where it shouldn't be used. +type ChannelStatus uint8 + +var ( + // ChanStatusDefault is the normal state of an open channel. + ChanStatusDefault ChannelStatus + + // ChanStatusBorked indicates that the channel has entered an + // irreconcilable state, triggered by a state desynchronization or + // channel breach. Channels in this state should never be added to the + // htlc switch. + ChanStatusBorked ChannelStatus = 1 + + // ChanStatusCommitBroadcasted indicates that a commitment for this + // channel has been broadcasted. + ChanStatusCommitBroadcasted ChannelStatus = 1 << 1 + + // ChanStatusLocalDataLoss indicates that we have lost channel state + // for this channel, and broadcasting our latest commitment might be + // considered a breach. + // + // TODO(halseh): actually enforce that we are not force closing such a + // channel. + ChanStatusLocalDataLoss ChannelStatus = 1 << 2 + + // ChanStatusRestored is a status flag that signals that the channel + // has been restored, and doesn't have all the fields a typical channel + // will have. + ChanStatusRestored ChannelStatus = 1 << 3 +) + +// chanStatusStrings maps a ChannelStatus to a human friendly string that +// describes that status. +var chanStatusStrings = map[ChannelStatus]string{ + ChanStatusDefault: "ChanStatusDefault", + ChanStatusBorked: "ChanStatusBorked", + ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted", + ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss", + ChanStatusRestored: "ChanStatusRestored", +} + +// orderedChanStatusFlags is an in-order list of all that channel status flags. +var orderedChanStatusFlags = []ChannelStatus{ + ChanStatusDefault, + ChanStatusBorked, + ChanStatusCommitBroadcasted, + ChanStatusLocalDataLoss, + ChanStatusRestored, +} + +// String returns a human-readable representation of the ChannelStatus. +func (c ChannelStatus) String() string { + // If no flags are set, then this is the default case. + if c == 0 { + return chanStatusStrings[ChanStatusDefault] + } + + // Add individual bit flags. + statusStr := "" + for _, flag := range orderedChanStatusFlags { + if c&flag == flag { + statusStr += chanStatusStrings[flag] + "|" + c -= flag + } + } + + // Remove anything to the right of the final bar, including it as well. + statusStr = strings.TrimRight(statusStr, "|") + + // Add any remaining flags which aren't accounted for as hex. + if c != 0 { + statusStr += "|0x" + strconv.FormatUint(uint64(c), 16) + } + + // If this was purely an unknown flag, then remove the extra bar at the + // start of the string. + statusStr = strings.TrimLeft(statusStr, "|") + + return statusStr +} + +// OpenChannel encapsulates the persistent and dynamic state of an open channel +// with a remote node. An open channel supports several options for on-disk +// serialization depending on the exact context. Full (upon channel creation) +// state commitments, and partial (due to a commitment update) writes are +// supported. Each partial write due to a state update appends the new update +// to an on-disk log, which can then subsequently be queried in order to +// "time-travel" to a prior state. +type OpenChannel struct { + // ChanType denotes which type of channel this is. + ChanType ChannelType + + // ChainHash is a hash which represents the blockchain that this + // channel will be opened within. This value is typically the genesis + // hash. In the case that the original chain went through a contentious + // hard-fork, then this value will be tweaked using the unique fork + // point on each branch. + ChainHash chainhash.Hash + + // FundingOutpoint is the outpoint of the final funding transaction. + // This value uniquely and globally identifies the channel within the + // target blockchain as specified by the chain hash parameter. + FundingOutpoint wire.OutPoint + + // ShortChannelID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + ShortChannelID lnwire.ShortChannelID + + // IsPending indicates whether a channel's funding transaction has been + // confirmed. + IsPending bool + + // IsInitiator is a bool which indicates if we were the original + // initiator for the channel. This value may affect how higher levels + // negotiate fees, or close the channel. + IsInitiator bool + + // chanStatus is the current status of this channel. If it is not in + // the state Default, it should not be used for forwarding payments. + chanStatus ChannelStatus + + // FundingBroadcastHeight is the height in which the funding + // transaction was broadcast. This value can be used by higher level + // sub-systems to determine if a channel is stale and/or should have + // been confirmed before a certain height. + FundingBroadcastHeight uint32 + + // NumConfsRequired is the number of confirmations a channel's funding + // transaction must have received in order to be considered available + // for normal transactional use. + NumConfsRequired uint16 + + // ChannelFlags holds the flags that were sent as part of the + // open_channel message. + ChannelFlags lnwire.FundingFlag + + // IdentityPub is the identity public key of the remote node this + // channel has been established with. + IdentityPub *btcec.PublicKey + + // Capacity is the total capacity of this channel. + Capacity btcutil.Amount + + // TotalMSatSent is the total number of milli-satoshis we've sent + // within this channel. + TotalMSatSent lnwire.MilliSatoshi + + // TotalMSatReceived is the total number of milli-satoshis we've + // received within this channel. + TotalMSatReceived lnwire.MilliSatoshi + + // LocalChanCfg is the channel configuration for the local node. + LocalChanCfg ChannelConfig + + // RemoteChanCfg is the channel configuration for the remote node. + RemoteChanCfg ChannelConfig + + // LocalCommitment is the current local commitment state for the local + // party. This is stored distinct from the state of the remote party + // as there are certain asymmetric parameters which affect the + // structure of each commitment. + LocalCommitment ChannelCommitment + + // RemoteCommitment is the current remote commitment state for the + // remote party. This is stored distinct from the state of the local + // party as there are certain asymmetric parameters which affect the + // structure of each commitment. + RemoteCommitment ChannelCommitment + + // RemoteCurrentRevocation is the current revocation for their + // commitment transaction. However, since this the derived public key, + // we don't yet have the private key so we aren't yet able to verify + // that it's actually in the hash chain. + RemoteCurrentRevocation *btcec.PublicKey + + // RemoteNextRevocation is the revocation key to be used for the *next* + // commitment transaction we create for the local node. Within the + // specification, this value is referred to as the + // per-commitment-point. + RemoteNextRevocation *btcec.PublicKey + + // RevocationProducer is used to generate the revocation in such a way + // that remote side might store it efficiently and have the ability to + // restore the revocation by index if needed. Current implementation of + // secret producer is shachain producer. + RevocationProducer shachain.Producer + + // RevocationStore is used to efficiently store the revocations for + // previous channels states sent to us by remote side. Current + // implementation of secret store is shachain store. + RevocationStore shachain.Store + + // Packager is used to create and update forwarding packages for this + // channel, which encodes all necessary information to recover from + // failures and reforward HTLCs that were not fully processed. + Packager FwdPackager + + // FundingTxn is the transaction containing this channel's funding + // outpoint. Upon restarts, this txn will be rebroadcast if the channel + // is found to be pending. + // + // NOTE: This value will only be populated for single-funder channels + // for which we are the initiator. + FundingTxn *wire.MsgTx + + // TODO(roasbeef): eww + Db *DB + + // TODO(roasbeef): just need to store local and remote HTLC's? + + sync.RWMutex +} + +// ShortChanID returns the current ShortChannelID of this channel. +func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID { + c.RLock() + defer c.RUnlock() + + return c.ShortChannelID +} + +// ChanStatus returns the current ChannelStatus of this channel. +func (c *OpenChannel) ChanStatus() ChannelStatus { + c.RLock() + defer c.RUnlock() + + return c.chanStatus +} + +// ApplyChanStatus allows the caller to modify the internal channel state in a +// thead-safe manner. +func (c *OpenChannel) ApplyChanStatus(status ChannelStatus) error { + c.Lock() + defer c.Unlock() + + return c.putChanStatus(status) +} + +// ClearChanStatus allows the caller to clear a particular channel status from +// the primary channel status bit field. After this method returns, a call to +// HasChanStatus(status) should return false. +func (c *OpenChannel) ClearChanStatus(status ChannelStatus) error { + c.Lock() + defer c.Unlock() + + return c.clearChanStatus(status) +} + +// HasChanStatus returns true if the internal bitfield channel status of the +// target channel has the specified status bit set. +func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool { + c.RLock() + defer c.RUnlock() + + return c.hasChanStatus(status) +} + +func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool { + return c.chanStatus&status == status +} + +// RefreshShortChanID updates the in-memory short channel ID using the latest +// value observed on disk. +func (c *OpenChannel) RefreshShortChanID() error { + c.Lock() + defer c.Unlock() + + var sid lnwire.ShortChannelID + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return err + } + + sid = channel.ShortChannelID + + return nil + }) + if err != nil { + return err + } + + c.ShortChannelID = sid + c.Packager = NewChannelPackager(sid) + + return nil +} + +// fetchChanBucket is a helper function that returns the bucket where a +// channel's data resides in given: the public key for the node, the outpoint, +// and the chainhash that the channel resides on. +func fetchChanBucket(tx *bbolt.Tx, nodeKey *btcec.PublicKey, + outPoint *wire.OutPoint, chainHash chainhash.Hash) (*bbolt.Bucket, error) { + + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return nil, ErrNoChanDBExists + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := nodeKey.SerializeCompressed() + nodeChanBucket := openChanBucket.Bucket(nodePub) + if nodeChanBucket == nil { + return nil, ErrNoActiveChannels + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket := nodeChanBucket.Bucket(chainHash[:]) + if chainBucket == nil { + return nil, ErrNoActiveChannels + } + + // With the bucket for the node and chain fetched, we can now go down + // another level, for this channel itself. + var chanPointBuf bytes.Buffer + if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + return nil, err + } + chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) + if chanBucket == nil { + return nil, ErrChannelNotFound + } + + return chanBucket, nil +} + +// fullSync syncs the contents of an OpenChannel while re-using an existing +// database transaction. +func (c *OpenChannel) fullSync(tx *bbolt.Tx) error { + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket, err := tx.CreateBucketIfNotExists(openChannelBucket) + if err != nil { + return err + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := c.IdentityPub.SerializeCompressed() + nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub) + if err != nil { + return err + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(c.ChainHash[:]) + if err != nil { + return err + } + + // With the bucket for the node fetched, we can now go down another + // level, creating the bucket for this channel itself. + var chanPointBuf bytes.Buffer + if err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint); err != nil { + return err + } + chanBucket, err := chainBucket.CreateBucket( + chanPointBuf.Bytes(), + ) + switch { + case err == bbolt.ErrBucketExists: + // If this channel already exists, then in order to avoid + // overriding it, we'll return an error back up to the caller. + return ErrChanAlreadyExists + case err != nil: + return err + } + + return putOpenChannel(chanBucket, c) +} + +// MarkAsOpen marks a channel as fully open given a locator that uniquely +// describes its location within the chain. +func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { + c.Lock() + defer c.Unlock() + + if err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return err + } + + channel.IsPending = false + channel.ShortChannelID = openLoc + + return putOpenChannel(chanBucket, channel) + }); err != nil { + return err + } + + c.IsPending = false + c.ShortChannelID = openLoc + c.Packager = NewChannelPackager(openLoc) + + return nil +} + +// MarkDataLoss marks sets the channel status to LocalDataLoss and stores the +// passed commitPoint for use to retrieve funds in case the remote force closes +// the channel. +func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { + c.Lock() + defer c.Unlock() + + var b bytes.Buffer + if err := WriteElement(&b, commitPoint); err != nil { + return err + } + + putCommitPoint := func(chanBucket *bbolt.Bucket) error { + return chanBucket.Put(dataLossCommitPointKey, b.Bytes()) + } + + return c.putChanStatus(ChanStatusLocalDataLoss, putCommitPoint) +} + +// DataLossCommitPoint retrieves the stored commit point set during +// MarkDataLoss. If not found ErrNoCommitPoint is returned. +func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { + var commitPoint *btcec.PublicKey + + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: + return ErrNoCommitPoint + default: + return err + } + + bs := chanBucket.Get(dataLossCommitPointKey) + if bs == nil { + return ErrNoCommitPoint + } + r := bytes.NewReader(bs) + if err := ReadElements(r, &commitPoint); err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err + } + + return commitPoint, nil +} + +// MarkBorked marks the event when the channel as reached an irreconcilable +// state, such as a channel breach or state desynchronization. Borked channels +// should never be added to the switch. +func (c *OpenChannel) MarkBorked() error { + c.Lock() + defer c.Unlock() + + return c.putChanStatus(ChanStatusBorked) +} + +// ChanSyncMsg returns the ChannelReestablish message that should be sent upon +// reconnection with the remote peer that we're maintaining this channel with. +// The information contained within this message is necessary to re-sync our +// commitment chains in the case of a last or only partially processed message. +// When the remote party receiver this message one of three things may happen: +// +// 1. We're fully synced and no messages need to be sent. +// 2. We didn't get the last CommitSig message they sent, to they'll re-send +// it. +// 3. We didn't get the last RevokeAndAck message they sent, so they'll +// re-send it. +// +// If this is a restored channel, having status ChanStatusRestored, then we'll +// modify our typical chan sync message to ensure they force close even if +// we're on the very first state. +func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { + c.Lock() + defer c.Unlock() + + // The remote commitment height that we'll send in the + // ChannelReestablish message is our current commitment height plus + // one. If the receiver thinks that our commitment height is actually + // *equal* to this value, then they'll re-send the last commitment that + // they sent but we never fully processed. + localHeight := c.LocalCommitment.CommitHeight + nextLocalCommitHeight := localHeight + 1 + + // The second value we'll send is the height of the remote commitment + // from our PoV. If the receiver thinks that their height is actually + // *one plus* this value, then they'll re-send their last revocation. + remoteChainTipHeight := c.RemoteCommitment.CommitHeight + + // If this channel has undergone a commitment update, then in order to + // prove to the remote party our knowledge of their prior commitment + // state, we'll also send over the last commitment secret that the + // remote party sent. + var lastCommitSecret [32]byte + if remoteChainTipHeight != 0 { + remoteSecret, err := c.RevocationStore.LookUp( + remoteChainTipHeight - 1, + ) + if err != nil { + return nil, err + } + lastCommitSecret = [32]byte(*remoteSecret) + } + + // Additionally, we'll send over the current unrevoked commitment on + // our local commitment transaction. + currentCommitSecret, err := c.RevocationProducer.AtIndex( + localHeight, + ) + if err != nil { + return nil, err + } + + // If we've restored this channel, then we'll purposefully give them an + // invalid LocalUnrevokedCommitPoint so they'll force close the channel + // allowing us to sweep our funds. + if c.hasChanStatus(ChanStatusRestored) { + currentCommitSecret[0] ^= 1 + + // If this is a tweakless channel, then we'll purposefully send + // a next local height taht's invalid to trigger a force close + // on their end. We do this as tweakless channels don't require + // that the commitment point is valid, only that it's present. + if c.ChanType.IsTweakless() { + nextLocalCommitHeight = 0 + } + } + + return &lnwire.ChannelReestablish{ + ChanID: lnwire.NewChanIDFromOutPoint( + &c.FundingOutpoint, + ), + NextLocalCommitHeight: nextLocalCommitHeight, + RemoteCommitTailHeight: remoteChainTipHeight, + LastRemoteCommitSecret: lastCommitSecret, + LocalUnrevokedCommitPoint: input.ComputeCommitmentPoint( + currentCommitSecret[:], + ), + }, nil +} + +// isBorked returns true if the channel has been marked as borked in the +// database. This requires an existing database transaction to already be +// active. +// +// NOTE: The primary mutex should already be held before this method is called. +func (c *OpenChannel) isBorked(chanBucket *bbolt.Bucket) (bool, error) { + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return false, err + } + + return channel.chanStatus != ChanStatusDefault, nil +} + +// MarkCommitmentBroadcasted marks the channel as a commitment transaction has +// been broadcast, either our own or the remote, and we should watch the chain +// for it to confirm before taking any further action. It takes as argument the +// closing tx _we believe_ will appear in the chain. This is only used to +// republish this tx at startup to ensure propagation, and we should still +// handle the case where a different tx actually hits the chain. +func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx) error { + c.Lock() + defer c.Unlock() + + var b bytes.Buffer + if err := WriteElement(&b, closeTx); err != nil { + return err + } + + putClosingTx := func(chanBucket *bbolt.Bucket) error { + return chanBucket.Put(closingTxKey, b.Bytes()) + } + + return c.putChanStatus(ChanStatusCommitBroadcasted, putClosingTx) +} + +// BroadcastedCommitment retrieves the stored closing tx set during +// MarkCommitmentBroadcasted. If not found ErrNoCloseTx is returned. +func (c *OpenChannel) BroadcastedCommitment() (*wire.MsgTx, error) { + var closeTx *wire.MsgTx + + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: + return ErrNoCloseTx + default: + return err + } + + bs := chanBucket.Get(closingTxKey) + if bs == nil { + return ErrNoCloseTx + } + r := bytes.NewReader(bs) + return ReadElement(r, &closeTx) + }) + if err != nil { + return nil, err + } + + return closeTx, nil +} + +// putChanStatus appends the given status to the channel. fs is an optional +// list of closures that are given the chanBucket in order to atomically add +// extra information together with the new status. +func (c *OpenChannel) putChanStatus(status ChannelStatus, + fs ...func(*bbolt.Bucket) error) error { + + if err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return err + } + + // Add this status to the existing bitvector found in the DB. + status = channel.chanStatus | status + channel.chanStatus = status + + if err := putOpenChannel(chanBucket, channel); err != nil { + return err + } + + for _, f := range fs { + if err := f(chanBucket); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + // Update the in-memory representation to keep it in sync with the DB. + c.chanStatus = status + + return nil +} + +func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { + if err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return err + } + + // Unset this bit in the bitvector on disk. + status = channel.chanStatus & ^status + channel.chanStatus = status + + return putOpenChannel(chanBucket, channel) + }); err != nil { + return err + } + + // Update the in-memory representation to keep it in sync with the DB. + c.chanStatus = status + + return nil +} + +// putChannel serializes, and stores the current state of the channel in its +// entirety. +func putOpenChannel(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + // First, we'll write out all the relatively static fields, that are + // decided upon initial channel creation. + if err := putChanInfo(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan info: %v", err) + } + + // With the static channel info written out, we'll now write out the + // current commitment state for both parties. + if err := putChanCommitments(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan commitments: %v", err) + } + + // Finally, we'll write out the revocation state for both parties + // within a distinct key space. + if err := putChanRevocationState(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan revocations: %v", err) + } + + return nil +} + +// fetchOpenChannel retrieves, and deserializes (including decrypting +// sensitive) the complete channel currently active with the passed nodeID. +func fetchOpenChannel(chanBucket *bbolt.Bucket, + chanPoint *wire.OutPoint) (*OpenChannel, error) { + + channel := &OpenChannel{ + FundingOutpoint: *chanPoint, + } + + // First, we'll read all the static information that changes less + // frequently from disk. + if err := fetchChanInfo(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan info: %v", err) + } + + // With the static information read, we'll now read the current + // commitment state for both sides of the channel. + if err := fetchChanCommitments(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan commitments: %v", err) + } + + // Finally, we'll retrieve the current revocation state so we can + // properly + if err := fetchChanRevocationState(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan revocations: %v", err) + } + + channel.Packager = NewChannelPackager(channel.ShortChannelID) + + return channel, nil +} + +// SyncPending writes the contents of the channel to the database while it's in +// the pending (waiting for funding confirmation) state. The IsPending flag +// will be set to true. When the channel's funding transaction is confirmed, +// the channel should be marked as "open" and the IsPending flag set to false. +// Note that this function also creates a LinkNode relationship between this +// newly created channel and a new LinkNode instance. This allows listing all +// channels in the database globally, or according to the LinkNode they were +// created with. +// +// TODO(roasbeef): addr param should eventually be an lnwire.NetAddress type +// that includes service bits. +func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { + c.Lock() + defer c.Unlock() + + c.FundingBroadcastHeight = pendingHeight + + return c.Db.Update(func(tx *bbolt.Tx) error { + return syncNewChannel(tx, c, []net.Addr{addr}) + }) +} + +// syncNewChannel will write the passed channel to disk, and also create a +// LinkNode (if needed) for the channel peer. +func syncNewChannel(tx *bbolt.Tx, c *OpenChannel, addrs []net.Addr) error { + // First, sync all the persistent channel state to disk. + if err := c.fullSync(tx); err != nil { + return err + } + + nodeInfoBucket, err := tx.CreateBucketIfNotExists(nodeInfoBucket) + if err != nil { + return err + } + + // If a LinkNode for this identity public key already exists, + // then we can exit early. + nodePub := c.IdentityPub.SerializeCompressed() + if nodeInfoBucket.Get(nodePub) != nil { + return nil + } + + // Next, we need to establish a (possibly) new LinkNode relationship + // for this channel. The LinkNode metadata contains reachability, + // up-time, and service bits related information. + linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...) + + // TODO(roasbeef): do away with link node all together? + + return putLinkNode(nodeInfoBucket, linkNode) +} + +// UpdateCommitment updates the commitment state for the specified party +// (remote or local). The commitment stat completely describes the balance +// state at this point in the commitment chain. This method its to be called on +// two occasions: when we revoke our prior commitment state, and when the +// remote party revokes their prior commitment state. +func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment) error { + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state as all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return ErrNoRestoredChannelMutation + } + + err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := c.isBorked(chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + if err = putChanInfo(chanBucket, c); err != nil { + return fmt.Errorf("unable to store chan info: %v", err) + } + + // With the proper bucket fetched, we'll now write the latest + // commitment state to disk for the target party. + err = putChanCommitment( + chanBucket, newCommitment, true, + ) + if err != nil { + return fmt.Errorf("unable to store chan "+ + "revocations: %v", err) + } + + return nil + }) + if err != nil { + return err + } + + c.LocalCommitment = *newCommitment + + return nil +} + +// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are +// contained within ChannelDeltas which encode the current state of the +// commitment between state updates. +// +// TODO(roasbeef): save space by using smaller ints at tail end? +type HTLC struct { + // Signature is the signature for the second level covenant transaction + // for this HTLC. The second level transaction is a timeout tx in the + // case that this is an outgoing HTLC, and a success tx in the case + // that this is an incoming HTLC. + // + // TODO(roasbeef): make [64]byte instead? + Signature []byte + + // RHash is the payment hash of the HTLC. + RHash [32]byte + + // Amt is the amount of milli-satoshis this HTLC escrows. + Amt lnwire.MilliSatoshi + + // RefundTimeout is the absolute timeout on the HTLC that the sender + // must wait before reclaiming the funds in limbo. + RefundTimeout uint32 + + // OutputIndex is the output index for this particular HTLC output + // within the commitment transaction. + OutputIndex int32 + + // Incoming denotes whether we're the receiver or the sender of this + // HTLC. + Incoming bool + + // OnionBlob is an opaque blob which is used to complete multi-hop + // routing. + OnionBlob []byte + + // HtlcIndex is the HTLC counter index of this active, outstanding + // HTLC. This differs from the LogIndex, as the HtlcIndex is only + // incremented for each offered HTLC, while they LogIndex is + // incremented for each update (includes settle+fail). + HtlcIndex uint64 + + // LogIndex is the cumulative log index of this HTLC. This differs + // from the HtlcIndex as this will be incremented for each new log + // update added. + LogIndex uint64 +} + +// SerializeHtlcs writes out the passed set of HTLC's into the passed writer +// using the current default on-disk serialization format. +// +// NOTE: This API is NOT stable, the on-disk format will likely change in the +// future. +func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { + numHtlcs := uint16(len(htlcs)) + if err := WriteElement(b, numHtlcs); err != nil { + return err + } + + for _, htlc := range htlcs { + if err := WriteElements(b, + htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, + htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:], + htlc.HtlcIndex, htlc.LogIndex, + ); err != nil { + return err + } + } + + return nil +} + +// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed +// io.Reader. The bytes within the passed reader MUST have been previously +// written to using the SerializeHtlcs function. +// +// NOTE: This API is NOT stable, the on-disk format will likely change in the +// future. +func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { + var numHtlcs uint16 + if err := ReadElement(r, &numHtlcs); err != nil { + return nil, err + } + + var htlcs []HTLC + if numHtlcs == 0 { + return htlcs, nil + } + + htlcs = make([]HTLC, numHtlcs) + for i := uint16(0); i < numHtlcs; i++ { + if err := ReadElements(r, + &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, + &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, + &htlcs[i].Incoming, &htlcs[i].OnionBlob, + &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, + ); err != nil { + return htlcs, err + } + } + + return htlcs, nil +} + +// Copy returns a full copy of the target HTLC. +func (h *HTLC) Copy() HTLC { + clone := HTLC{ + Incoming: h.Incoming, + Amt: h.Amt, + RefundTimeout: h.RefundTimeout, + OutputIndex: h.OutputIndex, + } + copy(clone.Signature[:], h.Signature) + copy(clone.RHash[:], h.RHash[:]) + + return clone +} + +// LogUpdate represents a pending update to the remote commitment chain. The +// log update may be an add, fail, or settle entry. We maintain this data in +// order to be able to properly retransmit our proposed +// state if necessary. +type LogUpdate struct { + // LogIndex is the log index of this proposed commitment update entry. + LogIndex uint64 + + // UpdateMsg is the update message that was included within the our + // local update log. The LogIndex value denotes the log index of this + // update which will be used when restoring our local update log if + // we're left with a dangling update on restart. + UpdateMsg lnwire.Message +} + +// Encode writes a log update to the provided io.Writer. +func (l *LogUpdate) Encode(w io.Writer) error { + return WriteElements(w, l.LogIndex, l.UpdateMsg) +} + +// Decode reads a log update from the provided io.Reader. +func (l *LogUpdate) Decode(r io.Reader) error { + return ReadElements(r, &l.LogIndex, &l.UpdateMsg) +} + +// CircuitKey is used by a channel to uniquely identify the HTLCs it receives +// from the switch, and is used to purge our in-memory state of HTLCs that have +// already been processed by a link. Two list of CircuitKeys are included in +// each CommitDiff to allow a link to determine which in-memory htlcs directed +// the opening and closing of circuits in the switch's circuit map. +type CircuitKey struct { + // ChanID is the short chanid indicating the HTLC's origin. + // + // NOTE: It is fine for this value to be blank, as this indicates a + // locally-sourced payment. + ChanID lnwire.ShortChannelID + + // HtlcID is the unique htlc index predominately assigned by links, + // though can also be assigned by switch in the case of locally-sourced + // payments. + HtlcID uint64 +} + +// SetBytes deserializes the given bytes into this CircuitKey. +func (k *CircuitKey) SetBytes(bs []byte) error { + if len(bs) != 16 { + return ErrInvalidCircuitKeyLen + } + + k.ChanID = lnwire.NewShortChanIDFromInt( + binary.BigEndian.Uint64(bs[:8])) + k.HtlcID = binary.BigEndian.Uint64(bs[8:]) + + return nil +} + +// Bytes returns the serialized bytes for this circuit key. +func (k CircuitKey) Bytes() []byte { + var bs = make([]byte, 16) + binary.BigEndian.PutUint64(bs[:8], k.ChanID.ToUint64()) + binary.BigEndian.PutUint64(bs[8:], k.HtlcID) + return bs +} + +// Encode writes a CircuitKey to the provided io.Writer. +func (k *CircuitKey) Encode(w io.Writer) error { + var scratch [16]byte + binary.BigEndian.PutUint64(scratch[:8], k.ChanID.ToUint64()) + binary.BigEndian.PutUint64(scratch[8:], k.HtlcID) + + _, err := w.Write(scratch[:]) + return err +} + +// Decode reads a CircuitKey from the provided io.Reader. +func (k *CircuitKey) Decode(r io.Reader) error { + var scratch [16]byte + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return err + } + k.ChanID = lnwire.NewShortChanIDFromInt( + binary.BigEndian.Uint64(scratch[:8])) + k.HtlcID = binary.BigEndian.Uint64(scratch[8:]) + + return nil +} + +// String returns a string representation of the CircuitKey. +func (k CircuitKey) String() string { + return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID) +} + +// CommitDiff represents the delta needed to apply the state transition between +// two subsequent commitment states. Given state N and state N+1, one is able +// to apply the set of messages contained within the CommitDiff to N to arrive +// at state N+1. Each time a new commitment is extended, we'll write a new +// commitment (along with the full commitment state) to disk so we can +// re-transmit the state in the case of a connection loss or message drop. +type CommitDiff struct { + // ChannelCommitment is the full commitment state that one would arrive + // at by applying the set of messages contained in the UpdateDiff to + // the prior accepted commitment. + Commitment ChannelCommitment + + // LogUpdates is the set of messages sent prior to the commitment state + // transition in question. Upon reconnection, if we detect that they + // don't have the commitment, then we re-send this along with the + // proper signature. + LogUpdates []LogUpdate + + // CommitSig is the exact CommitSig message that should be sent after + // the set of LogUpdates above has been retransmitted. The signatures + // within this message should properly cover the new commitment state + // and also the HTLC's within the new commitment state. + CommitSig *lnwire.CommitSig + + // OpenedCircuitKeys is a set of unique identifiers for any downstream + // Add packets included in this commitment txn. After a restart, this + // set of htlcs is acked from the link's incoming mailbox to ensure + // there isn't an attempt to re-add them to this commitment txn. + OpenedCircuitKeys []CircuitKey + + // ClosedCircuitKeys records the unique identifiers for any settle/fail + // packets that were resolved by this commitment txn. After a restart, + // this is used to ensure those circuits are removed from the circuit + // map, and the downstream packets in the link's mailbox are removed. + ClosedCircuitKeys []CircuitKey + + // AddAcks specifies the locations (commit height, pkg index) of any + // Adds that were failed/settled in this commit diff. This will ack + // entries in *this* channel's forwarding packages. + // + // NOTE: This value is not serialized, it is used to atomically mark the + // resolution of adds, such that they will not be reprocessed after a + // restart. + AddAcks []AddRef + + // SettleFailAcks specifies the locations (chan id, commit height, pkg + // index) of any Settles or Fails that were locked into this commit + // diff, and originate from *another* channel, i.e. the outgoing link. + // + // NOTE: This value is not serialized, it is used to atomically acks + // settles and fails from the forwarding packages of other channels, + // such that they will not be reforwarded internally after a restart. + SettleFailAcks []SettleFailRef +} + +func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { + if err := serializeChanCommit(w, &diff.Commitment); err != nil { + return err + } + + if err := diff.CommitSig.Encode(w, 0); err != nil { + return err + } + + numUpdates := uint16(len(diff.LogUpdates)) + if err := binary.Write(w, byteOrder, numUpdates); err != nil { + return err + } + + for _, diff := range diff.LogUpdates { + err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) + if err != nil { + return err + } + } + + numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) + if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { + return err + } + + for _, openRef := range diff.OpenedCircuitKeys { + err := WriteElements(w, openRef.ChanID, openRef.HtlcID) + if err != nil { + return err + } + } + + numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) + if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { + return err + } + + for _, closedRef := range diff.ClosedCircuitKeys { + err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) + if err != nil { + return err + } + } + + return nil +} + +func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { + var ( + d CommitDiff + err error + ) + + d.Commitment, err = deserializeChanCommit(r) + if err != nil { + return nil, err + } + + d.CommitSig = &lnwire.CommitSig{} + if err := d.CommitSig.Decode(r, 0); err != nil { + return nil, err + } + + var numUpdates uint16 + if err := binary.Read(r, byteOrder, &numUpdates); err != nil { + return nil, err + } + + d.LogUpdates = make([]LogUpdate, numUpdates) + for i := 0; i < int(numUpdates); i++ { + err := ReadElements(r, + &d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg, + ) + if err != nil { + return nil, err + } + } + + var numOpenRefs uint16 + if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { + return nil, err + } + + d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs) + for i := 0; i < int(numOpenRefs); i++ { + err := ReadElements(r, + &d.OpenedCircuitKeys[i].ChanID, + &d.OpenedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + var numClosedRefs uint16 + if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { + return nil, err + } + + d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs) + for i := 0; i < int(numClosedRefs); i++ { + err := ReadElements(r, + &d.ClosedCircuitKeys[i].ChanID, + &d.ClosedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + return &d, nil +} + +// AppendRemoteCommitChain appends a new CommitDiff to the end of the +// commitment chain for the remote party. This method is to be used once we +// have prepared a new commitment state for the remote party, but before we +// transmit it to the remote party. The contents of the argument should be +// sufficient to retransmit the updates and signature needed to reconstruct the +// state in full, in the case that we need to retransmit. +func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state at all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return ErrNoRestoredChannelMutation + } + + return c.Db.Update(func(tx *bbolt.Tx) error { + // First, we'll grab the writable bucket where this channel's + // data resides. + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := c.isBorked(chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + // Any outgoing settles and fails necessarily have a + // corresponding adds in this channel's forwarding packages. + // Mark all of these as being fully processed in our forwarding + // package, which prevents us from reprocessing them after + // startup. + err = c.Packager.AckAddHtlcs(tx, diff.AddAcks...) + if err != nil { + return err + } + + // Additionally, we ack from any fails or settles that are + // persisted in another channel's forwarding package. This + // prevents the same fails and settles from being retransmitted + // after restarts. The actual fail or settle we need to + // propagate to the remote party is now in the commit diff. + err = c.Packager.AckSettleFails(tx, diff.SettleFailAcks...) + if err != nil { + return err + } + + // TODO(roasbeef): use seqno to derive key for later LCP + + // With the bucket retrieved, we'll now serialize the commit + // diff itself, and write it to disk. + var b bytes.Buffer + if err := serializeCommitDiff(&b, diff); err != nil { + return err + } + return chanBucket.Put(commitDiffKey, b.Bytes()) + }) +} + +// RemoteCommitChainTip returns the "tip" of the current remote commitment +// chain. This value will be non-nil iff, we've created a new commitment for +// the remote party that they haven't yet ACK'd. In this case, their commitment +// chain will have a length of two: their current unrevoked commitment, and +// this new pending commitment. Once they revoked their prior state, we'll swap +// these pointers, causing the tip and the tail to point to the same entry. +func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { + var cd *CommitDiff + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: + return ErrNoPendingCommit + default: + return err + } + + tipBytes := chanBucket.Get(commitDiffKey) + if tipBytes == nil { + return ErrNoPendingCommit + } + + tipReader := bytes.NewReader(tipBytes) + dcd, err := deserializeCommitDiff(tipReader) + if err != nil { + return err + } + + cd = dcd + return nil + }) + if err != nil { + return nil, err + } + + return cd, err +} + +// InsertNextRevocation inserts the _next_ commitment point (revocation) into +// the database, and also modifies the internal RemoteNextRevocation attribute +// to point to the passed key. This method is to be using during final channel +// set up, _after_ the channel has been fully confirmed. +// +// NOTE: If this method isn't called, then the target channel won't be able to +// propose new states for the commitment state of the remote party. +func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { + c.Lock() + defer c.Unlock() + + c.RemoteNextRevocation = revKey + + err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + return putChanRevocationState(chanBucket, c) + }) + if err != nil { + return err + } + + return nil +} + +// AdvanceCommitChainTail records the new state transition within an on-disk +// append-only log which records all state transitions by the remote peer. In +// the case of an uncooperative broadcast of a prior state by the remote peer, +// this log can be consulted in order to reconstruct the state needed to +// rectify the situation. This method will add the current commitment for the +// remote party to the revocation log, and promote the current pending +// commitment to the current remote commitment. +func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error { + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state at all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return ErrNoRestoredChannelMutation + } + + var newRemoteCommit *ChannelCommitment + + err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := c.isBorked(chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + // Persist the latest preimage state to disk as the remote peer + // has just added to our local preimage store, and given us a + // new pending revocation key. + if err := putChanRevocationState(chanBucket, c); err != nil { + return err + } + + // With the current preimage producer/store state updated, + // append a new log entry recording this the delta of this + // state transition. + // + // TODO(roasbeef): could make the deltas relative, would save + // space, but then tradeoff for more disk-seeks to recover the + // full state. + logKey := revocationLogBucket + logBucket, err := chanBucket.CreateBucketIfNotExists(logKey) + if err != nil { + return err + } + + // Before we append this revoked state to the revocation log, + // we'll swap out what's currently the tail of the commit tip, + // with the current locked-in commitment for the remote party. + tipBytes := chanBucket.Get(commitDiffKey) + tipReader := bytes.NewReader(tipBytes) + newCommit, err := deserializeCommitDiff(tipReader) + if err != nil { + return err + } + err = putChanCommitment( + chanBucket, &newCommit.Commitment, false, + ) + if err != nil { + return err + } + if err := chanBucket.Delete(commitDiffKey); err != nil { + return err + } + + // With the commitment pointer swapped, we can now add the + // revoked (prior) state to the revocation log. + // + // TODO(roasbeef): store less + err = appendChannelLogEntry(logBucket, &c.RemoteCommitment) + if err != nil { + return err + } + + // Lastly, we write the forwarding package to disk so that we + // can properly recover from failures and reforward HTLCs that + // have not received a corresponding settle/fail. + if err := c.Packager.AddFwdPkg(tx, fwdPkg); err != nil { + return err + } + + newRemoteCommit = &newCommit.Commitment + + return nil + }) + if err != nil { + return err + } + + // With the db transaction complete, we'll swap over the in-memory + // pointer of the new remote commitment, which was previously the tip + // of the commit chain. + c.RemoteCommitment = *newRemoteCommit + + return nil +} + +// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure +// this always returns the next index that has been not been allocated, this +// will first try to examine any pending commitments, before falling back to the +// last locked-in local commitment. +func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { + // First, load the most recent commit diff that we initiated for the + // remote party. If no pending commit is found, this is not treated as + // a critical error, since we can always fall back. + pendingRemoteCommit, err := c.RemoteCommitChainTip() + if err != nil && err != ErrNoPendingCommit { + return 0, err + } + + // If a pending commit was found, its local htlc index will be at least + // as large as the one on our local commitment. + if pendingRemoteCommit != nil { + return pendingRemoteCommit.Commitment.LocalHtlcIndex, nil + } + + // Otherwise, fallback to using the local htlc index of our commitment. + return c.LocalCommitment.LocalHtlcIndex, nil +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in map indexed by the +// remote commitment height at which the updates were locked in. +func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { + c.RLock() + defer c.RUnlock() + + var fwdPkgs []*FwdPkg + if err := c.Db.View(func(tx *bbolt.Tx) error { + var err error + fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) + return err + }); err != nil { + return nil, err + } + + return fwdPkgs, nil +} + +// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs +// indicating that a response to this Add has been committed to the remote party. +// Doing so will prevent these Add HTLCs from being reforwarded internally. +func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + return c.Packager.AckAddHtlcs(tx, addRefs...) + }) +} + +// AckSettleFails updates the SettleFailFilter containing any of the provided +// SettleFailRefs, indicating that the response has been delivered to the +// incoming link, corresponding to a particular AddRef. Doing so will prevent +// the responses from being retransmitted internally. +func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + return c.Packager.AckSettleFails(tx, settleFailRefs...) + }) +} + +// SetFwdFilter atomically sets the forwarding filter for the forwarding package +// identified by `height`. +func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + return c.Packager.SetFwdFilter(tx, height, fwdFilter) + }) +} + +// RemoveFwdPkg atomically removes a forwarding package specified by the remote +// commitment height. +// +// NOTE: This method should only be called on packages marked FwdStateCompleted. +func (c *OpenChannel) RemoveFwdPkg(height uint64) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + return c.Packager.RemovePkg(tx, height) + }) +} + +// RevocationLogTail returns the "tail", or the end of the current revocation +// log. This entry represents the last previous state for the remote node's +// commitment chain. The ChannelDelta returned by this method will always lag +// one state behind the most current (unrevoked) state of the remote node's +// commitment chain. +func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { + c.RLock() + defer c.RUnlock() + + // If we haven't created any state updates yet, then we'll exit early as + // there's nothing to be found on disk in the revocation bucket. + if c.RemoteCommitment.CommitHeight == 0 { + return nil, nil + } + + var commit ChannelCommitment + if err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + logBucket := chanBucket.Bucket(revocationLogBucket) + if logBucket == nil { + return ErrNoPastDeltas + } + + // Once we have the bucket that stores the revocation log from + // this channel, we'll jump to the _last_ key in bucket. As we + // store the update number on disk in a big-endian format, + // this will retrieve the latest entry. + cursor := logBucket.Cursor() + _, tailLogEntry := cursor.Last() + logEntryReader := bytes.NewReader(tailLogEntry) + + // Once we have the entry, we'll decode it into the channel + // delta pointer we created above. + var dbErr error + commit, dbErr = deserializeChanCommit(logEntryReader) + if dbErr != nil { + return dbErr + } + + return nil + }); err != nil { + return nil, err + } + + return &commit, nil +} + +// CommitmentHeight returns the current commitment height. The commitment +// height represents the number of updates to the commitment state to date. +// This value is always monotonically increasing. This method is provided in +// order to allow multiple instances of a particular open channel to obtain a +// consistent view of the number of channel updates to date. +func (c *OpenChannel) CommitmentHeight() (uint64, error) { + c.RLock() + defer c.RUnlock() + + var height uint64 + err := c.Db.View(func(tx *bbolt.Tx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + commit, err := fetchChanCommitment(chanBucket, true) + if err != nil { + return err + } + + height = commit.CommitHeight + return nil + }) + if err != nil { + return 0, err + } + + return height, nil +} + +// FindPreviousState scans through the append-only log in an attempt to recover +// the previous channel state indicated by the update number. This method is +// intended to be used for obtaining the relevant data needed to claim all +// funds rightfully spendable in the case of an on-chain broadcast of the +// commitment transaction. +func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, error) { + c.RLock() + defer c.RUnlock() + + var commit ChannelCommitment + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + logBucket := chanBucket.Bucket(revocationLogBucket) + if logBucket == nil { + return ErrNoPastDeltas + } + + c, err := fetchChannelLogEntry(logBucket, updateNum) + if err != nil { + return err + } + + commit = c + return nil + }) + if err != nil { + return nil, err + } + + return &commit, nil +} + +// ClosureType is an enum like structure that details exactly _how_ a channel +// was closed. Three closure types are currently possible: none, cooperative, +// local force close, remote force close, and (remote) breach. +type ClosureType uint8 + +const ( + // CooperativeClose indicates that a channel has been closed + // cooperatively. This means that both channel peers were online and + // signed a new transaction paying out the settled balance of the + // contract. + CooperativeClose ClosureType = 0 + + // LocalForceClose indicates that we have unilaterally broadcast our + // current commitment state on-chain. + LocalForceClose ClosureType = 1 + + // RemoteForceClose indicates that the remote peer has unilaterally + // broadcast their current commitment state on-chain. + RemoteForceClose ClosureType = 4 + + // BreachClose indicates that the remote peer attempted to broadcast a + // prior _revoked_ channel state. + BreachClose ClosureType = 2 + + // FundingCanceled indicates that the channel never was fully opened + // before it was marked as closed in the database. This can happen if + // we or the remote fail at some point during the opening workflow, or + // we timeout waiting for the funding transaction to be confirmed. + FundingCanceled ClosureType = 3 + + // Abandoned indicates that the channel state was removed without + // any further actions. This is intended to clean up unusable + // channels during development. + Abandoned ClosureType = 5 +) + +// ChannelCloseSummary contains the final state of a channel at the point it +// was closed. Once a channel is closed, all the information pertaining to that +// channel within the openChannelBucket is deleted, and a compact summary is +// put in place instead. +type ChannelCloseSummary struct { + // ChanPoint is the outpoint for this channel's funding transaction, + // and is used as a unique identifier for the channel. + ChanPoint wire.OutPoint + + // ShortChanID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + ShortChanID lnwire.ShortChannelID + + // ChainHash is the hash of the genesis block that this channel resides + // within. + ChainHash chainhash.Hash + + // ClosingTXID is the txid of the transaction which ultimately closed + // this channel. + ClosingTXID chainhash.Hash + + // RemotePub is the public key of the remote peer that we formerly had + // a channel with. + RemotePub *btcec.PublicKey + + // Capacity was the total capacity of the channel. + Capacity btcutil.Amount + + // CloseHeight is the height at which the funding transaction was + // spent. + CloseHeight uint32 + + // SettledBalance is our total balance settled balance at the time of + // channel closure. This _does not_ include the sum of any outputs that + // have been time-locked as a result of the unilateral channel closure. + SettledBalance btcutil.Amount + + // TimeLockedBalance is the sum of all the time-locked outputs at the + // time of channel closure. If we triggered the force closure of this + // channel, then this value will be non-zero if our settled output is + // above the dust limit. If we were on the receiving side of a channel + // force closure, then this value will be non-zero if we had any + // outstanding outgoing HTLC's at the time of channel closure. + TimeLockedBalance btcutil.Amount + + // CloseType details exactly _how_ the channel was closed. Five closure + // types are possible: cooperative, local force, remote force, breach + // and funding canceled. + CloseType ClosureType + + // IsPending indicates whether this channel is in the 'pending close' + // state, which means the channel closing transaction has been + // confirmed, but not yet been fully resolved. In the case of a channel + // that has been cooperatively closed, it will go straight into the + // fully resolved state as soon as the closing transaction has been + // confirmed. However, for channels that have been force closed, they'll + // stay marked as "pending" until _all_ the pending funds have been + // swept. + IsPending bool + + // RemoteCurrentRevocation is the current revocation for their + // commitment transaction. However, since this is the derived public key, + // we don't yet have the private key so we aren't yet able to verify + // that it's actually in the hash chain. + RemoteCurrentRevocation *btcec.PublicKey + + // RemoteNextRevocation is the revocation key to be used for the *next* + // commitment transaction we create for the local node. Within the + // specification, this value is referred to as the + // per-commitment-point. + RemoteNextRevocation *btcec.PublicKey + + // LocalChanCfg is the channel configuration for the local node. + LocalChanConfig ChannelConfig + + // LastChanSyncMsg is the ChannelReestablish message for this channel + // for the state at the point where it was closed. + LastChanSyncMsg *lnwire.ChannelReestablish +} + +// CloseChannel closes a previously active Lightning channel. Closing a channel +// entails deleting all saved state within the database concerning this +// channel. This method also takes a struct that summarizes the state of the +// channel at closing, this compact representation will be the only component +// of a channel left over after a full closing. +func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoChanDBExists + } + + nodePub := c.IdentityPub.SerializeCompressed() + nodeChanBucket := openChanBucket.Bucket(nodePub) + if nodeChanBucket == nil { + return ErrNoActiveChannels + } + + chainBucket := nodeChanBucket.Bucket(c.ChainHash[:]) + if chainBucket == nil { + return ErrNoActiveChannels + } + + var chanPointBuf bytes.Buffer + err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint) + if err != nil { + return err + } + chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) + if chanBucket == nil { + return ErrNoActiveChannels + } + + // Before we delete the channel state, we'll read out the full + // details, as we'll also store portions of this information + // for record keeping. + chanState, err := fetchOpenChannel( + chanBucket, &c.FundingOutpoint, + ) + if err != nil { + return err + } + + // Now that the index to this channel has been deleted, purge + // the remaining channel metadata from the database. + err = deleteOpenChannel(chanBucket, chanPointBuf.Bytes()) + if err != nil { + return err + } + + // With the base channel data deleted, attempt to delete the + // information stored within the revocation log. + logBucket := chanBucket.Bucket(revocationLogBucket) + if logBucket != nil { + err = chanBucket.DeleteBucket(revocationLogBucket) + if err != nil { + return err + } + } + + err = chainBucket.DeleteBucket(chanPointBuf.Bytes()) + if err != nil { + return err + } + + // Finally, create a summary of this channel in the closed + // channel bucket for this node. + return putChannelCloseSummary( + tx, chanPointBuf.Bytes(), summary, chanState, + ) + }) +} + +// ChannelSnapshot is a frozen snapshot of the current channel state. A +// snapshot is detached from the original channel that generated it, providing +// read-only access to the current or prior state of an active channel. +// +// TODO(roasbeef): remove all together? pretty much just commitment +type ChannelSnapshot struct { + // RemoteIdentity is the identity public key of the remote node that we + // are maintaining the open channel with. + RemoteIdentity btcec.PublicKey + + // ChanPoint is the outpoint that created the channel. This output is + // found within the funding transaction and uniquely identified the + // channel on the resident chain. + ChannelPoint wire.OutPoint + + // ChainHash is the genesis hash of the chain that the channel resides + // within. + ChainHash chainhash.Hash + + // Capacity is the total capacity of the channel. + Capacity btcutil.Amount + + // TotalMSatSent is the total number of milli-satoshis we've sent + // within this channel. + TotalMSatSent lnwire.MilliSatoshi + + // TotalMSatReceived is the total number of milli-satoshis we've + // received within this channel. + TotalMSatReceived lnwire.MilliSatoshi + + // ChannelCommitment is the current up-to-date commitment for the + // target channel. + ChannelCommitment +} + +// Snapshot returns a read-only snapshot of the current channel state. This +// snapshot includes information concerning the current settled balance within +// the channel, metadata detailing total flows, and any outstanding HTLCs. +func (c *OpenChannel) Snapshot() *ChannelSnapshot { + c.RLock() + defer c.RUnlock() + + localCommit := c.LocalCommitment + snapshot := &ChannelSnapshot{ + RemoteIdentity: *c.IdentityPub, + ChannelPoint: c.FundingOutpoint, + Capacity: c.Capacity, + TotalMSatSent: c.TotalMSatSent, + TotalMSatReceived: c.TotalMSatReceived, + ChainHash: c.ChainHash, + ChannelCommitment: ChannelCommitment{ + LocalBalance: localCommit.LocalBalance, + RemoteBalance: localCommit.RemoteBalance, + CommitHeight: localCommit.CommitHeight, + CommitFee: localCommit.CommitFee, + }, + } + + // Copy over the current set of HTLCs to ensure the caller can't mutate + // our internal state. + snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) + for i, h := range localCommit.Htlcs { + snapshot.Htlcs[i] = h.Copy() + } + + return snapshot +} + +// LatestCommitments returns the two latest commitments for both the local and +// remote party. These commitments are read from disk to ensure that only the +// latest fully committed state is returned. The first commitment returned is +// the local commitment, and the second returned is the remote commitment. +func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + return fetchChanCommitments(chanBucket, c) + }) + if err != nil { + return nil, nil, err + } + + return &c.LocalCommitment, &c.RemoteCommitment, nil +} + +// RemoteRevocationStore returns the most up to date commitment version of the +// revocation storage tree for the remote party. This method can be used when +// acting on a possible contract breach to ensure, that the caller has the most +// up to date information required to deliver justice. +func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + return fetchChanRevocationState(chanBucket, c) + }) + if err != nil { + return nil, err + } + + return c.RevocationStore, nil +} + +func putChannelCloseSummary(tx *bbolt.Tx, chanID []byte, + summary *ChannelCloseSummary, lastChanState *OpenChannel) error { + + closedChanBucket, err := tx.CreateBucketIfNotExists(closedChannelBucket) + if err != nil { + return err + } + + summary.RemoteCurrentRevocation = lastChanState.RemoteCurrentRevocation + summary.RemoteNextRevocation = lastChanState.RemoteNextRevocation + summary.LocalChanConfig = lastChanState.LocalChanCfg + + var b bytes.Buffer + if err := serializeChannelCloseSummary(&b, summary); err != nil { + return err + } + + return closedChanBucket.Put(chanID, b.Bytes()) +} + +func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { + err := WriteElements(w, + cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, + cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, + cs.TimeLockedBalance, cs.CloseType, cs.IsPending, + ) + if err != nil { + return err + } + + // If this is a close channel summary created before the addition of + // the new fields, then we can exit here. + if cs.RemoteCurrentRevocation == nil { + return WriteElements(w, false) + } + + // If fields are present, write boolean to indicate this, and continue. + if err := WriteElements(w, true); err != nil { + return err + } + + if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil { + return err + } + + if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil { + return err + } + + // The RemoteNextRevocation field is optional, as it's possible for a + // channel to be closed before we learn of the next unrevoked + // revocation point for the remote party. Write a boolen indicating + // whether this field is present or not. + if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil { + return err + } + + // Write the field, if present. + if cs.RemoteNextRevocation != nil { + if err = WriteElements(w, cs.RemoteNextRevocation); err != nil { + return err + } + } + + // Write whether the channel sync message is present. + if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil { + return err + } + + // Write the channel sync message, if present. + if cs.LastChanSyncMsg != nil { + if err := WriteElements(w, cs.LastChanSyncMsg); err != nil { + return err + } + } + + return nil +} + +func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { + c := &ChannelCloseSummary{} + + err := ReadElements(r, + &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, + &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, + &c.TimeLockedBalance, &c.CloseType, &c.IsPending, + ) + if err != nil { + return nil, err + } + + // We'll now check to see if the channel close summary was encoded with + // any of the additional optional fields. + var hasNewFields bool + err = ReadElements(r, &hasNewFields) + if err != nil { + return nil, err + } + + // If fields are not present, we can return. + if !hasNewFields { + return c, nil + } + + // Otherwise read the new fields. + if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil { + return nil, err + } + + if err := readChanConfig(r, &c.LocalChanConfig); err != nil { + return nil, err + } + + // Finally, we'll attempt to read the next unrevoked commitment point + // for the remote party. If we closed the channel before receiving a + // funding locked message then this might not be present. A boolean + // indicating whether the field is present will come first. + var hasRemoteNextRevocation bool + err = ReadElements(r, &hasRemoteNextRevocation) + if err != nil { + return nil, err + } + + // If this field was written, read it. + if hasRemoteNextRevocation { + err = ReadElements(r, &c.RemoteNextRevocation) + if err != nil { + return nil, err + } + } + + // Check if we have a channel sync message to read. + var hasChanSyncMsg bool + err = ReadElements(r, &hasChanSyncMsg) + if err == io.EOF { + return c, nil + } else if err != nil { + return nil, err + } + + // If a chan sync message is present, read it. + if hasChanSyncMsg { + // We must pass in reference to a lnwire.Message for the codec + // to support it. + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { + return nil, err + } + + chanSync, ok := msg.(*lnwire.ChannelReestablish) + if !ok { + return nil, errors.New("unable cast db Message to " + + "ChannelReestablish") + } + c.LastChanSyncMsg = chanSync + } + + return c, nil +} + +func writeChanConfig(b io.Writer, c *ChannelConfig) error { + return WriteElements(b, + c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, + c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, + c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, + c.HtlcBasePoint, + ) +} + +func putChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + var w bytes.Buffer + if err := WriteElements(&w, + channel.ChanType, channel.ChainHash, channel.FundingOutpoint, + channel.ShortChannelID, channel.IsPending, channel.IsInitiator, + channel.chanStatus, channel.FundingBroadcastHeight, + channel.NumConfsRequired, channel.ChannelFlags, + channel.IdentityPub, channel.Capacity, channel.TotalMSatSent, + channel.TotalMSatReceived, + ); err != nil { + return err + } + + // For single funder channels that we initiated, write the funding txn. + if channel.ChanType.IsSingleFunder() && channel.IsInitiator && + !channel.hasChanStatus(ChanStatusRestored) { + + if err := WriteElement(&w, channel.FundingTxn); err != nil { + return err + } + } + + if err := writeChanConfig(&w, &channel.LocalChanCfg); err != nil { + return err + } + if err := writeChanConfig(&w, &channel.RemoteChanCfg); err != nil { + return err + } + + return chanBucket.Put(chanInfoKey, w.Bytes()) +} + +func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { + if err := WriteElements(w, + c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, + c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, + c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, + c.CommitSig, + ); err != nil { + return err + } + + return SerializeHtlcs(w, c.Htlcs...) +} + +func putChanCommitment(chanBucket *bbolt.Bucket, c *ChannelCommitment, + local bool) error { + + var commitKey []byte + if local { + commitKey = append(chanCommitmentKey, byte(0x00)) + } else { + commitKey = append(chanCommitmentKey, byte(0x01)) + } + + var b bytes.Buffer + if err := serializeChanCommit(&b, c); err != nil { + return err + } + + return chanBucket.Put(commitKey, b.Bytes()) +} + +func putChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + // If this is a restored channel, then we don't have any commitments to + // write. + if channel.hasChanStatus(ChanStatusRestored) { + return nil + } + + err := putChanCommitment( + chanBucket, &channel.LocalCommitment, true, + ) + if err != nil { + return err + } + + return putChanCommitment( + chanBucket, &channel.RemoteCommitment, false, + ) +} + +func putChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + + var b bytes.Buffer + err := WriteElements( + &b, channel.RemoteCurrentRevocation, channel.RevocationProducer, + channel.RevocationStore, + ) + if err != nil { + return err + } + + // TODO(roasbeef): don't keep producer on disk + + // If the next revocation is present, which is only the case after the + // FundingLocked message has been sent, then we'll write it to disk. + if channel.RemoteNextRevocation != nil { + err = WriteElements(&b, channel.RemoteNextRevocation) + if err != nil { + return err + } + } + + return chanBucket.Put(revocationStateKey, b.Bytes()) +} + +func readChanConfig(b io.Reader, c *ChannelConfig) error { + return ReadElements(b, + &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, + &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, + &c.MultiSigKey, &c.RevocationBasePoint, + &c.PaymentBasePoint, &c.DelayBasePoint, + &c.HtlcBasePoint, + ) +} + +func fetchChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + infoBytes := chanBucket.Get(chanInfoKey) + if infoBytes == nil { + return ErrNoChanInfoFound + } + r := bytes.NewReader(infoBytes) + + if err := ReadElements(r, + &channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint, + &channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator, + &channel.chanStatus, &channel.FundingBroadcastHeight, + &channel.NumConfsRequired, &channel.ChannelFlags, + &channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent, + &channel.TotalMSatReceived, + ); err != nil { + return err + } + + // For single funder channels that we initiated, read the funding txn. + if channel.ChanType.IsSingleFunder() && channel.IsInitiator && + !channel.hasChanStatus(ChanStatusRestored) { + + if err := ReadElement(r, &channel.FundingTxn); err != nil { + return err + } + } + + if err := readChanConfig(r, &channel.LocalChanCfg); err != nil { + return err + } + if err := readChanConfig(r, &channel.RemoteChanCfg); err != nil { + return err + } + + channel.Packager = NewChannelPackager(channel.ShortChannelID) + + return nil +} + +func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { + var c ChannelCommitment + + err := ReadElements(r, + &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, + &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, + &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, + ) + if err != nil { + return c, err + } + + c.Htlcs, err = DeserializeHtlcs(r) + if err != nil { + return c, err + } + + return c, nil +} + +func fetchChanCommitment(chanBucket *bbolt.Bucket, local bool) (ChannelCommitment, error) { + var commitKey []byte + if local { + commitKey = append(chanCommitmentKey, byte(0x00)) + } else { + commitKey = append(chanCommitmentKey, byte(0x01)) + } + + commitBytes := chanBucket.Get(commitKey) + if commitBytes == nil { + return ChannelCommitment{}, ErrNoCommitmentsFound + } + + r := bytes.NewReader(commitBytes) + return deserializeChanCommit(r) +} + +func fetchChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + var err error + + // If this is a restored channel, then we don't have any commitments to + // read. + if channel.hasChanStatus(ChanStatusRestored) { + return nil + } + + channel.LocalCommitment, err = fetchChanCommitment(chanBucket, true) + if err != nil { + return err + } + channel.RemoteCommitment, err = fetchChanCommitment(chanBucket, false) + if err != nil { + return err + } + + return nil +} + +func fetchChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + revBytes := chanBucket.Get(revocationStateKey) + if revBytes == nil { + return ErrNoRevocationsFound + } + r := bytes.NewReader(revBytes) + + err := ReadElements( + r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer, + &channel.RevocationStore, + ) + if err != nil { + return err + } + + // If there aren't any bytes left in the buffer, then we don't yet have + // the next remote revocation, so we can exit early here. + if r.Len() == 0 { + return nil + } + + // Otherwise we'll read the next revocation for the remote party which + // is always the last item within the buffer. + return ReadElements(r, &channel.RemoteNextRevocation) +} + +func deleteOpenChannel(chanBucket *bbolt.Bucket, chanPointBytes []byte) error { + + if err := chanBucket.Delete(chanInfoKey); err != nil { + return err + } + + err := chanBucket.Delete(append(chanCommitmentKey, byte(0x00))) + if err != nil { + return err + } + err = chanBucket.Delete(append(chanCommitmentKey, byte(0x01))) + if err != nil { + return err + } + + if err := chanBucket.Delete(revocationStateKey); err != nil { + return err + } + + if diff := chanBucket.Get(commitDiffKey); diff != nil { + return chanBucket.Delete(commitDiffKey) + } + + return nil + +} + +// makeLogKey converts a uint64 into an 8 byte array. +func makeLogKey(updateNum uint64) [8]byte { + var key [8]byte + byteOrder.PutUint64(key[:], updateNum) + return key +} + +func appendChannelLogEntry(log *bbolt.Bucket, + commit *ChannelCommitment) error { + + var b bytes.Buffer + if err := serializeChanCommit(&b, commit); err != nil { + return err + } + + logEntrykey := makeLogKey(commit.CommitHeight) + return log.Put(logEntrykey[:], b.Bytes()) +} + +func fetchChannelLogEntry(log *bbolt.Bucket, + updateNum uint64) (ChannelCommitment, error) { + + logEntrykey := makeLogKey(updateNum) + commitBytes := log.Get(logEntrykey[:]) + if commitBytes == nil { + return ChannelCommitment{}, fmt.Errorf("log entry not found") + } + + commitReader := bytes.NewReader(commitBytes) + return deserializeChanCommit(commitReader) +} diff --git a/channeldb/migration_01_to_11/channel_cache.go b/channeldb/migration_01_to_11/channel_cache.go new file mode 100644 index 00000000..5d391e00 --- /dev/null +++ b/channeldb/migration_01_to_11/channel_cache.go @@ -0,0 +1,50 @@ +package migration_01_to_11 + +// channelCache is an in-memory cache used to improve the performance of +// ChanUpdatesInHorizon. It caches the chan info and edge policies for a +// particular channel. +type channelCache struct { + n int + channels map[uint64]ChannelEdge +} + +// newChannelCache creates a new channelCache with maximum capacity of n +// channels. +func newChannelCache(n int) *channelCache { + return &channelCache{ + n: n, + channels: make(map[uint64]ChannelEdge), + } +} + +// get returns the channel from the cache, if it exists. +func (c *channelCache) get(chanid uint64) (ChannelEdge, bool) { + channel, ok := c.channels[chanid] + return channel, ok +} + +// insert adds the entry to the channel cache. If an entry for chanid already +// exists, it will be replaced with the new entry. If the entry doesn't exist, +// it will be inserted to the cache, performing a random eviction if the cache +// is at capacity. +func (c *channelCache) insert(chanid uint64, channel ChannelEdge) { + // If entry exists, replace it. + if _, ok := c.channels[chanid]; ok { + c.channels[chanid] = channel + return + } + + // Otherwise, evict an entry at random and insert. + if len(c.channels) == c.n { + for id := range c.channels { + delete(c.channels, id) + break + } + } + c.channels[chanid] = channel +} + +// remove deletes an edge for chanid from the cache, if it exists. +func (c *channelCache) remove(chanid uint64) { + delete(c.channels, chanid) +} diff --git a/channeldb/migration_01_to_11/channel_cache_test.go b/channeldb/migration_01_to_11/channel_cache_test.go new file mode 100644 index 00000000..b2929635 --- /dev/null +++ b/channeldb/migration_01_to_11/channel_cache_test.go @@ -0,0 +1,105 @@ +package migration_01_to_11 + +import ( + "reflect" + "testing" +) + +// TestChannelCache checks the behavior of the channelCache with respect to +// insertion, eviction, and removal of cache entries. +func TestChannelCache(t *testing.T) { + const cacheSize = 100 + + // Create a new channel cache with the configured max size. + c := newChannelCache(cacheSize) + + // As a sanity check, assert that querying the empty cache does not + // return an entry. + _, ok := c.get(0) + if ok { + t.Fatalf("channel cache should be empty") + } + + // Now, fill up the cache entirely. + for i := uint64(0); i < cacheSize; i++ { + c.insert(i, channelForInt(i)) + } + + // Assert that the cache has all of the entries just inserted, since no + // eviction should occur until we try to surpass the max size. + assertHasChanEntries(t, c, 0, cacheSize) + + // Now, insert a new element that causes the cache to evict an element. + c.insert(cacheSize, channelForInt(cacheSize)) + + // Assert that the cache has this last entry, as the cache should evict + // some prior element and not the newly inserted one. + assertHasChanEntries(t, c, cacheSize, cacheSize) + + // Iterate over all inserted elements and construct a set of the evicted + // elements. + evicted := make(map[uint64]struct{}) + for i := uint64(0); i < cacheSize+1; i++ { + _, ok := c.get(i) + if !ok { + evicted[i] = struct{}{} + } + } + + // Assert that exactly one element has been evicted. + numEvicted := len(evicted) + if numEvicted != 1 { + t.Fatalf("expected one evicted entry, got: %d", numEvicted) + } + + // Remove the highest item which initially caused the eviction and + // reinsert the element that was evicted prior. + c.remove(cacheSize) + for i := range evicted { + c.insert(i, channelForInt(i)) + } + + // Since the removal created an extra slot, the last insertion should + // not have caused an eviction and the entries for all channels in the + // original set that filled the cache should be present. + assertHasChanEntries(t, c, 0, cacheSize) + + // Finally, reinsert the existing set back into the cache and test that + // the cache still has all the entries. If the randomized eviction were + // happening on inserts for existing cache items, we expect this to fail + // with high probability. + for i := uint64(0); i < cacheSize; i++ { + c.insert(i, channelForInt(i)) + } + assertHasChanEntries(t, c, 0, cacheSize) + +} + +// assertHasEntries queries the edge cache for all channels in the range [start, +// end), asserting that they exist and their value matches the entry produced by +// entryForInt. +func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) { + t.Helper() + + for i := start; i < end; i++ { + entry, ok := c.get(i) + if !ok { + t.Fatalf("channel cache should contain chan %d", i) + } + + expEntry := channelForInt(i) + if !reflect.DeepEqual(entry, expEntry) { + t.Fatalf("entry mismatch, want: %v, got: %v", + expEntry, entry) + } + } +} + +// channelForInt generates a unique ChannelEdge given an integer. +func channelForInt(i uint64) ChannelEdge { + return ChannelEdge{ + Info: &ChannelEdgeInfo{ + ChannelID: i, + }, + } +} diff --git a/channeldb/migration_01_to_11/channel_test.go b/channeldb/migration_01_to_11/channel_test.go new file mode 100644 index 00000000..53fb39d7 --- /dev/null +++ b/channeldb/migration_01_to_11/channel_test.go @@ -0,0 +1,1041 @@ +package migration_01_to_11 + +import ( + "bytes" + "io/ioutil" + "math/rand" + "net" + "os" + "reflect" + "runtime" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + _ "github.com/btcsuite/btcwallet/walletdb/bdb" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +var ( + key = [chainhash.HashSize]byte{ + 0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d, + 0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9, + } + rev = [chainhash.HashSize]byte{ + 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0x2d, 0xe7, 0x93, 0xe4, + } + testTx = &wire.MsgTx{ + Version: 1, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{}, + Index: 0xffffffff, + }, + SignatureScript: []byte{0x04, 0x31, 0xdc, 0x00, 0x1b, 0x01, 0x62}, + Sequence: 0xffffffff, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 5000000000, + PkScript: []byte{ + 0x41, // OP_DATA_65 + 0x04, 0xd6, 0x4b, 0xdf, 0xd0, 0x9e, 0xb1, 0xc5, + 0xfe, 0x29, 0x5a, 0xbd, 0xeb, 0x1d, 0xca, 0x42, + 0x81, 0xbe, 0x98, 0x8e, 0x2d, 0xa0, 0xb6, 0xc1, + 0xc6, 0xa5, 0x9d, 0xc2, 0x26, 0xc2, 0x86, 0x24, + 0xe1, 0x81, 0x75, 0xe8, 0x51, 0xc9, 0x6b, 0x97, + 0x3d, 0x81, 0xb0, 0x1c, 0xc3, 0x1f, 0x04, 0x78, + 0x34, 0xbc, 0x06, 0xd6, 0xd6, 0xed, 0xf6, 0x20, + 0xd1, 0x84, 0x24, 0x1a, 0x6a, 0xed, 0x8b, 0x63, + 0xa6, // 65-byte signature + 0xac, // OP_CHECKSIG + }, + }, + }, + LockTime: 5, + } + privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + + wireSig, _ = lnwire.NewSigFromSignature(testSig) +) + +// makeTestDB creates a new instance of the ChannelDB for testing purposes. A +// callback which cleans up the created temporary directories is also returned +// and intended to be executed after the test completes. +func makeTestDB() (*DB, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + return nil, nil, err + } + + // Next, create channeldb for the first time. + cdb, err := Open(tempDirName) + if err != nil { + return nil, nil, err + } + + cleanUp := func() { + cdb.Close() + os.RemoveAll(tempDirName) + } + + return cdb, cleanUp, nil +} + +func createTestChannelState(cdb *DB) (*OpenChannel, error) { + // Simulate 1000 channel updates. + producer, err := shachain.NewRevocationProducerFromBytes(key[:]) + if err != nil { + return nil, err + } + store := shachain.NewRevocationStore() + for i := 0; i < 1; i++ { + preImage, err := producer.AtIndex(uint64(i)) + if err != nil { + return nil, err + } + + if err := store.AddNextEntry(preImage); err != nil { + return nil, err + } + } + + localCfg := ChannelConfig{ + ChannelConstraints: ChannelConstraints{ + DustLimit: btcutil.Amount(rand.Int63()), + MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()), + ChanReserve: btcutil.Amount(rand.Int63()), + MinHTLC: lnwire.MilliSatoshi(rand.Int63()), + MaxAcceptedHtlcs: uint16(rand.Int31()), + CsvDelay: uint16(rand.Int31()), + }, + MultiSigKey: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + RevocationBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + PaymentBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + DelayBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + HtlcBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + } + remoteCfg := ChannelConfig{ + ChannelConstraints: ChannelConstraints{ + DustLimit: btcutil.Amount(rand.Int63()), + MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()), + ChanReserve: btcutil.Amount(rand.Int63()), + MinHTLC: lnwire.MilliSatoshi(rand.Int63()), + MaxAcceptedHtlcs: uint16(rand.Int31()), + CsvDelay: uint16(rand.Int31()), + }, + MultiSigKey: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyMultiSig, + Index: 9, + }, + }, + RevocationBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyRevocationBase, + Index: 8, + }, + }, + PaymentBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyPaymentBase, + Index: 7, + }, + }, + DelayBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyDelayBase, + Index: 6, + }, + }, + HtlcBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyHtlcBase, + Index: 5, + }, + }, + } + + chanID := lnwire.NewShortChanIDFromInt(uint64(rand.Int63())) + + return &OpenChannel{ + ChanType: SingleFunder, + ChainHash: key, + FundingOutpoint: wire.OutPoint{Hash: key, Index: rand.Uint32()}, + ShortChannelID: chanID, + IsInitiator: true, + IsPending: true, + IdentityPub: pubKey, + Capacity: btcutil.Amount(10000), + LocalChanCfg: localCfg, + RemoteChanCfg: remoteCfg, + TotalMSatSent: 8, + TotalMSatReceived: 2, + LocalCommitment: ChannelCommitment{ + CommitHeight: 0, + LocalBalance: lnwire.MilliSatoshi(9000), + RemoteBalance: lnwire.MilliSatoshi(3000), + CommitFee: btcutil.Amount(rand.Int63()), + FeePerKw: btcutil.Amount(5000), + CommitTx: testTx, + CommitSig: bytes.Repeat([]byte{1}, 71), + }, + RemoteCommitment: ChannelCommitment{ + CommitHeight: 0, + LocalBalance: lnwire.MilliSatoshi(3000), + RemoteBalance: lnwire.MilliSatoshi(9000), + CommitFee: btcutil.Amount(rand.Int63()), + FeePerKw: btcutil.Amount(5000), + CommitTx: testTx, + CommitSig: bytes.Repeat([]byte{1}, 71), + }, + NumConfsRequired: 4, + RemoteCurrentRevocation: privKey.PubKey(), + RemoteNextRevocation: privKey.PubKey(), + RevocationProducer: producer, + RevocationStore: store, + Db: cdb, + Packager: NewChannelPackager(chanID), + FundingTxn: testTx, + }, nil +} + +func TestOpenChannelPutGetDelete(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create the test channel state, then add an additional fake HTLC + // before syncing to disk. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + state.LocalCommitment.Htlcs = []HTLC{ + { + Signature: testSig.Serialize(), + Incoming: true, + Amt: 10, + RHash: key, + RefundTimeout: 1, + OnionBlob: []byte("onionblob"), + }, + } + state.RemoteCommitment.Htlcs = []HTLC{ + { + Signature: testSig.Serialize(), + Incoming: false, + Amt: 10, + RHash: key, + RefundTimeout: 1, + OnionBlob: []byte("onionblob"), + }, + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + if err := state.SyncPending(addr, 101); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch open channel: %v", err) + } + + newState := openChannels[0] + + // The decoded channel state should be identical to what we stored + // above. + if !reflect.DeepEqual(state, newState) { + t.Fatalf("channel state doesn't match:: %v vs %v", + spew.Sdump(state), spew.Sdump(newState)) + } + + // We'll also test that the channel is properly able to hot swap the + // next revocation for the state machine. This tests the initial + // post-funding revocation exchange. + nextRevKey, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to create new private key: %v", err) + } + if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil { + t.Fatalf("unable to update revocation: %v", err) + } + + openChannels, err = cdb.FetchOpenChannels(state.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch open channel: %v", err) + } + updatedChan := openChannels[0] + + // Ensure that the revocation was set properly. + if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) { + t.Fatalf("next revocation wasn't updated") + } + + // Finally to wrap up the test, delete the state of the channel within + // the database. This involves "closing" the channel which removes all + // written state, and creates a small "summary" elsewhere within the + // database. + closeSummary := &ChannelCloseSummary{ + ChanPoint: state.FundingOutpoint, + RemotePub: state.IdentityPub, + SettledBalance: btcutil.Amount(500), + TimeLockedBalance: btcutil.Amount(10000), + IsPending: false, + CloseType: CooperativeClose, + } + if err := state.CloseChannel(closeSummary); err != nil { + t.Fatalf("unable to close channel: %v", err) + } + + // As the channel is now closed, attempting to fetch all open channels + // for our fake node ID should return an empty slice. + openChans, err := cdb.FetchOpenChannels(state.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch open channels: %v", err) + } + if len(openChans) != 0 { + t.Fatalf("all channels not deleted, found %v", len(openChans)) + } + + // Additionally, attempting to fetch all the open channels globally + // should yield no results. + openChans, err = cdb.FetchAllChannels() + if err != nil { + t.Fatal("unable to fetch all open chans") + } + if len(openChans) != 0 { + t.Fatalf("all channels not deleted, found %v", len(openChans)) + } +} + +func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { + if !reflect.DeepEqual(a, b) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: commitments don't match: %v vs %v", + line, spew.Sdump(a), spew.Sdump(b)) + } +} + +func TestChannelStateTransition(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First create a minimal channel, then perform a full sync in order to + // persist the data. + channel, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + if err := channel.SyncPending(addr, 101); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + // Add some HTLCs which were added during this new state transition. + // Half of the HTLCs are incoming, while the other half are outgoing. + var ( + htlcs []HTLC + htlcAmt lnwire.MilliSatoshi + ) + for i := uint32(0); i < 10; i++ { + var incoming bool + if i > 5 { + incoming = true + } + htlc := HTLC{ + Signature: testSig.Serialize(), + Incoming: incoming, + Amt: 10, + RHash: key, + RefundTimeout: i, + OutputIndex: int32(i * 3), + LogIndex: uint64(i * 2), + HtlcIndex: uint64(i), + } + htlc.OnionBlob = make([]byte, 10) + copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10)) + htlcs = append(htlcs, htlc) + htlcAmt += htlc.Amt + } + + // Create a new channel delta which includes the above HTLCs, some + // balance updates, and an increment of the current commitment height. + // Additionally, modify the signature and commitment transaction. + newSequence := uint32(129498) + newSig := bytes.Repeat([]byte{3}, 71) + newTx := channel.LocalCommitment.CommitTx.Copy() + newTx.TxIn[0].Sequence = newSequence + commitment := ChannelCommitment{ + CommitHeight: 1, + LocalLogIndex: 2, + LocalHtlcIndex: 1, + RemoteLogIndex: 2, + RemoteHtlcIndex: 1, + LocalBalance: lnwire.MilliSatoshi(1e8), + RemoteBalance: lnwire.MilliSatoshi(1e8), + CommitFee: 55, + FeePerKw: 99, + CommitTx: newTx, + CommitSig: newSig, + Htlcs: htlcs, + } + + // First update the local node's broadcastable state and also add a + // CommitDiff remote node's as well in order to simulate a proper state + // transition. + if err := channel.UpdateCommitment(&commitment); err != nil { + t.Fatalf("unable to update commitment: %v", err) + } + + // The balances, new update, the HTLCs and the changes to the fake + // commitment transaction along with the modified signature should all + // have been updated. + updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch updated channel: %v", err) + } + assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) + numDiskUpdates, err := updatedChannel[0].CommitmentHeight() + if err != nil { + t.Fatalf("unable to read commitment height from disk: %v", err) + } + if numDiskUpdates != uint64(commitment.CommitHeight) { + t.Fatalf("num disk updates doesn't match: %v vs %v", + numDiskUpdates, commitment.CommitHeight) + } + + // Attempting to query for a commitment diff should return + // ErrNoPendingCommit as we haven't yet created a new state for them. + _, err = channel.RemoteCommitChainTip() + if err != ErrNoPendingCommit { + t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) + } + + // To simulate us extending a new state to the remote party, we'll also + // create a new commit diff for them. + remoteCommit := commitment + remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8) + remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8) + remoteCommit.CommitHeight = 1 + commitDiff := &CommitDiff{ + Commitment: remoteCommit, + CommitSig: &lnwire.CommitSig{ + ChanID: lnwire.ChannelID(key), + CommitSig: wireSig, + HtlcSigs: []lnwire.Sig{ + wireSig, + wireSig, + }, + }, + LogUpdates: []LogUpdate{ + { + LogIndex: 1, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ID: 1, + Amount: lnwire.NewMSatFromSatoshis(100), + Expiry: 25, + }, + }, + { + LogIndex: 2, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ID: 2, + Amount: lnwire.NewMSatFromSatoshis(200), + Expiry: 50, + }, + }, + }, + OpenedCircuitKeys: []CircuitKey{}, + ClosedCircuitKeys: []CircuitKey{}, + } + copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], + bytes.Repeat([]byte{1}, 32)) + copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], + bytes.Repeat([]byte{2}, 32)) + if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { + t.Fatalf("unable to add to commit chain: %v", err) + } + + // The commitment tip should now match the commitment that we just + // inserted. + diskCommitDiff, err := channel.RemoteCommitChainTip() + if err != nil { + t.Fatalf("unable to fetch commit diff: %v", err) + } + if !reflect.DeepEqual(commitDiff, diskCommitDiff) { + t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit), + spew.Sdump(diskCommitDiff)) + } + + // We'll save the old remote commitment as this will be added to the + // revocation log shortly. + oldRemoteCommit := channel.RemoteCommitment + + // Next, write to the log which tracks the necessary revocation state + // needed to rectify any fishy behavior by the remote party. Modify the + // current uncollapsed revocation state to simulate a state transition + // by the remote party. + channel.RemoteCurrentRevocation = channel.RemoteNextRevocation + newPriv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to generate key: %v", err) + } + channel.RemoteNextRevocation = newPriv.PubKey() + + fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, + diskCommitDiff.LogUpdates, nil) + + err = channel.AdvanceCommitChainTail(fwdPkg) + if err != nil { + t.Fatalf("unable to append to revocation log: %v", err) + } + + // At this point, the remote commit chain should be nil, and the posted + // remote commitment should match the one we added as a diff above. + if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit { + t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) + } + + // We should be able to fetch the channel delta created above by its + // update number with all the state properly reconstructed. + diskPrevCommit, err := channel.FindPreviousState( + oldRemoteCommit.CommitHeight, + ) + if err != nil { + t.Fatalf("unable to fetch past delta: %v", err) + } + + // The two deltas (the original vs the on-disk version) should + // identical, and all HTLC data should properly be retained. + assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit) + + // The state number recovered from the tail of the revocation log + // should be identical to this current state. + logTail, err := channel.RevocationLogTail() + if err != nil { + t.Fatalf("unable to retrieve log: %v", err) + } + if logTail.CommitHeight != oldRemoteCommit.CommitHeight { + t.Fatal("update number doesn't match") + } + + oldRemoteCommit = channel.RemoteCommitment + + // Next modify the posted diff commitment slightly, then create a new + // commitment diff and advance the tail. + commitDiff.Commitment.CommitHeight = 2 + commitDiff.Commitment.LocalBalance -= htlcAmt + commitDiff.Commitment.RemoteBalance += htlcAmt + commitDiff.LogUpdates = []LogUpdate{} + if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { + t.Fatalf("unable to add to commit chain: %v", err) + } + + fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil) + + err = channel.AdvanceCommitChainTail(fwdPkg) + if err != nil { + t.Fatalf("unable to append to revocation log: %v", err) + } + + // Once again, fetch the state and ensure it has been properly updated. + prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight) + if err != nil { + t.Fatalf("unable to fetch past delta: %v", err) + } + assertCommitmentEqual(t, &oldRemoteCommit, prevCommit) + + // Once again, state number recovered from the tail of the revocation + // log should be identical to this current state. + logTail, err = channel.RevocationLogTail() + if err != nil { + t.Fatalf("unable to retrieve log: %v", err) + } + if logTail.CommitHeight != oldRemoteCommit.CommitHeight { + t.Fatal("update number doesn't match") + } + + // The revocation state stored on-disk should now also be identical. + updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch updated channel: %v", err) + } + if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) { + t.Fatalf("revocation state was not synced") + } + if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) { + t.Fatalf("revocation state was not synced") + } + + // Now attempt to delete the channel from the database. + closeSummary := &ChannelCloseSummary{ + ChanPoint: channel.FundingOutpoint, + RemotePub: channel.IdentityPub, + SettledBalance: btcutil.Amount(500), + TimeLockedBalance: btcutil.Amount(10000), + IsPending: false, + CloseType: RemoteForceClose, + } + if err := updatedChannel[0].CloseChannel(closeSummary); err != nil { + t.Fatalf("unable to delete updated channel: %v", err) + } + + // If we attempt to fetch the target channel again, it shouldn't be + // found. + channels, err := cdb.FetchOpenChannels(channel.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch updated channels: %v", err) + } + if len(channels) != 0 { + t.Fatalf("%v channels, found, but none should be", + len(channels)) + } + + // Attempting to find previous states on the channel should fail as the + // revocation log has been deleted. + _, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight) + if err == nil { + t.Fatal("revocation log search should have failed") + } +} + +func TestFetchPendingChannels(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create first test channel state + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + + const broadcastHeight = 99 + if err := state.SyncPending(addr, broadcastHeight); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + pendingChannels, err := cdb.FetchPendingChannels() + if err != nil { + t.Fatalf("unable to list pending channels: %v", err) + } + + if len(pendingChannels) != 1 { + t.Fatalf("incorrect number of pending channels: expecting %v,"+ + "got %v", 1, len(pendingChannels)) + } + + // The broadcast height of the pending channel should have been set + // properly. + if pendingChannels[0].FundingBroadcastHeight != broadcastHeight { + t.Fatalf("broadcast height mismatch: expected %v, got %v", + pendingChannels[0].FundingBroadcastHeight, + broadcastHeight) + } + + chanOpenLoc := lnwire.ShortChannelID{ + BlockHeight: 5, + TxIndex: 10, + TxPosition: 15, + } + err = pendingChannels[0].MarkAsOpen(chanOpenLoc) + if err != nil { + t.Fatalf("unable to mark channel as open: %v", err) + } + + if pendingChannels[0].IsPending { + t.Fatalf("channel marked open should no longer be pending") + } + + if pendingChannels[0].ShortChanID() != chanOpenLoc { + t.Fatalf("channel opening height not updated: expected %v, "+ + "got %v", spew.Sdump(pendingChannels[0].ShortChanID()), + chanOpenLoc) + } + + // Next, we'll re-fetch the channel to ensure that the open height was + // properly set. + openChans, err := cdb.FetchAllChannels() + if err != nil { + t.Fatalf("unable to fetch channels: %v", err) + } + if openChans[0].ShortChanID() != chanOpenLoc { + t.Fatalf("channel opening heights don't match: expected %v, "+ + "got %v", spew.Sdump(openChans[0].ShortChanID()), + chanOpenLoc) + } + if openChans[0].FundingBroadcastHeight != broadcastHeight { + t.Fatalf("broadcast height mismatch: expected %v, got %v", + openChans[0].FundingBroadcastHeight, + broadcastHeight) + } + + pendingChannels, err = cdb.FetchPendingChannels() + if err != nil { + t.Fatalf("unable to list pending channels: %v", err) + } + + if len(pendingChannels) != 0 { + t.Fatalf("incorrect number of pending channels: expecting %v,"+ + "got %v", 0, len(pendingChannels)) + } +} + +func TestFetchClosedChannels(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First create a test channel, that we'll be closing within this pull + // request. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + // Next sync the channel to disk, marking it as being in a pending open + // state. + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + const broadcastHeight = 99 + if err := state.SyncPending(addr, broadcastHeight); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + // Next, simulate the confirmation of the channel by marking it as + // pending within the database. + chanOpenLoc := lnwire.ShortChannelID{ + BlockHeight: 5, + TxIndex: 10, + TxPosition: 15, + } + err = state.MarkAsOpen(chanOpenLoc) + if err != nil { + t.Fatalf("unable to mark channel as open: %v", err) + } + + // Next, close the channel by including a close channel summary in the + // database. + summary := &ChannelCloseSummary{ + ChanPoint: state.FundingOutpoint, + ClosingTXID: rev, + RemotePub: state.IdentityPub, + Capacity: state.Capacity, + SettledBalance: state.LocalCommitment.LocalBalance.ToSatoshis(), + TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000, + CloseType: RemoteForceClose, + IsPending: true, + LocalChanConfig: state.LocalChanCfg, + } + if err := state.CloseChannel(summary); err != nil { + t.Fatalf("unable to close channel: %v", err) + } + + // Query the database to ensure that the channel has now been properly + // closed. We should get the same result whether querying for pending + // channels only, or not. + pendingClosed, err := cdb.FetchClosedChannels(true) + if err != nil { + t.Fatalf("failed fetching closed channels: %v", err) + } + if len(pendingClosed) != 1 { + t.Fatalf("incorrect number of pending closed channels: expecting %v,"+ + "got %v", 1, len(pendingClosed)) + } + if !reflect.DeepEqual(summary, pendingClosed[0]) { + t.Fatalf("database summaries don't match: expected %v got %v", + spew.Sdump(summary), spew.Sdump(pendingClosed[0])) + } + closed, err := cdb.FetchClosedChannels(false) + if err != nil { + t.Fatalf("failed fetching all closed channels: %v", err) + } + if len(closed) != 1 { + t.Fatalf("incorrect number of closed channels: expecting %v, "+ + "got %v", 1, len(closed)) + } + if !reflect.DeepEqual(summary, closed[0]) { + t.Fatalf("database summaries don't match: expected %v got %v", + spew.Sdump(summary), spew.Sdump(closed[0])) + } + + // Mark the channel as fully closed. + err = cdb.MarkChanFullyClosed(&state.FundingOutpoint) + if err != nil { + t.Fatalf("failed fully closing channel: %v", err) + } + + // The channel should no longer be considered pending, but should still + // be retrieved when fetching all the closed channels. + closed, err = cdb.FetchClosedChannels(false) + if err != nil { + t.Fatalf("failed fetching closed channels: %v", err) + } + if len(closed) != 1 { + t.Fatalf("incorrect number of closed channels: expecting %v, "+ + "got %v", 1, len(closed)) + } + pendingClose, err := cdb.FetchClosedChannels(true) + if err != nil { + t.Fatalf("failed fetching channels pending close: %v", err) + } + if len(pendingClose) != 0 { + t.Fatalf("incorrect number of closed channels: expecting %v, "+ + "got %v", 0, len(closed)) + } +} + +// TestFetchWaitingCloseChannels ensures that the correct channels that are +// waiting to be closed are returned. +func TestFetchWaitingCloseChannels(t *testing.T) { + t.Parallel() + + const numChannels = 2 + const broadcastHeight = 99 + addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555} + + // We'll start by creating two channels within our test database. One of + // them will have their funding transaction confirmed on-chain, while + // the other one will remain unconfirmed. + db, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + channels := make([]*OpenChannel, numChannels) + for i := 0; i < numChannels; i++ { + channel, err := createTestChannelState(db) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + err = channel.SyncPending(addr, broadcastHeight) + if err != nil { + t.Fatalf("unable to sync channel: %v", err) + } + channels[i] = channel + } + + // We'll only confirm the first one. + channelConf := lnwire.ShortChannelID{ + BlockHeight: broadcastHeight + 1, + TxIndex: 10, + TxPosition: 15, + } + if err := channels[0].MarkAsOpen(channelConf); err != nil { + t.Fatalf("unable to mark channel as open: %v", err) + } + + // Then, we'll mark the channels as if their commitments were broadcast. + // This would happen in the event of a force close and should make the + // channels enter a state of waiting close. + for _, channel := range channels { + closeTx := wire.NewMsgTx(2) + closeTx.AddTxIn( + &wire.TxIn{ + PreviousOutPoint: channel.FundingOutpoint, + }, + ) + if err := channel.MarkCommitmentBroadcasted(closeTx); err != nil { + t.Fatalf("unable to mark commitment broadcast: %v", err) + } + } + + // Now, we'll fetch all the channels waiting to be closed from the + // database. We should expect to see both channels above, even if any of + // them haven't had their funding transaction confirm on-chain. + waitingCloseChannels, err := db.FetchWaitingCloseChannels() + if err != nil { + t.Fatalf("unable to fetch all waiting close channels: %v", err) + } + if len(waitingCloseChannels) != 2 { + t.Fatalf("expected %d channels waiting to be closed, got %d", 2, + len(waitingCloseChannels)) + } + expectedChannels := make(map[wire.OutPoint]struct{}) + for _, channel := range channels { + expectedChannels[channel.FundingOutpoint] = struct{}{} + } + for _, channel := range waitingCloseChannels { + if _, ok := expectedChannels[channel.FundingOutpoint]; !ok { + t.Fatalf("expected channel %v to be waiting close", + channel.FundingOutpoint) + } + + // Finally, make sure we can retrieve the closing tx for the + // channel. + closeTx, err := channel.BroadcastedCommitment() + if err != nil { + t.Fatalf("Unable to retrieve commitment: %v", err) + } + + if closeTx.TxIn[0].PreviousOutPoint != channel.FundingOutpoint { + t.Fatalf("expected outpoint %v, got %v", + channel.FundingOutpoint, + closeTx.TxIn[0].PreviousOutPoint) + } + } +} + +// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory +// short channel ID of another OpenChannel to reflect a preceding call to +// MarkOpen on a different OpenChannel. +func TestRefreshShortChanID(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First create a test channel. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + + // Mark the channel as pending within the channeldb. + const broadcastHeight = 99 + if err := state.SyncPending(addr, broadcastHeight); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + // Next, locate the pending channel with the database. + pendingChannels, err := cdb.FetchPendingChannels() + if err != nil { + t.Fatalf("unable to load pending channels; %v", err) + } + + var pendingChannel *OpenChannel + for _, channel := range pendingChannels { + if channel.FundingOutpoint == state.FundingOutpoint { + pendingChannel = channel + break + } + } + if pendingChannel == nil { + t.Fatalf("unable to find pending channel with funding "+ + "outpoint=%v: %v", state.FundingOutpoint, err) + } + + // Next, simulate the confirmation of the channel by marking it as + // pending within the database. + chanOpenLoc := lnwire.ShortChannelID{ + BlockHeight: 105, + TxIndex: 10, + TxPosition: 15, + } + + err = state.MarkAsOpen(chanOpenLoc) + if err != nil { + t.Fatalf("unable to mark channel open: %v", err) + } + + // The short_chan_id of the receiver to MarkAsOpen should reflect the + // open location, but the other pending channel should remain unchanged. + if state.ShortChanID() == pendingChannel.ShortChanID() { + t.Fatalf("pending channel short_chan_ID should not have been " + + "updated before refreshing short_chan_id") + } + + // Now that the receiver's short channel id has been updated, check to + // ensure that the channel packager's source has been updated as well. + // This ensures that the packager will read and write to buckets + // corresponding to the new short chan id, instead of the prior. + if state.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + state.Packager.(*ChannelPackager).source) + } + + // Now, refresh the short channel ID of the pending channel. + err = pendingChannel.RefreshShortChanID() + if err != nil { + t.Fatalf("unable to refresh short_chan_id: %v", err) + } + + // This should result in both OpenChannel's now having the same + // ShortChanID. + if state.ShortChanID() != pendingChannel.ShortChanID() { + t.Fatalf("expected pending channel short_chan_id to be "+ + "refreshed: want %v, got %v", state.ShortChanID(), + pendingChannel.ShortChanID()) + } + + // Check to ensure that the _other_ OpenChannel channel packager's + // source has also been updated after the refresh. This ensures that the + // other packagers will read and write to buckets corresponding to the + // updated short chan id. + if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + pendingChannel.Packager.(*ChannelPackager).source) + } +} diff --git a/channeldb/migration_01_to_11/codec.go b/channeldb/migration_01_to_11/codec.go new file mode 100644 index 00000000..cfef35e0 --- /dev/null +++ b/channeldb/migration_01_to_11/codec.go @@ -0,0 +1,454 @@ +package migration_01_to_11 + +import ( + "encoding/binary" + "fmt" + "io" + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +// writeOutpoint writes an outpoint to the passed writer using the minimal +// amount of bytes possible. +func writeOutpoint(w io.Writer, o *wire.OutPoint) error { + if _, err := w.Write(o.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, byteOrder, o.Index); err != nil { + return err + } + + return nil +} + +// readOutpoint reads an outpoint from the passed reader that was previously +// written using the writeOutpoint struct. +func readOutpoint(r io.Reader, o *wire.OutPoint) error { + if _, err := io.ReadFull(r, o.Hash[:]); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &o.Index); err != nil { + return err + } + + return nil +} + +// UnknownElementType is an error returned when the codec is unable to encode or +// decode a particular type. +type UnknownElementType struct { + method string + element interface{} +} + +// NewUnknownElementType creates a new UnknownElementType error from the passed +// method name and element. +func NewUnknownElementType(method string, el interface{}) UnknownElementType { + return UnknownElementType{method: method, element: el} +} + +// Error returns the name of the method that encountered the error, as well as +// the type that was unsupported. +func (e UnknownElementType) Error() string { + return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) +} + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for storage on disk. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case keychain.KeyDescriptor: + if err := binary.Write(w, byteOrder, e.Family); err != nil { + return err + } + if err := binary.Write(w, byteOrder, e.Index); err != nil { + return err + } + + if e.PubKey != nil { + if err := binary.Write(w, byteOrder, true); err != nil { + return fmt.Errorf("error writing serialized element: %s", err) + } + + return WriteElement(w, e.PubKey) + } + + return binary.Write(w, byteOrder, false) + case ChannelType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case chainhash.Hash: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wire.OutPoint: + return writeOutpoint(w, &e) + + case lnwire.ShortChannelID: + if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { + return err + } + + case lnwire.ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case int64, uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case int32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint16: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint8: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case bool: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case btcutil.Amount: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case lnwire.MilliSatoshi: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case *btcec.PrivateKey: + b := e.Serialize() + if _, err := w.Write(b); err != nil { + return err + } + + case *btcec.PublicKey: + b := e.SerializeCompressed() + if _, err := w.Write(b); err != nil { + return err + } + + case shachain.Producer: + return e.Encode(w) + + case shachain.Store: + return e.Encode(w) + + case *wire.MsgTx: + return e.Serialize(w) + + case [32]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + case lnwire.Message: + if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + return err + } + + case ChannelStatus: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case ClosureType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case lnwire.FundingFlag: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case net.Addr: + if err := serializeAddr(w, e); err != nil { + return err + } + + case []net.Addr: + if err := WriteElement(w, uint32(len(e))); err != nil { + return err + } + + for _, addr := range e { + if err := serializeAddr(w, addr); err != nil { + return err + } + } + + default: + return UnknownElementType{"WriteElement", e} + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + return nil +} + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of the database. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *keychain.KeyDescriptor: + if err := binary.Read(r, byteOrder, &e.Family); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &e.Index); err != nil { + return err + } + + var hasPubKey bool + if err := binary.Read(r, byteOrder, &hasPubKey); err != nil { + return err + } + + if hasPubKey { + return ReadElement(r, &e.PubKey) + } + + case *ChannelType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *chainhash.Hash: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wire.OutPoint: + return readOutpoint(r, e) + + case *lnwire.ShortChannelID: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + *e = lnwire.NewShortChanIDFromInt(a) + + case *lnwire.ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *int64, *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *int32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint16: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint8: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *bool: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *btcutil.Amount: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = btcutil.Amount(a) + + case *lnwire.MilliSatoshi: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = lnwire.MilliSatoshi(a) + + case **btcec.PrivateKey: + var b [btcec.PrivKeyBytesLen]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), b[:]) + *e = priv + + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + *e = pubKey + + case *shachain.Producer: + var root [32]byte + if _, err := io.ReadFull(r, root[:]); err != nil { + return err + } + + // TODO(roasbeef): remove + producer, err := shachain.NewRevocationProducerFromBytes(root[:]) + if err != nil { + return err + } + + *e = producer + + case *shachain.Store: + store, err := shachain.NewRevocationStoreFromBytes(r) + if err != nil { + return err + } + + *e = store + + case **wire.MsgTx: + tx := wire.NewMsgTx(2) + if err := tx.Deserialize(r); err != nil { + return err + } + + *e = tx + + case *[32]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + case *lnwire.Message: + msg, err := lnwire.ReadMessage(r, 0) + if err != nil { + return err + } + + *e = msg + + case *ChannelStatus: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *ClosureType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *lnwire.FundingFlag: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *net.Addr: + addr, err := deserializeAddr(r) + if err != nil { + return err + } + *e = addr + + case *[]net.Addr: + var numAddrs uint32 + if err := ReadElement(r, &numAddrs); err != nil { + return err + } + + *e = make([]net.Addr, numAddrs) + for i := uint32(0); i < numAddrs; i++ { + addr, err := deserializeAddr(r) + if err != nil { + return err + } + (*e)[i] = addr + } + + default: + return UnknownElementType{"ReadElement", e} + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + return nil +} diff --git a/channeldb/migration_01_to_11/db.go b/channeldb/migration_01_to_11/db.go new file mode 100644 index 00000000..c4306400 --- /dev/null +++ b/channeldb/migration_01_to_11/db.go @@ -0,0 +1,1185 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" + "os" + "path/filepath" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lnwire" +) + +const ( + dbName = "channel.db" + dbFilePermission = 0600 +) + +// migration is a function which takes a prior outdated version of the database +// instances and mutates the key/bucket structure to arrive at a more +// up-to-date version of the database. +type migration func(tx *bbolt.Tx) error + +type version struct { + number uint32 + migration migration +} + +var ( + // dbVersions is storing all versions of database. If current version + // of database don't match with latest version this list will be used + // for retrieving all migration function that are need to apply to the + // current db. + dbVersions = []version{ + { + // The base DB version requires no migration. + number: 0, + migration: nil, + }, + { + // The version of the database where two new indexes + // for the update time of node and channel updates were + // added. + number: 1, + migration: migrateNodeAndEdgeUpdateIndex, + }, + { + // The DB version that added the invoice event time + // series. + number: 2, + migration: migrateInvoiceTimeSeries, + }, + { + // The DB version that updated the embedded invoice in + // outgoing payments to match the new format. + number: 3, + migration: migrateInvoiceTimeSeriesOutgoingPayments, + }, + { + // The version of the database where every channel + // always has two entries in the edges bucket. If + // a policy is unknown, this will be represented + // by a special byte sequence. + number: 4, + migration: migrateEdgePolicies, + }, + { + // The DB version where we persist each attempt to send + // an HTLC to a payment hash, and track whether the + // payment is in-flight, succeeded, or failed. + number: 5, + migration: paymentStatusesMigration, + }, + { + // The DB version that properly prunes stale entries + // from the edge update index. + number: 6, + migration: migratePruneEdgeUpdateIndex, + }, + { + // The DB version that migrates the ChannelCloseSummary + // to a format where optional fields are indicated with + // boolean flags. + number: 7, + migration: migrateOptionalChannelCloseSummaryFields, + }, + { + // The DB version that changes the gossiper's message + // store keys to account for the message's type and + // ShortChannelID. + number: 8, + migration: migrateGossipMessageStoreKeys, + }, + { + // The DB version where the payments and payment + // statuses are moved to being stored in a combined + // bucket. + number: 9, + migration: migrateOutgoingPayments, + }, + { + // The DB version where we started to store legacy + // payload information for all routes, as well as the + // optional TLV records. + number: 10, + migration: migrateRouteSerialization, + }, + { + // Add invoice htlc and cltv delta fields. + number: 11, + migration: migrateInvoices, + }, + } + + // Big endian is the preferred byte order, due to cursor scans over + // integer keys iterating in order. + byteOrder = binary.BigEndian +) + +// DB is the primary datastore for the lnd daemon. The database stores +// information related to nodes, routing data, open/closed channels, fee +// schedules, and reputation data. +type DB struct { + *bbolt.DB + dbPath string + graph *ChannelGraph + now func() time.Time +} + +// Open opens an existing channeldb. Any necessary schemas migrations due to +// updates will take place as necessary. +func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { + path := filepath.Join(dbPath, dbName) + + if !fileExists(path) { + if err := createChannelDB(dbPath); err != nil { + return nil, err + } + } + + opts := DefaultOptions() + for _, modifier := range modifiers { + modifier(&opts) + } + + // Specify bbolt freelist options to reduce heap pressure in case the + // freelist grows to be very large. + options := &bbolt.Options{ + NoFreelistSync: opts.NoFreelistSync, + FreelistType: bbolt.FreelistMapType, + } + + bdb, err := bbolt.Open(path, dbFilePermission, options) + if err != nil { + return nil, err + } + + chanDB := &DB{ + DB: bdb, + dbPath: dbPath, + now: time.Now, + } + chanDB.graph = newChannelGraph( + chanDB, opts.RejectCacheSize, opts.ChannelCacheSize, + ) + + // Synchronize the version of database and apply migrations if needed. + if err := chanDB.syncVersions(dbVersions); err != nil { + bdb.Close() + return nil, err + } + + return chanDB, nil +} + +// Path returns the file path to the channel database. +func (d *DB) Path() string { + return d.dbPath +} + +// Wipe completely deletes all saved state within all used buckets within the +// database. The deletion is done in a single transaction, therefore this +// operation is fully atomic. +func (d *DB) Wipe() error { + return d.Update(func(tx *bbolt.Tx) error { + err := tx.DeleteBucket(openChannelBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + err = tx.DeleteBucket(closedChannelBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + err = tx.DeleteBucket(invoiceBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + err = tx.DeleteBucket(nodeInfoBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + err = tx.DeleteBucket(nodeBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + err = tx.DeleteBucket(edgeBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + err = tx.DeleteBucket(edgeIndexBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + err = tx.DeleteBucket(graphMetaBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + return nil + }) +} + +// createChannelDB creates and initializes a fresh version of channeldb. In +// the case that the target path has not yet been created or doesn't yet exist, +// then the path is created. Additionally, all required top-level buckets used +// within the database are created. +func createChannelDB(dbPath string) error { + if !fileExists(dbPath) { + if err := os.MkdirAll(dbPath, 0700); err != nil { + return err + } + } + + path := filepath.Join(dbPath, dbName) + bdb, err := bbolt.Open(path, dbFilePermission, nil) + if err != nil { + return err + } + + err = bdb.Update(func(tx *bbolt.Tx) error { + if _, err := tx.CreateBucket(openChannelBucket); err != nil { + return err + } + if _, err := tx.CreateBucket(closedChannelBucket); err != nil { + return err + } + + if _, err := tx.CreateBucket(forwardingLogBucket); err != nil { + return err + } + + if _, err := tx.CreateBucket(fwdPackagesKey); err != nil { + return err + } + + if _, err := tx.CreateBucket(invoiceBucket); err != nil { + return err + } + + if _, err := tx.CreateBucket(paymentBucket); err != nil { + return err + } + + if _, err := tx.CreateBucket(nodeInfoBucket); err != nil { + return err + } + + nodes, err := tx.CreateBucket(nodeBucket) + if err != nil { + return err + } + _, err = nodes.CreateBucket(aliasIndexBucket) + if err != nil { + return err + } + _, err = nodes.CreateBucket(nodeUpdateIndexBucket) + if err != nil { + return err + } + + edges, err := tx.CreateBucket(edgeBucket) + if err != nil { + return err + } + if _, err := edges.CreateBucket(edgeIndexBucket); err != nil { + return err + } + if _, err := edges.CreateBucket(edgeUpdateIndexBucket); err != nil { + return err + } + if _, err := edges.CreateBucket(channelPointBucket); err != nil { + return err + } + if _, err := edges.CreateBucket(zombieBucket); err != nil { + return err + } + + graphMeta, err := tx.CreateBucket(graphMetaBucket) + if err != nil { + return err + } + _, err = graphMeta.CreateBucket(pruneLogBucket) + if err != nil { + return err + } + + if _, err := tx.CreateBucket(metaBucket); err != nil { + return err + } + + meta := &Meta{ + DbVersionNumber: getLatestDBVersion(dbVersions), + } + return putMeta(meta, tx) + }) + if err != nil { + return fmt.Errorf("unable to create new channeldb") + } + + return bdb.Close() +} + +// fileExists returns true if the file exists, and false otherwise. +func fileExists(path string) bool { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false + } + } + + return true +} + +// FetchOpenChannels starts a new database transaction and returns all stored +// currently active/open channels associated with the target nodeID. In the case +// that no active channels are known to have been created with this node, then a +// zero-length slice is returned. +func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { + var channels []*OpenChannel + err := d.View(func(tx *bbolt.Tx) error { + var err error + channels, err = d.fetchOpenChannels(tx, nodeID) + return err + }) + + return channels, err +} + +// fetchOpenChannels uses and existing database transaction and returns all +// stored currently active/open channels associated with the target nodeID. In +// the case that no active channels are known to have been created with this +// node, then a zero-length slice is returned. +func (d *DB) fetchOpenChannels(tx *bbolt.Tx, + nodeID *btcec.PublicKey) ([]*OpenChannel, error) { + + // Get the bucket dedicated to storing the metadata for open channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return nil, nil + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + pub := nodeID.SerializeCompressed() + nodeChanBucket := openChanBucket.Bucket(pub) + if nodeChanBucket == nil { + return nil, nil + } + + // Next, we'll need to go down an additional layer in order to retrieve + // the channels for each chain the node knows of. + var channels []*OpenChannel + err := nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so ignore it. + if v != nil { + return nil + } + + // If we've found a valid chainhash bucket, then we'll retrieve + // that so we can extract all the channels. + chainBucket := nodeChanBucket.Bucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read bucket for chain=%x", + chainHash[:]) + } + + // Finally, we both of the necessary buckets retrieved, fetch + // all the active channels related to this node. + nodeChannels, err := d.fetchNodeChannels(chainBucket) + if err != nil { + return fmt.Errorf("unable to read channel for "+ + "chain_hash=%x, node_key=%x: %v", + chainHash[:], pub, err) + } + + channels = append(channels, nodeChannels...) + return nil + }) + + return channels, err +} + +// fetchNodeChannels retrieves all active channels from the target chainBucket +// which is under a node's dedicated channel bucket. This function is typically +// used to fetch all the active channels related to a particular node. +func (d *DB) fetchNodeChannels(chainBucket *bbolt.Bucket) ([]*OpenChannel, error) { + + var channels []*OpenChannel + + // A node may have channels on several chains, so for each known chain, + // we'll extract all the channels. + err := chainBucket.ForEach(func(chanPoint, v []byte) error { + // If there's a value, it's not a bucket so ignore it. + if v != nil { + return nil + } + + // Once we've found a valid channel bucket, we'll extract it + // from the node's chain bucket. + chanBucket := chainBucket.Bucket(chanPoint) + + var outPoint wire.OutPoint + err := readOutpoint(bytes.NewReader(chanPoint), &outPoint) + if err != nil { + return err + } + oChannel, err := fetchOpenChannel(chanBucket, &outPoint) + if err != nil { + return fmt.Errorf("unable to read channel data for "+ + "chan_point=%v: %v", outPoint, err) + } + oChannel.Db = d + + channels = append(channels, oChannel) + + return nil + }) + if err != nil { + return nil, err + } + + return channels, nil +} + +// FetchChannel attempts to locate a channel specified by the passed channel +// point. If the channel cannot be found, then an error will be returned. +func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { + var ( + targetChan *OpenChannel + targetChanPoint bytes.Buffer + ) + + if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { + return nil, err + } + + // chanScan will traverse the following bucket structure: + // * nodePub => chainHash => chanPoint + // + // At each level we go one further, ensuring that we're traversing the + // proper key (that's actually a bucket). By only reading the bucket + // structure and skipping fully decoding each channel, we save a good + // bit of CPU as we don't need to do things like decompress public + // keys. + chanScan := func(tx *bbolt.Tx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoActiveChannels + } + + // Within the node channel bucket, are the set of node pubkeys + // we have channels with, we don't know the entire set, so + // we'll check them all. + return openChanBucket.ForEach(func(nodePub, v []byte) error { + // Ensure that this is a key the same size as a pubkey, + // and also that it leads directly to a bucket. + if len(nodePub) != 33 || v != nil { + return nil + } + + nodeChanBucket := openChanBucket.Bucket(nodePub) + if nodeChanBucket == nil { + return nil + } + + // The next layer down is all the chains that this node + // has channels on with us. + return nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so + // ignore it. + if v != nil { + return nil + } + + chainBucket := nodeChanBucket.Bucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read "+ + "bucket for chain=%x", chainHash[:]) + } + + // Finally we reach the leaf bucket that stores + // all the chanPoints for this node. + chanBucket := chainBucket.Bucket( + targetChanPoint.Bytes(), + ) + if chanBucket == nil { + return nil + } + + channel, err := fetchOpenChannel( + chanBucket, &chanPoint, + ) + if err != nil { + return err + } + + targetChan = channel + targetChan.Db = d + + return nil + }) + }) + } + + err := d.View(chanScan) + if err != nil { + return nil, err + } + + if targetChan != nil { + return targetChan, nil + } + + // If we can't find the channel, then we return with an error, as we + // have nothing to backup. + return nil, ErrChannelNotFound +} + +// FetchAllChannels attempts to retrieve all open channels currently stored +// within the database, including pending open, fully open and channels waiting +// for a closing transaction to confirm. +func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { + var channels []*OpenChannel + + // TODO(halseth): fetch all in one db tx. + openChannels, err := d.FetchAllOpenChannels() + if err != nil { + return nil, err + } + channels = append(channels, openChannels...) + + pendingChannels, err := d.FetchPendingChannels() + if err != nil { + return nil, err + } + channels = append(channels, pendingChannels...) + + waitingClose, err := d.FetchWaitingCloseChannels() + if err != nil { + return nil, err + } + channels = append(channels, waitingClose...) + + return channels, nil +} + +// FetchAllOpenChannels will return all channels that have the funding +// transaction confirmed, and is not waiting for a closing transaction to be +// confirmed. +func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { + return fetchChannels(d, false, false) +} + +// FetchPendingChannels will return channels that have completed the process of +// generating and broadcasting funding transactions, but whose funding +// transactions have yet to be confirmed on the blockchain. +func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { + return fetchChannels(d, true, false) +} + +// FetchWaitingCloseChannels will return all channels that have been opened, +// but are now waiting for a closing transaction to be confirmed. +// +// NOTE: This includes channels that are also pending to be opened. +func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { + waitingClose, err := fetchChannels(d, false, true) + if err != nil { + return nil, err + } + pendingWaitingClose, err := fetchChannels(d, true, true) + if err != nil { + return nil, err + } + + return append(waitingClose, pendingWaitingClose...), nil +} + +// fetchChannels attempts to retrieve channels currently stored in the +// database. The pending parameter determines whether only pending channels +// will be returned, or only open channels will be returned. The waitingClose +// parameter determines whether only channels waiting for a closing transaction +// to be confirmed should be returned. If no active channels exist within the +// network, then ErrNoActiveChannels is returned. +func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { + var channels []*OpenChannel + + err := d.View(func(tx *bbolt.Tx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoActiveChannels + } + + // Next, fetch the bucket dedicated to storing metadata related + // to all nodes. All keys within this bucket are the serialized + // public keys of all our direct counterparties. + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return fmt.Errorf("node bucket not created") + } + + // Finally for each node public key in the bucket, fetch all + // the channels related to this particular node. + return nodeMetaBucket.ForEach(func(k, v []byte) error { + nodeChanBucket := openChanBucket.Bucket(k) + if nodeChanBucket == nil { + return nil + } + + return nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so + // ignore it. + if v != nil { + return nil + } + + // If we've found a valid chainhash bucket, + // then we'll retrieve that so we can extract + // all the channels. + chainBucket := nodeChanBucket.Bucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read "+ + "bucket for chain=%x", chainHash[:]) + } + + nodeChans, err := d.fetchNodeChannels(chainBucket) + if err != nil { + return fmt.Errorf("unable to read "+ + "channel for chain_hash=%x, "+ + "node_key=%x: %v", chainHash[:], k, err) + } + for _, channel := range nodeChans { + if channel.IsPending != pending { + continue + } + + // If the channel is in any other state + // than Default, then it means it is + // waiting to be closed. + channelWaitingClose := + channel.ChanStatus() != ChanStatusDefault + + // Only include it if we requested + // channels with the same waitingClose + // status. + if channelWaitingClose != waitingClose { + continue + } + + channels = append(channels, channel) + } + return nil + }) + + }) + }) + if err != nil { + return nil, err + } + + return channels, nil +} + +// FetchClosedChannels attempts to fetch all closed channels from the database. +// The pendingOnly bool toggles if channels that aren't yet fully closed should +// be returned in the response or not. When a channel was cooperatively closed, +// it becomes fully closed after a single confirmation. When a channel was +// forcibly closed, it will become fully closed after _all_ the pending funds +// (if any) have been swept. +func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, error) { + var chanSummaries []*ChannelCloseSummary + + if err := d.View(func(tx *bbolt.Tx) error { + closeBucket := tx.Bucket(closedChannelBucket) + if closeBucket == nil { + return ErrNoClosedChannels + } + + return closeBucket.ForEach(func(chanID []byte, summaryBytes []byte) error { + summaryReader := bytes.NewReader(summaryBytes) + chanSummary, err := deserializeCloseChannelSummary(summaryReader) + if err != nil { + return err + } + + // If the query specified to only include pending + // channels, then we'll skip any channels which aren't + // currently pending. + if !chanSummary.IsPending && pendingOnly { + return nil + } + + chanSummaries = append(chanSummaries, chanSummary) + return nil + }) + }); err != nil { + return nil, err + } + + return chanSummaries, nil +} + +// ErrClosedChannelNotFound signals that a closed channel could not be found in +// the channeldb. +var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary") + +// FetchClosedChannel queries for a channel close summary using the channel +// point of the channel in question. +func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) { + var chanSummary *ChannelCloseSummary + if err := d.View(func(tx *bbolt.Tx) error { + closeBucket := tx.Bucket(closedChannelBucket) + if closeBucket == nil { + return ErrClosedChannelNotFound + } + + var b bytes.Buffer + var err error + if err = writeOutpoint(&b, chanID); err != nil { + return err + } + + summaryBytes := closeBucket.Get(b.Bytes()) + if summaryBytes == nil { + return ErrClosedChannelNotFound + } + + summaryReader := bytes.NewReader(summaryBytes) + chanSummary, err = deserializeCloseChannelSummary(summaryReader) + + return err + }); err != nil { + return nil, err + } + + return chanSummary, nil +} + +// FetchClosedChannelForID queries for a channel close summary using the +// channel ID of the channel in question. +func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( + *ChannelCloseSummary, error) { + + var chanSummary *ChannelCloseSummary + if err := d.View(func(tx *bbolt.Tx) error { + closeBucket := tx.Bucket(closedChannelBucket) + if closeBucket == nil { + return ErrClosedChannelNotFound + } + + // The first 30 bytes of the channel ID and outpoint will be + // equal. + cursor := closeBucket.Cursor() + op, c := cursor.Seek(cid[:30]) + + // We scan over all possible candidates for this channel ID. + for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() { + var outPoint wire.OutPoint + err := readOutpoint(bytes.NewReader(op), &outPoint) + if err != nil { + return err + } + + // If the found outpoint does not correspond to this + // channel ID, we continue. + if !cid.IsChanPoint(&outPoint) { + continue + } + + // Deserialize the close summary and return. + r := bytes.NewReader(c) + chanSummary, err = deserializeCloseChannelSummary(r) + if err != nil { + return err + } + + return nil + } + return ErrClosedChannelNotFound + }); err != nil { + return nil, err + } + + return chanSummary, nil +} + +// MarkChanFullyClosed marks a channel as fully closed within the database. A +// channel should be marked as fully closed if the channel was initially +// cooperatively closed and it's reached a single confirmation, or after all +// the pending funds in a channel that has been forcibly closed have been +// swept. +func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { + return d.Update(func(tx *bbolt.Tx) error { + var b bytes.Buffer + if err := writeOutpoint(&b, chanPoint); err != nil { + return err + } + + chanID := b.Bytes() + + closedChanBucket, err := tx.CreateBucketIfNotExists( + closedChannelBucket, + ) + if err != nil { + return err + } + + chanSummaryBytes := closedChanBucket.Get(chanID) + if chanSummaryBytes == nil { + return fmt.Errorf("no closed channel for "+ + "chan_point=%v found", chanPoint) + } + + chanSummaryReader := bytes.NewReader(chanSummaryBytes) + chanSummary, err := deserializeCloseChannelSummary( + chanSummaryReader, + ) + if err != nil { + return err + } + + chanSummary.IsPending = false + + var newSummary bytes.Buffer + err = serializeChannelCloseSummary(&newSummary, chanSummary) + if err != nil { + return err + } + + err = closedChanBucket.Put(chanID, newSummary.Bytes()) + if err != nil { + return err + } + + // Now that the channel is closed, we'll check if we have any + // other open channels with this peer. If we don't we'll + // garbage collect it to ensure we don't establish persistent + // connections to peers without open channels. + return d.pruneLinkNode(tx, chanSummary.RemotePub) + }) +} + +// pruneLinkNode determines whether we should garbage collect a link node from +// the database due to no longer having any open channels with it. If there are +// any left, then this acts as a no-op. +func (d *DB) pruneLinkNode(tx *bbolt.Tx, remotePub *btcec.PublicKey) error { + openChannels, err := d.fetchOpenChannels(tx, remotePub) + if err != nil { + return fmt.Errorf("unable to fetch open channels for peer %x: "+ + "%v", remotePub.SerializeCompressed(), err) + } + + if len(openChannels) > 0 { + return nil + } + + log.Infof("Pruning link node %x with zero open channels from database", + remotePub.SerializeCompressed()) + + return d.deleteLinkNode(tx, remotePub) +} + +// PruneLinkNodes attempts to prune all link nodes found within the databse with +// whom we no longer have any open channels with. +func (d *DB) PruneLinkNodes() error { + return d.Update(func(tx *bbolt.Tx) error { + linkNodes, err := d.fetchAllLinkNodes(tx) + if err != nil { + return err + } + + for _, linkNode := range linkNodes { + err := d.pruneLinkNode(tx, linkNode.IdentityPub) + if err != nil { + return err + } + } + + return nil + }) +} + +// ChannelShell is a shell of a channel that is meant to be used for channel +// recovery purposes. It contains a minimal OpenChannel instance along with +// addresses for that target node. +type ChannelShell struct { + // NodeAddrs the set of addresses that this node has known to be + // reachable at in the past. + NodeAddrs []net.Addr + + // Chan is a shell of an OpenChannel, it contains only the items + // required to restore the channel on disk. + Chan *OpenChannel +} + +// RestoreChannelShells is a method that allows the caller to reconstruct the +// state of an OpenChannel from the ChannelShell. We'll attempt to write the +// new channel to disk, create a LinkNode instance with the passed node +// addresses, and finally create an edge within the graph for the channel as +// well. This method is idempotent, so repeated calls with the same set of +// channel shells won't modify the database after the initial call. +func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { + chanGraph := d.ChannelGraph() + + // TODO(conner): find way to do this w/o accessing internal members? + chanGraph.cacheMu.Lock() + defer chanGraph.cacheMu.Unlock() + + var chansRestored []uint64 + err := d.Update(func(tx *bbolt.Tx) error { + for _, channelShell := range channelShells { + channel := channelShell.Chan + + // When we make a channel, we mark that the channel has + // been restored, this will signal to other sub-systems + // to not attempt to use the channel as if it was a + // regular one. + channel.chanStatus |= ChanStatusRestored + + // First, we'll attempt to create a new open channel + // and link node for this channel. If the channel + // already exists, then in order to ensure this method + // is idempotent, we'll continue to the next step. + channel.Db = d + err := syncNewChannel( + tx, channel, channelShell.NodeAddrs, + ) + if err != nil { + return err + } + + // Next, we'll create an active edge in the graph + // database for this channel in order to restore our + // partial view of the network. + // + // TODO(roasbeef): if we restore *after* the channel + // has been closed on chain, then need to inform the + // router that it should try and prune these values as + // we can detect them + edgeInfo := ChannelEdgeInfo{ + ChannelID: channel.ShortChannelID.ToUint64(), + ChainHash: channel.ChainHash, + ChannelPoint: channel.FundingOutpoint, + Capacity: channel.Capacity, + } + + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + selfNode, err := chanGraph.sourceNode(nodes) + if err != nil { + return err + } + + // Depending on which pub key is smaller, we'll assign + // our roles as "node1" and "node2". + chanPeer := channel.IdentityPub.SerializeCompressed() + selfIsSmaller := bytes.Compare( + selfNode.PubKeyBytes[:], chanPeer, + ) == -1 + if selfIsSmaller { + copy(edgeInfo.NodeKey1Bytes[:], selfNode.PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], chanPeer) + } else { + copy(edgeInfo.NodeKey1Bytes[:], chanPeer) + copy(edgeInfo.NodeKey2Bytes[:], selfNode.PubKeyBytes[:]) + } + + // With the edge info shell constructed, we'll now add + // it to the graph. + err = chanGraph.addChannelEdge(tx, &edgeInfo) + if err != nil && err != ErrEdgeAlreadyExist { + return err + } + + // Similarly, we'll construct a channel edge shell and + // add that itself to the graph. + chanEdge := ChannelEdgePolicy{ + ChannelID: edgeInfo.ChannelID, + LastUpdate: time.Now(), + } + + // If their pubkey is larger, then we'll flip the + // direction bit to indicate that us, the "second" node + // is updating their policy. + if !selfIsSmaller { + chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection + } + + _, err = updateEdgePolicy(tx, &chanEdge) + if err != nil { + return err + } + + chansRestored = append(chansRestored, edgeInfo.ChannelID) + } + + return nil + }) + if err != nil { + return err + } + + for _, chanid := range chansRestored { + chanGraph.rejectCache.remove(chanid) + chanGraph.chanCache.remove(chanid) + } + + return nil +} + +// AddrsForNode consults the graph and channel database for all addresses known +// to the passed node public key. +func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { + var ( + linkNode *LinkNode + graphNode LightningNode + ) + + dbErr := d.View(func(tx *bbolt.Tx) error { + var err error + + linkNode, err = fetchLinkNode(tx, nodePub) + if err != nil { + return err + } + + // We'll also query the graph for this peer to see if they have + // any addresses that we don't currently have stored within the + // link node database. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + compressedPubKey := nodePub.SerializeCompressed() + graphNode, err = fetchLightningNode(nodes, compressedPubKey) + if err != nil && err != ErrGraphNodeNotFound { + // If the node isn't found, then that's OK, as we still + // have the link node data. + return err + } + + return nil + }) + if dbErr != nil { + return nil, dbErr + } + + // Now that we have both sources of addrs for this node, we'll use a + // map to de-duplicate any addresses between the two sources, and + // produce a final list of the combined addrs. + addrs := make(map[string]net.Addr) + for _, addr := range linkNode.Addresses { + addrs[addr.String()] = addr + } + for _, addr := range graphNode.Addresses { + addrs[addr.String()] = addr + } + dedupedAddrs := make([]net.Addr, 0, len(addrs)) + for _, addr := range addrs { + dedupedAddrs = append(dedupedAddrs, addr) + } + + return dedupedAddrs, nil +} + +// syncVersions function is used for safe db version synchronization. It +// applies migration functions to the current database and recovers the +// previous state of db if at least one error/panic appeared during migration. +func (d *DB) syncVersions(versions []version) error { + meta, err := d.FetchMeta(nil) + if err != nil { + if err == ErrMetaNotFound { + meta = &Meta{} + } else { + return err + } + } + + latestVersion := getLatestDBVersion(versions) + log.Infof("Checking for schema update: latest_version=%v, "+ + "db_version=%v", latestVersion, meta.DbVersionNumber) + + switch { + + // If the database reports a higher version that we are aware of, the + // user is probably trying to revert to a prior version of lnd. We fail + // here to prevent reversions and unintended corruption. + case meta.DbVersionNumber > latestVersion: + log.Errorf("Refusing to revert from db_version=%d to "+ + "lower version=%d", meta.DbVersionNumber, + latestVersion) + return ErrDBReversion + + // If the current database version matches the latest version number, + // then we don't need to perform any migrations. + case meta.DbVersionNumber == latestVersion: + return nil + } + + log.Infof("Performing database schema migration") + + // Otherwise, we fetch the migrations which need to applied, and + // execute them serially within a single database transaction to ensure + // the migration is atomic. + migrations, migrationVersions := getMigrationsToApply( + versions, meta.DbVersionNumber, + ) + return d.Update(func(tx *bbolt.Tx) error { + for i, migration := range migrations { + if migration == nil { + continue + } + + log.Infof("Applying migration #%v", migrationVersions[i]) + + if err := migration(tx); err != nil { + log.Infof("Unable to apply migration #%v", + migrationVersions[i]) + return err + } + } + + meta.DbVersionNumber = latestVersion + return putMeta(meta, tx) + }) +} + +// ChannelGraph returns a new instance of the directed channel graph. +func (d *DB) ChannelGraph() *ChannelGraph { + return d.graph +} + +func getLatestDBVersion(versions []version) uint32 { + return versions[len(versions)-1].number +} + +// getMigrationsToApply retrieves the migration function that should be +// applied to the database. +func getMigrationsToApply(versions []version, version uint32) ([]migration, []uint32) { + migrations := make([]migration, 0, len(versions)) + migrationVersions := make([]uint32, 0, len(versions)) + + for _, v := range versions { + if v.number > version { + migrations = append(migrations, v.migration) + migrationVersions = append(migrationVersions, v.number) + } + } + + return migrations, migrationVersions +} diff --git a/channeldb/migration_01_to_11/db_test.go b/channeldb/migration_01_to_11/db_test.go new file mode 100644 index 00000000..721546e7 --- /dev/null +++ b/channeldb/migration_01_to_11/db_test.go @@ -0,0 +1,471 @@ +package migration_01_to_11 + +import ( + "io/ioutil" + "math" + "math/rand" + "net" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +func TestOpenWithCreate(t *testing.T) { + t.Parallel() + + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + defer os.RemoveAll(tempDirName) + + // Next, open thereby creating channeldb for the first time. + dbPath := filepath.Join(tempDirName, "cdb") + cdb, err := Open(dbPath) + if err != nil { + t.Fatalf("unable to create channeldb: %v", err) + } + if err := cdb.Close(); err != nil { + t.Fatalf("unable to close channeldb: %v", err) + } + + // The path should have been successfully created. + if !fileExists(dbPath) { + t.Fatalf("channeldb failed to create data directory") + } +} + +// TestWipe tests that the database wipe operation completes successfully +// and that the buckets are deleted. It also checks that attempts to fetch +// information while the buckets are not set return the correct errors. +func TestWipe(t *testing.T) { + t.Parallel() + + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + defer os.RemoveAll(tempDirName) + + // Next, open thereby creating channeldb for the first time. + dbPath := filepath.Join(tempDirName, "cdb") + cdb, err := Open(dbPath) + if err != nil { + t.Fatalf("unable to create channeldb: %v", err) + } + defer cdb.Close() + + if err := cdb.Wipe(); err != nil { + t.Fatalf("unable to wipe channeldb: %v", err) + } + // Check correct errors are returned + _, err = cdb.FetchAllOpenChannels() + if err != ErrNoActiveChannels { + t.Fatalf("fetching open channels: expected '%v' instead got '%v'", + ErrNoActiveChannels, err) + } + _, err = cdb.FetchClosedChannels(false) + if err != ErrNoClosedChannels { + t.Fatalf("fetching closed channels: expected '%v' instead got '%v'", + ErrNoClosedChannels, err) + } +} + +// TestFetchClosedChannelForID tests that we are able to properly retrieve a +// ChannelCloseSummary from the DB given a ChannelID. +func TestFetchClosedChannelForID(t *testing.T) { + t.Parallel() + + const numChans = 101 + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create the test channel state, that we will mutate the index of the + // funding point. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + // Now run through the number of channels, and modify the outpoint index + // to create new channel IDs. + for i := uint32(0); i < numChans; i++ { + // Save the open channel to disk. + state.FundingOutpoint.Index = i + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + if err := state.SyncPending(addr, 101); err != nil { + t.Fatalf("unable to save and serialize channel "+ + "state: %v", err) + } + + // Close the channel. To make sure we retrieve the correct + // summary later, we make them differ in the SettledBalance. + closeSummary := &ChannelCloseSummary{ + ChanPoint: state.FundingOutpoint, + RemotePub: state.IdentityPub, + SettledBalance: btcutil.Amount(500 + i), + } + if err := state.CloseChannel(closeSummary); err != nil { + t.Fatalf("unable to close channel: %v", err) + } + } + + // Now run though them all again and make sure we are able to retrieve + // summaries from the DB. + for i := uint32(0); i < numChans; i++ { + state.FundingOutpoint.Index = i + + // We calculate the ChannelID and use it to fetch the summary. + cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) + fetchedSummary, err := cdb.FetchClosedChannelForID(cid) + if err != nil { + t.Fatalf("unable to fetch close summary: %v", err) + } + + // Make sure we retrieved the correct one by checking the + // SettledBalance. + if fetchedSummary.SettledBalance != btcutil.Amount(500+i) { + t.Fatalf("summaries don't match: expected %v got %v", + btcutil.Amount(500+i), + fetchedSummary.SettledBalance) + } + } + + // As a final test we make sure that we get ErrClosedChannelNotFound + // for a ChannelID we didn't add to the DB. + state.FundingOutpoint.Index++ + cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) + _, err = cdb.FetchClosedChannelForID(cid) + if err != ErrClosedChannelNotFound { + t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err) + } +} + +// TestAddrsForNode tests the we're able to properly obtain all the addresses +// for a target node. +func TestAddrsForNode(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + graph := cdb.ChannelGraph() + + // We'll make a test vertex to insert into the database, as the source + // node, but this node will only have half the number of addresses it + // usually does. + testNode, err := createTestVertex(cdb) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + testNode.Addresses = []net.Addr{testAddr} + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // Next, we'll make a link node with the same pubkey, but with an + // additional address. + nodePub, err := testNode.PubKey() + if err != nil { + t.Fatalf("unable to recv node pub: %v", err) + } + linkNode := cdb.NewLinkNode( + wire.MainNet, nodePub, anotherAddr, + ) + if err := linkNode.Sync(); err != nil { + t.Fatalf("unable to sync link node: %v", err) + } + + // Now that we've created a link node, as well as a vertex for the + // node, we'll query for all its addresses. + nodeAddrs, err := cdb.AddrsForNode(nodePub) + if err != nil { + t.Fatalf("unable to obtain node addrs: %v", err) + } + + expectedAddrs := make(map[string]struct{}) + expectedAddrs[testAddr.String()] = struct{}{} + expectedAddrs[anotherAddr.String()] = struct{}{} + + // Finally, ensure that all the expected addresses are found. + if len(nodeAddrs) != len(expectedAddrs) { + t.Fatalf("expected %v addrs, got %v", + len(expectedAddrs), len(nodeAddrs)) + } + for _, addr := range nodeAddrs { + if _, ok := expectedAddrs[addr.String()]; !ok { + t.Fatalf("unexpected addr: %v", addr) + } + } +} + +// TestFetchChannel tests that we're able to fetch an arbitrary channel from +// disk. +func TestFetchChannel(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create the test channel state that we'll sync to the database + // shortly. + channelState, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + // Mark the channel as pending, then immediately mark it as open to it + // can be fully visible. + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + if err := channelState.SyncPending(addr, 9); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99)) + if err != nil { + t.Fatalf("unable to mark channel open: %v", err) + } + + // Next, attempt to fetch the channel by its chan point. + dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint) + if err != nil { + t.Fatalf("unable to fetch channel: %v", err) + } + + // The decoded channel state should be identical to what we stored + // above. + if !reflect.DeepEqual(channelState, dbChannel) { + t.Fatalf("channel state doesn't match:: %v vs %v", + spew.Sdump(channelState), spew.Sdump(dbChannel)) + } + + // If we attempt to query for a non-exist ante channel, then we should + // get an error. + channelState2, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + channelState2.FundingOutpoint.Index ^= 1 + + _, err = cdb.FetchChannel(channelState2.FundingOutpoint) + if err == nil { + t.Fatalf("expected query to fail") + } +} + +func genRandomChannelShell() (*ChannelShell, error) { + var testPriv [32]byte + if _, err := rand.Read(testPriv[:]); err != nil { + return nil, err + } + + _, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:]) + + var chanPoint wire.OutPoint + if _, err := rand.Read(chanPoint.Hash[:]); err != nil { + return nil, err + } + + pub.Curve = nil + + chanPoint.Index = uint32(rand.Intn(math.MaxUint16)) + + chanStatus := ChanStatusDefault | ChanStatusRestored + + var shaChainPriv [32]byte + if _, err := rand.Read(testPriv[:]); err != nil { + return nil, err + } + revRoot, err := chainhash.NewHash(shaChainPriv[:]) + if err != nil { + return nil, err + } + shaChainProducer := shachain.NewRevocationProducer(*revRoot) + + return &ChannelShell{ + NodeAddrs: []net.Addr{&net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }}, + Chan: &OpenChannel{ + chanStatus: chanStatus, + ChainHash: rev, + FundingOutpoint: chanPoint, + ShortChannelID: lnwire.NewShortChanIDFromInt( + uint64(rand.Int63()), + ), + IdentityPub: pub, + LocalChanCfg: ChannelConfig{ + ChannelConstraints: ChannelConstraints{ + CsvDelay: uint16(rand.Int63()), + }, + PaymentBasePoint: keychain.KeyDescriptor{ + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamily(rand.Int63()), + Index: uint32(rand.Int63()), + }, + }, + }, + RemoteCurrentRevocation: pub, + IsPending: false, + RevocationStore: shachain.NewRevocationStore(), + RevocationProducer: shaChainProducer, + }, + }, nil +} + +// TestRestoreChannelShells tests that we're able to insert a partially channel +// populated to disk. This is useful for channel recovery purposes. We should +// find the new channel shell on disk, and also the db should be populated with +// an edge for that channel. +func TestRestoreChannelShells(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First, we'll make our channel shell, it will only have the minimal + // amount of information required for us to initiate the data loss + // protection feature. + channelShell, err := genRandomChannelShell() + if err != nil { + t.Fatalf("unable to gen channel shell: %v", err) + } + + graph := cdb.ChannelGraph() + + // Before we can restore the channel, we'll need to make a source node + // in the graph as the channel edge we create will need to have a + // origin. + testNode, err := createTestVertex(cdb) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // With the channel shell constructed, we'll now insert it into the + // database with the restoration method. + if err := cdb.RestoreChannelShells(channelShell); err != nil { + t.Fatalf("unable to restore channel shell: %v", err) + } + + // Now that the channel has been inserted, we'll attempt to query for + // it to ensure we can properly locate it via various means. + // + // First, we'll attempt to query for all channels that we have with the + // node public key that was restored. + nodeChans, err := cdb.FetchOpenChannels(channelShell.Chan.IdentityPub) + if err != nil { + t.Fatalf("unable find channel: %v", err) + } + + // We should now find a single channel from the database. + if len(nodeChans) != 1 { + t.Fatalf("unable to find restored channel by node "+ + "pubkey: %v", err) + } + + // Ensure that it isn't possible to modify the commitment state machine + // of this restored channel. + channel := nodeChans[0] + err = channel.UpdateCommitment(nil) + if err != ErrNoRestoredChannelMutation { + t.Fatalf("able to mutate restored channel") + } + err = channel.AppendRemoteCommitChain(nil) + if err != ErrNoRestoredChannelMutation { + t.Fatalf("able to mutate restored channel") + } + err = channel.AdvanceCommitChainTail(nil) + if err != ErrNoRestoredChannelMutation { + t.Fatalf("able to mutate restored channel") + } + + // That single channel should have the proper channel point, and also + // the expected set of flags to indicate that it was a restored + // channel. + if nodeChans[0].FundingOutpoint != channelShell.Chan.FundingOutpoint { + t.Fatalf("wrong funding outpoint: expected %v, got %v", + nodeChans[0].FundingOutpoint, + channelShell.Chan.FundingOutpoint) + } + if !nodeChans[0].HasChanStatus(ChanStatusRestored) { + t.Fatalf("node has wrong status flags: %v", + nodeChans[0].chanStatus) + } + + // We should also be able to find the channel if we query for it + // directly. + _, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint) + if err != nil { + t.Fatalf("unable to fetch channel: %v", err) + } + + // We should also be able to find the link node that was inserted by + // its public key. + linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch link node: %v", err) + } + + // The node should have the same address, as specified in the channel + // shell. + if reflect.DeepEqual(linkNode.Addresses, channelShell.NodeAddrs) { + t.Fatalf("addr mismach: expected %v, got %v", + linkNode.Addresses, channelShell.NodeAddrs) + } + + // Finally, we'll ensure that the edge for the channel was properly + // inserted. + chanInfos, err := graph.FetchChanInfos( + []uint64{channelShell.Chan.ShortChannelID.ToUint64()}, + ) + if err != nil { + t.Fatalf("unable to find edges: %v", err) + } + + if len(chanInfos) != 1 { + t.Fatalf("wrong amount of chan infos: expected %v got %v", + len(chanInfos), 1) + } + + // We should only find a single edge. + if chanInfos[0].Policy1 != nil && chanInfos[0].Policy2 != nil { + t.Fatalf("only a single edge should be inserted: %v", err) + } +} diff --git a/channeldb/migration_01_to_11/doc.go b/channeldb/migration_01_to_11/doc.go new file mode 100644 index 00000000..c90412f2 --- /dev/null +++ b/channeldb/migration_01_to_11/doc.go @@ -0,0 +1 @@ +package migration_01_to_11 diff --git a/channeldb/migration_01_to_11/error.go b/channeldb/migration_01_to_11/error.go new file mode 100644 index 00000000..f264fb70 --- /dev/null +++ b/channeldb/migration_01_to_11/error.go @@ -0,0 +1,128 @@ +package migration_01_to_11 + +import ( + "errors" + "fmt" +) + +var ( + // ErrNoChanDBExists is returned when a channel bucket hasn't been + // created. + ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created") + + // ErrDBReversion is returned when detecting an attempt to revert to a + // prior database version. + ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version") + + // ErrLinkNodesNotFound is returned when node info bucket hasn't been + // created. + ErrLinkNodesNotFound = fmt.Errorf("no link nodes exist") + + // ErrNoActiveChannels is returned when there is no active (open) + // channels within the database. + ErrNoActiveChannels = fmt.Errorf("no active channels exist") + + // ErrNoPastDeltas is returned when the channel delta bucket hasn't been + // created. + ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas") + + // ErrInvoiceNotFound is returned when a targeted invoice can't be + // found. + ErrInvoiceNotFound = fmt.Errorf("unable to locate invoice") + + // ErrNoInvoicesCreated is returned when we don't have invoices in + // our database to return. + ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices") + + // ErrDuplicateInvoice is returned when an invoice with the target + // payment hash already exists. + ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists") + + // ErrNoPaymentsCreated is returned when bucket of payments hasn't been + // created. + ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") + + // ErrNodeNotFound is returned when node bucket exists, but node with + // specific identity can't be found. + ErrNodeNotFound = fmt.Errorf("link node with target identity not found") + + // ErrChannelNotFound is returned when we attempt to locate a channel + // for a specific chain, but it is not found. + ErrChannelNotFound = fmt.Errorf("channel not found") + + // ErrMetaNotFound is returned when meta bucket hasn't been + // created. + ErrMetaNotFound = fmt.Errorf("unable to locate meta information") + + // ErrGraphNotFound is returned when at least one of the components of + // graph doesn't exist. + ErrGraphNotFound = fmt.Errorf("graph bucket not initialized") + + // ErrGraphNeverPruned is returned when graph was never pruned. + ErrGraphNeverPruned = fmt.Errorf("graph never pruned") + + // ErrSourceNodeNotSet is returned if the source node of the graph + // hasn't been added The source node is the center node within a + // star-graph. + ErrSourceNodeNotSet = fmt.Errorf("source node does not exist") + + // ErrGraphNodesNotFound is returned in case none of the nodes has + // been added in graph node bucket. + ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist") + + // ErrGraphNoEdgesFound is returned in case of none of the channel/edges + // has been added in graph edge bucket. + ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist") + + // ErrGraphNodeNotFound is returned when we're unable to find the target + // node. + ErrGraphNodeNotFound = fmt.Errorf("unable to find node") + + // ErrEdgeNotFound is returned when an edge for the target chanID + // can't be found. + ErrEdgeNotFound = fmt.Errorf("edge not found") + + // ErrZombieEdge is an error returned when we attempt to look up an edge + // but it is marked as a zombie within the zombie index. + ErrZombieEdge = errors.New("edge marked as zombie") + + // ErrEdgeAlreadyExist is returned when edge with specific + // channel id can't be added because it already exist. + ErrEdgeAlreadyExist = fmt.Errorf("edge already exist") + + // ErrNodeAliasNotFound is returned when alias for node can't be found. + ErrNodeAliasNotFound = fmt.Errorf("alias for node not found") + + // ErrUnknownAddressType is returned when a node's addressType is not + // an expected value. + ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved") + + // ErrNoClosedChannels is returned when a node is queries for all the + // channels it has closed, but it hasn't yet closed any channels. + ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet") + + // ErrNoForwardingEvents is returned in the case that a query fails due + // to the log not having any recorded events. + ErrNoForwardingEvents = fmt.Errorf("no recorded forwarding events") + + // ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel + // policy field is not found in the db even though its message flags + // indicate it should be. + ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " + + "present") + + // ErrChanAlreadyExists is return when the caller attempts to create a + // channel with a channel point that is already present in the + // database. + ErrChanAlreadyExists = fmt.Errorf("channel already exists") +) + +// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the +// caller attempts to write an announcement message which bares too many extra +// opaque bytes. We limit this value in order to ensure that we don't waste +// disk space due to nodes unnecessarily padding out their announcements with +// garbage data. +func ErrTooManyExtraOpaqueBytes(numBytes int) error { + return fmt.Errorf("max allowed number of opaque bytes is %v, received "+ + "%v bytes", MaxAllowedExtraOpaqueBytes, numBytes) +} diff --git a/channeldb/migration_01_to_11/fees.go b/channeldb/migration_01_to_11/fees.go new file mode 100644 index 00000000..c90412f2 --- /dev/null +++ b/channeldb/migration_01_to_11/fees.go @@ -0,0 +1 @@ +package migration_01_to_11 diff --git a/channeldb/migration_01_to_11/forwarding_log.go b/channeldb/migration_01_to_11/forwarding_log.go new file mode 100644 index 00000000..6b9f8f5d --- /dev/null +++ b/channeldb/migration_01_to_11/forwarding_log.go @@ -0,0 +1,274 @@ +package migration_01_to_11 + +import ( + "bytes" + "io" + "sort" + "time" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // forwardingLogBucket is the bucket that we'll use to store the + // forwarding log. The forwarding log contains a time series database + // of the forwarding history of a lightning daemon. Each key within the + // bucket is a timestamp (in nano seconds since the unix epoch), and + // the value a slice of a forwarding event for that timestamp. + forwardingLogBucket = []byte("circuit-fwd-log") +) + +const ( + // forwardingEventSize is the size of a forwarding event. The breakdown + // is as follows: + // + // * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in + // || 8 byte value out + // + // From the value in and value out, callers can easily compute the + // total fee extract from a forwarding event. + forwardingEventSize = 32 + + // MaxResponseEvents is the max number of forwarding events that will + // be returned by a single query response. This size was selected to + // safely remain under gRPC's 4MiB message size response limit. As each + // full forwarding event (including the timestamp) is 40 bytes, we can + // safely return 50k entries in a single response. + MaxResponseEvents = 50000 +) + +// ForwardingLog returns an instance of the ForwardingLog object backed by the +// target database instance. +func (d *DB) ForwardingLog() *ForwardingLog { + return &ForwardingLog{ + db: d, + } +} + +// ForwardingLog is a time series database that logs the fulfilment of payment +// circuits by a lightning network daemon. The log contains a series of +// forwarding events which map a timestamp to a forwarding event. A forwarding +// event describes which channels were used to create+settle a circuit, and the +// amount involved. Subtracting the outgoing amount from the incoming amount +// reveals the fee charged for the forwarding service. +type ForwardingLog struct { + db *DB +} + +// ForwardingEvent is an event in the forwarding log's time series. Each +// forwarding event logs the creation and tear-down of a payment circuit. A +// circuit is created once an incoming HTLC has been fully forwarded, and +// destroyed once the payment has been settled. +type ForwardingEvent struct { + // Timestamp is the settlement time of this payment circuit. + Timestamp time.Time + + // IncomingChanID is the incoming channel ID of the payment circuit. + IncomingChanID lnwire.ShortChannelID + + // OutgoingChanID is the outgoing channel ID of the payment circuit. + OutgoingChanID lnwire.ShortChannelID + + // AmtIn is the amount of the incoming HTLC. Subtracting this from the + // outgoing amount gives the total fees of this payment circuit. + AmtIn lnwire.MilliSatoshi + + // AmtOut is the amount of the outgoing HTLC. Subtracting the incoming + // amount from this gives the total fees for this payment circuit. + AmtOut lnwire.MilliSatoshi +} + +// encodeForwardingEvent writes out the target forwarding event to the passed +// io.Writer, using the expected DB format. Note that the timestamp isn't +// serialized as this will be the key value within the bucket. +func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { + return WriteElements( + w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut, + ) +} + +// decodeForwardingEvent attempts to decode the raw bytes of a serialized +// forwarding event into the target ForwardingEvent. Note that the timestamp +// won't be decoded, as the caller is expected to set this due to the bucket +// structure of the forwarding log. +func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error { + return ReadElements( + r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut, + ) +} + +// AddForwardingEvents adds a series of forwarding events to the database. +// Before inserting, the set of events will be sorted according to their +// timestamp. This ensures that all writes to disk are sequential. +func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error { + // Before we create the database transaction, we'll ensure that the set + // of forwarding events are properly sorted according to their + // timestamp. + sort.Slice(events, func(i, j int) bool { + return events[i].Timestamp.Before(events[j].Timestamp) + }) + + var timestamp [8]byte + + return f.db.Batch(func(tx *bbolt.Tx) error { + // First, we'll fetch the bucket that stores our time series + // log. + logBucket, err := tx.CreateBucketIfNotExists( + forwardingLogBucket, + ) + if err != nil { + return err + } + + // With the bucket obtained, we can now begin to write out the + // series of events. + for _, event := range events { + var eventBytes [forwardingEventSize]byte + eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize]) + + // First, we'll serialize this timestamp into our + // timestamp buffer. + byteOrder.PutUint64( + timestamp[:], uint64(event.Timestamp.UnixNano()), + ) + + // With the key encoded, we'll then encode the event + // into our buffer, then write it out to disk. + err := encodeForwardingEvent(eventBuf, &event) + if err != nil { + return err + } + err = logBucket.Put(timestamp[:], eventBuf.Bytes()) + if err != nil { + return err + } + } + + return nil + }) +} + +// ForwardingEventQuery represents a query to the forwarding log payment +// circuit time series database. The query allows a caller to retrieve all +// records for a particular time slice, offset in that time slice, limiting the +// total number of responses returned. +type ForwardingEventQuery struct { + // StartTime is the start time of the time slice. + StartTime time.Time + + // EndTime is the end time of the time slice. + EndTime time.Time + + // IndexOffset is the offset within the time slice to start at. This + // can be used to start the response at a particular record. + IndexOffset uint32 + + // NumMaxEvents is the max number of events to return. + NumMaxEvents uint32 +} + +// ForwardingLogTimeSlice is the response to a forwarding query. It includes +// the original query, the set events that match the query, and an integer +// which represents the offset index of the last item in the set of retuned +// events. This integer allows callers to resume their query using this offset +// in the event that the query's response exceeds the max number of returnable +// events. +type ForwardingLogTimeSlice struct { + ForwardingEventQuery + + // ForwardingEvents is the set of events in our time series that answer + // the query embedded above. + ForwardingEvents []ForwardingEvent + + // LastIndexOffset is the index of the last element in the set of + // returned ForwardingEvents above. Callers can use this to resume + // their query in the event that the time slice has too many events to + // fit into a single response. + LastIndexOffset uint32 +} + +// Query allows a caller to query the forwarding event time series for a +// particular time slice. The caller can control the precise time as well as +// the number of events to be returned. +// +// TODO(roasbeef): rename? +func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) { + resp := ForwardingLogTimeSlice{ + ForwardingEventQuery: q, + } + + // If the user provided an index offset, then we'll not know how many + // records we need to skip. We'll also keep track of the record offset + // as that's part of the final return value. + recordsToSkip := q.IndexOffset + recordOffset := q.IndexOffset + + err := f.db.View(func(tx *bbolt.Tx) error { + // If the bucket wasn't found, then there aren't any events to + // be returned. + logBucket := tx.Bucket(forwardingLogBucket) + if logBucket == nil { + return ErrNoForwardingEvents + } + + // We'll be using a cursor to seek into the database, so we'll + // populate byte slices that represent the start of the key + // space we're interested in, and the end. + var startTime, endTime [8]byte + byteOrder.PutUint64(startTime[:], uint64(q.StartTime.UnixNano())) + byteOrder.PutUint64(endTime[:], uint64(q.EndTime.UnixNano())) + + // If we know that a set of log events exists, then we'll begin + // our seek through the log in order to satisfy the query. + // We'll continue until either we reach the end of the range, + // or reach our max number of events. + logCursor := logBucket.Cursor() + timestamp, events := logCursor.Seek(startTime[:]) + for ; timestamp != nil && bytes.Compare(timestamp, endTime[:]) <= 0; timestamp, events = logCursor.Next() { + // If our current return payload exceeds the max number + // of events, then we'll exit now. + if uint32(len(resp.ForwardingEvents)) >= q.NumMaxEvents { + return nil + } + + // If we're not yet past the user defined offset, then + // we'll continue to seek forward. + if recordsToSkip > 0 { + recordsToSkip-- + continue + } + + currentTime := time.Unix( + 0, int64(byteOrder.Uint64(timestamp)), + ) + + // At this point, we've skipped enough records to start + // to collate our query. For each record, we'll + // increment the final record offset so the querier can + // utilize pagination to seek further. + readBuf := bytes.NewReader(events) + for readBuf.Len() != 0 { + var event ForwardingEvent + err := decodeForwardingEvent(readBuf, &event) + if err != nil { + return err + } + + event.Timestamp = currentTime + resp.ForwardingEvents = append(resp.ForwardingEvents, event) + + recordOffset++ + } + } + + return nil + }) + if err != nil && err != ErrNoForwardingEvents { + return ForwardingLogTimeSlice{}, err + } + + resp.LastIndexOffset = recordOffset + + return resp, nil +} diff --git a/channeldb/migration_01_to_11/forwarding_log_test.go b/channeldb/migration_01_to_11/forwarding_log_test.go new file mode 100644 index 00000000..9e0de7c4 --- /dev/null +++ b/channeldb/migration_01_to_11/forwarding_log_test.go @@ -0,0 +1,265 @@ +package migration_01_to_11 + +import ( + "math/rand" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lnwire" + + "time" +) + +// TestForwardingLogBasicStorageAndQuery tests that we're able to store and +// then query for items that have previously been added to the event log. +func TestForwardingLogBasicStorageAndQuery(t *testing.T) { + t.Parallel() + + // First, we'll set up a test database, and use that to instantiate the + // forwarding event log that we'll be using for the duration of the + // test. + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + log := ForwardingLog{ + db: db, + } + + initialTime := time.Unix(1234, 0) + timestamp := time.Unix(1234, 0) + + // We'll create 100 random events, which each event being spaced 10 + // minutes after the prior event. + numEvents := 100 + events := make([]ForwardingEvent, numEvents) + for i := 0; i < numEvents; i++ { + events[i] = ForwardingEvent{ + Timestamp: timestamp, + IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + AmtIn: lnwire.MilliSatoshi(rand.Int63()), + AmtOut: lnwire.MilliSatoshi(rand.Int63()), + } + + timestamp = timestamp.Add(time.Minute * 10) + } + + // Now that all of our set of events constructed, we'll add them to the + // database in a batch manner. + if err := log.AddForwardingEvents(events); err != nil { + t.Fatalf("unable to add events: %v", err) + } + + // With our events added we'll now construct a basic query to retrieve + // all of the events. + eventQuery := ForwardingEventQuery{ + StartTime: initialTime, + EndTime: timestamp, + IndexOffset: 0, + NumMaxEvents: 1000, + } + timeSlice, err := log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // The set of returned events should match identically, as they should + // be returned in sorted order. + if !reflect.DeepEqual(events, timeSlice.ForwardingEvents) { + t.Fatalf("event mismatch: expected %v vs %v", + spew.Sdump(events), spew.Sdump(timeSlice.ForwardingEvents)) + } + + // The offset index of the final entry should be numEvents, so the + // number of total events we've written. + if timeSlice.LastIndexOffset != uint32(numEvents) { + t.Fatalf("wrong final offset: expected %v, got %v", + timeSlice.LastIndexOffset, numEvents) + } +} + +// TestForwardingLogQueryOptions tests that the query offset works properly. So +// if we add a series of events, then we should be able to seek within the +// timeslice accordingly. This exercises the index offset and num max event +// field in the query, and also the last index offset field int he response. +func TestForwardingLogQueryOptions(t *testing.T) { + t.Parallel() + + // First, we'll set up a test database, and use that to instantiate the + // forwarding event log that we'll be using for the duration of the + // test. + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + log := ForwardingLog{ + db: db, + } + + initialTime := time.Unix(1234, 0) + endTime := time.Unix(1234, 0) + + // We'll create 20 random events, which each event being spaced 10 + // minutes after the prior event. + numEvents := 20 + events := make([]ForwardingEvent, numEvents) + for i := 0; i < numEvents; i++ { + events[i] = ForwardingEvent{ + Timestamp: endTime, + IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + AmtIn: lnwire.MilliSatoshi(rand.Int63()), + AmtOut: lnwire.MilliSatoshi(rand.Int63()), + } + + endTime = endTime.Add(time.Minute * 10) + } + + // Now that all of our set of events constructed, we'll add them to the + // database in a batch manner. + if err := log.AddForwardingEvents(events); err != nil { + t.Fatalf("unable to add events: %v", err) + } + + // With all of our events added, we should be able to query for the + // first 10 events using the max event query field. + eventQuery := ForwardingEventQuery{ + StartTime: initialTime, + EndTime: endTime, + IndexOffset: 0, + NumMaxEvents: 10, + } + timeSlice, err := log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // We should get exactly 10 events back. + if len(timeSlice.ForwardingEvents) != 10 { + t.Fatalf("wrong number of events: expected %v, got %v", 10, + len(timeSlice.ForwardingEvents)) + } + + // The set of events returned should be the first 10 events that we + // added. + if !reflect.DeepEqual(events[:10], timeSlice.ForwardingEvents) { + t.Fatalf("wrong response: expected %v, got %v", + spew.Sdump(events[:10]), + spew.Sdump(timeSlice.ForwardingEvents)) + } + + // The final offset should be the exact number of events returned. + if timeSlice.LastIndexOffset != 10 { + t.Fatalf("wrong index offset: expected %v, got %v", 10, + timeSlice.LastIndexOffset) + } + + // If we use the final offset to query again, then we should get 10 + // more events, that are the last 10 events we wrote. + eventQuery.IndexOffset = 10 + timeSlice, err = log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // We should get exactly 10 events back once again. + if len(timeSlice.ForwardingEvents) != 10 { + t.Fatalf("wrong number of events: expected %v, got %v", 10, + len(timeSlice.ForwardingEvents)) + } + + // The events that we got back should be the last 10 events that we + // wrote out. + if !reflect.DeepEqual(events[10:], timeSlice.ForwardingEvents) { + t.Fatalf("wrong response: expected %v, got %v", + spew.Sdump(events[10:]), + spew.Sdump(timeSlice.ForwardingEvents)) + } + + // Finally, the last index offset should be 20, or the number of + // records we've written out. + if timeSlice.LastIndexOffset != 20 { + t.Fatalf("wrong index offset: expected %v, got %v", 20, + timeSlice.LastIndexOffset) + } +} + +// TestForwardingLogQueryLimit tests that we're able to properly limit the +// number of events that are returned as part of a query. +func TestForwardingLogQueryLimit(t *testing.T) { + t.Parallel() + + // First, we'll set up a test database, and use that to instantiate the + // forwarding event log that we'll be using for the duration of the + // test. + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + log := ForwardingLog{ + db: db, + } + + initialTime := time.Unix(1234, 0) + endTime := time.Unix(1234, 0) + + // We'll create 200 random events, which each event being spaced 10 + // minutes after the prior event. + numEvents := 200 + events := make([]ForwardingEvent, numEvents) + for i := 0; i < numEvents; i++ { + events[i] = ForwardingEvent{ + Timestamp: endTime, + IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + AmtIn: lnwire.MilliSatoshi(rand.Int63()), + AmtOut: lnwire.MilliSatoshi(rand.Int63()), + } + + endTime = endTime.Add(time.Minute * 10) + } + + // Now that all of our set of events constructed, we'll add them to the + // database in a batch manner. + if err := log.AddForwardingEvents(events); err != nil { + t.Fatalf("unable to add events: %v", err) + } + + // Once the events have been written out, we'll issue a query over the + // entire range, but restrict the number of events to the first 100. + eventQuery := ForwardingEventQuery{ + StartTime: initialTime, + EndTime: endTime, + IndexOffset: 0, + NumMaxEvents: 100, + } + timeSlice, err := log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // We should get exactly 100 events back. + if len(timeSlice.ForwardingEvents) != 100 { + t.Fatalf("wrong number of events: expected %v, got %v", 10, + len(timeSlice.ForwardingEvents)) + } + + // The set of events returned should be the first 100 events that we + // added. + if !reflect.DeepEqual(events[:100], timeSlice.ForwardingEvents) { + t.Fatalf("wrong response: expected %v, got %v", + spew.Sdump(events[:100]), + spew.Sdump(timeSlice.ForwardingEvents)) + } + + // The final offset should be the exact number of events returned. + if timeSlice.LastIndexOffset != 100 { + t.Fatalf("wrong index offset: expected %v, got %v", 100, + timeSlice.LastIndexOffset) + } +} diff --git a/channeldb/migration_01_to_11/forwarding_package.go b/channeldb/migration_01_to_11/forwarding_package.go new file mode 100644 index 00000000..cbbf90cf --- /dev/null +++ b/channeldb/migration_01_to_11/forwarding_package.go @@ -0,0 +1,928 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" +) + +// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding +// package has potentially been mangled. +var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") + +// FwdState is an enum used to describe the lifecycle of a FwdPkg. +type FwdState byte + +const ( + // FwdStateLockedIn is the starting state for all forwarding packages. + // Packages in this state have not yet committed to the exact set of + // Adds to forward to the switch. + FwdStateLockedIn FwdState = iota + + // FwdStateProcessed marks the state in which all Adds have been + // locally processed and the forwarding decision to the switch has been + // persisted. + FwdStateProcessed + + // FwdStateCompleted signals that all Adds have been acked, and that all + // settles and fails have been delivered to their sources. Packages in + // this state can be removed permanently. + FwdStateCompleted +) + +var ( + // fwdPackagesKey is the root-level bucket that all forwarding packages + // are written. This bucket is further subdivided based on the short + // channel ID of each channel. + fwdPackagesKey = []byte("fwd-packages") + + // addBucketKey is the bucket to which all Add log updates are written. + addBucketKey = []byte("add-updates") + + // failSettleBucketKey is the bucket to which all Settle/Fail log + // updates are written. + failSettleBucketKey = []byte("fail-settle-updates") + + // fwdFilterKey is a key used to write the set of Adds that passed + // validation and are to be forwarded to the switch. + // NOTE: The presence of this key within a forwarding package indicates + // that the package has reached FwdStateProcessed. + fwdFilterKey = []byte("fwd-filter-key") + + // ackFilterKey is a key used to access the PkgFilter indicating which + // Adds have received a Settle/Fail. This response may come from a + // number of sources, including: exitHop settle/fails, switch failures, + // chain arbiter interjections, as well as settle/fails from the + // next hop in the route. + ackFilterKey = []byte("ack-filter-key") + + // settleFailFilterKey is a key used to access the PkgFilter indicating + // which Settles/Fails in have been received and processed by the link + // that originally received the Add. + settleFailFilterKey = []byte("settle-fail-filter-key") +) + +// PkgFilter is used to compactly represent a particular subset of the Adds in a +// forwarding package. Each filter is represented as a simple, statically-sized +// bitvector, where the elements are intended to be the indices of the Adds as +// they are written in the FwdPkg. +type PkgFilter struct { + count uint16 + filter []byte +} + +// NewPkgFilter initializes an empty PkgFilter supporting `count` elements. +func NewPkgFilter(count uint16) *PkgFilter { + // We add 7 to ensure that the integer division yields properly rounded + // values. + filterLen := (count + 7) / 8 + + return &PkgFilter{ + count: count, + filter: make([]byte, filterLen), + } +} + +// Count returns the number of elements represented by this PkgFilter. +func (f *PkgFilter) Count() uint16 { + return f.count +} + +// Set marks the `i`-th element as included by this filter. +// NOTE: It is assumed that i is always less than count. +func (f *PkgFilter) Set(i uint16) { + byt := i / 8 + bit := i % 8 + + // Set the i-th bit in the filter. + // TODO(conner): ignore if > count to prevent panic? + f.filter[byt] |= byte(1 << (7 - bit)) +} + +// Contains queries the filter for membership of index `i`. +// NOTE: It is assumed that i is always less than count. +func (f *PkgFilter) Contains(i uint16) bool { + byt := i / 8 + bit := i % 8 + + // Read the i-th bit in the filter. + // TODO(conner): ignore if > count to prevent panic? + return f.filter[byt]&(1<<(7-bit)) != 0 +} + +// Equal checks two PkgFilters for equality. +func (f *PkgFilter) Equal(f2 *PkgFilter) bool { + if f == f2 { + return true + } + if f.count != f2.count { + return false + } + + return bytes.Equal(f.filter, f2.filter) +} + +// IsFull returns true if every element in the filter has been Set, and false +// otherwise. +func (f *PkgFilter) IsFull() bool { + // Batch validate bytes that are fully used. + for i := uint16(0); i < f.count/8; i++ { + if f.filter[i] != 0xFF { + return false + } + } + + // If the count is not a multiple of 8, check that the filter contains + // all remaining bits. + rem := f.count % 8 + for idx := f.count - rem; idx < f.count; idx++ { + if !f.Contains(idx) { + return false + } + } + + return true +} + +// Size returns number of bytes produced when the PkgFilter is serialized. +func (f *PkgFilter) Size() uint16 { + // 2 bytes for uint16 `count`, then round up number of bytes required to + // represent `count` bits. + return 2 + (f.count+7)/8 +} + +// Encode writes the filter to the provided io.Writer. +func (f *PkgFilter) Encode(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, f.count); err != nil { + return err + } + + _, err := w.Write(f.filter) + + return err +} + +// Decode reads the filter from the provided io.Reader. +func (f *PkgFilter) Decode(r io.Reader) error { + if err := binary.Read(r, binary.BigEndian, &f.count); err != nil { + return err + } + + f.filter = make([]byte, f.Size()-2) + _, err := io.ReadFull(r, f.filter) + + return err +} + +// FwdPkg records all adds, settles, and fails that were locked in as a result +// of the remote peer sending us a revocation. Each package is identified by +// the short chanid and remote commitment height corresponding to the revocation +// that locked in the HTLCs. For everything except a locally initiated payment, +// settles and fails in a forwarding package must have a corresponding Add in +// another package, and can be removed individually once the source link has +// received the fail/settle. +// +// Adds cannot be removed, as we need to present the same batch of Adds to +// properly handle replay protection. Instead, we use a PkgFilter to mark that +// we have finished processing a particular Add. A FwdPkg should only be deleted +// after the AckFilter is full and all settles and fails have been persistently +// removed. +type FwdPkg struct { + // Source identifies the channel that wrote this forwarding package. + Source lnwire.ShortChannelID + + // Height is the height of the remote commitment chain that locked in + // this forwarding package. + Height uint64 + + // State signals the persistent condition of the package and directs how + // to reprocess the package in the event of failures. + State FwdState + + // Adds contains all add messages which need to be processed and + // forwarded to the switch. Adds does not change over the life of a + // forwarding package. + Adds []LogUpdate + + // FwdFilter is a filter containing the indices of all Adds that were + // forwarded to the switch. + FwdFilter *PkgFilter + + // AckFilter is a filter containing the indices of all Adds for which + // the source has received a settle or fail and is reflected in the next + // commitment txn. A package should not be removed until IsFull() + // returns true. + AckFilter *PkgFilter + + // SettleFails contains all settle and fail messages that should be + // forwarded to the switch. + SettleFails []LogUpdate + + // SettleFailFilter is a filter containing the indices of all Settle or + // Fails originating in this package that have been received and locked + // into the incoming link's commitment state. + SettleFailFilter *PkgFilter +} + +// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This +// should be used to create a package at the time we receive a revocation. +func NewFwdPkg(source lnwire.ShortChannelID, height uint64, + addUpdates, settleFailUpdates []LogUpdate) *FwdPkg { + + nAddUpdates := uint16(len(addUpdates)) + nSettleFailUpdates := uint16(len(settleFailUpdates)) + + return &FwdPkg{ + Source: source, + Height: height, + State: FwdStateLockedIn, + Adds: addUpdates, + FwdFilter: NewPkgFilter(nAddUpdates), + AckFilter: NewPkgFilter(nAddUpdates), + SettleFails: settleFailUpdates, + SettleFailFilter: NewPkgFilter(nSettleFailUpdates), + } +} + +// ID returns an unique identifier for this package, used to ensure that sphinx +// replay processing of this batch is idempotent. +func (f *FwdPkg) ID() []byte { + var id = make([]byte, 16) + byteOrder.PutUint64(id[:8], f.Source.ToUint64()) + byteOrder.PutUint64(id[8:], f.Height) + return id +} + +// String returns a human-readable description of the forwarding package. +func (f *FwdPkg) String() string { + return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)", + f, f.Source, f.Height, len(f.Adds), len(f.SettleFails)) +} + +// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID +// is assumed to be that of the packager. +type AddRef struct { + // Height is the remote commitment height that locked in the Add. + Height uint64 + + // Index is the index of the Add within the fwd pkg's Adds. + // + // NOTE: This index is static over the lifetime of a forwarding package. + Index uint16 +} + +// Encode serializes the AddRef to the given io.Writer. +func (a *AddRef) Encode(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, a.Height); err != nil { + return err + } + + return binary.Write(w, binary.BigEndian, a.Index) +} + +// Decode deserializes the AddRef from the given io.Reader. +func (a *AddRef) Decode(r io.Reader) error { + if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil { + return err + } + + return binary.Read(r, binary.BigEndian, &a.Index) +} + +// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A +// channel does not remove its own Settle/Fail htlcs, so the source is provided +// to locate a db bucket belonging to another channel. +type SettleFailRef struct { + // Source identifies the outgoing link that locked in the settle or + // fail. This is then used by the *incoming* link to find the settle + // fail in another link's forwarding packages. + Source lnwire.ShortChannelID + + // Height is the remote commitment height that locked in this + // Settle/Fail. + Height uint64 + + // Index is the index of the Add with the fwd pkg's SettleFails. + // + // NOTE: This index is static over the lifetime of a forwarding package. + Index uint16 +} + +// SettleFailAcker is a generic interface providing the ability to acknowledge +// settle/fail HTLCs stored in forwarding packages. +type SettleFailAcker interface { + // AckSettleFails atomically updates the settle-fail filters in *other* + // channels' forwarding packages. + AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error +} + +// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages +// of any active channel. +type GlobalFwdPkgReader interface { + // LoadChannelFwdPkgs loads all known forwarding packages for the given + // channel. + LoadChannelFwdPkgs(tx *bbolt.Tx, + source lnwire.ShortChannelID) ([]*FwdPkg, error) +} + +// FwdOperator defines the interfaces for managing forwarding packages that are +// external to a particular channel. This interface is used by the switch to +// read forwarding packages from arbitrary channels, and acknowledge settles and +// fails for locally-sourced payments. +type FwdOperator interface { + // GlobalFwdPkgReader provides read access to all known forwarding + // packages + GlobalFwdPkgReader + + // SettleFailAcker grants the ability to acknowledge settles or fails + // residing in arbitrary forwarding packages. + SettleFailAcker +} + +// SwitchPackager is a concrete implementation of the FwdOperator interface. +// A SwitchPackager offers the ability to read any forwarding package, and ack +// arbitrary settle and fail HTLCs. +type SwitchPackager struct{} + +// NewSwitchPackager instantiates a new SwitchPackager. +func NewSwitchPackager() *SwitchPackager { + return &SwitchPackager{} +} + +// AckSettleFails atomically updates the settle-fail filters in *other* +// channels' forwarding packages, to mark that the switch has received a settle +// or fail residing in the forwarding package of a link. +func (*SwitchPackager) AckSettleFails(tx *bbolt.Tx, + settleFailRefs ...SettleFailRef) error { + + return ackSettleFails(tx, settleFailRefs) +} + +// LoadChannelFwdPkgs loads all forwarding packages for a particular channel. +func (*SwitchPackager) LoadChannelFwdPkgs(tx *bbolt.Tx, + source lnwire.ShortChannelID) ([]*FwdPkg, error) { + + return loadChannelFwdPkgs(tx, source) +} + +// FwdPackager supports all operations required to modify fwd packages, such as +// creation, updates, reading, and removal. The interfaces are broken down in +// this way to support future delegation of the subinterfaces. +type FwdPackager interface { + // AddFwdPkg serializes and writes a FwdPkg for this channel at the + // remote commitment height included in the forwarding package. + AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error + + // SetFwdFilter looks up the forwarding package at the remote `height` + // and sets the `fwdFilter`, marking the Adds for which: + // 1) We are not the exit node + // 2) Passed all validation + // 3) Should be forwarded to the switch immediately after a failure + SetFwdFilter(tx *bbolt.Tx, height uint64, fwdFilter *PkgFilter) error + + // AckAddHtlcs atomically updates the add filters in this channel's + // forwarding packages to mark the resolution of an Add that was + // received from the remote party. + AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error + + // SettleFailAcker allows a link to acknowledge settle/fail HTLCs + // belonging to other channels. + SettleFailAcker + + // LoadFwdPkgs loads all known forwarding packages owned by this + // channel. + LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) + + // RemovePkg deletes a forwarding package owned by this channel at + // the provided remote `height`. + RemovePkg(tx *bbolt.Tx, height uint64) error +} + +// ChannelPackager is used by a channel to manage the lifecycle of its forwarding +// packages. The packager is tied to a particular source channel ID, allowing it +// to create and edit its own packages. Each packager also has the ability to +// remove fail/settle htlcs that correspond to an add contained in one of +// source's packages. +type ChannelPackager struct { + source lnwire.ShortChannelID +} + +// NewChannelPackager creates a new packager for a single channel. +func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { + return &ChannelPackager{ + source: source, + } +} + +// AddFwdPkg writes a newly locked in forwarding package to disk. +func (*ChannelPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error { + fwdPkgBkt, err := tx.CreateBucketIfNotExists(fwdPackagesKey) + if err != nil { + return err + } + + source := makeLogKey(fwdPkg.Source.ToUint64()) + sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) + if err != nil { + return err + } + + heightKey := makeLogKey(fwdPkg.Height) + heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) + if err != nil { + return err + } + + // Write ADD updates we received at this commit height. + addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) + if err != nil { + return err + } + + // Write SETTLE/FAIL updates we received at this commit height. + failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey) + if err != nil { + return err + } + + for i := range fwdPkg.Adds { + err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) + if err != nil { + return err + } + } + + // Persist the initialized pkg filter, which will be used to determine + // when we can remove this forwarding package from disk. + var ackFilterBuf bytes.Buffer + if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil { + return err + } + + for i := range fwdPkg.SettleFails { + err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i]) + if err != nil { + return err + } + } + + var settleFailFilterBuf bytes.Buffer + err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) + if err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. +func putLogUpdate(bkt *bbolt.Bucket, idx uint16, htlc *LogUpdate) error { + var b bytes.Buffer + if err := htlc.Encode(&b); err != nil { + return err + } + + return bkt.Put(uint16Key(idx), b.Bytes()) +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in a map indexed by the +// remote commitment height at which the updates were locked in. +func (p *ChannelPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) { + return loadChannelFwdPkgs(tx, p.source) +} + +// loadChannelFwdPkgs loads all forwarding packages owned by `source`. +func loadChannelFwdPkgs(tx *bbolt.Tx, source lnwire.ShortChannelID) ([]*FwdPkg, error) { + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil, nil + } + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) + if sourceBkt == nil { + return nil, nil + } + + var heights []uint64 + if err := sourceBkt.ForEach(func(k, _ []byte) error { + if len(k) != 8 { + return ErrCorruptedFwdPkg + } + + heights = append(heights, byteOrder.Uint64(k)) + + return nil + }); err != nil { + return nil, err + } + + // Load the forwarding package for each retrieved height. + fwdPkgs := make([]*FwdPkg, 0, len(heights)) + for _, height := range heights { + fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) + if err != nil { + return nil, err + } + + fwdPkgs = append(fwdPkgs, fwdPkg) + } + + return fwdPkgs, nil +} + +// loadFwPkg reads the packager's fwd pkg at a given height, and determines the +// appropriate FwdState. +func loadFwdPkg(fwdPkgBkt *bbolt.Bucket, source lnwire.ShortChannelID, + height uint64) (*FwdPkg, error) { + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) + if sourceBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.Bucket(heightKey[:]) + if heightBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + // Load ADDs from disk. + addBkt := heightBkt.Bucket(addBucketKey) + if addBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + adds, err := loadHtlcs(addBkt) + if err != nil { + return nil, err + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + ackFilterReader := bytes.NewReader(ackFilterBytes) + + ackFilter := &PkgFilter{} + if err := ackFilter.Decode(ackFilterReader); err != nil { + return nil, err + } + + // Load SETTLE/FAILs from disk. + failSettleBkt := heightBkt.Bucket(failSettleBucketKey) + if failSettleBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + failSettles, err := loadHtlcs(failSettleBkt) + if err != nil { + return nil, err + } + + // Load settle fail filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + + settleFailFilter := &PkgFilter{} + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return nil, err + } + + // Initialize the fwding package, which always starts in the + // FwdStateLockedIn. We can determine what state the package was left in + // by examining constraints on the information loaded from disk. + fwdPkg := &FwdPkg{ + Source: source, + State: FwdStateLockedIn, + Height: height, + Adds: adds, + AckFilter: ackFilter, + SettleFails: failSettles, + SettleFailFilter: settleFailFilter, + } + + // Check to see if we have written the set exported filter adds to + // disk. If we haven't, processing of this package was never started, or + // failed during the last attempt. + fwdFilterBytes := heightBkt.Get(fwdFilterKey) + if fwdFilterBytes == nil { + nAdds := uint16(len(adds)) + fwdPkg.FwdFilter = NewPkgFilter(nAdds) + return fwdPkg, nil + } + + fwdFilterReader := bytes.NewReader(fwdFilterBytes) + fwdPkg.FwdFilter = &PkgFilter{} + if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { + return nil, err + } + + // Otherwise, a complete round of processing was completed, and we + // advance the package to FwdStateProcessed. + fwdPkg.State = FwdStateProcessed + + // If every add, settle, and fail has been fully acknowledged, we can + // safely set the package's state to FwdStateCompleted, signalling that + // it can be garbage collected. + if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { + fwdPkg.State = FwdStateCompleted + } + + return fwdPkg, nil +} + +// loadHtlcs retrieves all serialized htlcs in a bucket, returning +// them in order of the indexes they were written under. +func loadHtlcs(bkt *bbolt.Bucket) ([]LogUpdate, error) { + var htlcs []LogUpdate + if err := bkt.ForEach(func(_, v []byte) error { + var htlc LogUpdate + if err := htlc.Decode(bytes.NewReader(v)); err != nil { + return err + } + + htlcs = append(htlcs, htlc) + + return nil + }); err != nil { + return nil, err + } + + return htlcs, nil +} + +// SetFwdFilter writes the set of indexes corresponding to Adds at the +// `height` that are to be forwarded to the switch. Calling this method causes +// the forwarding package at `height` to be in FwdStateProcessed. We write this +// forwarding decision so that we always arrive at the same behavior for HTLCs +// leaving this channel. After a restart, we skip validation of these Adds, +// since they are assumed to have already been validated, and make the switch or +// outgoing link responsible for handling replays. +func (p *ChannelPackager) SetFwdFilter(tx *bbolt.Tx, height uint64, + fwdFilter *PkgFilter) error { + + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + source := makeLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(source[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.Bucket(heightKey[:]) + if heightBkt == nil { + return ErrCorruptedFwdPkg + } + + // If the fwd filter has already been written, we return early to avoid + // modifying the persistent state. + forwardedAddsBytes := heightBkt.Get(fwdFilterKey) + if forwardedAddsBytes != nil { + return nil + } + + // Otherwise we serialize and write the provided fwd filter. + var b bytes.Buffer + if err := fwdFilter.Encode(&b); err != nil { + return err + } + + return heightBkt.Put(fwdFilterKey, b.Bytes()) +} + +// AckAddHtlcs accepts a list of references to add htlcs, and updates the +// AckAddFilter of those forwarding packages to indicate that a settle or fail +// has been received in response to the add. +func (p *ChannelPackager) AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error { + if len(addRefs) == 0 { + return nil + } + + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + sourceKey := makeLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + // Organize the forward references such that we just get a single slice + // of indexes for each unique height. + heightDiffs := make(map[uint64][]uint16) + for _, addRef := range addRefs { + heightDiffs[addRef.Height] = append( + heightDiffs[addRef.Height], + addRef.Index, + ) + } + + // Load each height bucket once and remove all acked htlcs at that + // height. + for height, indexes := range heightDiffs { + err := ackAddHtlcsAtHeight(sourceBkt, height, indexes) + if err != nil { + return err + } + } + + return nil +} + +// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package +// with a list of indexes, writing the resulting filter back in its place. +func ackAddHtlcsAtHeight(sourceBkt *bbolt.Bucket, height uint64, + indexes []uint16) error { + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.Bucket(heightKey[:]) + if heightBkt == nil { + // If the height bucket isn't found, this could be because the + // forwarding package was already removed. We'll return nil to + // signal that the operation is successful, as there is nothing + // to ack. + return nil + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return ErrCorruptedFwdPkg + } + + ackFilter := &PkgFilter{} + ackFilterReader := bytes.NewReader(ackFilterBytes) + if err := ackFilter.Decode(ackFilterReader); err != nil { + return err + } + + // Update the ack filter for this height. + for _, index := range indexes { + ackFilter.Set(index) + } + + // Write the resulting filter to disk. + var ackFilterBuf bytes.Buffer + if err := ackFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()) +} + +// AckSettleFails persistently acknowledges settles or fails from a remote forwarding +// package. This should only be called after the source of the Add has locked in +// the settle/fail, or it becomes otherwise safe to forgo retransmitting the +// settle/fail after a restart. +func (p *ChannelPackager) AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error { + return ackSettleFails(tx, settleFailRefs) +} + +// ackSettleFails persistently acknowledges a batch of settle fail references. +func ackSettleFails(tx *bbolt.Tx, settleFailRefs []SettleFailRef) error { + if len(settleFailRefs) == 0 { + return nil + } + + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + // Organize the forward references such that we just get a single slice + // of indexes for each unique destination-height pair. + destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16) + for _, settleFailRef := range settleFailRefs { + destHeights, ok := destHeightDiffs[settleFailRef.Source] + if !ok { + destHeights = make(map[uint64][]uint16) + destHeightDiffs[settleFailRef.Source] = destHeights + } + + destHeights[settleFailRef.Height] = append( + destHeights[settleFailRef.Height], + settleFailRef.Index, + ) + } + + // With the references organized by destination and height, we now load + // each remote bucket, and update the settle fail filter for any + // settle/fail htlcs. + for dest, destHeights := range destHeightDiffs { + destKey := makeLogKey(dest.ToUint64()) + destBkt := fwdPkgBkt.Bucket(destKey[:]) + if destBkt == nil { + // If the destination bucket is not found, this is + // likely the result of the destination channel being + // closed and having it's forwarding packages wiped. We + // won't treat this as an error, because the response + // will no longer be retransmitted internally. + continue + } + + for height, indexes := range destHeights { + err := ackSettleFailsAtHeight(destBkt, height, indexes) + if err != nil { + return err + } + } + } + + return nil +} + +// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes +// at particular a height by updating the settle fail filter. +func ackSettleFailsAtHeight(destBkt *bbolt.Bucket, height uint64, + indexes []uint16) error { + + heightKey := makeLogKey(height) + heightBkt := destBkt.Bucket(heightKey[:]) + if heightBkt == nil { + // If the height bucket isn't found, this could be because the + // forwarding package was already removed. We'll return nil to + // signal that the operation is as there is nothing to ack. + return nil + } + + // Load ack filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return ErrCorruptedFwdPkg + } + + settleFailFilter := &PkgFilter{} + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return err + } + + // Update the ack filter for this height. + for _, index := range indexes { + settleFailFilter.Set(index) + } + + // Write the resulting filter to disk. + var settleFailFilterBuf bytes.Buffer + if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// RemovePkg deletes the forwarding package at the given height from the +// packager's source bucket. +func (p *ChannelPackager) RemovePkg(tx *bbolt.Tx, height uint64) error { + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil + } + + sourceBytes := makeLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(sourceBytes[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + + return sourceBkt.DeleteBucket(heightKey[:]) +} + +// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. +func uint16Key(i uint16) []byte { + key := make([]byte, 2) + byteOrder.PutUint16(key, i) + return key +} + +// Compile-time constraint to ensure that ChannelPackager implements the public +// FwdPackager interface. +var _ FwdPackager = (*ChannelPackager)(nil) + +// Compile-time constraint to ensure that SwitchPackager implements the public +// FwdOperator interface. +var _ FwdOperator = (*SwitchPackager)(nil) diff --git a/channeldb/migration_01_to_11/forwarding_package_test.go b/channeldb/migration_01_to_11/forwarding_package_test.go new file mode 100644 index 00000000..1128aad3 --- /dev/null +++ b/channeldb/migration_01_to_11/forwarding_package_test.go @@ -0,0 +1,815 @@ +package migration_01_to_11_test + +import ( + "bytes" + "io/ioutil" + "path/filepath" + "runtime" + "testing" + + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +// TestPkgFilterBruteForce tests the behavior of a pkg filter up to size 1000, +// which is greater than the number of HTLCs we permit on a commitment txn. +// This should encapsulate every potential filter used in practice. +func TestPkgFilterBruteForce(t *testing.T) { + t.Parallel() + + checkPkgFilterRange(t, 1000) +} + +// checkPkgFilterRange verifies the behavior of a pkg filter when doing a linear +// insertion of `high` elements. This is primarily to test that IsFull functions +// properly for all relevant sizes of `high`. +func checkPkgFilterRange(t *testing.T, high int) { + for i := uint16(0); i < uint16(high); i++ { + f := channeldb.NewPkgFilter(i) + + if f.Count() != i { + t.Fatalf("pkg filter count=%d is actually %d", + i, f.Count()) + } + checkPkgFilterEncodeDecode(t, i, f) + + for j := uint16(0); j < i; j++ { + if f.Contains(j) { + t.Fatalf("pkg filter count=%d contains %d "+ + "before being added", i, j) + } + + f.Set(j) + checkPkgFilterEncodeDecode(t, i, f) + + if !f.Contains(j) { + t.Fatalf("pkg filter count=%d missing %d "+ + "after being added", i, j) + } + + if j < i-1 && f.IsFull() { + t.Fatalf("pkg filter count=%d already full", i) + } + } + + if !f.IsFull() { + t.Fatalf("pkg filter count=%d not full", i) + } + checkPkgFilterEncodeDecode(t, i, f) + } +} + +// TestPkgFilterRand uses a random permutation to verify the proper behavior of +// the pkg filter if the entries are not inserted in-order. +func TestPkgFilterRand(t *testing.T) { + t.Parallel() + + checkPkgFilterRand(t, 3, 17) +} + +// checkPkgFilterRand checks the behavior of a pkg filter by randomly inserting +// indices and asserting the invariants. The order in which indices are inserted +// is parameterized by a base `b` coprime to `p`, and using modular +// exponentiation to generate all elements in [1,p). +func checkPkgFilterRand(t *testing.T, b, p uint16) { + f := channeldb.NewPkgFilter(p) + var j = b + for i := uint16(1); i < p; i++ { + if f.Contains(j) { + t.Fatalf("pkg filter contains %d-%d "+ + "before being added", i, j) + } + + f.Set(j) + checkPkgFilterEncodeDecode(t, i, f) + + if !f.Contains(j) { + t.Fatalf("pkg filter missing %d-%d "+ + "after being added", i, j) + } + + if i < p-1 && f.IsFull() { + t.Fatalf("pkg filter %d already full", i) + } + checkPkgFilterEncodeDecode(t, i, f) + + j = (b * j) % p + } + + // Set 0 independently, since it will never be emitted by the generator. + f.Set(0) + checkPkgFilterEncodeDecode(t, p, f) + + if !f.IsFull() { + t.Fatalf("pkg filter count=%d not full", p) + } + checkPkgFilterEncodeDecode(t, p, f) +} + +// checkPkgFilterEncodeDecode tests the serialization of a pkg filter by: +// 1) writing it to a buffer +// 2) verifying the number of bytes written matches the filter's Size() +// 3) reconstructing the filter decoding the bytes +// 4) checking that the two filters are the same according to Equal +func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) { + var b bytes.Buffer + if err := f.Encode(&b); err != nil { + t.Fatalf("unable to serialize pkg filter: %v", err) + } + + // +2 for uint16 length + size := uint16(len(b.Bytes())) + if size != f.Size() { + t.Fatalf("pkg filter count=%d serialized size differs, "+ + "Size(): %d, len(bytes): %v", i, f.Size(), size) + } + + reader := bytes.NewReader(b.Bytes()) + + f2 := &channeldb.PkgFilter{} + if err := f2.Decode(reader); err != nil { + t.Fatalf("unable to deserialize pkg filter: %v", err) + } + + if !f.Equal(f2) { + t.Fatalf("pkg filter count=%v does is not equal "+ + "after deserialization, want: %v, got %v", + i, f, f2) + } +} + +var ( + chanID = lnwire.NewChanIDFromOutPoint(&wire.OutPoint{}) + + adds = []channeldb.LogUpdate{ + { + LogIndex: 0, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ChanID: chanID, + ID: 1, + Amount: 100, + Expiry: 1000, + PaymentHash: [32]byte{0}, + }, + }, + { + LogIndex: 1, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ChanID: chanID, + ID: 1, + Amount: 101, + Expiry: 1001, + PaymentHash: [32]byte{1}, + }, + }, + } + + settleFails = []channeldb.LogUpdate{ + { + LogIndex: 2, + UpdateMsg: &lnwire.UpdateFulfillHTLC{ + ChanID: chanID, + ID: 0, + PaymentPreimage: [32]byte{0}, + }, + }, + { + LogIndex: 3, + UpdateMsg: &lnwire.UpdateFailHTLC{ + ChanID: chanID, + ID: 1, + Reason: []byte{}, + }, + }, + } +) + +// TestPackagerEmptyFwdPkg checks that the state transitions exhibited by a +// forwarding package that contains no adds, fails or settles. We expect that +// the fwdpkg reaches FwdStateCompleted immediately after writing the forwarding +// decision via SetFwdFilter. +func TestPackagerEmptyFwdPkg(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package with no htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, nil) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. With no HTLCs, + // the ack filter will have no elements, and should always return true. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Now, write the forwarding decision. In this case, its just an empty + // fwd filter. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + // We should still have one package on disk. Since the forwarding + // decision has been written, it will minimally be in FwdStateProcessed. + // However with no htlcs, it should leap frog to FwdStateCompleted. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// TestPackagerOnlyAdds checks that the fwdpkg does not reach FwdStateCompleted +// as soon as all the adds in the package have been acked using AckAddHtlcs. +func TestPackagerOnlyAdds(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package that only has add + // htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, nil) + + nAdds := len(adds) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. The package + // has unacked add HTLCs, so the ack filter should not be full. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + // Now, write the forwarding decision. Since we have not explicitly + // added any adds to the fwdfilter, this would indicate that all of the + // adds were 1) settled locally by this link (exit hop), or 2) the htlc + // was failed locally. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + for i := range adds { + // We should still have one package on disk. Since the forwarding + // decision has been written, it will minimally be in FwdStateProcessed. + // However not allf of the HTLCs have been acked, so should not + // have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + addRef := channeldb.AddRef{ + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckAddHtlcs(tx, addRef) + }); err != nil { + t.Fatalf("unable to ack add htlc: %v", err) + } + } + + // We should still have one package on disk. Now that all adds have been + // acked, the ack filter should return true and the package should be + // FwdStateCompleted since there are no other settle/fail packets. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// TestPackagerOnlySettleFails asserts that the fwdpkg remains in +// FwdStateProcessed after writing the forwarding decision when there are no +// adds in the fwdpkg. We expect this because an empty FwdFilter will always +// return true, but we are still waiting for the remaining fails and settles to +// be deleted. +func TestPackagerOnlySettleFails(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package that only has add + // htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, settleFails) + + nSettleFails := len(settleFails) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. The package + // has unacked add HTLCs, so the ack filter should not be full. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Now, write the forwarding decision. Since we have not explicitly + // added any adds to the fwdfilter, this would indicate that all of the + // adds were 1) settled locally by this link (exit hop), or 2) the htlc + // was failed locally. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + for i := range settleFails { + // We should still have one package on disk. Since the + // forwarding decision has been written, it will minimally be in + // FwdStateProcessed. However, not all of the HTLCs have been + // acked, so should not have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], false) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + failSettleRef := channeldb.SettleFailRef{ + Source: shortChanID, + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckSettleFails(tx, failSettleRef) + }); err != nil { + t.Fatalf("unable to ack add htlc: %v", err) + } + } + + // We should still have one package on disk. Now that all settles and + // fails have been removed, package should be FwdStateCompleted since + // there are no other add packets. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], true) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// TestPackagerAddsThenSettleFails writes a fwdpkg containing both adds and +// settle/fails, then checks the behavior when the adds are acked before any of +// the settle fails. Here we expect pkg to remain in FwdStateProcessed while the +// remainder of the fail/settles are being deleted. +func TestPackagerAddsThenSettleFails(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package that only has add + // htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) + + nAdds := len(adds) + nSettleFails := len(settleFails) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. The package + // has unacked add HTLCs, so the ack filter should not be full. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + // Now, write the forwarding decision. Since we have not explicitly + // added any adds to the fwdfilter, this would indicate that all of the + // adds were 1) settled locally by this link (exit hop), or 2) the htlc + // was failed locally. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + for i := range adds { + // We should still have one package on disk. Since the forwarding + // decision has been written, it will minimally be in FwdStateProcessed. + // However not allf of the HTLCs have been acked, so should not + // have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], false) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + addRef := channeldb.AddRef{ + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckAddHtlcs(tx, addRef) + }); err != nil { + t.Fatalf("unable to ack add htlc: %v", err) + } + } + + for i := range settleFails { + // We should still have one package on disk. Since the + // forwarding decision has been written, it will minimally be in + // FwdStateProcessed. However not allf of the HTLCs have been + // acked, so should not have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], false) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + failSettleRef := channeldb.SettleFailRef{ + Source: shortChanID, + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckSettleFails(tx, failSettleRef) + }); err != nil { + t.Fatalf("unable to remove settle/fail htlc: %v", err) + } + } + + // We should still have one package on disk. Now that all settles and + // fails have been removed, package should be FwdStateCompleted since + // there are no other add packets. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], true) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// TestPackagerSettleFailsThenAdds writes a fwdpkg with both adds and +// settle/fails, then checks the behavior when the settle/fails are removed +// before any of the adds have been acked. This should cause the fwdpkg to +// remain in FwdStateProcessed until the final ack is recorded, at which point +// it should be promoted directly to FwdStateCompleted.since all adds have been +// removed. +func TestPackagerSettleFailsThenAdds(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package that has both add + // and settle/fail htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) + + nAdds := len(adds) + nSettleFails := len(settleFails) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. The package + // has unacked add HTLCs, so the ack filter should not be full. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + // Now, write the forwarding decision. Since we have not explicitly + // added any adds to the fwdfilter, this would indicate that all of the + // adds were 1) settled locally by this link (exit hop), or 2) the htlc + // was failed locally. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + // Simulate another channel deleting the settle/fails it received from + // the original fwd pkg. + // TODO(conner): use different packager/s? + for i := range settleFails { + // We should still have one package on disk. Since the + // forwarding decision has been written, it will minimally be in + // FwdStateProcessed. However none all of the add HTLCs have + // been acked, so should not have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], false) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + failSettleRef := channeldb.SettleFailRef{ + Source: shortChanID, + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckSettleFails(tx, failSettleRef) + }); err != nil { + t.Fatalf("unable to remove settle/fail htlc: %v", err) + } + } + + // Now simulate this channel receiving a fail/settle for the adds in the + // fwdpkg. + for i := range adds { + // Again, we should still have one package on disk and be in + // FwdStateProcessed. This should not change until all of the + // add htlcs have been acked. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], true) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + addRef := channeldb.AddRef{ + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckAddHtlcs(tx, addRef) + }); err != nil { + t.Fatalf("unable to ack add htlc: %v", err) + } + } + + // We should still have one package on disk. Now that all settles and + // fails have been removed, package should be FwdStateCompleted since + // there are no other add packets. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], true) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// assertFwdPkgState checks the current state of a fwdpkg meets our +// expectations. +func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg, + state channeldb.FwdState) { + _, _, line, _ := runtime.Caller(1) + if fwdPkg.State != state { + t.Fatalf("line %d: expected fwdpkg in state %v, found %v", + line, state, fwdPkg.State) + } +} + +// assertFwdPkgNumAddsSettleFails checks that the number of adds and +// settle/fail log updates are correct. +func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg, + expectedNumAdds, expectedNumSettleFails int) { + _, _, line, _ := runtime.Caller(1) + if len(fwdPkg.Adds) != expectedNumAdds { + t.Fatalf("line %d: expected fwdpkg to have %d adds, found %d", + line, expectedNumAdds, len(fwdPkg.Adds)) + } + + if len(fwdPkg.SettleFails) != expectedNumSettleFails { + t.Fatalf("line %d: expected fwdpkg to have %d settle/fails, found %d", + line, expectedNumSettleFails, len(fwdPkg.SettleFails)) + } +} + +// assertAckFilterIsFull checks whether or not a fwdpkg's ack filter matches our +// expected full-ness. +func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { + _, _, line, _ := runtime.Caller(1) + if fwdPkg.AckFilter.IsFull() != expected { + t.Fatalf("line %d: expected fwdpkg ack filter IsFull to be %v, "+ + "found %v", line, expected, fwdPkg.AckFilter.IsFull()) + } +} + +// assertSettleFailFilterIsFull checks whether or not a fwdpkg's settle fail +// filter matches our expected full-ness. +func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { + _, _, line, _ := runtime.Caller(1) + if fwdPkg.SettleFailFilter.IsFull() != expected { + t.Fatalf("line %d: expected fwdpkg settle/fail filter IsFull to be %v, "+ + "found %v", line, expected, fwdPkg.SettleFailFilter.IsFull()) + } +} + +// loadFwdPkgs is a helper method that reads all forwarding packages for a +// particular packager. +func loadFwdPkgs(t *testing.T, db *bbolt.DB, + packager channeldb.FwdPackager) []*channeldb.FwdPkg { + + var fwdPkgs []*channeldb.FwdPkg + if err := db.View(func(tx *bbolt.Tx) error { + var err error + fwdPkgs, err = packager.LoadFwdPkgs(tx) + return err + }); err != nil { + t.Fatalf("unable to load fwd pkgs: %v", err) + } + + return fwdPkgs +} + +// makeFwdPkgDB initializes a test database for forwarding packages. If the +// provided path is an empty, it will create a temp dir/file to use. +func makeFwdPkgDB(t *testing.T, path string) *bbolt.DB { + if path == "" { + var err error + path, err = ioutil.TempDir("", "fwdpkgdb") + if err != nil { + t.Fatalf("unable to create temp path: %v", err) + } + + path = filepath.Join(path, "fwdpkg.db") + } + + db, err := bbolt.Open(path, 0600, nil) + if err != nil { + t.Fatalf("unable to open boltdb: %v", err) + } + + return db +} diff --git a/channeldb/migration_01_to_11/graph.go b/channeldb/migration_01_to_11/graph.go new file mode 100644 index 00000000..d90863c6 --- /dev/null +++ b/channeldb/migration_01_to_11/graph.go @@ -0,0 +1,4060 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "image/color" + "io" + "math" + "net" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // nodeBucket is a bucket which houses all the vertices or nodes within + // the channel graph. This bucket has a single-sub bucket which adds an + // additional index from pubkey -> alias. Within the top-level of this + // bucket, the key space maps a node's compressed public key to the + // serialized information for that node. Additionally, there's a + // special key "source" which stores the pubkey of the source node. The + // source node is used as the starting point for all graph/queries and + // traversals. The graph is formed as a star-graph with the source node + // at the center. + // + // maps: pubKey -> nodeInfo + // maps: source -> selfPubKey + nodeBucket = []byte("graph-node") + + // nodeUpdateIndexBucket is a sub-bucket of the nodeBucket. This bucket + // will be used to quickly look up the "freshness" of a node's last + // update to the network. The bucket only contains keys, and no values, + // it's mapping: + // + // maps: updateTime || nodeID -> nil + nodeUpdateIndexBucket = []byte("graph-node-update-index") + + // sourceKey is a special key that resides within the nodeBucket. The + // sourceKey maps a key to the public key of the "self node". + sourceKey = []byte("source") + + // aliasIndexBucket is a sub-bucket that's nested within the main + // nodeBucket. This bucket maps the public key of a node to its + // current alias. This bucket is provided as it can be used within a + // future UI layer to add an additional degree of confirmation. + aliasIndexBucket = []byte("alias") + + // edgeBucket is a bucket which houses all of the edge or channel + // information within the channel graph. This bucket essentially acts + // as an adjacency list, which in conjunction with a range scan, can be + // used to iterate over all the incoming and outgoing edges for a + // particular node. Key in the bucket use a prefix scheme which leads + // with the node's public key and sends with the compact edge ID. + // For each chanID, there will be two entries within the bucket, as the + // graph is directed: nodes may have different policies w.r.t to fees + // for their respective directions. + // + // maps: pubKey || chanID -> channel edge policy for node + edgeBucket = []byte("graph-edge") + + // unknownPolicy is represented as an empty slice. It is + // used as the value in edgeBucket for unknown channel edge policies. + // Unknown policies are still stored in the database to enable efficient + // lookup of incoming channel edges. + unknownPolicy = []byte{} + + // chanStart is an array of all zero bytes which is used to perform + // range scans within the edgeBucket to obtain all of the outgoing + // edges for a particular node. + chanStart [8]byte + + // edgeIndexBucket is an index which can be used to iterate all edges + // in the bucket, grouping them according to their in/out nodes. + // Additionally, the items in this bucket also contain the complete + // edge information for a channel. The edge information includes the + // capacity of the channel, the nodes that made the channel, etc. This + // bucket resides within the edgeBucket above. Creation of an edge + // proceeds in two phases: first the edge is added to the edge index, + // afterwards the edgeBucket can be updated with the latest details of + // the edge as they are announced on the network. + // + // maps: chanID -> pubKey1 || pubKey2 || restofEdgeInfo + edgeIndexBucket = []byte("edge-index") + + // edgeUpdateIndexBucket is a sub-bucket of the main edgeBucket. This + // bucket contains an index which allows us to gauge the "freshness" of + // a channel's last updates. + // + // maps: updateTime || chanID -> nil + edgeUpdateIndexBucket = []byte("edge-update-index") + + // channelPointBucket maps a channel's full outpoint (txid:index) to + // its short 8-byte channel ID. This bucket resides within the + // edgeBucket above, and can be used to quickly remove an edge due to + // the outpoint being spent, or to query for existence of a channel. + // + // maps: outPoint -> chanID + channelPointBucket = []byte("chan-index") + + // zombieBucket is a sub-bucket of the main edgeBucket bucket + // responsible for maintaining an index of zombie channels. Each entry + // exists within the bucket as follows: + // + // maps: chanID -> pubKey1 || pubKey2 + // + // The chanID represents the channel ID of the edge that is marked as a + // zombie and is used as the key, which maps to the public keys of the + // edge's participants. + zombieBucket = []byte("zombie-index") + + // disabledEdgePolicyBucket is a sub-bucket of the main edgeBucket bucket + // responsible for maintaining an index of disabled edge policies. Each + // entry exists within the bucket as follows: + // + // maps: -> []byte{} + // + // The chanID represents the channel ID of the edge and the direction is + // one byte representing the direction of the edge. The main purpose of + // this index is to allow pruning disabled channels in a fast way without + // the need to iterate all over the graph. + disabledEdgePolicyBucket = []byte("disabled-edge-policy-index") + + // graphMetaBucket is a top-level bucket which stores various meta-deta + // related to the on-disk channel graph. Data stored in this bucket + // includes the block to which the graph has been synced to, the total + // number of channels, etc. + graphMetaBucket = []byte("graph-meta") + + // pruneLogBucket is a bucket within the graphMetaBucket that stores + // a mapping from the block height to the hash for the blocks used to + // prune the graph. + // Once a new block is discovered, any channels that have been closed + // (by spending the outpoint) can safely be removed from the graph, and + // the block is added to the prune log. We need to keep such a log for + // the case where a reorg happens, and we must "rewind" the state of the + // graph by removing channels that were previously confirmed. In such a + // case we'll remove all entries from the prune log with a block height + // that no longer exists. + pruneLogBucket = []byte("prune-log") +) + +const ( + // MaxAllowedExtraOpaqueBytes is the largest amount of opaque bytes that + // we'll permit to be written to disk. We limit this as otherwise, it + // would be possible for a node to create a ton of updates and slowly + // fill our disk, and also waste bandwidth due to relaying. + MaxAllowedExtraOpaqueBytes = 10000 + + // feeRateParts is the total number of parts used to express fee rates. + feeRateParts = 1e6 +) + +// ChannelGraph is a persistent, on-disk graph representation of the Lightning +// Network. This struct can be used to implement path finding algorithms on top +// of, and also to update a node's view based on information received from the +// p2p network. Internally, the graph is stored using a modified adjacency list +// representation with some added object interaction possible with each +// serialized edge/node. The graph is stored is directed, meaning that are two +// edges stored for each channel: an inbound/outbound edge for each node pair. +// Nodes, edges, and edge information can all be added to the graph +// independently. Edge removal results in the deletion of all edge information +// for that edge. +type ChannelGraph struct { + db *DB + + cacheMu sync.RWMutex + rejectCache *rejectCache + chanCache *channelCache +} + +// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The +// returned instance has its own unique reject cache and channel cache. +func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int) *ChannelGraph { + return &ChannelGraph{ + db: db, + rejectCache: newRejectCache(rejectCacheSize), + chanCache: newChannelCache(chanCacheSize), + } +} + +// Database returns a pointer to the underlying database. +func (c *ChannelGraph) Database() *DB { + return c.db +} + +// ForEachChannel iterates through all the channel edges stored within the +// graph and invokes the passed callback for each edge. The callback takes two +// edges as since this is a directed graph, both the in/out edges are visited. +// If the callback returns an error, then the transaction is aborted and the +// iteration stops early. +// +// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer +// for that particular channel edge routing policy will be passed into the +// callback. +func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + // TODO(roasbeef): ptr map to reduce # of allocs? no duplicates + + return c.db.View(func(tx *bbolt.Tx) error { + // First, grab the node bucket. This will be used to populate + // the Node pointers in each edge read from disk. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // Next, grab the edge bucket which stores the edges, and also + // the index itself so we can group the directed edges together + // logically. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // For each edge pair within the edge index, we fetch each edge + // itself and also the node information in order to fully + // populated the object. + return edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { + infoReader := bytes.NewReader(edgeInfoBytes) + edgeInfo, err := deserializeChanEdgeInfo(infoReader) + if err != nil { + return err + } + edgeInfo.db = c.db + + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { + return err + } + + // With both edges read, execute the call back. IF this + // function returns an error then the transaction will + // be aborted. + return cb(&edgeInfo, edge1, edge2) + }) + }) +} + +// ForEachNodeChannel iterates through all channels of a given node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// the connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal. +func (c *ChannelGraph) ForEachNodeChannel(tx *bbolt.Tx, nodePub []byte, + cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + db := c.db + + return nodeTraversal(tx, nodePub, db, cb) +} + +// DisabledChannelIDs returns the channel ids of disabled channels. +// A channel is disabled when two of the associated ChanelEdgePolicies +// have their disabled bit on. +func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { + var disabledChanIDs []uint64 + chanEdgeFound := make(map[uint64]struct{}) + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + + disabledEdgePolicyIndex := edges.Bucket(disabledEdgePolicyBucket) + if disabledEdgePolicyIndex == nil { + return nil + } + + // We iterate over all disabled policies and we add each channel that + // has more than one disabled policy to disabledChanIDs array. + return disabledEdgePolicyIndex.ForEach(func(k, v []byte) error { + chanID := byteOrder.Uint64(k[:8]) + _, edgeFound := chanEdgeFound[chanID] + if edgeFound { + delete(chanEdgeFound, chanID) + disabledChanIDs = append(disabledChanIDs, chanID) + return nil + } + + chanEdgeFound[chanID] = struct{}{} + return nil + }) + }) + if err != nil { + return nil, err + } + + return disabledChanIDs, nil +} + +// ForEachNode iterates through all the stored vertices/nodes in the graph, +// executing the passed callback with each node encountered. If the callback +// returns an error, then the transaction is aborted and the iteration stops +// early. +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal +// +// TODO(roasbeef): add iterator interface to allow for memory efficient graph +// traversal when graph gets mega +func (c *ChannelGraph) ForEachNode(tx *bbolt.Tx, cb func(*bbolt.Tx, *LightningNode) error) error { + traversal := func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + return nodes.ForEach(func(pubKey, nodeBytes []byte) error { + // If this is the source key, then we skip this + // iteration as the value for this key is a pubKey + // rather than raw node information. + if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { + return nil + } + + nodeReader := bytes.NewReader(nodeBytes) + node, err := deserializeLightningNode(nodeReader) + if err != nil { + return err + } + node.db = c.db + + // Execute the callback, the transaction will abort if + // this returns an error. + return cb(tx, &node) + }) + } + + // If no transaction was provided, then we'll create a new transaction + // to execute the transaction within. + if tx == nil { + return c.db.View(traversal) + } + + // Otherwise, we re-use the existing transaction to execute the graph + // traversal. + return traversal(tx) +} + +// SourceNode returns the source node of the graph. The source node is treated +// as the center node within a star-graph. This method may be used to kick off +// a path finding algorithm in order to explore the reachability of another +// node based off the source node. +func (c *ChannelGraph) SourceNode() (*LightningNode, error) { + var source *LightningNode + err := c.db.View(func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + node, err := c.sourceNode(nodes) + if err != nil { + return err + } + source = node + + return nil + }) + if err != nil { + return nil, err + } + + return source, nil +} + +// sourceNode uses an existing database transaction and returns the source node +// of the graph. The source node is treated as the center node within a +// star-graph. This method may be used to kick off a path finding algorithm in +// order to explore the reachability of another node based off the source node. +func (c *ChannelGraph) sourceNode(nodes *bbolt.Bucket) (*LightningNode, error) { + selfPub := nodes.Get(sourceKey) + if selfPub == nil { + return nil, ErrSourceNodeNotSet + } + + // With the pubKey of the source node retrieved, we're able to + // fetch the full node information. + node, err := fetchLightningNode(nodes, selfPub) + if err != nil { + return nil, err + } + node.db = c.db + + return &node, nil +} + +// SetSourceNode sets the source node within the graph database. The source +// node is to be used as the center of a star-graph within path finding +// algorithms. +func (c *ChannelGraph) SetSourceNode(node *LightningNode) error { + nodePubBytes := node.PubKeyBytes[:] + + return c.db.Update(func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } + + // Next we create the mapping from source to the targeted + // public key. + if err := nodes.Put(sourceKey, nodePubBytes); err != nil { + return err + } + + // Finally, we commit the information of the lightning node + // itself. + return addLightningNode(tx, node) + }) +} + +// AddLightningNode adds a vertex/node to the graph database. If the node is not +// in the database from before, this will add a new, unconnected one to the +// graph. If it is present from before, this will update that node's +// information. Note that this method is expected to only be called to update +// an already present node from a node announcement, or to insert a node found +// in a channel update. +// +// TODO(roasbeef): also need sig of announcement +func (c *ChannelGraph) AddLightningNode(node *LightningNode) error { + return c.db.Update(func(tx *bbolt.Tx) error { + return addLightningNode(tx, node) + }) +} + +func addLightningNode(tx *bbolt.Tx, node *LightningNode) error { + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } + + aliases, err := nodes.CreateBucketIfNotExists(aliasIndexBucket) + if err != nil { + return err + } + + updateIndex, err := nodes.CreateBucketIfNotExists( + nodeUpdateIndexBucket, + ) + if err != nil { + return err + } + + return putLightningNode(nodes, aliases, updateIndex, node) +} + +// LookupAlias attempts to return the alias as advertised by the target node. +// TODO(roasbeef): currently assumes that aliases are unique... +func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) { + var alias string + + err := c.db.View(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + aliases := nodes.Bucket(aliasIndexBucket) + if aliases == nil { + return ErrGraphNodesNotFound + } + + nodePub := pub.SerializeCompressed() + a := aliases.Get(nodePub) + if a == nil { + return ErrNodeAliasNotFound + } + + // TODO(roasbeef): should actually be using the utf-8 + // package... + alias = string(a) + return nil + }) + if err != nil { + return "", err + } + + return alias, nil +} + +// DeleteLightningNode starts a new database transaction to remove a vertex/node +// from the database according to the node's public key. +func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error { + // TODO(roasbeef): ensure dangling edges are removed... + return c.db.Update(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodeNotFound + } + + return c.deleteLightningNode( + nodes, nodePub.SerializeCompressed(), + ) + }) +} + +// deleteLightningNode uses an existing database transaction to remove a +// vertex/node from the database according to the node's public key. +func (c *ChannelGraph) deleteLightningNode(nodes *bbolt.Bucket, + compressedPubKey []byte) error { + + aliases := nodes.Bucket(aliasIndexBucket) + if aliases == nil { + return ErrGraphNodesNotFound + } + + if err := aliases.Delete(compressedPubKey); err != nil { + return err + } + + // Before we delete the node, we'll fetch its current state so we can + // determine when its last update was to clear out the node update + // index. + node, err := fetchLightningNode(nodes, compressedPubKey) + if err != nil { + return err + } + + if err := nodes.Delete(compressedPubKey); err != nil { + + return err + } + + // Finally, we'll delete the index entry for the node within the + // nodeUpdateIndexBucket as this node is no longer active, so we don't + // need to track its last update. + nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) + if nodeUpdateIndex == nil { + return ErrGraphNodesNotFound + } + + // In order to delete the entry, we'll need to reconstruct the key for + // its last update. + updateUnix := uint64(node.LastUpdate.Unix()) + var indexKey [8 + 33]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + copy(indexKey[8:], compressedPubKey) + + return nodeUpdateIndex.Delete(indexKey[:]) +} + +// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An +// undirected edge from the two target nodes are created. The information +// stored denotes the static attributes of the channel, such as the channelID, +// the keys involved in creation of the channel, and the set of features that +// the channel supports. The chanPoint and chanID are used to uniquely identify +// the edge globally within the database. +func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo) error { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + err := c.db.Update(func(tx *bbolt.Tx) error { + return c.addChannelEdge(tx, edge) + }) + if err != nil { + return err + } + + c.rejectCache.remove(edge.ChannelID) + c.chanCache.remove(edge.ChannelID) + + return nil +} + +// addChannelEdge is the private form of AddChannelEdge that allows callers to +// utilize an existing db transaction. +func (c *ChannelGraph) addChannelEdge(tx *bbolt.Tx, edge *ChannelEdgeInfo) error { + // Construct the channel's primary key which is the 8-byte channel ID. + var chanKey [8]byte + binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) + + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return err + } + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + + // First, attempt to check if this edge has already been created. If + // so, then we can exit early as this method is meant to be idempotent. + if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo != nil { + return ErrEdgeAlreadyExist + } + + // Before we insert the channel into the database, we'll ensure that + // both nodes already exist in the channel graph. If either node + // doesn't, then we'll insert a "shell" node that just includes its + // public key, so subsequent validation and queries can work properly. + _, node1Err := fetchLightningNode(nodes, edge.NodeKey1Bytes[:]) + switch { + case node1Err == ErrGraphNodeNotFound: + node1Shell := LightningNode{ + PubKeyBytes: edge.NodeKey1Bytes, + HaveNodeAnnouncement: false, + } + err := addLightningNode(tx, &node1Shell) + if err != nil { + return fmt.Errorf("unable to create shell node "+ + "for: %x", edge.NodeKey1Bytes) + + } + case node1Err != nil: + return err + } + + _, node2Err := fetchLightningNode(nodes, edge.NodeKey2Bytes[:]) + switch { + case node2Err == ErrGraphNodeNotFound: + node2Shell := LightningNode{ + PubKeyBytes: edge.NodeKey2Bytes, + HaveNodeAnnouncement: false, + } + err := addLightningNode(tx, &node2Shell) + if err != nil { + return fmt.Errorf("unable to create shell node "+ + "for: %x", edge.NodeKey2Bytes) + + } + case node2Err != nil: + return err + } + + // If the edge hasn't been created yet, then we'll first add it to the + // edge index in order to associate the edge between two nodes and also + // store the static components of the channel. + if err := putChanEdgeInfo(edgeIndex, edge, chanKey); err != nil { + return err + } + + // Mark edge policies for both sides as unknown. This is to enable + // efficient incoming channel lookup for a node. + for _, key := range []*[33]byte{&edge.NodeKey1Bytes, + &edge.NodeKey2Bytes} { + + err := putChanEdgePolicyUnknown(edges, edge.ChannelID, + key[:]) + if err != nil { + return err + } + } + + // Finally we add it to the channel index which maps channel points + // (outpoints) to the shorter channel ID's. + var b bytes.Buffer + if err := writeOutpoint(&b, &edge.ChannelPoint); err != nil { + return err + } + return chanIndex.Put(b.Bytes(), chanKey[:]) +} + +// HasChannelEdge returns true if the database knows of a channel edge with the +// passed channel ID, and false otherwise. If an edge with that ID is found +// within the graph, then two time stamps representing the last time the edge +// was updated for both directed edges are returned along with the boolean. If +// it is not found, then the zombie index is checked and its result is returned +// as the second boolean. +func (c *ChannelGraph) HasChannelEdge( + chanID uint64) (time.Time, time.Time, bool, bool, error) { + + var ( + upd1Time time.Time + upd2Time time.Time + exists bool + isZombie bool + ) + + // We'll query the cache with the shared lock held to allow multiple + // readers to access values in the cache concurrently if they exist. + c.cacheMu.RLock() + if entry, ok := c.rejectCache.get(chanID); ok { + c.cacheMu.RUnlock() + upd1Time = time.Unix(entry.upd1Time, 0) + upd2Time = time.Unix(entry.upd2Time, 0) + exists, isZombie = entry.flags.unpack() + return upd1Time, upd2Time, exists, isZombie, nil + } + c.cacheMu.RUnlock() + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + // The item was not found with the shared lock, so we'll acquire the + // exclusive lock and check the cache again in case another method added + // the entry to the cache while no lock was held. + if entry, ok := c.rejectCache.get(chanID); ok { + upd1Time = time.Unix(entry.upd1Time, 0) + upd2Time = time.Unix(entry.upd2Time, 0) + exists, isZombie = entry.flags.unpack() + return upd1Time, upd2Time, exists, isZombie, nil + } + + if err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + var channelID [8]byte + byteOrder.PutUint64(channelID[:], chanID) + + // If the edge doesn't exist, then we'll also check our zombie + // index. + if edgeIndex.Get(channelID[:]) == nil { + exists = false + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex != nil { + isZombie, _, _ = isZombieEdge( + zombieIndex, chanID, + ) + } + + return nil + } + + exists = true + isZombie = false + + // If the channel has been found in the graph, then retrieve + // the edges itself so we can return the last updated + // timestamps. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodeNotFound + } + + e1, e2, err := fetchChanEdgePolicies(edgeIndex, edges, nodes, + channelID[:], c.db) + if err != nil { + return err + } + + // As we may have only one of the edges populated, only set the + // update time if the edge was found in the database. + if e1 != nil { + upd1Time = e1.LastUpdate + } + if e2 != nil { + upd2Time = e2.LastUpdate + } + + return nil + }); err != nil { + return time.Time{}, time.Time{}, exists, isZombie, err + } + + c.rejectCache.insert(chanID, rejectCacheEntry{ + upd1Time: upd1Time.Unix(), + upd2Time: upd2Time.Unix(), + flags: packRejectFlags(exists, isZombie), + }) + + return upd1Time, upd2Time, exists, isZombie, nil +} + +// UpdateChannelEdge retrieves and update edge of the graph database. Method +// only reserved for updating an edge info after its already been created. +// In order to maintain this constraints, we return an error in the scenario +// that an edge info hasn't yet been created yet, but someone attempts to update +// it. +func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { + // Construct the channel's primary key which is the 8-byte channel ID. + var chanKey [8]byte + binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) + + return c.db.Update(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edge == nil { + return ErrEdgeNotFound + } + + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrEdgeNotFound + } + + if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo == nil { + return ErrEdgeNotFound + } + + return putChanEdgeInfo(edgeIndex, edge, chanKey) + }) +} + +const ( + // pruneTipBytes is the total size of the value which stores a prune + // entry of the graph in the prune log. The "prune tip" is the last + // entry in the prune log, and indicates if the channel graph is in + // sync with the current UTXO state. The structure of the value + // is: blockHash, taking 32 bytes total. + pruneTipBytes = 32 +) + +// PruneGraph prunes newly closed channels from the channel graph in response +// to a new block being solved on the network. Any transactions which spend the +// funding output of any known channels within he graph will be deleted. +// Additionally, the "prune tip", or the last block which has been used to +// prune the graph is stored so callers can ensure the graph is fully in sync +// with the current UTXO state. A slice of channels that have been closed by +// the target block are returned if the function succeeds without error. +func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, + blockHash *chainhash.Hash, blockHeight uint32) ([]*ChannelEdgeInfo, error) { + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + var chansClosed []*ChannelEdgeInfo + + err := c.db.Update(func(tx *bbolt.Tx) error { + // First grab the edges bucket which houses the information + // we'd like to delete + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return err + } + + // Next grab the two edge indexes which will also need to be updated. + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrSourceNodeNotSet + } + zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) + if err != nil { + return err + } + + // For each of the outpoints that have been spent within the + // block, we attempt to delete them from the graph as if that + // outpoint was a channel, then it has now been closed. + for _, chanPoint := range spentOutputs { + // TODO(roasbeef): load channel bloom filter, continue + // if NOT if filter + + var opBytes bytes.Buffer + if err := writeOutpoint(&opBytes, chanPoint); err != nil { + return err + } + + // First attempt to see if the channel exists within + // the database, if not, then we can exit early. + chanID := chanIndex.Get(opBytes.Bytes()) + if chanID == nil { + continue + } + + // However, if it does, then we'll read out the full + // version so we can add it to the set of deleted + // channels. + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + + // Attempt to delete the channel, an ErrEdgeNotFound + // will be returned if that outpoint isn't known to be + // a channel. If no error is returned, then a channel + // was successfully pruned. + err = delChannelEdge( + edges, edgeIndex, chanIndex, zombieIndex, nodes, + chanID, false, + ) + if err != nil && err != ErrEdgeNotFound { + return err + } + + chansClosed = append(chansClosed, &edgeInfo) + } + + metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) + if err != nil { + return err + } + + pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) + if err != nil { + return err + } + + // With the graph pruned, add a new entry to the prune log, + // which can be used to check if the graph is fully synced with + // the current UTXO state. + var blockHeightBytes [4]byte + byteOrder.PutUint32(blockHeightBytes[:], blockHeight) + + var newTip [pruneTipBytes]byte + copy(newTip[:], blockHash[:]) + + err = pruneBucket.Put(blockHeightBytes[:], newTip[:]) + if err != nil { + return err + } + + // Now that the graph has been pruned, we'll also attempt to + // prune any nodes that have had a channel closed within the + // latest block. + return c.pruneGraphNodes(nodes, edgeIndex) + }) + if err != nil { + return nil, err + } + + for _, channel := range chansClosed { + c.rejectCache.remove(channel.ChannelID) + c.chanCache.remove(channel.ChannelID) + } + + return chansClosed, nil +} + +// PruneGraphNodes is a garbage collection method which attempts to prune out +// any nodes from the channel graph that are currently unconnected. This ensure +// that we only maintain a graph of reachable nodes. In the event that a pruned +// node gains more channels, it will be re-added back to the graph. +func (c *ChannelGraph) PruneGraphNodes() error { + return c.db.Update(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNotFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + return c.pruneGraphNodes(nodes, edgeIndex) + }) +} + +// pruneGraphNodes attempts to remove any nodes from the graph who have had a +// channel closed within the current block. If the node still has existing +// channels in the graph, this will act as a no-op. +func (c *ChannelGraph) pruneGraphNodes(nodes *bbolt.Bucket, + edgeIndex *bbolt.Bucket) error { + + log.Trace("Pruning nodes from graph with no open channels") + + // We'll retrieve the graph's source node to ensure we don't remove it + // even if it no longer has any open channels. + sourceNode, err := c.sourceNode(nodes) + if err != nil { + return err + } + + // We'll use this map to keep count the number of references to a node + // in the graph. A node should only be removed once it has no more + // references in the graph. + nodeRefCounts := make(map[[33]byte]int) + err = nodes.ForEach(func(pubKey, nodeBytes []byte) error { + // If this is the source key, then we skip this + // iteration as the value for this key is a pubKey + // rather than raw node information. + if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { + return nil + } + + var nodePub [33]byte + copy(nodePub[:], pubKey) + nodeRefCounts[nodePub] = 0 + + return nil + }) + if err != nil { + return err + } + + // To ensure we never delete the source node, we'll start off by + // bumping its ref count to 1. + nodeRefCounts[sourceNode.PubKeyBytes] = 1 + + // Next, we'll run through the edgeIndex which maps a channel ID to the + // edge info. We'll use this scan to populate our reference count map + // above. + err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { + // The first 66 bytes of the edge info contain the pubkeys of + // the nodes that this edge attaches. We'll extract them, and + // add them to the ref count map. + var node1, node2 [33]byte + copy(node1[:], edgeInfoBytes[:33]) + copy(node2[:], edgeInfoBytes[33:]) + + // With the nodes extracted, we'll increase the ref count of + // each of the nodes. + nodeRefCounts[node1]++ + nodeRefCounts[node2]++ + + return nil + }) + if err != nil { + return err + } + + // Finally, we'll make a second pass over the set of nodes, and delete + // any nodes that have a ref count of zero. + var numNodesPruned int + for nodePubKey, refCount := range nodeRefCounts { + // If the ref count of the node isn't zero, then we can safely + // skip it as it still has edges to or from it within the + // graph. + if refCount != 0 { + continue + } + + // If we reach this point, then there are no longer any edges + // that connect this node, so we can delete it. + if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { + log.Warnf("Unable to prune node %x from the "+ + "graph: %v", nodePubKey, err) + continue + } + + log.Infof("Pruned unconnected node %x from channel graph", + nodePubKey[:]) + + numNodesPruned++ + } + + if numNodesPruned > 0 { + log.Infof("Pruned %v unconnected nodes from the channel graph", + numNodesPruned) + } + + return nil +} + +// DisconnectBlockAtHeight is used to indicate that the block specified +// by the passed height has been disconnected from the main chain. This +// will "rewind" the graph back to the height below, deleting channels +// that are no longer confirmed from the graph. The prune log will be +// set to the last prune height valid for the remaining chain. +// Channels that were removed from the graph resulting from the +// disconnected block are returned. +func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo, + error) { + + // Every channel having a ShortChannelID starting at 'height' + // will no longer be confirmed. + startShortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + } + + // Delete everything after this height from the db. + endShortChanID := lnwire.ShortChannelID{ + BlockHeight: math.MaxUint32 & 0x00ffffff, + TxIndex: math.MaxUint32 & 0x00ffffff, + TxPosition: math.MaxUint16, + } + // The block height will be the 3 first bytes of the channel IDs. + var chanIDStart [8]byte + byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64()) + var chanIDEnd [8]byte + byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64()) + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + // Keep track of the channels that are removed from the graph. + var removedChans []*ChannelEdgeInfo + + if err := c.db.Update(func(tx *bbolt.Tx) error { + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return err + } + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) + if err != nil { + return err + } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } + + // Scan from chanIDStart to chanIDEnd, deleting every + // found edge. + // NOTE: we must delete the edges after the cursor loop, since + // modifying the bucket while traversing is not safe. + var keys [][]byte + cursor := edgeIndex.Cursor() + for k, v := cursor.Seek(chanIDStart[:]); k != nil && + bytes.Compare(k, chanIDEnd[:]) <= 0; k, v = cursor.Next() { + + edgeInfoReader := bytes.NewReader(v) + edgeInfo, err := deserializeChanEdgeInfo(edgeInfoReader) + if err != nil { + return err + } + + keys = append(keys, k) + removedChans = append(removedChans, &edgeInfo) + } + + for _, k := range keys { + err = delChannelEdge( + edges, edgeIndex, chanIndex, zombieIndex, nodes, + k, false, + ) + if err != nil && err != ErrEdgeNotFound { + return err + } + } + + // Delete all the entries in the prune log having a height + // greater or equal to the block disconnected. + metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) + if err != nil { + return err + } + + pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) + if err != nil { + return err + } + + var pruneKeyStart [4]byte + byteOrder.PutUint32(pruneKeyStart[:], height) + + var pruneKeyEnd [4]byte + byteOrder.PutUint32(pruneKeyEnd[:], math.MaxUint32) + + // To avoid modifying the bucket while traversing, we delete + // the keys in a second loop. + var pruneKeys [][]byte + pruneCursor := pruneBucket.Cursor() + for k, _ := pruneCursor.Seek(pruneKeyStart[:]); k != nil && + bytes.Compare(k, pruneKeyEnd[:]) <= 0; k, _ = pruneCursor.Next() { + + pruneKeys = append(pruneKeys, k) + } + + for _, k := range pruneKeys { + if err := pruneBucket.Delete(k); err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, err + } + + for _, channel := range removedChans { + c.rejectCache.remove(channel.ChannelID) + c.chanCache.remove(channel.ChannelID) + } + + return removedChans, nil +} + +// PruneTip returns the block height and hash of the latest block that has been +// used to prune channels in the graph. Knowing the "prune tip" allows callers +// to tell if the graph is currently in sync with the current best known UTXO +// state. +func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { + var ( + tipHash chainhash.Hash + tipHeight uint32 + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + graphMeta := tx.Bucket(graphMetaBucket) + if graphMeta == nil { + return ErrGraphNotFound + } + pruneBucket := graphMeta.Bucket(pruneLogBucket) + if pruneBucket == nil { + return ErrGraphNeverPruned + } + + pruneCursor := pruneBucket.Cursor() + + // The prune key with the largest block height will be our + // prune tip. + k, v := pruneCursor.Last() + if k == nil { + return ErrGraphNeverPruned + } + + // Once we have the prune tip, the value will be the block hash, + // and the key the block height. + copy(tipHash[:], v[:]) + tipHeight = byteOrder.Uint32(k[:]) + + return nil + }) + if err != nil { + return nil, 0, err + } + + return &tipHash, tipHeight, nil +} + +// DeleteChannelEdges removes edges with the given channel IDs from the database +// and marks them as zombies. This ensures that we're unable to re-add it to our +// database once again. If an edge does not exist within the database, then +// ErrEdgeNotFound will be returned. +func (c *ChannelGraph) DeleteChannelEdges(chanIDs ...uint64) error { + // TODO(roasbeef): possibly delete from node bucket if node has no more + // channels + // TODO(roasbeef): don't delete both edges? + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + err := c.db.Update(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrEdgeNotFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrEdgeNotFound + } + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return ErrEdgeNotFound + } + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodeNotFound + } + zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) + if err != nil { + return err + } + + var rawChanID [8]byte + for _, chanID := range chanIDs { + byteOrder.PutUint64(rawChanID[:], chanID) + err := delChannelEdge( + edges, edgeIndex, chanIndex, zombieIndex, nodes, + rawChanID[:], true, + ) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return err + } + + for _, chanID := range chanIDs { + c.rejectCache.remove(chanID) + c.chanCache.remove(chanID) + } + + return nil +} + +// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the +// passed channel point (outpoint). If the passed channel doesn't exist within +// the database, then ErrEdgeNotFound is returned. +func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { + var chanID uint64 + if err := c.db.View(func(tx *bbolt.Tx) error { + var err error + chanID, err = getChanID(tx, chanPoint) + return err + }); err != nil { + return 0, err + } + + return chanID, nil +} + +// getChanID returns the assigned channel ID for a given channel point. +func getChanID(tx *bbolt.Tx, chanPoint *wire.OutPoint) (uint64, error) { + var b bytes.Buffer + if err := writeOutpoint(&b, chanPoint); err != nil { + return 0, err + } + + edges := tx.Bucket(edgeBucket) + if edges == nil { + return 0, ErrGraphNoEdgesFound + } + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return 0, ErrGraphNoEdgesFound + } + + chanIDBytes := chanIndex.Get(b.Bytes()) + if chanIDBytes == nil { + return 0, ErrEdgeNotFound + } + + chanID := byteOrder.Uint64(chanIDBytes) + + return chanID, nil +} + +// TODO(roasbeef): allow updates to use Batch? + +// HighestChanID returns the "highest" known channel ID in the channel graph. +// This represents the "newest" channel from the PoV of the chain. This method +// can be used by peers to quickly determine if they're graphs are in sync. +func (c *ChannelGraph) HighestChanID() (uint64, error) { + var cid uint64 + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // In order to find the highest chan ID, we'll fetch a cursor + // and use that to seek to the "end" of our known rage. + cidCursor := edgeIndex.Cursor() + + lastChanID, _ := cidCursor.Last() + + // If there's no key, then this means that we don't actually + // know of any channels, so we'll return a predicable error. + if lastChanID == nil { + return ErrGraphNoEdgesFound + } + + // Otherwise, we'll de serialize the channel ID and return it + // to the caller. + cid = byteOrder.Uint64(lastChanID) + return nil + }) + if err != nil && err != ErrGraphNoEdgesFound { + return 0, err + } + + return cid, nil +} + +// ChannelEdge represents the complete set of information for a channel edge in +// the known channel graph. This struct couples the core information of the +// edge as well as each of the known advertised edge policies. +type ChannelEdge struct { + // Info contains all the static information describing the channel. + Info *ChannelEdgeInfo + + // Policy1 points to the "first" edge policy of the channel containing + // the dynamic information required to properly route through the edge. + Policy1 *ChannelEdgePolicy + + // Policy2 points to the "second" edge policy of the channel containing + // the dynamic information required to properly route through the edge. + Policy2 *ChannelEdgePolicy +} + +// ChanUpdatesInHorizon returns all the known channel edges which have at least +// one edge that has an update timestamp within the specified horizon. +func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { + // To ensure we don't return duplicate ChannelEdges, we'll use an + // additional map to keep track of the edges already seen to prevent + // re-adding it. + edgesSeen := make(map[uint64]struct{}) + edgesToCache := make(map[uint64]ChannelEdge) + var edgesInHorizon []ChannelEdge + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + var hits int + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) + if edgeUpdateIndex == nil { + return ErrGraphNoEdgesFound + } + + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + // We'll now obtain a cursor to perform a range query within + // the index to find all channels within the horizon. + updateCursor := edgeUpdateIndex.Cursor() + + var startTimeBytes, endTimeBytes [8 + 8]byte + byteOrder.PutUint64( + startTimeBytes[:8], uint64(startTime.Unix()), + ) + byteOrder.PutUint64( + endTimeBytes[:8], uint64(endTime.Unix()), + ) + + // With our start and end times constructed, we'll step through + // the index collecting the info and policy of each update of + // each channel that has a last update within the time range. + for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && + bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { + + // We have a new eligible entry, so we'll slice of the + // chan ID so we can query it in the DB. + chanID := indexKey[8:] + + // If we've already retrieved the info and policies for + // this edge, then we can skip it as we don't need to do + // so again. + chanIDInt := byteOrder.Uint64(chanID) + if _, ok := edgesSeen[chanIDInt]; ok { + continue + } + + if channel, ok := c.chanCache.get(chanIDInt); ok { + hits++ + edgesSeen[chanIDInt] = struct{}{} + edgesInHorizon = append(edgesInHorizon, channel) + continue + } + + // First, we'll fetch the static edge information. + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + chanID := byteOrder.Uint64(chanID) + return fmt.Errorf("unable to fetch info for "+ + "edge with chan_id=%v: %v", chanID, err) + } + edgeInfo.db = c.db + + // With the static information obtained, we'll now + // fetch the dynamic policy info. + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { + chanID := byteOrder.Uint64(chanID) + return fmt.Errorf("unable to fetch policies "+ + "for edge with chan_id=%v: %v", chanID, + err) + } + + // Finally, we'll collate this edge with the rest of + // edges to be returned. + edgesSeen[chanIDInt] = struct{}{} + channel := ChannelEdge{ + Info: &edgeInfo, + Policy1: edge1, + Policy2: edge2, + } + edgesInHorizon = append(edgesInHorizon, channel) + edgesToCache[chanIDInt] = channel + } + + return nil + }) + switch { + case err == ErrGraphNoEdgesFound: + fallthrough + case err == ErrGraphNodesNotFound: + break + + case err != nil: + return nil, err + } + + // Insert any edges loaded from disk into the cache. + for chanid, channel := range edgesToCache { + c.chanCache.insert(chanid, channel) + } + + log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)", + float64(hits)/float64(len(edgesInHorizon)), hits, + len(edgesInHorizon)) + + return edgesInHorizon, nil +} + +// NodeUpdatesInHorizon returns all the known lightning node which have an +// update timestamp within the passed range. This method can be used by two +// nodes to quickly determine if they have the same set of up to date node +// announcements. +func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { + var nodesInHorizon []LightningNode + + err := c.db.View(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) + if nodeUpdateIndex == nil { + return ErrGraphNodesNotFound + } + + // We'll now obtain a cursor to perform a range query within + // the index to find all node announcements within the horizon. + updateCursor := nodeUpdateIndex.Cursor() + + var startTimeBytes, endTimeBytes [8 + 33]byte + byteOrder.PutUint64( + startTimeBytes[:8], uint64(startTime.Unix()), + ) + byteOrder.PutUint64( + endTimeBytes[:8], uint64(endTime.Unix()), + ) + + // With our start and end times constructed, we'll step through + // the index collecting info for each node within the time + // range. + for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && + bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { + + nodePub := indexKey[8:] + node, err := fetchLightningNode(nodes, nodePub) + if err != nil { + return err + } + node.db = c.db + + nodesInHorizon = append(nodesInHorizon, node) + } + + return nil + }) + switch { + case err == ErrGraphNoEdgesFound: + fallthrough + case err == ErrGraphNodesNotFound: + break + + case err != nil: + return nil, err + } + + return nodesInHorizon, nil +} + +// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan +// ID's that we don't know and are not known zombies of the passed set. In other +// words, we perform a set difference of our set of chan ID's and the ones +// passed in. This method can be used by callers to determine the set of +// channels another peer knows of that we don't. +func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { + var newChanIDs []uint64 + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // Fetch the zombie index, it may not exist if no edges have + // ever been marked as zombies. If the index has been + // initialized, we will use it later to skip known zombie edges. + zombieIndex := edges.Bucket(zombieBucket) + + // We'll run through the set of chanIDs and collate only the + // set of channel that are unable to be found within our db. + var cidBytes [8]byte + for _, cid := range chanIDs { + byteOrder.PutUint64(cidBytes[:], cid) + + // If the edge is already known, skip it. + if v := edgeIndex.Get(cidBytes[:]); v != nil { + continue + } + + // If the edge is a known zombie, skip it. + if zombieIndex != nil { + isZombie, _, _ := isZombieEdge(zombieIndex, cid) + if isZombie { + continue + } + } + + newChanIDs = append(newChanIDs, cid) + } + + return nil + }) + switch { + // If we don't know of any edges yet, then we'll return the entire set + // of chan IDs specified. + case err == ErrGraphNoEdgesFound: + return chanIDs, nil + + case err != nil: + return nil, err + } + + return newChanIDs, nil +} + +// FilterChannelRange returns the channel ID's of all known channels which were +// mined in a block height within the passed range. This method can be used to +// quickly share with a peer the set of channels we know of within a particular +// range to catch them up after a period of time offline. +func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) { + var chanIDs []uint64 + + startChanID := &lnwire.ShortChannelID{ + BlockHeight: startHeight, + } + + endChanID := lnwire.ShortChannelID{ + BlockHeight: endHeight, + TxIndex: math.MaxUint32 & 0x00ffffff, + TxPosition: math.MaxUint16, + } + + // As we need to perform a range scan, we'll convert the starting and + // ending height to their corresponding values when encoded using short + // channel ID's. + var chanIDStart, chanIDEnd [8]byte + byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) + byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + cursor := edgeIndex.Cursor() + + // We'll now iterate through the database, and find each + // channel ID that resides within the specified range. + var cid uint64 + for k, _ := cursor.Seek(chanIDStart[:]); k != nil && + bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() { + + // This channel ID rests within the target range, so + // we'll convert it into an integer and add it to our + // returned set. + cid = byteOrder.Uint64(k) + chanIDs = append(chanIDs, cid) + } + + return nil + }) + switch { + // If we don't know of any channels yet, then there's nothing to + // filter, so we'll return an empty slice. + case err == ErrGraphNoEdgesFound: + return chanIDs, nil + + case err != nil: + return nil, err + } + + return chanIDs, nil +} + +// FetchChanInfos returns the set of channel edges that correspond to the passed +// channel ID's. If an edge is the query is unknown to the database, it will +// skipped and the result will contain only those edges that exist at the time +// of the query. This can be used to respond to peer queries that are seeking to +// fill in gaps in their view of the channel graph. +func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { + // TODO(roasbeef): sort cids? + + var ( + chanEdges []ChannelEdge + cidBytes [8]byte + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + for _, cid := range chanIDs { + byteOrder.PutUint64(cidBytes[:], cid) + + // First, we'll fetch the static edge information. If + // the edge is unknown, we will skip the edge and + // continue gathering all known edges. + edgeInfo, err := fetchChanEdgeInfo( + edgeIndex, cidBytes[:], + ) + switch { + case err == ErrEdgeNotFound: + continue + case err != nil: + return err + } + edgeInfo.db = c.db + + // With the static information obtained, we'll now + // fetch the dynamic policy info. + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, cidBytes[:], c.db, + ) + if err != nil { + return err + } + + chanEdges = append(chanEdges, ChannelEdge{ + Info: &edgeInfo, + Policy1: edge1, + Policy2: edge2, + }) + } + return nil + }) + if err != nil { + return nil, err + } + + return chanEdges, nil +} + +func delEdgeUpdateIndexEntry(edgesBucket *bbolt.Bucket, chanID uint64, + edge1, edge2 *ChannelEdgePolicy) error { + + // First, we'll fetch the edge update index bucket which currently + // stores an entry for the channel we're about to delete. + updateIndex := edgesBucket.Bucket(edgeUpdateIndexBucket) + if updateIndex == nil { + // No edges in bucket, return early. + return nil + } + + // Now that we have the bucket, we'll attempt to construct a template + // for the index key: updateTime || chanid. + var indexKey [8 + 8]byte + byteOrder.PutUint64(indexKey[8:], chanID) + + // With the template constructed, we'll attempt to delete an entry that + // would have been created by both edges: we'll alternate the update + // times, as one may had overridden the other. + if edge1 != nil { + byteOrder.PutUint64(indexKey[:8], uint64(edge1.LastUpdate.Unix())) + if err := updateIndex.Delete(indexKey[:]); err != nil { + return err + } + } + + // We'll also attempt to delete the entry that may have been created by + // the second edge. + if edge2 != nil { + byteOrder.PutUint64(indexKey[:8], uint64(edge2.LastUpdate.Unix())) + if err := updateIndex.Delete(indexKey[:]); err != nil { + return err + } + } + + return nil +} + +func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, + nodes *bbolt.Bucket, chanID []byte, isZombie bool) error { + + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + + // We'll also remove the entry in the edge update index bucket before + // we delete the edges themselves so we can access their last update + // times. + cid := byteOrder.Uint64(chanID) + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, nil, + ) + if err != nil { + return err + } + err = delEdgeUpdateIndexEntry(edges, cid, edge1, edge2) + if err != nil { + return err + } + + // The edge key is of the format pubKey || chanID. First we construct + // the latter half, populating the channel ID. + var edgeKey [33 + 8]byte + copy(edgeKey[33:], chanID) + + // With the latter half constructed, copy over the first public key to + // delete the edge in this direction, then the second to delete the + // edge in the opposite direction. + copy(edgeKey[:33], edgeInfo.NodeKey1Bytes[:]) + if edges.Get(edgeKey[:]) != nil { + if err := edges.Delete(edgeKey[:]); err != nil { + return err + } + } + copy(edgeKey[:33], edgeInfo.NodeKey2Bytes[:]) + if edges.Get(edgeKey[:]) != nil { + if err := edges.Delete(edgeKey[:]); err != nil { + return err + } + } + + // As part of deleting the edge we also remove all disabled entries + // from the edgePolicyDisabledIndex bucket. We do that for both directions. + updateEdgePolicyDisabledIndex(edges, cid, false, false) + updateEdgePolicyDisabledIndex(edges, cid, true, false) + + // With the edge data deleted, we can purge the information from the two + // edge indexes. + if err := edgeIndex.Delete(chanID); err != nil { + return err + } + var b bytes.Buffer + if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + return err + } + if err := chanIndex.Delete(b.Bytes()); err != nil { + return err + } + + // Finally, we'll mark the edge as a zombie within our index if it's + // being removed due to the channel becoming a zombie. We do this to + // ensure we don't store unnecessary data for spent channels. + if !isZombie { + return nil + } + + return markEdgeZombie( + zombieIndex, byteOrder.Uint64(chanID), edgeInfo.NodeKey1Bytes, + edgeInfo.NodeKey2Bytes, + ) +} + +// UpdateEdgePolicy updates the edge routing policy for a single directed edge +// within the database for the referenced channel. The `flags` attribute within +// the ChannelEdgePolicy determines which of the directed edges are being +// updated. If the flag is 1, then the first node's information is being +// updated, otherwise it's the second node's information. The node ordering is +// determined by the lexicographical ordering of the identity public keys of +// the nodes on either side of the channel. +func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + var isUpdate1 bool + err := c.db.Update(func(tx *bbolt.Tx) error { + var err error + isUpdate1, err = updateEdgePolicy(tx, edge) + return err + }) + if err != nil { + return err + } + + // If an entry for this channel is found in reject cache, we'll modify + // the entry with the updated timestamp for the direction that was just + // written. If the edge doesn't exist, we'll load the cache entry lazily + // during the next query for this edge. + if entry, ok := c.rejectCache.get(edge.ChannelID); ok { + if isUpdate1 { + entry.upd1Time = edge.LastUpdate.Unix() + } else { + entry.upd2Time = edge.LastUpdate.Unix() + } + c.rejectCache.insert(edge.ChannelID, entry) + } + + // If an entry for this channel is found in channel cache, we'll modify + // the entry with the updated policy for the direction that was just + // written. If the edge doesn't exist, we'll defer loading the info and + // policies and lazily read from disk during the next query. + if channel, ok := c.chanCache.get(edge.ChannelID); ok { + if isUpdate1 { + channel.Policy1 = edge + } else { + channel.Policy2 = edge + } + c.chanCache.insert(edge.ChannelID, channel) + } + + return nil +} + +// updateEdgePolicy attempts to update an edge's policy within the relevant +// buckets using an existing database transaction. The returned boolean will be +// true if the updated policy belongs to node1, and false if the policy belonged +// to node2. +func updateEdgePolicy(tx *bbolt.Tx, edge *ChannelEdgePolicy) (bool, error) { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return false, ErrEdgeNotFound + + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return false, ErrEdgeNotFound + } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return false, err + } + + // Create the channelID key be converting the channel ID + // integer into a byte slice. + var chanID [8]byte + byteOrder.PutUint64(chanID[:], edge.ChannelID) + + // With the channel ID, we then fetch the value storing the two + // nodes which connect this channel edge. + nodeInfo := edgeIndex.Get(chanID[:]) + if nodeInfo == nil { + return false, ErrEdgeNotFound + } + + // Depending on the flags value passed above, either the first + // or second edge policy is being updated. + var fromNode, toNode []byte + var isUpdate1 bool + if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { + fromNode = nodeInfo[:33] + toNode = nodeInfo[33:66] + isUpdate1 = true + } else { + fromNode = nodeInfo[33:66] + toNode = nodeInfo[:33] + isUpdate1 = false + } + + // Finally, with the direction of the edge being updated + // identified, we update the on-disk edge representation. + err = putChanEdgePolicy(edges, nodes, edge, fromNode, toNode) + if err != nil { + return false, err + } + + return isUpdate1, nil +} + +// LightningNode represents an individual vertex/node within the channel graph. +// A node is connected to other nodes by one or more channel edges emanating +// from it. As the graph is directed, a node will also have an incoming edge +// attached to it for each outgoing edge. +type LightningNode struct { + // PubKeyBytes is the raw bytes of the public key of the target node. + PubKeyBytes [33]byte + pubKey *btcec.PublicKey + + // HaveNodeAnnouncement indicates whether we received a node + // announcement for this particular node. If true, the remaining fields + // will be set, if false only the PubKey is known for this node. + HaveNodeAnnouncement bool + + // LastUpdate is the last time the vertex information for this node has + // been updated. + LastUpdate time.Time + + // Address is the TCP address this node is reachable over. + Addresses []net.Addr + + // Color is the selected color for the node. + Color color.RGBA + + // Alias is a nick-name for the node. The alias can be used to confirm + // a node's identity or to serve as a short ID for an address book. + Alias string + + // AuthSigBytes is the raw signature under the advertised public key + // which serves to authenticate the attributes announced by this node. + AuthSigBytes []byte + + // Features is the list of protocol features supported by this node. + Features *lnwire.FeatureVector + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte + + db *DB + + // TODO(roasbeef): discovery will need storage to keep it's last IP + // address and re-announce if interface changes? + + // TODO(roasbeef): add update method and fetch? +} + +// PubKey is the node's long-term identity public key. This key will be used to +// authenticated any advertisements/updates sent by the node. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (l *LightningNode) PubKey() (*btcec.PublicKey, error) { + if l.pubKey != nil { + return l.pubKey, nil + } + + key, err := btcec.ParsePubKey(l.PubKeyBytes[:], btcec.S256()) + if err != nil { + return nil, err + } + l.pubKey = key + + return key, nil +} + +// AuthSig is a signature under the advertised public key which serves to +// authenticate the attributes announced by this node. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (l *LightningNode) AuthSig() (*btcec.Signature, error) { + return btcec.ParseSignature(l.AuthSigBytes, btcec.S256()) +} + +// AddPubKey is a setter-link method that can be used to swap out the public +// key for a node. +func (l *LightningNode) AddPubKey(key *btcec.PublicKey) { + l.pubKey = key + copy(l.PubKeyBytes[:], key.SerializeCompressed()) +} + +// NodeAnnouncement retrieves the latest node announcement of the node. +func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement, + error) { + + if !l.HaveNodeAnnouncement { + return nil, fmt.Errorf("node does not have node announcement") + } + + alias, err := lnwire.NewNodeAlias(l.Alias) + if err != nil { + return nil, err + } + + nodeAnn := &lnwire.NodeAnnouncement{ + Features: l.Features.RawFeatureVector, + NodeID: l.PubKeyBytes, + RGBColor: l.Color, + Alias: alias, + Addresses: l.Addresses, + Timestamp: uint32(l.LastUpdate.Unix()), + ExtraOpaqueData: l.ExtraOpaqueData, + } + + if !signed { + return nodeAnn, nil + } + + sig, err := lnwire.NewSigFromRawSignature(l.AuthSigBytes) + if err != nil { + return nil, err + } + + nodeAnn.Signature = sig + + return nodeAnn, nil +} + +// isPublic determines whether the node is seen as public within the graph from +// the source node's point of view. An existing database transaction can also be +// specified. +func (l *LightningNode) isPublic(tx *bbolt.Tx, sourcePubKey []byte) (bool, error) { + // In order to determine whether this node is publicly advertised within + // the graph, we'll need to look at all of its edges and check whether + // they extend to any other node than the source node. errDone will be + // used to terminate the check early. + nodeIsPublic := false + errDone := errors.New("done") + err := l.ForEachChannel(tx, func(_ *bbolt.Tx, info *ChannelEdgeInfo, + _, _ *ChannelEdgePolicy) error { + + // If this edge doesn't extend to the source node, we'll + // terminate our search as we can now conclude that the node is + // publicly advertised within the graph due to the local node + // knowing of the current edge. + if !bytes.Equal(info.NodeKey1Bytes[:], sourcePubKey) && + !bytes.Equal(info.NodeKey2Bytes[:], sourcePubKey) { + + nodeIsPublic = true + return errDone + } + + // Since the edge _does_ extend to the source node, we'll also + // need to ensure that this is a public edge. + if info.AuthProof != nil { + nodeIsPublic = true + return errDone + } + + // Otherwise, we'll continue our search. + return nil + }) + if err != nil && err != errDone { + return false, err + } + + return nodeIsPublic, nil +} + +// FetchLightningNode attempts to look up a target node by its identity public +// key. If the node isn't found in the database, then ErrGraphNodeNotFound is +// returned. +func (c *ChannelGraph) FetchLightningNode(pub *btcec.PublicKey) (*LightningNode, error) { + var node *LightningNode + nodePub := pub.SerializeCompressed() + err := c.db.View(func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // If a key for this serialized public key isn't found, then + // the target node doesn't exist within the database. + nodeBytes := nodes.Get(nodePub) + if nodeBytes == nil { + return ErrGraphNodeNotFound + } + + // If the node is found, then we can de deserialize the node + // information to return to the user. + nodeReader := bytes.NewReader(nodeBytes) + n, err := deserializeLightningNode(nodeReader) + if err != nil { + return err + } + n.db = c.db + + node = &n + + return nil + }) + if err != nil { + return nil, err + } + + return node, nil +} + +// HasLightningNode determines if the graph has a vertex identified by the +// target node identity public key. If the node exists in the database, a +// timestamp of when the data for the node was lasted updated is returned along +// with a true boolean. Otherwise, an empty time.Time is returned with a false +// boolean. +func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, error) { + var ( + updateTime time.Time + exists bool + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // If a key for this serialized public key isn't found, we can + // exit early. + nodeBytes := nodes.Get(nodePub[:]) + if nodeBytes == nil { + exists = false + return nil + } + + // Otherwise we continue on to obtain the time stamp + // representing the last time the data for this node was + // updated. + nodeReader := bytes.NewReader(nodeBytes) + node, err := deserializeLightningNode(nodeReader) + if err != nil { + return err + } + + exists = true + updateTime = node.LastUpdate + return nil + }) + if err != nil { + return time.Time{}, exists, err + } + + return updateTime, exists, nil +} + +// nodeTraversal is used to traverse all channels of a node given by its +// public key and passes channel information into the specified callback. +func nodeTraversal(tx *bbolt.Tx, nodePub []byte, db *DB, + cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + + traversal := func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNotFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // In order to reach all the edges for this node, we take + // advantage of the construction of the key-space within the + // edge bucket. The keys are stored in the form: pubKey || + // chanID. Therefore, starting from a chanID of zero, we can + // scan forward in the bucket, grabbing all the edges for the + // node. Once the prefix no longer matches, then we know we're + // done. + var nodeStart [33 + 8]byte + copy(nodeStart[:], nodePub) + copy(nodeStart[33:], chanStart[:]) + + // Starting from the key pubKey || 0, we seek forward in the + // bucket until the retrieved key no longer has the public key + // as its prefix. This indicates that we've stepped over into + // another node's edges, so we can terminate our scan. + edgeCursor := edges.Cursor() + for nodeEdge, _ := edgeCursor.Seek(nodeStart[:]); bytes.HasPrefix(nodeEdge, nodePub); nodeEdge, _ = edgeCursor.Next() { + // If the prefix still matches, the channel id is + // returned in nodeEdge. Channel id is used to lookup + // the node at the other end of the channel and both + // edge policies. + chanID := nodeEdge[33:] + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + edgeInfo.db = db + + outgoingPolicy, err := fetchChanEdgePolicy( + edges, chanID, nodePub, nodes, + ) + if err != nil { + return err + } + + otherNode, err := edgeInfo.OtherNodeKeyBytes(nodePub) + if err != nil { + return err + } + + incomingPolicy, err := fetchChanEdgePolicy( + edges, chanID, otherNode[:], nodes, + ) + if err != nil { + return err + } + + // Finally, we execute the callback. + err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy) + if err != nil { + return err + } + } + + return nil + } + + // If no transaction was provided, then we'll create a new transaction + // to execute the transaction within. + if tx == nil { + return db.View(traversal) + } + + // Otherwise, we re-use the existing transaction to execute the graph + // traversal. + return traversal(tx) +} + +// ForEachChannel iterates through all channels of this node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// the connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal. +func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, + cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + + nodePub := l.PubKeyBytes[:] + db := l.db + + return nodeTraversal(tx, nodePub, db, cb) +} + +// ChannelEdgeInfo represents a fully authenticated channel along with all its +// unique attributes. Once an authenticated channel announcement has been +// processed on the network, then an instance of ChannelEdgeInfo encapsulating +// the channels attributes is stored. The other portions relevant to routing +// policy of a channel are stored within a ChannelEdgePolicy for each direction +// of the channel. +type ChannelEdgeInfo struct { + // ChannelID is the unique channel ID for the channel. The first 3 + // bytes are the block height, the next 3 the index within the block, + // and the last 2 bytes are the output index for the channel. + ChannelID uint64 + + // ChainHash is the hash that uniquely identifies the chain that this + // channel was opened within. + // + // TODO(roasbeef): need to modify db keying for multi-chain + // * must add chain hash to prefix as well + ChainHash chainhash.Hash + + // NodeKey1Bytes is the raw public key of the first node. + NodeKey1Bytes [33]byte + nodeKey1 *btcec.PublicKey + + // NodeKey2Bytes is the raw public key of the first node. + NodeKey2Bytes [33]byte + nodeKey2 *btcec.PublicKey + + // BitcoinKey1Bytes is the raw public key of the first node. + BitcoinKey1Bytes [33]byte + bitcoinKey1 *btcec.PublicKey + + // BitcoinKey2Bytes is the raw public key of the first node. + BitcoinKey2Bytes [33]byte + bitcoinKey2 *btcec.PublicKey + + // Features is an opaque byte slice that encodes the set of channel + // specific features that this channel edge supports. + Features []byte + + // AuthProof is the authentication proof for this channel. This proof + // contains a set of signatures binding four identities, which attests + // to the legitimacy of the advertised channel. + AuthProof *ChannelAuthProof + + // ChannelPoint is the funding outpoint of the channel. This can be + // used to uniquely identify the channel within the channel graph. + ChannelPoint wire.OutPoint + + // Capacity is the total capacity of the channel, this is determined by + // the value output in the outpoint that created this channel. + Capacity btcutil.Amount + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte + + db *DB +} + +// AddNodeKeys is a setter-like method that can be used to replace the set of +// keys for the target ChannelEdgeInfo. +func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, + bitcoinKey2 *btcec.PublicKey) { + + c.nodeKey1 = nodeKey1 + copy(c.NodeKey1Bytes[:], c.nodeKey1.SerializeCompressed()) + + c.nodeKey2 = nodeKey2 + copy(c.NodeKey2Bytes[:], nodeKey2.SerializeCompressed()) + + c.bitcoinKey1 = bitcoinKey1 + copy(c.BitcoinKey1Bytes[:], c.bitcoinKey1.SerializeCompressed()) + + c.bitcoinKey2 = bitcoinKey2 + copy(c.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) +} + +// NodeKey1 is the identity public key of the "first" node that was involved in +// the creation of this channel. A node is considered "first" if the +// lexicographical ordering the its serialized public key is "smaller" than +// that of the other node involved in channel creation. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) { + if c.nodeKey1 != nil { + return c.nodeKey1, nil + } + + key, err := btcec.ParsePubKey(c.NodeKey1Bytes[:], btcec.S256()) + if err != nil { + return nil, err + } + c.nodeKey1 = key + + return key, nil +} + +// NodeKey2 is the identity public key of the "second" node that was +// involved in the creation of this channel. A node is considered +// "second" if the lexicographical ordering the its serialized public +// key is "larger" than that of the other node involved in channel +// creation. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) { + if c.nodeKey2 != nil { + return c.nodeKey2, nil + } + + key, err := btcec.ParsePubKey(c.NodeKey2Bytes[:], btcec.S256()) + if err != nil { + return nil, err + } + c.nodeKey2 = key + + return key, nil +} + +// BitcoinKey1 is the Bitcoin multi-sig key belonging to the first +// node, that was involved in the funding transaction that originally +// created the channel that this struct represents. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) { + if c.bitcoinKey1 != nil { + return c.bitcoinKey1, nil + } + + key, err := btcec.ParsePubKey(c.BitcoinKey1Bytes[:], btcec.S256()) + if err != nil { + return nil, err + } + c.bitcoinKey1 = key + + return key, nil +} + +// BitcoinKey2 is the Bitcoin multi-sig key belonging to the second +// node, that was involved in the funding transaction that originally +// created the channel that this struct represents. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { + if c.bitcoinKey2 != nil { + return c.bitcoinKey2, nil + } + + key, err := btcec.ParsePubKey(c.BitcoinKey2Bytes[:], btcec.S256()) + if err != nil { + return nil, err + } + c.bitcoinKey2 = key + + return key, nil +} + +// OtherNodeKeyBytes returns the node key bytes of the other end of +// the channel. +func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( + [33]byte, error) { + + switch { + case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): + return c.NodeKey2Bytes, nil + case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): + return c.NodeKey1Bytes, nil + default: + return [33]byte{}, fmt.Errorf("node not participating in this channel") + } +} + +// FetchOtherNode attempts to fetch the full LightningNode that's opposite of +// the target node in the channel. This is useful when one knows the pubkey of +// one of the nodes, and wishes to obtain the full LightningNode for the other +// end of the channel. +func (c *ChannelEdgeInfo) FetchOtherNode(tx *bbolt.Tx, thisNodeKey []byte) (*LightningNode, error) { + + // Ensure that the node passed in is actually a member of the channel. + var targetNodeBytes [33]byte + switch { + case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): + targetNodeBytes = c.NodeKey2Bytes + case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): + targetNodeBytes = c.NodeKey1Bytes + default: + return nil, fmt.Errorf("node not participating in this channel") + } + + var targetNode *LightningNode + fetchNodeFunc := func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + node, err := fetchLightningNode(nodes, targetNodeBytes[:]) + if err != nil { + return err + } + node.db = c.db + + targetNode = &node + + return nil + } + + // If the transaction is nil, then we'll need to create a new one, + // otherwise we can use the existing db transaction. + var err error + if tx == nil { + err = c.db.View(fetchNodeFunc) + } else { + err = fetchNodeFunc(tx) + } + + return targetNode, err +} + +// ChannelAuthProof is the authentication proof (the signature portion) for a +// channel. Using the four signatures contained in the struct, and some +// auxiliary knowledge (the funding script, node identities, and outpoint) nodes +// on the network are able to validate the authenticity and existence of a +// channel. Each of these signatures signs the following digest: chanID || +// nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len || +// features. +type ChannelAuthProof struct { + // nodeSig1 is a cached instance of the first node signature. + nodeSig1 *btcec.Signature + + // NodeSig1Bytes are the raw bytes of the first node signature encoded + // in DER format. + NodeSig1Bytes []byte + + // nodeSig2 is a cached instance of the second node signature. + nodeSig2 *btcec.Signature + + // NodeSig2Bytes are the raw bytes of the second node signature + // encoded in DER format. + NodeSig2Bytes []byte + + // bitcoinSig1 is a cached instance of the first bitcoin signature. + bitcoinSig1 *btcec.Signature + + // BitcoinSig1Bytes are the raw bytes of the first bitcoin signature + // encoded in DER format. + BitcoinSig1Bytes []byte + + // bitcoinSig2 is a cached instance of the second bitcoin signature. + bitcoinSig2 *btcec.Signature + + // BitcoinSig2Bytes are the raw bytes of the second bitcoin signature + // encoded in DER format. + BitcoinSig2Bytes []byte +} + +// Node1Sig is the signature using the identity key of the node that is first +// in a lexicographical ordering of the serialized public keys of the two nodes +// that created the channel. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelAuthProof) Node1Sig() (*btcec.Signature, error) { + if c.nodeSig1 != nil { + return c.nodeSig1, nil + } + + sig, err := btcec.ParseSignature(c.NodeSig1Bytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.nodeSig1 = sig + + return sig, nil +} + +// Node2Sig is the signature using the identity key of the node that is second +// in a lexicographical ordering of the serialized public keys of the two nodes +// that created the channel. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelAuthProof) Node2Sig() (*btcec.Signature, error) { + if c.nodeSig2 != nil { + return c.nodeSig2, nil + } + + sig, err := btcec.ParseSignature(c.NodeSig2Bytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.nodeSig2 = sig + + return sig, nil +} + +// BitcoinSig1 is the signature using the public key of the first node that was +// used in the channel's multi-sig output. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelAuthProof) BitcoinSig1() (*btcec.Signature, error) { + if c.bitcoinSig1 != nil { + return c.bitcoinSig1, nil + } + + sig, err := btcec.ParseSignature(c.BitcoinSig1Bytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.bitcoinSig1 = sig + + return sig, nil +} + +// BitcoinSig2 is the signature using the public key of the second node that +// was used in the channel's multi-sig output. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelAuthProof) BitcoinSig2() (*btcec.Signature, error) { + if c.bitcoinSig2 != nil { + return c.bitcoinSig2, nil + } + + sig, err := btcec.ParseSignature(c.BitcoinSig2Bytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.bitcoinSig2 = sig + + return sig, nil +} + +// IsEmpty check is the authentication proof is empty Proof is empty if at +// least one of the signatures are equal to nil. +func (c *ChannelAuthProof) IsEmpty() bool { + return len(c.NodeSig1Bytes) == 0 || + len(c.NodeSig2Bytes) == 0 || + len(c.BitcoinSig1Bytes) == 0 || + len(c.BitcoinSig2Bytes) == 0 +} + +// ChannelEdgePolicy represents a *directed* edge within the channel graph. For +// each channel in the database, there are two distinct edges: one for each +// possible direction of travel along the channel. The edges themselves hold +// information concerning fees, and minimum time-lock information which is +// utilized during path finding. +type ChannelEdgePolicy struct { + // SigBytes is the raw bytes of the signature of the channel edge + // policy. We'll only parse these if the caller needs to access the + // signature for validation purposes. Do not set SigBytes directly, but + // use SetSigBytes instead to make sure that the cache is invalidated. + SigBytes []byte + + // sig is a cached fully parsed signature. + sig *btcec.Signature + + // ChannelID is the unique channel ID for the channel. The first 3 + // bytes are the block height, the next 3 the index within the block, + // and the last 2 bytes are the output index for the channel. + ChannelID uint64 + + // LastUpdate is the last time an authenticated edge for this channel + // was received. + LastUpdate time.Time + + // MessageFlags is a bitfield which indicates the presence of optional + // fields (like max_htlc) in the policy. + MessageFlags lnwire.ChanUpdateMsgFlags + + // ChannelFlags is a bitfield which signals the capabilities of the + // channel as well as the directed edge this update applies to. + ChannelFlags lnwire.ChanUpdateChanFlags + + // TimeLockDelta is the number of blocks this node will subtract from + // the expiry of an incoming HTLC. This value expresses the time buffer + // the node would like to HTLC exchanges. + TimeLockDelta uint16 + + // MinHTLC is the smallest value HTLC this node will accept, expressed + // in millisatoshi. + MinHTLC lnwire.MilliSatoshi + + // MaxHTLC is the largest value HTLC this node will accept, expressed + // in millisatoshi. + MaxHTLC lnwire.MilliSatoshi + + // FeeBaseMSat is the base HTLC fee that will be charged for forwarding + // ANY HTLC, expressed in mSAT's. + FeeBaseMSat lnwire.MilliSatoshi + + // FeeProportionalMillionths is the rate that the node will charge for + // HTLCs for each millionth of a satoshi forwarded. + FeeProportionalMillionths lnwire.MilliSatoshi + + // Node is the LightningNode that this directed edge leads to. Using + // this pointer the channel graph can further be traversed. + Node *LightningNode + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte + + db *DB +} + +// Signature is a channel announcement signature, which is needed for proper +// edge policy announcement. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelEdgePolicy) Signature() (*btcec.Signature, error) { + if c.sig != nil { + return c.sig, nil + } + + sig, err := btcec.ParseSignature(c.SigBytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.sig = sig + + return sig, nil +} + +// SetSigBytes updates the signature and invalidates the cached parsed +// signature. +func (c *ChannelEdgePolicy) SetSigBytes(sig []byte) { + c.SigBytes = sig + c.sig = nil +} + +// IsDisabled determines whether the edge has the disabled bit set. +func (c *ChannelEdgePolicy) IsDisabled() bool { + return c.ChannelFlags&lnwire.ChanUpdateDisabled == + lnwire.ChanUpdateDisabled +} + +// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over +// the passed active payment channel. This value is currently computed as +// specified in BOLT07, but will likely change in the near future. +func (c *ChannelEdgePolicy) ComputeFee( + amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts +} + +// divideCeil divides dividend by factor and rounds the result up. +func divideCeil(dividend, factor lnwire.MilliSatoshi) lnwire.MilliSatoshi { + return (dividend + factor - 1) / factor +} + +// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming +// amount. +func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( + incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return incomingAmt - divideCeil( + feeRateParts*(incomingAmt-c.FeeBaseMSat), + feeRateParts+c.FeeProportionalMillionths, + ) +} + +// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for +// the channel identified by the funding outpoint. If the channel can't be +// found, then ErrEdgeNotFound is returned. A struct which houses the general +// information for the channel itself is returned as well as two structs that +// contain the routing policies for the channel in either direction. +func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, +) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { + + var ( + edgeInfo *ChannelEdgeInfo + policy1 *ChannelEdgePolicy + policy2 *ChannelEdgePolicy + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + // First, grab the node bucket. This will be used to populate + // the Node pointers in each edge read from disk. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // Next, grab the edge bucket which stores the edges, and also + // the index itself so we can group the directed edges together + // logically. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // If the channel's outpoint doesn't exist within the outpoint + // index, then the edge does not exist. + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return ErrGraphNoEdgesFound + } + var b bytes.Buffer + if err := writeOutpoint(&b, op); err != nil { + return err + } + chanID := chanIndex.Get(b.Bytes()) + if chanID == nil { + return ErrEdgeNotFound + } + + // If the channel is found to exists, then we'll first retrieve + // the general information for the channel. + edge, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + edgeInfo = &edge + edgeInfo.db = c.db + + // Once we have the information about the channels' parameters, + // we'll fetch the routing policies for each for the directed + // edges. + e1, e2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { + return err + } + + policy1 = e1 + policy2 = e2 + return nil + }) + if err != nil { + return nil, nil, nil, err + } + + return edgeInfo, policy1, policy2, nil +} + +// FetchChannelEdgesByID attempts to lookup the two directed edges for the +// channel identified by the channel ID. If the channel can't be found, then +// ErrEdgeNotFound is returned. A struct which houses the general information +// for the channel itself is returned as well as two structs that contain the +// routing policies for the channel in either direction. +// +// ErrZombieEdge an be returned if the edge is currently marked as a zombie +// within the database. In this case, the ChannelEdgePolicy's will be nil, and +// the ChannelEdgeInfo will only include the public keys of each node. +func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, +) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { + + var ( + edgeInfo *ChannelEdgeInfo + policy1 *ChannelEdgePolicy + policy2 *ChannelEdgePolicy + channelID [8]byte + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + // First, grab the node bucket. This will be used to populate + // the Node pointers in each edge read from disk. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // Next, grab the edge bucket which stores the edges, and also + // the index itself so we can group the directed edges together + // logically. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + byteOrder.PutUint64(channelID[:], chanID) + + // Now, attempt to fetch edge. + edge, err := fetchChanEdgeInfo(edgeIndex, channelID[:]) + + // If it doesn't exist, we'll quickly check our zombie index to + // see if we've previously marked it as so. + if err == ErrEdgeNotFound { + // If the zombie index doesn't exist, or the edge is not + // marked as a zombie within it, then we'll return the + // original ErrEdgeNotFound error. + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex == nil { + return ErrEdgeNotFound + } + + isZombie, pubKey1, pubKey2 := isZombieEdge( + zombieIndex, chanID, + ) + if !isZombie { + return ErrEdgeNotFound + } + + // Otherwise, the edge is marked as a zombie, so we'll + // populate the edge info with the public keys of each + // party as this is the only information we have about + // it and return an error signaling so. + edgeInfo = &ChannelEdgeInfo{ + NodeKey1Bytes: pubKey1, + NodeKey2Bytes: pubKey2, + } + return ErrZombieEdge + } + + // Otherwise, we'll just return the error if any. + if err != nil { + return err + } + + edgeInfo = &edge + edgeInfo.db = c.db + + // Then we'll attempt to fetch the accompanying policies of this + // edge. + e1, e2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, channelID[:], c.db, + ) + if err != nil { + return err + } + + policy1 = e1 + policy2 = e2 + return nil + }) + if err == ErrZombieEdge { + return edgeInfo, nil, nil, err + } + if err != nil { + return nil, nil, nil, err + } + + return edgeInfo, policy1, policy2, nil +} + +// IsPublicNode is a helper method that determines whether the node with the +// given public key is seen as a public node in the graph from the graph's +// source node's point of view. +func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { + var nodeIsPublic bool + err := c.db.View(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + ourPubKey := nodes.Get(sourceKey) + if ourPubKey == nil { + return ErrSourceNodeNotSet + } + node, err := fetchLightningNode(nodes, pubKey[:]) + if err != nil { + return err + } + + nodeIsPublic, err = node.isPublic(tx, ourPubKey) + return err + }) + if err != nil { + return false, err + } + + return nodeIsPublic, nil +} + +// genMultiSigP2WSH generates the p2wsh'd multisig script for 2 of 2 pubkeys. +func genMultiSigP2WSH(aPub, bPub []byte) ([]byte, error) { + if len(aPub) != 33 || len(bPub) != 33 { + return nil, fmt.Errorf("Pubkey size error. Compressed " + + "pubkeys only") + } + + // Swap to sort pubkeys if needed. Keys are sorted in lexicographical + // order. The signatures within the scriptSig must also adhere to the + // order, ensuring that the signatures for each public key appears in + // the proper order on the stack. + if bytes.Compare(aPub, bPub) == 1 { + aPub, bPub = bPub, aPub + } + + // First, we'll generate the witness script for the multi-sig. + bldr := txscript.NewScriptBuilder() + bldr.AddOp(txscript.OP_2) + bldr.AddData(aPub) // Add both pubkeys (sorted). + bldr.AddData(bPub) + bldr.AddOp(txscript.OP_2) + bldr.AddOp(txscript.OP_CHECKMULTISIG) + witnessScript, err := bldr.Script() + if err != nil { + return nil, err + } + + // With the witness script generated, we'll now turn it into a p2sh + // script: + // * OP_0 + bldr = txscript.NewScriptBuilder() + bldr.AddOp(txscript.OP_0) + scriptHash := sha256.Sum256(witnessScript) + bldr.AddData(scriptHash[:]) + + return bldr.Script() +} + +// EdgePoint couples the outpoint of a channel with the funding script that it +// creates. The FilteredChainView will use this to watch for spends of this +// edge point on chain. We require both of these values as depending on the +// concrete implementation, either the pkScript, or the out point will be used. +type EdgePoint struct { + // FundingPkScript is the p2wsh multi-sig script of the target channel. + FundingPkScript []byte + + // OutPoint is the outpoint of the target channel. + OutPoint wire.OutPoint +} + +// String returns a human readable version of the target EdgePoint. We return +// the outpoint directly as it is enough to uniquely identify the edge point. +func (e *EdgePoint) String() string { + return e.OutPoint.String() +} + +// ChannelView returns the verifiable edge information for each active channel +// within the known channel graph. The set of UTXO's (along with their scripts) +// returned are the ones that need to be watched on chain to detect channel +// closes on the resident blockchain. +func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { + var edgePoints []EdgePoint + if err := c.db.View(func(tx *bbolt.Tx) error { + // We're going to iterate over the entire channel index, so + // we'll need to fetch the edgeBucket to get to the index as + // it's a sub-bucket. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // Once we have the proper bucket, we'll range over each key + // (which is the channel point for the channel) and decode it, + // accumulating each entry. + return chanIndex.ForEach(func(chanPointBytes, chanID []byte) error { + chanPointReader := bytes.NewReader(chanPointBytes) + + var chanPoint wire.OutPoint + err := readOutpoint(chanPointReader, &chanPoint) + if err != nil { + return err + } + + edgeInfo, err := fetchChanEdgeInfo( + edgeIndex, chanID, + ) + if err != nil { + return err + } + + pkScript, err := genMultiSigP2WSH( + edgeInfo.BitcoinKey1Bytes[:], + edgeInfo.BitcoinKey2Bytes[:], + ) + if err != nil { + return err + } + + edgePoints = append(edgePoints, EdgePoint{ + FundingPkScript: pkScript, + OutPoint: chanPoint, + }) + + return nil + }) + }); err != nil { + return nil, err + } + + return edgePoints, nil +} + +// NewChannelEdgePolicy returns a new blank ChannelEdgePolicy. +func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy { + return &ChannelEdgePolicy{db: c.db} +} + +// markEdgeZombie marks an edge as a zombie within our zombie index. The public +// keys should represent the node public keys of the two parties involved in the +// edge. +func markEdgeZombie(zombieIndex *bbolt.Bucket, chanID uint64, pubKey1, + pubKey2 [33]byte) error { + + var k [8]byte + byteOrder.PutUint64(k[:], chanID) + + var v [66]byte + copy(v[:33], pubKey1[:]) + copy(v[33:], pubKey2[:]) + + return zombieIndex.Put(k[:], v[:]) +} + +// MarkEdgeLive clears an edge from our zombie index, deeming it as live. +func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + err := c.db.Update(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex == nil { + return nil + } + + var k [8]byte + byteOrder.PutUint64(k[:], chanID) + return zombieIndex.Delete(k[:]) + }) + if err != nil { + return err + } + + c.rejectCache.remove(chanID) + c.chanCache.remove(chanID) + + return nil +} + +// IsZombieEdge returns whether the edge is considered zombie. If it is a +// zombie, then the two node public keys corresponding to this edge are also +// returned. +func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) { + var ( + isZombie bool + pubKey1, pubKey2 [33]byte + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex == nil { + return nil + } + + isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID) + return nil + }) + if err != nil { + return false, [33]byte{}, [33]byte{} + } + + return isZombie, pubKey1, pubKey2 +} + +// isZombieEdge returns whether an entry exists for the given channel in the +// zombie index. If an entry exists, then the two node public keys corresponding +// to this edge are also returned. +func isZombieEdge(zombieIndex *bbolt.Bucket, + chanID uint64) (bool, [33]byte, [33]byte) { + + var k [8]byte + byteOrder.PutUint64(k[:], chanID) + + v := zombieIndex.Get(k[:]) + if v == nil { + return false, [33]byte{}, [33]byte{} + } + + var pubKey1, pubKey2 [33]byte + copy(pubKey1[:], v[:33]) + copy(pubKey2[:], v[33:]) + + return true, pubKey1, pubKey2 +} + +// NumZombies returns the current number of zombie channels in the graph. +func (c *ChannelGraph) NumZombies() (uint64, error) { + var numZombies uint64 + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return nil + } + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex == nil { + return nil + } + + return zombieIndex.ForEach(func(_, _ []byte) error { + numZombies++ + return nil + }) + }) + if err != nil { + return 0, err + } + + return numZombies, nil +} + +func putLightningNode(nodeBucket *bbolt.Bucket, aliasBucket *bbolt.Bucket, + updateIndex *bbolt.Bucket, node *LightningNode) error { + + var ( + scratch [16]byte + b bytes.Buffer + ) + + pub, err := node.PubKey() + if err != nil { + return err + } + nodePub := pub.SerializeCompressed() + + // If the node has the update time set, write it, else write 0. + updateUnix := uint64(0) + if node.LastUpdate.Unix() > 0 { + updateUnix = uint64(node.LastUpdate.Unix()) + } + + byteOrder.PutUint64(scratch[:8], updateUnix) + if _, err := b.Write(scratch[:8]); err != nil { + return err + } + + if _, err := b.Write(nodePub); err != nil { + return err + } + + // If we got a node announcement for this node, we will have the rest + // of the data available. If not we don't have more data to write. + if !node.HaveNodeAnnouncement { + // Write HaveNodeAnnouncement=0. + byteOrder.PutUint16(scratch[:2], 0) + if _, err := b.Write(scratch[:2]); err != nil { + return err + } + + return nodeBucket.Put(nodePub, b.Bytes()) + } + + // Write HaveNodeAnnouncement=1. + byteOrder.PutUint16(scratch[:2], 1) + if _, err := b.Write(scratch[:2]); err != nil { + return err + } + + if err := binary.Write(&b, byteOrder, node.Color.R); err != nil { + return err + } + if err := binary.Write(&b, byteOrder, node.Color.G); err != nil { + return err + } + if err := binary.Write(&b, byteOrder, node.Color.B); err != nil { + return err + } + + if err := wire.WriteVarString(&b, 0, node.Alias); err != nil { + return err + } + + if err := node.Features.Encode(&b); err != nil { + return err + } + + numAddresses := uint16(len(node.Addresses)) + byteOrder.PutUint16(scratch[:2], numAddresses) + if _, err := b.Write(scratch[:2]); err != nil { + return err + } + + for _, address := range node.Addresses { + if err := serializeAddr(&b, address); err != nil { + return err + } + } + + sigLen := len(node.AuthSigBytes) + if sigLen > 80 { + return fmt.Errorf("max sig len allowed is 80, had %v", + sigLen) + } + + err = wire.WriteVarBytes(&b, 0, node.AuthSigBytes) + if err != nil { + return err + } + + if len(node.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(node.ExtraOpaqueData)) + } + err = wire.WriteVarBytes(&b, 0, node.ExtraOpaqueData) + if err != nil { + return err + } + + if err := aliasBucket.Put(nodePub, []byte(node.Alias)); err != nil { + return err + } + + // With the alias bucket updated, we'll now update the index that + // tracks the time series of node updates. + var indexKey [8 + 33]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + copy(indexKey[8:], nodePub) + + // If there was already an old index entry for this node, then we'll + // delete the old one before we write the new entry. + if nodeBytes := nodeBucket.Get(nodePub); nodeBytes != nil { + // Extract out the old update time to we can reconstruct the + // prior index key to delete it from the index. + oldUpdateTime := nodeBytes[:8] + + var oldIndexKey [8 + 33]byte + copy(oldIndexKey[:8], oldUpdateTime) + copy(oldIndexKey[8:], nodePub) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + return nodeBucket.Put(nodePub, b.Bytes()) +} + +func fetchLightningNode(nodeBucket *bbolt.Bucket, + nodePub []byte) (LightningNode, error) { + + nodeBytes := nodeBucket.Get(nodePub) + if nodeBytes == nil { + return LightningNode{}, ErrGraphNodeNotFound + } + + nodeReader := bytes.NewReader(nodeBytes) + return deserializeLightningNode(nodeReader) +} + +func deserializeLightningNode(r io.Reader) (LightningNode, error) { + var ( + node LightningNode + scratch [8]byte + err error + ) + + if _, err := r.Read(scratch[:]); err != nil { + return LightningNode{}, err + } + + unix := int64(byteOrder.Uint64(scratch[:])) + node.LastUpdate = time.Unix(unix, 0) + + if _, err := io.ReadFull(r, node.PubKeyBytes[:]); err != nil { + return LightningNode{}, err + } + + if _, err := r.Read(scratch[:2]); err != nil { + return LightningNode{}, err + } + + hasNodeAnn := byteOrder.Uint16(scratch[:2]) + if hasNodeAnn == 1 { + node.HaveNodeAnnouncement = true + } else { + node.HaveNodeAnnouncement = false + } + + // The rest of the data is optional, and will only be there if we got a node + // announcement for this node. + if !node.HaveNodeAnnouncement { + return node, nil + } + + // We did get a node announcement for this node, so we'll have the rest + // of the data available. + if err := binary.Read(r, byteOrder, &node.Color.R); err != nil { + return LightningNode{}, err + } + if err := binary.Read(r, byteOrder, &node.Color.G); err != nil { + return LightningNode{}, err + } + if err := binary.Read(r, byteOrder, &node.Color.B); err != nil { + return LightningNode{}, err + } + + node.Alias, err = wire.ReadVarString(r, 0) + if err != nil { + return LightningNode{}, err + } + + fv := lnwire.NewFeatureVector(nil, lnwire.GlobalFeatures) + err = fv.Decode(r) + if err != nil { + return LightningNode{}, err + } + node.Features = fv + + if _, err := r.Read(scratch[:2]); err != nil { + return LightningNode{}, err + } + numAddresses := int(byteOrder.Uint16(scratch[:2])) + + var addresses []net.Addr + for i := 0; i < numAddresses; i++ { + address, err := deserializeAddr(r) + if err != nil { + return LightningNode{}, err + } + addresses = append(addresses, address) + } + node.Addresses = addresses + + node.AuthSigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") + if err != nil { + return LightningNode{}, err + } + + // We'll try and see if there are any opaque bytes left, if not, then + // we'll ignore the EOF error and return the node as is. + node.ExtraOpaqueData, err = wire.ReadVarBytes( + r, 0, MaxAllowedExtraOpaqueBytes, "blob", + ) + switch { + case err == io.ErrUnexpectedEOF: + case err == io.EOF: + case err != nil: + return LightningNode{}, err + } + + return node, nil +} + +func putChanEdgeInfo(edgeIndex *bbolt.Bucket, edgeInfo *ChannelEdgeInfo, chanID [8]byte) error { + var b bytes.Buffer + + if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { + return err + } + if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { + return err + } + if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { + return err + } + if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { + return err + } + + if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil { + return err + } + + authProof := edgeInfo.AuthProof + var nodeSig1, nodeSig2, bitcoinSig1, bitcoinSig2 []byte + if authProof != nil { + nodeSig1 = authProof.NodeSig1Bytes + nodeSig2 = authProof.NodeSig2Bytes + bitcoinSig1 = authProof.BitcoinSig1Bytes + bitcoinSig2 = authProof.BitcoinSig2Bytes + } + + if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil { + return err + } + if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil { + return err + } + if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil { + return err + } + if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil { + return err + } + + if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + return err + } + if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { + return err + } + if _, err := b.Write(chanID[:]); err != nil { + return err + } + if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil { + return err + } + + if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) + } + err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) + if err != nil { + return err + } + + return edgeIndex.Put(chanID[:], b.Bytes()) +} + +func fetchChanEdgeInfo(edgeIndex *bbolt.Bucket, + chanID []byte) (ChannelEdgeInfo, error) { + + edgeInfoBytes := edgeIndex.Get(chanID) + if edgeInfoBytes == nil { + return ChannelEdgeInfo{}, ErrEdgeNotFound + } + + edgeInfoReader := bytes.NewReader(edgeInfoBytes) + return deserializeChanEdgeInfo(edgeInfoReader) +} + +func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) { + var ( + err error + edgeInfo ChannelEdgeInfo + ) + + if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil { + return ChannelEdgeInfo{}, err + } + if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil { + return ChannelEdgeInfo{}, err + } + if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil { + return ChannelEdgeInfo{}, err + } + if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil { + return ChannelEdgeInfo{}, err + } + + edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features") + if err != nil { + return ChannelEdgeInfo{}, err + } + + proof := &ChannelAuthProof{} + + proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return ChannelEdgeInfo{}, err + } + proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return ChannelEdgeInfo{}, err + } + proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return ChannelEdgeInfo{}, err + } + proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return ChannelEdgeInfo{}, err + } + + if !proof.IsEmpty() { + edgeInfo.AuthProof = proof + } + + edgeInfo.ChannelPoint = wire.OutPoint{} + if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { + return ChannelEdgeInfo{}, err + } + if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { + return ChannelEdgeInfo{}, err + } + if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil { + return ChannelEdgeInfo{}, err + } + + if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil { + return ChannelEdgeInfo{}, err + } + + // We'll try and see if there are any opaque bytes left, if not, then + // we'll ignore the EOF error and return the edge as is. + edgeInfo.ExtraOpaqueData, err = wire.ReadVarBytes( + r, 0, MaxAllowedExtraOpaqueBytes, "blob", + ) + switch { + case err == io.ErrUnexpectedEOF: + case err == io.EOF: + case err != nil: + return ChannelEdgeInfo{}, err + } + + return edgeInfo, nil +} + +func putChanEdgePolicy(edges, nodes *bbolt.Bucket, edge *ChannelEdgePolicy, + from, to []byte) error { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], edge.ChannelID) + + var b bytes.Buffer + if err := serializeChanEdgePolicy(&b, edge, to); err != nil { + return err + } + + // Before we write out the new edge, we'll create a new entry in the + // update index in order to keep it fresh. + updateUnix := uint64(edge.LastUpdate.Unix()) + var indexKey [8 + 8]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + byteOrder.PutUint64(indexKey[8:], edge.ChannelID) + + updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + + // If there was already an entry for this edge, then we'll need to + // delete the old one to ensure we don't leave around any after-images. + // An unknown policy value does not have a update time recorded, so + // it also does not need to be removed. + if edgeBytes := edges.Get(edgeKey[:]); edgeBytes != nil && + !bytes.Equal(edgeBytes[:], unknownPolicy) { + + // In order to delete the old entry, we'll need to obtain the + // *prior* update time in order to delete it. To do this, we'll + // need to deserialize the existing policy within the database + // (now outdated by the new one), and delete its corresponding + // entry within the update index. We'll ignore any + // ErrEdgePolicyOptionalFieldNotFound error, as we only need + // the channel ID and update time to delete the entry. + // TODO(halseth): get rid of these invalid policies in a + // migration. + oldEdgePolicy, err := deserializeChanEdgePolicy( + bytes.NewReader(edgeBytes), nodes, + ) + if err != nil && err != ErrEdgePolicyOptionalFieldNotFound { + return err + } + + oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix()) + + var oldIndexKey [8 + 8]byte + byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime) + byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + updateEdgePolicyDisabledIndex( + edges, edge.ChannelID, + edge.ChannelFlags&lnwire.ChanUpdateDirection > 0, + edge.IsDisabled(), + ) + + return edges.Put(edgeKey[:], b.Bytes()[:]) +} + +// updateEdgePolicyDisabledIndex is used to update the disabledEdgePolicyIndex +// bucket by either add a new disabled ChannelEdgePolicy or remove an existing +// one. +// The direction represents the direction of the edge and disabled is used for +// deciding whether to remove or add an entry to the bucket. +// In general a channel is disabled if two entries for the same chanID exist +// in this bucket. +// Maintaining the bucket this way allows a fast retrieval of disabled +// channels, for example when prune is needed. +func updateEdgePolicyDisabledIndex(edges *bbolt.Bucket, chanID uint64, + direction bool, disabled bool) error { + + var disabledEdgeKey [8 + 1]byte + byteOrder.PutUint64(disabledEdgeKey[0:], chanID) + if direction { + disabledEdgeKey[8] = 1 + } + + disabledEdgePolicyIndex, err := edges.CreateBucketIfNotExists( + disabledEdgePolicyBucket, + ) + if err != nil { + return err + } + + if disabled { + return disabledEdgePolicyIndex.Put(disabledEdgeKey[:], []byte{}) + } + + return disabledEdgePolicyIndex.Delete(disabledEdgeKey[:]) +} + +// putChanEdgePolicyUnknown marks the edge policy as unknown +// in the edges bucket. +func putChanEdgePolicyUnknown(edges *bbolt.Bucket, channelID uint64, + from []byte) error { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], channelID) + + if edges.Get(edgeKey[:]) != nil { + return fmt.Errorf("Cannot write unknown policy for channel %v "+ + " when there is already a policy present", channelID) + } + + return edges.Put(edgeKey[:], unknownPolicy) +} + +func fetchChanEdgePolicy(edges *bbolt.Bucket, chanID []byte, + nodePub []byte, nodes *bbolt.Bucket) (*ChannelEdgePolicy, error) { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], nodePub) + copy(edgeKey[33:], chanID[:]) + + edgeBytes := edges.Get(edgeKey[:]) + if edgeBytes == nil { + return nil, ErrEdgeNotFound + } + + // No need to deserialize unknown policy. + if bytes.Equal(edgeBytes[:], unknownPolicy) { + return nil, nil + } + + edgeReader := bytes.NewReader(edgeBytes) + + ep, err := deserializeChanEdgePolicy(edgeReader, nodes) + switch { + // If the db policy was missing an expected optional field, we return + // nil as if the policy was unknown. + case err == ErrEdgePolicyOptionalFieldNotFound: + return nil, nil + + case err != nil: + return nil, err + } + + return ep, nil +} + +func fetchChanEdgePolicies(edgeIndex *bbolt.Bucket, edges *bbolt.Bucket, + nodes *bbolt.Bucket, chanID []byte, + db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { + + edgeInfo := edgeIndex.Get(chanID) + if edgeInfo == nil { + return nil, nil, ErrEdgeNotFound + } + + // The first node is contained within the first half of the edge + // information. We only propagate the error here and below if it's + // something other than edge non-existence. + node1Pub := edgeInfo[:33] + edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub, nodes) + if err != nil { + return nil, nil, err + } + + // As we may have a single direction of the edge but not the other, + // only fill in the database pointers if the edge is found. + if edge1 != nil { + edge1.db = db + edge1.Node.db = db + } + + // Similarly, the second node is contained within the latter + // half of the edge information. + node2Pub := edgeInfo[33:66] + edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub, nodes) + if err != nil { + return nil, nil, err + } + + if edge2 != nil { + edge2.db = db + edge2.Node.db = db + } + + return edge1, edge2, nil +} + +func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy, + to []byte) error { + + err := wire.WriteVarBytes(w, 0, edge.SigBytes) + if err != nil { + return err + } + + if err := binary.Write(w, byteOrder, edge.ChannelID); err != nil { + return err + } + + var scratch [8]byte + updateUnix := uint64(edge.LastUpdate.Unix()) + byteOrder.PutUint64(scratch[:], updateUnix) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, edge.MessageFlags); err != nil { + return err + } + if err := binary.Write(w, byteOrder, edge.ChannelFlags); err != nil { + return err + } + if err := binary.Write(w, byteOrder, edge.TimeLockDelta); err != nil { + return err + } + if err := binary.Write(w, byteOrder, uint64(edge.MinHTLC)); err != nil { + return err + } + if err := binary.Write(w, byteOrder, uint64(edge.FeeBaseMSat)); err != nil { + return err + } + if err := binary.Write(w, byteOrder, uint64(edge.FeeProportionalMillionths)); err != nil { + return err + } + + if _, err := w.Write(to); err != nil { + return err + } + + // If the max_htlc field is present, we write it. To be compatible with + // older versions that wasn't aware of this field, we write it as part + // of the opaque data. + // TODO(halseth): clean up when moving to TLV. + var opaqueBuf bytes.Buffer + if edge.MessageFlags.HasMaxHtlc() { + err := binary.Write(&opaqueBuf, byteOrder, uint64(edge.MaxHTLC)) + if err != nil { + return err + } + } + + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + if _, err := opaqueBuf.Write(edge.ExtraOpaqueData); err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, opaqueBuf.Bytes()); err != nil { + return err + } + return nil +} + +func deserializeChanEdgePolicy(r io.Reader, + nodes *bbolt.Bucket) (*ChannelEdgePolicy, error) { + + edge := &ChannelEdgePolicy{} + + var err error + edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") + if err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &edge.ChannelID); err != nil { + return nil, err + } + + var scratch [8]byte + if _, err := r.Read(scratch[:]); err != nil { + return nil, err + } + unix := int64(byteOrder.Uint64(scratch[:])) + edge.LastUpdate = time.Unix(unix, 0) + + if err := binary.Read(r, byteOrder, &edge.MessageFlags); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &edge.ChannelFlags); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &edge.TimeLockDelta); err != nil { + return nil, err + } + + var n uint64 + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.MinHTLC = lnwire.MilliSatoshi(n) + + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.FeeBaseMSat = lnwire.MilliSatoshi(n) + + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.FeeProportionalMillionths = lnwire.MilliSatoshi(n) + + var pub [33]byte + if _, err := r.Read(pub[:]); err != nil { + return nil, err + } + + node, err := fetchLightningNode(nodes, pub[:]) + if err != nil { + return nil, fmt.Errorf("unable to fetch node: %x, %v", + pub[:], err) + } + edge.Node = &node + + // We'll try and see if there are any opaque bytes left, if not, then + // we'll ignore the EOF error and return the edge as is. + edge.ExtraOpaqueData, err = wire.ReadVarBytes( + r, 0, MaxAllowedExtraOpaqueBytes, "blob", + ) + switch { + case err == io.ErrUnexpectedEOF: + case err == io.EOF: + case err != nil: + return nil, err + } + + // See if optional fields are present. + if edge.MessageFlags.HasMaxHtlc() { + // The max_htlc field should be at the beginning of the opaque + // bytes. + opq := edge.ExtraOpaqueData + + // If the max_htlc field is not present, it might be old data + // stored before this field was validated. We'll return the + // edge along with an error. + if len(opq) < 8 { + return edge, ErrEdgePolicyOptionalFieldNotFound + } + + maxHtlc := byteOrder.Uint64(opq[:8]) + edge.MaxHTLC = lnwire.MilliSatoshi(maxHtlc) + + // Exclude the parsed field from the rest of the opaque data. + edge.ExtraOpaqueData = opq[8:] + } + + return edge, nil +} diff --git a/channeldb/migration_01_to_11/graph_test.go b/channeldb/migration_01_to_11/graph_test.go new file mode 100644 index 00000000..00a8a000 --- /dev/null +++ b/channeldb/migration_01_to_11/graph_test.go @@ -0,0 +1,3197 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/sha256" + "fmt" + "image/color" + "math" + "math/big" + prand "math/rand" + "net" + "reflect" + "runtime" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + testAddr = &net.TCPAddr{IP: (net.IP)([]byte{0xA, 0x0, 0x0, 0x1}), + Port: 9000} + anotherAddr, _ = net.ResolveTCPAddr("tcp", + "[2001:db8:85a3:0:0:8a2e:370:7334]:80") + testAddrs = []net.Addr{testAddr, anotherAddr} + + testSig = &btcec.Signature{ + R: new(big.Int), + S: new(big.Int), + } + _, _ = testSig.R.SetString("63724406601629180062774974542967536251589935445068131219452686511677818569431", 10) + _, _ = testSig.S.SetString("18801056069249825825291287104931333862866033135609736119018462340006816851118", 10) + + testFeatures = lnwire.NewFeatureVector(nil, lnwire.GlobalFeatures) +) + +func createLightningNode(db *DB, priv *btcec.PrivateKey) (*LightningNode, error) { + updateTime := prand.Int63() + + pub := priv.PubKey().SerializeCompressed() + n := &LightningNode{ + HaveNodeAnnouncement: true, + AuthSigBytes: testSig.Serialize(), + LastUpdate: time.Unix(updateTime, 0), + Color: color.RGBA{1, 2, 3, 0}, + Alias: "kek" + string(pub[:]), + Features: testFeatures, + Addresses: testAddrs, + db: db, + } + copy(n.PubKeyBytes[:], priv.PubKey().SerializeCompressed()) + + return n, nil +} + +func createTestVertex(db *DB) (*LightningNode, error) { + priv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return nil, err + } + + return createLightningNode(db, priv) +} + +func TestNodeInsertionAndDeletion(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test basic insertion/deletion for vertexes from the + // graph, so we'll create a test vertex to start with. + _, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + node := &LightningNode{ + HaveNodeAnnouncement: true, + AuthSigBytes: testSig.Serialize(), + LastUpdate: time.Unix(1232342, 0), + Color: color.RGBA{1, 2, 3, 0}, + Alias: "kek", + Features: testFeatures, + Addresses: testAddrs, + ExtraOpaqueData: []byte("extra new data"), + db: db, + } + copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) + + // First, insert the node into the graph DB. This should succeed + // without any errors. + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, fetch the node from the database to ensure everything was + // serialized properly. + dbNode, err := graph.FetchLightningNode(testPub) + if err != nil { + t.Fatalf("unable to locate node: %v", err) + } + + if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { + t.Fatalf("unable to query for node: %v", err) + } else if !exists { + t.Fatalf("node should be found but wasn't") + } + + // The two nodes should match exactly! + if err := compareNodes(node, dbNode); err != nil { + t.Fatalf("nodes don't match: %v", err) + } + + // Next, delete the node from the graph, this should purge all data + // related to the node. + if err := graph.DeleteLightningNode(testPub); err != nil { + t.Fatalf("unable to delete node; %v", err) + } + + // Finally, attempt to fetch the node again. This should fail as the + // node should have been deleted from the database. + _, err = graph.FetchLightningNode(testPub) + if err != ErrGraphNodeNotFound { + t.Fatalf("fetch after delete should fail!") + } +} + +// TestPartialNode checks that we can add and retrieve a LightningNode where +// where only the pubkey is known to the database. +func TestPartialNode(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We want to be able to insert nodes into the graph that only has the + // PubKey set. + _, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + node := &LightningNode{ + HaveNodeAnnouncement: false, + } + copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) + + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, fetch the node from the database to ensure everything was + // serialized properly. + dbNode, err := graph.FetchLightningNode(testPub) + if err != nil { + t.Fatalf("unable to locate node: %v", err) + } + + if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { + t.Fatalf("unable to query for node: %v", err) + } else if !exists { + t.Fatalf("node should be found but wasn't") + } + + // The two nodes should match exactly! (with default values for + // LastUpdate and db set to satisfy compareNodes()) + node = &LightningNode{ + HaveNodeAnnouncement: false, + LastUpdate: time.Unix(0, 0), + db: db, + } + copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) + + if err := compareNodes(node, dbNode); err != nil { + t.Fatalf("nodes don't match: %v", err) + } + + // Next, delete the node from the graph, this should purge all data + // related to the node. + if err := graph.DeleteLightningNode(testPub); err != nil { + t.Fatalf("unable to delete node: %v", err) + } + + // Finally, attempt to fetch the node again. This should fail as the + // node should have been deleted from the database. + _, err = graph.FetchLightningNode(testPub) + if err != ErrGraphNodeNotFound { + t.Fatalf("fetch after delete should fail!") + } +} + +func TestAliasLookup(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the alias index within the database, so first + // create a new test node. + testNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // Add the node to the graph's database, this should also insert an + // entry into the alias index for this node. + if err := graph.AddLightningNode(testNode); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, attempt to lookup the alias. The alias should exactly match + // the one which the test node was assigned. + nodePub, err := testNode.PubKey() + if err != nil { + t.Fatalf("unable to generate pubkey: %v", err) + } + dbAlias, err := graph.LookupAlias(nodePub) + if err != nil { + t.Fatalf("unable to find alias: %v", err) + } + if dbAlias != testNode.Alias { + t.Fatalf("aliases don't match, expected %v got %v", + testNode.Alias, dbAlias) + } + + // Ensure that looking up a non-existent alias results in an error. + node, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + nodePub, err = node.PubKey() + if err != nil { + t.Fatalf("unable to generate pubkey: %v", err) + } + _, err = graph.LookupAlias(nodePub) + if err != ErrNodeAliasNotFound { + t.Fatalf("alias lookup should fail for non-existent pubkey") + } +} + +func TestSourceNode(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the setting/getting of the source node, so we + // first create a fake node to use within the test. + testNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // Attempt to fetch the source node, this should return an error as the + // source node hasn't yet been set. + if _, err := graph.SourceNode(); err != ErrSourceNodeNotSet { + t.Fatalf("source node shouldn't be set in new graph") + } + + // Set the source the source node, this should insert the node into the + // database in a special way indicating it's the source node. + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // Retrieve the source node from the database, it should exactly match + // the one we set above. + sourceNode, err := graph.SourceNode() + if err != nil { + t.Fatalf("unable to fetch source node: %v", err) + } + if err := compareNodes(testNode, sourceNode); err != nil { + t.Fatalf("nodes don't match: %v", err) + } +} + +func TestEdgeInsertionDeletion(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the insertion/deletion of edges, so we create two + // vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // In addition to the fake vertexes we create some fake channel + // identifiers. + chanID := uint64(prand.Int63()) + outpoint := wire.OutPoint{ + Hash: rev, + Index: 9, + } + + // Add the new edge to the database, this should proceed without any + // errors. + node1Pub, err := node1.PubKey() + if err != nil { + t.Fatalf("unable to generate node key: %v", err) + } + node2Pub, err := node2.PubKey() + if err != nil { + t.Fatalf("unable to generate node key: %v", err) + } + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: outpoint, + Capacity: 9000, + } + copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) + + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Ensure that both policies are returned as unknown (nil). + _, e1, e2, err := graph.FetchChannelEdgesByID(chanID) + if err != nil { + t.Fatalf("unable to fetch channel edge") + } + if e1 != nil || e2 != nil { + t.Fatalf("channel edges not unknown") + } + + // Next, attempt to delete the edge from the database, again this + // should proceed without any issues. + if err := graph.DeleteChannelEdges(chanID); err != nil { + t.Fatalf("unable to delete edge: %v", err) + } + + // Ensure that any query attempts to lookup the delete channel edge are + // properly deleted. + if _, _, _, err := graph.FetchChannelEdgesByOutpoint(&outpoint); err == nil { + t.Fatalf("channel edge not deleted") + } + if _, _, _, err := graph.FetchChannelEdgesByID(chanID); err == nil { + t.Fatalf("channel edge not deleted") + } + isZombie, _, _ := graph.IsZombieEdge(chanID) + if !isZombie { + t.Fatal("channel edge not marked as zombie") + } + + // Finally, attempt to delete a (now) non-existent edge within the + // database, this should result in an error. + err = graph.DeleteChannelEdges(chanID) + if err != ErrEdgeNotFound { + t.Fatalf("deleting a non-existent edge should fail!") + } +} + +func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, + node1, node2 *LightningNode) (ChannelEdgeInfo, lnwire.ShortChannelID) { + + shortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + TxIndex: txIndex, + TxPosition: txPosition, + } + outpoint := wire.OutPoint{ + Hash: rev, + Index: outPointIndex, + } + + node1Pub, _ := node1.PubKey() + node2Pub, _ := node2.PubKey() + edgeInfo := ChannelEdgeInfo{ + ChannelID: shortChanID.ToUint64(), + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: outpoint, + Capacity: 9000, + } + + copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) + + return edgeInfo, shortChanID +} + +// TestDisconnectBlockAtHeight checks that the pruned state of the channel +// database is what we expect after calling DisconnectBlockAtHeight. +func TestDisconnectBlockAtHeight(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // We'd like to test the insertion/deletion of edges, so we create two + // vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // In addition to the fake vertexes we create some fake channel + // identifiers. + var spendOutputs []*wire.OutPoint + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) + + // Prune the graph a few times to make sure we have entries in the + // prune log. + _, err = graph.PruneGraph(spendOutputs, &blockHash, 155) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + var blockHash2 chainhash.Hash + copy(blockHash2[:], bytes.Repeat([]byte{2}, 32)) + + _, err = graph.PruneGraph(spendOutputs, &blockHash2, 156) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // We'll create 3 almost identical edges, so first create a helper + // method containing all logic for doing so. + + // Create an edge which has its block height at 156. + height := uint32(156) + edgeInfo, _ := createEdge(height, 0, 0, 0, node1, node2) + + // Create an edge with block height 157. We give it + // maximum values for tx index and position, to make + // sure our database range scan get edges from the + // entire range. + edgeInfo2, _ := createEdge( + height+1, math.MaxUint32&0x00ffffff, math.MaxUint16, 1, + node1, node2, + ) + + // Create a third edge, this with a block height of 155. + edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) + + // Now add all these new edges to the database. + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + if err := graph.AddChannelEdge(&edgeInfo2); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + if err := graph.AddChannelEdge(&edgeInfo3); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Call DisconnectBlockAtHeight, which should prune every channel + // that has a funding height of 'height' or greater. + removed, err := graph.DisconnectBlockAtHeight(uint32(height)) + if err != nil { + t.Fatalf("unable to prune %v", err) + } + + // The two edges should have been removed. + if len(removed) != 2 { + t.Fatalf("expected two edges to be removed from graph, "+ + "only %d were", len(removed)) + } + if removed[0].ChannelID != edgeInfo.ChannelID { + t.Fatalf("expected edge to be removed from graph") + } + if removed[1].ChannelID != edgeInfo2.ChannelID { + t.Fatalf("expected edge to be removed from graph") + } + + // The two first edges should be removed from the db. + _, _, has, isZombie, err := graph.HasChannelEdge(edgeInfo.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if has { + t.Fatalf("edge1 was not pruned from the graph") + } + if isZombie { + t.Fatal("reorged edge1 should not be marked as zombie") + } + _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo2.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if has { + t.Fatalf("edge2 was not pruned from the graph") + } + if isZombie { + t.Fatal("reorged edge2 should not be marked as zombie") + } + + // Edge 3 should not be removed. + _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo3.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if !has { + t.Fatalf("edge3 was pruned from the graph") + } + if isZombie { + t.Fatal("edge3 was marked as zombie") + } + + // PruneTip should be set to the blockHash we specified for the block + // at height 155. + hash, h, err := graph.PruneTip() + if err != nil { + t.Fatalf("unable to get prune tip: %v", err) + } + if !blockHash.IsEqual(hash) { + t.Fatalf("expected best block to be %x, was %x", blockHash, hash) + } + if h != height-1 { + t.Fatalf("expected best block height to be %d, was %d", height-1, h) + } +} + +func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, + e2 *ChannelEdgeInfo) { + + if e1.ChannelID != e2.ChannelID { + t.Fatalf("chan id's don't match: %v vs %v", e1.ChannelID, + e2.ChannelID) + } + + if e1.ChainHash != e2.ChainHash { + t.Fatalf("chain hashes don't match: %v vs %v", e1.ChainHash, + e2.ChainHash) + } + + if !bytes.Equal(e1.NodeKey1Bytes[:], e2.NodeKey1Bytes[:]) { + t.Fatalf("nodekey1 doesn't match") + } + if !bytes.Equal(e1.NodeKey2Bytes[:], e2.NodeKey2Bytes[:]) { + t.Fatalf("nodekey2 doesn't match") + } + if !bytes.Equal(e1.BitcoinKey1Bytes[:], e2.BitcoinKey1Bytes[:]) { + t.Fatalf("bitcoinkey1 doesn't match") + } + if !bytes.Equal(e1.BitcoinKey2Bytes[:], e2.BitcoinKey2Bytes[:]) { + t.Fatalf("bitcoinkey2 doesn't match") + } + + if !bytes.Equal(e1.Features, e2.Features) { + t.Fatalf("features doesn't match: %x vs %x", e1.Features, + e2.Features) + } + + if !bytes.Equal(e1.AuthProof.NodeSig1Bytes, e2.AuthProof.NodeSig1Bytes) { + t.Fatalf("nodesig1 doesn't match: %v vs %v", + spew.Sdump(e1.AuthProof.NodeSig1Bytes), + spew.Sdump(e2.AuthProof.NodeSig1Bytes)) + } + if !bytes.Equal(e1.AuthProof.NodeSig2Bytes, e2.AuthProof.NodeSig2Bytes) { + t.Fatalf("nodesig2 doesn't match") + } + if !bytes.Equal(e1.AuthProof.BitcoinSig1Bytes, e2.AuthProof.BitcoinSig1Bytes) { + t.Fatalf("bitcoinsig1 doesn't match") + } + if !bytes.Equal(e1.AuthProof.BitcoinSig2Bytes, e2.AuthProof.BitcoinSig2Bytes) { + t.Fatalf("bitcoinsig2 doesn't match") + } + + if e1.ChannelPoint != e2.ChannelPoint { + t.Fatalf("channel point match: %v vs %v", e1.ChannelPoint, + e2.ChannelPoint) + } + + if e1.Capacity != e2.Capacity { + t.Fatalf("capacity doesn't match: %v vs %v", e1.Capacity, + e2.Capacity) + } + + if !bytes.Equal(e1.ExtraOpaqueData, e2.ExtraOpaqueData) { + t.Fatalf("extra data doesn't match: %v vs %v", + e2.ExtraOpaqueData, e2.ExtraOpaqueData) + } +} + +func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) { + + var ( + firstNode *LightningNode + secondNode *LightningNode + ) + if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 { + firstNode = node1 + secondNode = node2 + } else { + firstNode = node2 + secondNode = node1 + } + + // In addition to the fake vertexes we create some fake channel + // identifiers. + chanID := uint64(prand.Int63()) + outpoint := wire.OutPoint{ + Hash: rev, + Index: 9, + } + + // Add the new edge to the database, this should proceed without any + // errors. + edgeInfo := &ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: outpoint, + Capacity: 1000, + ExtraOpaqueData: []byte("new unknown feature"), + } + copy(edgeInfo.NodeKey1Bytes[:], firstNode.PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], secondNode.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], firstNode.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], secondNode.PubKeyBytes[:]) + + edge1 := &ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: chanID, + LastUpdate: time.Unix(433453, 0), + MessageFlags: 1, + ChannelFlags: 0, + TimeLockDelta: 99, + MinHTLC: 2342135, + MaxHTLC: 13928598, + FeeBaseMSat: 4352345, + FeeProportionalMillionths: 3452352, + Node: secondNode, + ExtraOpaqueData: []byte("new unknown feature2"), + db: db, + } + edge2 := &ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: chanID, + LastUpdate: time.Unix(124234, 0), + MessageFlags: 1, + ChannelFlags: 1, + TimeLockDelta: 99, + MinHTLC: 2342135, + MaxHTLC: 13928598, + FeeBaseMSat: 4352345, + FeeProportionalMillionths: 90392423, + Node: firstNode, + ExtraOpaqueData: []byte("new unknown feature1"), + db: db, + } + + return edgeInfo, edge1, edge2 +} + +func TestEdgeInfoUpdates(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the update of edges inserted into the database, so + // we create two vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Create an edge and add it to the db. + edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + + // Make sure inserting the policy at this point, before the edge info + // is added, will fail. + if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound { + t.Fatalf("expected ErrEdgeNotFound, got: %v", err) + } + + // Add the edge info. + if err := graph.AddChannelEdge(edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanID := edgeInfo.ChannelID + outpoint := edgeInfo.ChannelPoint + + // Next, insert both edge policies into the database, they should both + // be inserted without any issues. + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // Check for existence of the edge within the database, it should be + // found. + _, _, found, isZombie, err := graph.HasChannelEdge(chanID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if !found { + t.Fatalf("graph should have of inserted edge") + } + if isZombie { + t.Fatal("live edge should not be marked as zombie") + } + + // We should also be able to retrieve the channelID only knowing the + // channel point of the channel. + dbChanID, err := graph.ChannelID(&outpoint) + if err != nil { + t.Fatalf("unable to retrieve channel ID: %v", err) + } + if dbChanID != chanID { + t.Fatalf("chan ID's mismatch, expected %v got %v", dbChanID, + chanID) + } + + // With the edges inserted, perform some queries to ensure that they've + // been inserted properly. + dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID) + if err != nil { + t.Fatalf("unable to fetch channel by ID: %v", err) + } + if err := compareEdgePolicies(dbEdge1, edge1); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + if err := compareEdgePolicies(dbEdge2, edge2); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) + + // Next, attempt to query the channel edges according to the outpoint + // of the channel. + dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByOutpoint(&outpoint) + if err != nil { + t.Fatalf("unable to fetch channel by ID: %v", err) + } + if err := compareEdgePolicies(dbEdge1, edge1); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + if err := compareEdgePolicies(dbEdge2, edge2); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) +} + +func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy { + update := prand.Int63() + + return newEdgePolicy(chanID, op, db, update) +} + +func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, + updateTime int64) *ChannelEdgePolicy { + + return &ChannelEdgePolicy{ + ChannelID: chanID, + LastUpdate: time.Unix(updateTime, 0), + MessageFlags: 1, + ChannelFlags: 0, + TimeLockDelta: uint16(prand.Int63()), + MinHTLC: lnwire.MilliSatoshi(prand.Int63()), + MaxHTLC: lnwire.MilliSatoshi(prand.Int63()), + FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), + FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()), + db: db, + } +} + +func TestGraphTraversal(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test some of the graph traversal capabilities within + // the DB, so we'll create a series of fake nodes to insert into the + // graph. + const numNodes = 20 + nodes := make([]*LightningNode, numNodes) + nodeIndex := map[string]struct{}{} + for i := 0; i < numNodes; i++ { + node, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create node: %v", err) + } + + nodes[i] = node + nodeIndex[node.Alias] = struct{}{} + } + + // Add each of the nodes into the graph, they should be inserted + // without error. + for _, node := range nodes { + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + } + + // Iterate over each node as returned by the graph, if all nodes are + // reached, then the map created above should be empty. + err = graph.ForEachNode(nil, func(_ *bbolt.Tx, node *LightningNode) error { + delete(nodeIndex, node.Alias) + return nil + }) + if err != nil { + t.Fatalf("for each failure: %v", err) + } + if len(nodeIndex) != 0 { + t.Fatalf("all nodes not reached within ForEach") + } + + // Determine which node is "smaller", we'll need this in order to + // properly create the edges for the graph. + var firstNode, secondNode *LightningNode + if bytes.Compare(nodes[0].PubKeyBytes[:], nodes[1].PubKeyBytes[:]) == -1 { + firstNode = nodes[0] + secondNode = nodes[1] + } else { + firstNode = nodes[0] + secondNode = nodes[1] + } + + // Create 5 channels between the first two nodes we generated above. + const numChannels = 5 + chanIndex := map[uint64]struct{}{} + for i := 0; i < numChannels; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + chanID := uint64(i + 1) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: op, + Capacity: 1000, + } + copy(edgeInfo.NodeKey1Bytes[:], nodes[0].PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], nodes[1].PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], nodes[0].PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], nodes[1].PubKeyBytes[:]) + err := graph.AddChannelEdge(&edgeInfo) + if err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Create and add an edge with random data that points from + // node1 -> node2. + edge := randEdgePolicy(chanID, op, db) + edge.ChannelFlags = 0 + edge.Node = secondNode + edge.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // Create another random edge that points from node2 -> node1 + // this time. + edge = randEdgePolicy(chanID, op, db) + edge.ChannelFlags = 1 + edge.Node = firstNode + edge.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + chanIndex[chanID] = struct{}{} + } + + // Iterate through all the known channels within the graph DB, once + // again if the map is empty that indicates that all edges have + // properly been reached. + err = graph.ForEachChannel(func(ei *ChannelEdgeInfo, _ *ChannelEdgePolicy, + _ *ChannelEdgePolicy) error { + + delete(chanIndex, ei.ChannelID) + return nil + }) + if err != nil { + t.Fatalf("for each failure: %v", err) + } + if len(chanIndex) != 0 { + t.Fatalf("all edges not reached within ForEach") + } + + // Finally, we want to test the ability to iterate over all the + // outgoing channels for a particular node. + numNodeChans := 0 + err = firstNode.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo, + outEdge, inEdge *ChannelEdgePolicy) error { + + // All channels between first and second node should have fully + // (both sides) specified policies. + if inEdge == nil || outEdge == nil { + return fmt.Errorf("channel policy not present") + } + + // Each should indicate that it's outgoing (pointed + // towards the second node). + if !bytes.Equal(outEdge.Node.PubKeyBytes[:], secondNode.PubKeyBytes[:]) { + return fmt.Errorf("wrong outgoing edge") + } + + // The incoming edge should also indicate that it's pointing to + // the origin node. + if !bytes.Equal(inEdge.Node.PubKeyBytes[:], firstNode.PubKeyBytes[:]) { + return fmt.Errorf("wrong outgoing edge") + } + + numNodeChans++ + return nil + }) + if err != nil { + t.Fatalf("for each failure: %v", err) + } + if numNodeChans != numChannels { + t.Fatalf("all edges for node not reached within ForEach: "+ + "expected %v, got %v", numChannels, numNodeChans) + } +} + +func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash, + blockHeight uint32) { + + pruneHash, pruneHeight, err := graph.PruneTip() + if err != nil { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: unable to fetch prune tip: %v", line, err) + } + if !bytes.Equal(blockHash[:], pruneHash[:]) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line: %v, prune tips don't match, expected %x got %x", + line, blockHash, pruneHash) + } + if pruneHeight != blockHeight { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: prune heights don't match, expected %v "+ + "got %v", line, blockHeight, pruneHeight) + } +} + +func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { + numChans := 0 + if err := graph.ForEachChannel(func(*ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error { + + numChans++ + return nil + }); err != nil { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: unable to scan channels: %v", line, err) + } + if numChans != n { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: expected %v chans instead have %v", line, + n, numChans) + } +} + +func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { + numNodes := 0 + err := graph.ForEachNode(nil, func(_ *bbolt.Tx, _ *LightningNode) error { + numNodes++ + return nil + }) + if err != nil { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: unable to scan nodes: %v", line, err) + } + + if numNodes != n { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: expected %v nodes, got %v", line, n, numNodes) + } +} + +func assertChanViewEqual(t *testing.T, a []EdgePoint, b []EdgePoint) { + if len(a) != len(b) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chan views don't match", line) + } + + chanViewSet := make(map[wire.OutPoint]struct{}) + for _, op := range a { + chanViewSet[op.OutPoint] = struct{}{} + } + + for _, op := range b { + if _, ok := chanViewSet[op.OutPoint]; !ok { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chanPoint(%v) not found in first "+ + "view", line, op) + } + } +} + +func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoint) { + if len(a) != len(b) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chan views don't match", line) + } + + chanViewSet := make(map[wire.OutPoint]struct{}) + for _, op := range a { + chanViewSet[op.OutPoint] = struct{}{} + } + + for _, op := range b { + if _, ok := chanViewSet[*op]; !ok { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chanPoint(%v) not found in first "+ + "view", line, op) + } + } +} + +func TestGraphPruning(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // As initial set up for the test, we'll create a graph with 5 vertexes + // and enough edges to create a fully connected graph. The graph will + // be rather simple, representing a straight line. + const numNodes = 5 + graphNodes := make([]*LightningNode, numNodes) + for i := 0; i < numNodes; i++ { + node, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create node: %v", err) + } + + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + graphNodes[i] = node + } + + // With the vertexes created, we'll next create a series of channels + // between them. + channelPoints := make([]*wire.OutPoint, 0, numNodes-1) + edgePoints := make([]EdgePoint, 0, numNodes-1) + for i := 0; i < numNodes-1; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + chanID := uint64(i + 1) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channelPoints = append(channelPoints, &op) + + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: op, + Capacity: 1000, + } + copy(edgeInfo.NodeKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + pkScript, err := genMultiSigP2WSH( + edgeInfo.BitcoinKey1Bytes[:], edgeInfo.BitcoinKey2Bytes[:], + ) + if err != nil { + t.Fatalf("unable to gen multi-sig p2wsh: %v", err) + } + edgePoints = append(edgePoints, EdgePoint{ + FundingPkScript: pkScript, + OutPoint: op, + }) + + // Create and add an edge with random data that points from + // node_i -> node_i+1 + edge := randEdgePolicy(chanID, op, db) + edge.ChannelFlags = 0 + edge.Node = graphNodes[i] + edge.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // Create another random edge that points from node_i+1 -> + // node_i this time. + edge = randEdgePolicy(chanID, op, db) + edge.ChannelFlags = 1 + edge.Node = graphNodes[i] + edge.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + } + + // With all the channel points added, we'll consult the graph to ensure + // it has the same channel view as the one we just constructed. + channelView, err := graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + assertChanViewEqual(t, channelView, edgePoints) + + // Now with our test graph created, we can test the pruning + // capabilities of the channel graph. + + // First we create a mock block that ends up closing the first two + // channels. + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) + blockHeight := uint32(1) + block := channelPoints[:2] + prunedChans, err := graph.PruneGraph(block, &blockHash, blockHeight) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + if len(prunedChans) != 2 { + t.Fatalf("incorrect number of channels pruned: "+ + "expected %v, got %v", 2, prunedChans) + } + + // Now ensure that the prune tip has been updated. + assertPruneTip(t, graph, &blockHash, blockHeight) + + // Count up the number of channels known within the graph, only 2 + // should be remaining. + assertNumChans(t, graph, 2) + + // Those channels should also be missing from the channel view. + channelView, err = graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + assertChanViewEqualChanPoints(t, channelView, channelPoints[2:]) + + // Next we'll create a block that doesn't close any channels within the + // graph to test the negative error case. + fakeHash := sha256.Sum256([]byte("test prune")) + nonChannel := &wire.OutPoint{ + Hash: fakeHash, + Index: 9, + } + blockHash = sha256.Sum256(blockHash[:]) + blockHeight = 2 + prunedChans, err = graph.PruneGraph( + []*wire.OutPoint{nonChannel}, &blockHash, blockHeight, + ) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // No channels should have been detected as pruned. + if len(prunedChans) != 0 { + t.Fatalf("channels were pruned but shouldn't have been") + } + + // Once again, the prune tip should have been updated. We should still + // see both channels and their participants, along with the source node. + assertPruneTip(t, graph, &blockHash, blockHeight) + assertNumChans(t, graph, 2) + assertNumNodes(t, graph, 4) + + // Finally, create a block that prunes the remainder of the channels + // from the graph. + blockHash = sha256.Sum256(blockHash[:]) + blockHeight = 3 + prunedChans, err = graph.PruneGraph( + channelPoints[2:], &blockHash, blockHeight, + ) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // The remainder of the channels should have been pruned from the + // graph. + if len(prunedChans) != 2 { + t.Fatalf("incorrect number of channels pruned: "+ + "expected %v, got %v", 2, len(prunedChans)) + } + + // The prune tip should be updated, no channels should be found, and + // only the source node should remain within the current graph. + assertPruneTip(t, graph, &blockHash, blockHeight) + assertNumChans(t, graph, 0) + assertNumNodes(t, graph, 1) + + // Finally, the channel view at this point in the graph should now be + // completely empty. Those channels should also be missing from the + // channel view. + channelView, err = graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + if len(channelView) != 0 { + t.Fatalf("channel view should be empty, instead have: %v", + channelView) + } +} + +// TestHighestChanID tests that we're able to properly retrieve the highest +// known channel ID in the database. +func TestHighestChanID(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we don't yet have any channels in the database, then we should + // get a channel ID of zero if we ask for the highest channel ID. + bestID, err := graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + if bestID != 0 { + t.Fatalf("best ID w/ no chan should be zero, is instead: %v", + bestID) + } + + // Next, we'll insert two channels into the database, with each channel + // connecting the same two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // The first channel with be at height 10, while the other will be at + // height 100. + edge1, _ := createEdge(10, 0, 0, 0, node1, node2) + edge2, chanID2 := createEdge(100, 0, 0, 0, node1, node2) + + if err := graph.AddChannelEdge(&edge1); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + if err := graph.AddChannelEdge(&edge2); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Now that the edges has been inserted, we'll query for the highest + // known channel ID in the database. + bestID, err = graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + + if bestID != chanID2.ToUint64() { + t.Fatalf("expected %v got %v for best chan ID: ", + chanID2.ToUint64(), bestID) + } + + // If we add another edge, then the current best chan ID should be + // updated as well. + edge3, chanID3 := createEdge(1000, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edge3); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + bestID, err = graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + + if bestID != chanID3.ToUint64() { + t.Fatalf("expected %v got %v for best chan ID: ", + chanID3.ToUint64(), bestID) + } +} + +// TestChanUpdatesInHorizon tests the we're able to properly retrieve all known +// channel updates within a specific time horizon. It also tests that upon +// insertion of a new edge, the edge update index is updated properly. +func TestChanUpdatesInHorizon(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we issue an arbitrary query before any channel updates are + // inserted in the database, we should get zero results. + chanUpdates, err := graph.ChanUpdatesInHorizon( + time.Unix(999, 0), time.Unix(9999, 0), + ) + if err != nil { + t.Fatalf("unable to updates for updates: %v", err) + } + if len(chanUpdates) != 0 { + t.Fatalf("expected 0 chan updates, instead got %v", + len(chanUpdates)) + } + + // We'll start by creating two nodes which will seed our test graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll now create 10 channels between the two nodes, with update + // times 10 seconds after each other. + const numChans = 10 + startTime := time.Unix(1234, 0) + endTime := startTime + edges := make([]ChannelEdge, 0, numChans) + for i := 0; i < numChans; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + edge1UpdateTime := endTime + edge2UpdateTime := edge1UpdateTime.Add(time.Second) + endTime = endTime.Add(time.Second * 10) + + edge1 := newEdgePolicy( + chanID.ToUint64(), op, db, edge1UpdateTime.Unix(), + ) + edge1.ChannelFlags = 0 + edge1.Node = node2 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := newEdgePolicy( + chanID.ToUint64(), op, db, edge2UpdateTime.Unix(), + ) + edge2.ChannelFlags = 1 + edge2.Node = node1 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edges = append(edges, ChannelEdge{ + Info: &channel, + Policy1: edge1, + Policy2: edge2, + }) + } + + // With our channels loaded, we'll now start our series of queries. + queryCases := []struct { + start time.Time + end time.Time + + resp []ChannelEdge + }{ + // If we query for a time range that's strictly below our set + // of updates, then we'll get an empty result back. + { + start: time.Unix(100, 0), + end: time.Unix(200, 0), + }, + + // If we query for a time range that's well beyond our set of + // updates, we should get an empty set of results back. + { + start: time.Unix(99999, 0), + end: time.Unix(999999, 0), + }, + + // If we query for the start time, and 10 seconds directly + // after it, we should only get a single update, that first + // one. + { + start: time.Unix(1234, 0), + end: startTime.Add(time.Second * 10), + + resp: []ChannelEdge{edges[0]}, + }, + + // If we add 10 seconds past the first update, and then + // subtract 10 from the last update, then we should only get + // the 8 edges in the middle. + { + start: startTime.Add(time.Second * 10), + end: endTime.Add(-time.Second * 10), + + resp: edges[1:9], + }, + + // If we use the start and end time as is, we should get the + // entire range. + { + start: startTime, + end: endTime, + + resp: edges, + }, + } + for _, queryCase := range queryCases { + resp, err := graph.ChanUpdatesInHorizon( + queryCase.start, queryCase.end, + ) + if err != nil { + t.Fatalf("unable to query for updates: %v", err) + } + + if len(resp) != len(queryCase.resp) { + t.Fatalf("expected %v chans, got %v chans", + len(queryCase.resp), len(resp)) + + } + + for i := 0; i < len(resp); i++ { + chanExp := queryCase.resp[i] + chanRet := resp[i] + + assertEdgeInfoEqual(t, chanExp.Info, chanRet.Info) + + err := compareEdgePolicies(chanExp.Policy1, chanRet.Policy1) + if err != nil { + t.Fatal(err) + } + compareEdgePolicies(chanExp.Policy2, chanRet.Policy2) + if err != nil { + t.Fatal(err) + } + } + } +} + +// TestNodeUpdatesInHorizon tests that we're able to properly scan and retrieve +// the most recent node updates within a particular time horizon. +func TestNodeUpdatesInHorizon(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + startTime := time.Unix(1234, 0) + endTime := startTime + + // If we issue an arbitrary query before we insert any nodes into the + // database, then we shouldn't get any results back. + nodeUpdates, err := graph.NodeUpdatesInHorizon( + time.Unix(999, 0), time.Unix(9999, 0), + ) + if err != nil { + t.Fatalf("unable to query for node updates: %v", err) + } + if len(nodeUpdates) != 0 { + t.Fatalf("expected 0 node updates, instead got %v", + len(nodeUpdates)) + } + + // We'll create 10 node announcements, each with an update timestamp 10 + // seconds after the other. + const numNodes = 10 + nodeAnns := make([]LightningNode, 0, numNodes) + for i := 0; i < numNodes; i++ { + nodeAnn, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test vertex: %v", err) + } + + // The node ann will use the current end time as its last + // update them, then we'll add 10 seconds in order to create + // the proper update time for the next node announcement. + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + nodeAnn.LastUpdate = updateTime + + nodeAnns = append(nodeAnns, *nodeAnn) + + if err := graph.AddLightningNode(nodeAnn); err != nil { + t.Fatalf("unable to add lightning node: %v", err) + } + } + + queryCases := []struct { + start time.Time + end time.Time + + resp []LightningNode + }{ + // If we query for a time range that's strictly below our set + // of updates, then we'll get an empty result back. + { + start: time.Unix(100, 0), + end: time.Unix(200, 0), + }, + + // If we query for a time range that's well beyond our set of + // updates, we should get an empty set of results back. + { + start: time.Unix(99999, 0), + end: time.Unix(999999, 0), + }, + + // If we skip he first time epoch with out start time, then we + // should get back every now but the first. + { + start: startTime.Add(time.Second * 10), + end: endTime, + + resp: nodeAnns[1:], + }, + + // If we query for the range as is, we should get all 10 + // announcements back. + { + start: startTime, + end: endTime, + + resp: nodeAnns, + }, + + // If we reduce the ending time by 10 seconds, then we should + // get all but the last node we inserted. + { + start: startTime, + end: endTime.Add(-time.Second * 10), + + resp: nodeAnns[:9], + }, + } + for _, queryCase := range queryCases { + resp, err := graph.NodeUpdatesInHorizon(queryCase.start, queryCase.end) + if err != nil { + t.Fatalf("unable to query for nodes: %v", err) + } + + if len(resp) != len(queryCase.resp) { + t.Fatalf("expected %v nodes, got %v nodes", + len(queryCase.resp), len(resp)) + + } + + for i := 0; i < len(resp); i++ { + err := compareNodes(&queryCase.resp[i], &resp[i]) + if err != nil { + t.Fatal(err) + } + } + } +} + +// TestFilterKnownChanIDs tests that we're able to properly perform the set +// differences of an incoming set of channel ID's, and those that we already +// know of on disk. +func TestFilterKnownChanIDs(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we try to filter out a set of channel ID's before we even know of + // any channels, then we should get the entire set back. + preChanIDs := []uint64{1, 2, 3, 4} + filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs) + if err != nil { + t.Fatalf("unable to filter chan IDs: %v", err) + } + if !reflect.DeepEqual(preChanIDs, filteredIDs) { + t.Fatalf("chan IDs shouldn't have been filtered!") + } + + // We'll start by creating two nodes which will seed our test graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, we'll add 5 channel ID's to the graph, each of them having a + // block height 10 blocks after the previous. + const numChans = 5 + chanIDs := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanIDs = append(chanIDs, chanID.ToUint64()) + } + + const numZombies = 5 + zombieIDs := make([]uint64, 0, numZombies) + for i := 0; i < numZombies; i++ { + channel, chanID := createEdge( + uint32(i*10+1), 0, 0, 0, node1, node2, + ) + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + err := graph.DeleteChannelEdges(channel.ChannelID) + if err != nil { + t.Fatalf("unable to mark edge zombie: %v", err) + } + + zombieIDs = append(zombieIDs, chanID.ToUint64()) + } + + queryCases := []struct { + queryIDs []uint64 + + resp []uint64 + }{ + // If we attempt to filter out all chanIDs we know of, the + // response should be the empty set. + { + queryIDs: chanIDs, + }, + // If we attempt to filter out all zombies that we know of, the + // response should be the empty set. + { + queryIDs: zombieIDs, + }, + + // If we query for a set of ID's that we didn't insert, we + // should get the same set back. + { + queryIDs: []uint64{99, 100}, + resp: []uint64{99, 100}, + }, + + // If we query for a super-set of our the chan ID's inserted, + // we should only get those new chanIDs back. + { + queryIDs: append(chanIDs, []uint64{99, 101}...), + resp: []uint64{99, 101}, + }, + } + + for _, queryCase := range queryCases { + resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs) + if err != nil { + t.Fatalf("unable to filter chan IDs: %v", err) + } + + if !reflect.DeepEqual(resp, queryCase.resp) { + t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp), + spew.Sdump(resp)) + } + } +} + +// TestFilterChannelRange tests that we're able to properly retrieve the full +// set of short channel ID's for a given block range. +func TestFilterChannelRange(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // If we try to filter a channel range before we have any channels + // inserted, we should get an empty slice of results. + resp, err := graph.FilterChannelRange(10, 100) + if err != nil { + t.Fatalf("unable to filter channels: %v", err) + } + if len(resp) != 0 { + t.Fatalf("expected zero chans, instead got %v", len(resp)) + } + + // To start, we'll create a set of channels, each mined in a block 10 + // blocks after the prior one. + startHeight := uint32(100) + endHeight := startHeight + const numChans = 10 + chanIDs := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + chanHeight := endHeight + channel, chanID := createEdge( + uint32(chanHeight), uint32(i+1), 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanIDs = append(chanIDs, chanID.ToUint64()) + + endHeight += 10 + } + + // With our channels inserted, we'll construct a series of queries that + // we'll execute below in order to exercise the features of the + // FilterKnownChanIDs method. + queryCases := []struct { + startHeight uint32 + endHeight uint32 + + resp []uint64 + }{ + // If we query for the entire range, then we should get the same + // set of short channel IDs back. + { + startHeight: startHeight, + endHeight: endHeight, + + resp: chanIDs, + }, + + // If we query for a range of channels right before our range, we + // shouldn't get any results back. + { + startHeight: 0, + endHeight: 10, + }, + + // If we only query for the last height (range wise), we should + // only get that last channel. + { + startHeight: endHeight - 10, + endHeight: endHeight - 10, + + resp: chanIDs[9:], + }, + + // If we query for just the first height, we should only get a + // single channel back (the first one). + { + startHeight: startHeight, + endHeight: startHeight, + + resp: chanIDs[:1], + }, + } + for i, queryCase := range queryCases { + resp, err := graph.FilterChannelRange( + queryCase.startHeight, queryCase.endHeight, + ) + if err != nil { + t.Fatalf("unable to issue range query: %v", err) + } + + if !reflect.DeepEqual(resp, queryCase.resp) { + t.Fatalf("case #%v: expected %v, got %v", i, + queryCase.resp, resp) + } + } +} + +// TestFetchChanInfos tests that we're able to properly retrieve the full set +// of ChannelEdge structs for a given set of short channel ID's. +func TestFetchChanInfos(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll make 5 test channels, ensuring we keep track of which channel + // ID corresponds to a particular ChannelEdge. + const numChans = 5 + startTime := time.Unix(1234, 0) + endTime := startTime + edges := make([]ChannelEdge, 0, numChans) + edgeQuery := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + edge1 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge1.ChannelFlags = 0 + edge1.Node = node2 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge2.ChannelFlags = 1 + edge2.Node = node1 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edges = append(edges, ChannelEdge{ + Info: &channel, + Policy1: edge1, + Policy2: edge2, + }) + + edgeQuery = append(edgeQuery, chanID.ToUint64()) + } + + // Add an additional edge that does not exist. The query should skip + // this channel and return only infos for the edges that exist. + edgeQuery = append(edgeQuery, 500) + + // Add an another edge to the query that has been marked as a zombie + // edge. The query should also skip this channel. + zombieChan, zombieChanID := createEdge( + 666, 0, 0, 0, node1, node2, + ) + if err := graph.AddChannelEdge(&zombieChan); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + err = graph.DeleteChannelEdges(zombieChan.ChannelID) + if err != nil { + t.Fatalf("unable to delete and mark edge zombie: %v", err) + } + edgeQuery = append(edgeQuery, zombieChanID.ToUint64()) + + // We'll now attempt to query for the range of channel ID's we just + // inserted into the database. We should get the exact same set of + // edges back. + resp, err := graph.FetchChanInfos(edgeQuery) + if err != nil { + t.Fatalf("unable to fetch chan edges: %v", err) + } + if len(resp) != len(edges) { + t.Fatalf("expected %v edges, instead got %v", len(edges), + len(resp)) + } + + for i := 0; i < len(resp); i++ { + err := compareEdgePolicies(resp[i].Policy1, edges[i].Policy1) + if err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + err = compareEdgePolicies(resp[i].Policy2, edges[i].Policy2) + if err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, resp[i].Info, edges[i].Info) + } +} + +// TestIncompleteChannelPolicies tests that a channel that only has a policy +// specified on one end is properly returned in ForEachChannel calls from +// both sides. +func TestIncompleteChannelPolicies(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // Create two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Create channel between nodes. + txHash := sha256.Sum256([]byte{0}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(0), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Ensure that channel is reported with unknown policies. + + checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { + calls := 0 + node.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo, + outEdge, inEdge *ChannelEdgePolicy) error { + + if !expectedOut && outEdge != nil { + t.Fatalf("Expected no outgoing policy") + } + + if expectedOut && outEdge == nil { + t.Fatalf("Expected an outgoing policy") + } + + if !expectedIn && inEdge != nil { + t.Fatalf("Expected no incoming policy") + } + + if expectedIn && inEdge == nil { + t.Fatalf("Expected an incoming policy") + } + + calls++ + + return nil + }) + + if calls != 1 { + t.Fatalf("Expected only one callback call") + } + } + + checkPolicies(node2, false, false) + + // Only create an edge policy for node1 and leave the policy for node2 + // unknown. + updateTime := time.Unix(1234, 0) + + edgePolicy := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edgePolicy.ChannelFlags = 0 + edgePolicy.Node = node2 + edgePolicy.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + checkPolicies(node1, false, true) + checkPolicies(node2, true, false) + + // Create second policy and assert that both policies are reported + // as present. + edgePolicy = newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edgePolicy.ChannelFlags = 1 + edgePolicy.Node = node1 + edgePolicy.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + checkPolicies(node1, true, true) + checkPolicies(node2, true, true) +} + +// TestChannelEdgePruningUpdateIndexDeletion tests that once edges are deleted +// from the graph, then their entries within the update index are also cleaned +// up. +func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // With the two nodes created, we'll now create a random channel, as + // well as two edges in the database with distinct update times. + edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge1.ChannelFlags = 0 + edge1.Node = node1 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge2.ChannelFlags = 1 + edge2.Node = node2 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // checkIndexTimestamps is a helper function that checks the edge update + // index only includes the given timestamps. + checkIndexTimestamps := func(timestamps ...uint64) { + timestampSet := make(map[uint64]struct{}) + for _, t := range timestamps { + timestampSet[t] = struct{}{} + } + + err := db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) + if edgeUpdateIndex == nil { + return ErrGraphNoEdgesFound + } + + numEntries := edgeUpdateIndex.Stats().KeyN + expectedEntries := len(timestampSet) + if numEntries != expectedEntries { + return fmt.Errorf("expected %v entries in the "+ + "update index, got %v", expectedEntries, + numEntries) + } + + return edgeUpdateIndex.ForEach(func(k, _ []byte) error { + t := byteOrder.Uint64(k[:8]) + if _, ok := timestampSet[t]; !ok { + return fmt.Errorf("found unexpected "+ + "timestamp "+"%d", t) + } + + return nil + }) + }) + if err != nil { + t.Fatal(err) + } + } + + // With both edges policies added, we'll make sure to check they exist + // within the edge update index. + checkIndexTimestamps( + uint64(edge1.LastUpdate.Unix()), + uint64(edge2.LastUpdate.Unix()), + ) + + // Now, we'll update the edge policies to ensure the old timestamps are + // removed from the update index. + edge1.ChannelFlags = 2 + edge1.LastUpdate = time.Now() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + edge2.ChannelFlags = 3 + edge2.LastUpdate = edge1.LastUpdate.Add(time.Hour) + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // With the policies updated, we should now be able to find their + // updated entries within the update index. + checkIndexTimestamps( + uint64(edge1.LastUpdate.Unix()), + uint64(edge2.LastUpdate.Unix()), + ) + + // Now we'll prune the graph, removing the edges, and also the update + // index entries from the database all together. + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{2}, 32)) + _, err = graph.PruneGraph( + []*wire.OutPoint{&edgeInfo.ChannelPoint}, &blockHash, 101, + ) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // Finally, we'll check the database state one last time to conclude + // that we should no longer be able to locate _any_ entries within the + // edge update index. + checkIndexTimestamps() +} + +// TestPruneGraphNodes tests that unconnected vertexes are pruned via the +// PruneSyncState method. +func TestPruneGraphNodes(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + // We'll start off by inserting our source node, to ensure that it's + // the only node left after we prune the graph. + graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // With the source node inserted, we'll now add three nodes to the + // channel graph, at the end of the scenario, only two of these nodes + // should still be in the graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node3, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node3); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll now add a new edge to the graph, but only actually advertise + // the edge of *one* of the nodes. + edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + // We'll now insert an advertised edge, but it'll only be the edge that + // points from the first to the second node. + edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge1.ChannelFlags = 0 + edge1.Node = node1 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // We'll now initiate a around of graph pruning. + if err := graph.PruneGraphNodes(); err != nil { + t.Fatalf("unable to prune graph nodes: %v", err) + } + + // At this point, there should be 3 nodes left in the graph still: the + // source node (which can't be pruned), and node 1+2. Nodes 1 and two + // should still be left in the graph as there's half of an advertised + // edge between them. + assertNumNodes(t, graph, 3) + + // Finally, we'll ensure that node3, the only fully unconnected node as + // properly deleted from the graph and not another node in its place. + node3Pub, err := node3.PubKey() + if err != nil { + t.Fatalf("unable to fetch the pubkey of node3: %v", err) + } + if _, err := graph.FetchLightningNode(node3Pub); err == nil { + t.Fatalf("node 3 should have been deleted!") + } +} + +// TestAddChannelEdgeShellNodes tests that when we attempt to add a ChannelEdge +// to the graph, one or both of the nodes the edge involves aren't found in the +// database, then shell edges are created for each node if needed. +func TestAddChannelEdgeShellNodes(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // To start, we'll create two nodes, and only add one of them to the + // channel graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // We'll now create an edge between the two nodes, as a result, node2 + // should be inserted into the database as a shell node. + edgeInfo, _ := createEdge(100, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + node1Pub, err := node1.PubKey() + if err != nil { + t.Fatalf("unable to parse node 1 pub: %v", err) + } + node2Pub, err := node2.PubKey() + if err != nil { + t.Fatalf("unable to parse node 2 pub: %v", err) + } + + // Ensure that node1 was inserted as a full node, while node2 only has + // a shell node present. + node1, err = graph.FetchLightningNode(node1Pub) + if err != nil { + t.Fatalf("unable to fetch node1: %v", err) + } + if !node1.HaveNodeAnnouncement { + t.Fatalf("have shell announcement for node1, shouldn't") + } + + node2, err = graph.FetchLightningNode(node2Pub) + if err != nil { + t.Fatalf("unable to fetch node2: %v", err) + } + if node2.HaveNodeAnnouncement { + t.Fatalf("should have shell announcement for node2, but is full") + } +} + +// TestNodePruningUpdateIndexDeletion tests that once a node has been removed +// from the channel graph, we also remove the entry from the update index as +// well. +func TestNodePruningUpdateIndexDeletion(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with a single node that will be + // removed shortly. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll confirm that we can retrieve the node using + // NodeUpdatesInHorizon, using a time that's slightly beyond the last + // update time of our test node. + startTime := time.Unix(9, 0) + endTime := node1.LastUpdate.Add(time.Minute) + nodesInHorizon, err := graph.NodeUpdatesInHorizon(startTime, endTime) + if err != nil { + t.Fatalf("unable to fetch nodes in horizon: %v", err) + } + + // We should only have a single node, and that node should exactly + // match the node we just inserted. + if len(nodesInHorizon) != 1 { + t.Fatalf("should have 1 nodes instead have: %v", + len(nodesInHorizon)) + } + if err := compareNodes(node1, &nodesInHorizon[0]); err != nil { + t.Fatalf("nodes don't match: %v", err) + } + + // We'll now delete the node from the graph, this should result in it + // being removed from the update index as well. + nodePub, _ := node1.PubKey() + if err := graph.DeleteLightningNode(nodePub); err != nil { + t.Fatalf("unable to delete node: %v", err) + } + + // Now that the node has been deleted, we'll again query the nodes in + // the horizon. This time we should have no nodes at all. + nodesInHorizon, err = graph.NodeUpdatesInHorizon(startTime, endTime) + if err != nil { + t.Fatalf("unable to fetch nodes in horizon: %v", err) + } + + if len(nodesInHorizon) != 0 { + t.Fatalf("should have zero nodes instead have: %v", + len(nodesInHorizon)) + } +} + +// TestNodeIsPublic ensures that we properly detect nodes that are seen as +// public within the network graph. +func TestNodeIsPublic(t *testing.T) { + t.Parallel() + + // We'll start off the test by creating a small network of 3 + // participants with the following graph: + // + // Alice <-> Bob <-> Carol + // + // We'll need to create a separate database and channel graph for each + // participant to replicate real-world scenarios (private edges being in + // some graphs but not others, etc.). + aliceDB, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + aliceNode, err := createTestVertex(aliceDB) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + aliceGraph := aliceDB.ChannelGraph() + if err := aliceGraph.SetSourceNode(aliceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + bobDB, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + bobNode, err := createTestVertex(bobDB) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + bobGraph := bobDB.ChannelGraph() + if err := bobGraph.SetSourceNode(bobNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + carolDB, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + carolNode, err := createTestVertex(carolDB) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + carolGraph := carolDB.ChannelGraph() + if err := carolGraph.SetSourceNode(carolNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + aliceBobEdge, _ := createEdge(10, 0, 0, 0, aliceNode, bobNode) + bobCarolEdge, _ := createEdge(10, 1, 0, 1, bobNode, carolNode) + + // After creating all of our nodes and edges, we'll add them to each + // participant's graph. + nodes := []*LightningNode{aliceNode, bobNode, carolNode} + edges := []*ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} + dbs := []*DB{aliceDB, bobDB, carolDB} + graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} + for i, graph := range graphs { + for _, node := range nodes { + node.db = dbs[i] + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + } + for _, edge := range edges { + edge.db = dbs[i] + if err := graph.AddChannelEdge(edge); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + } + } + + // checkNodes is a helper closure that will be used to assert that the + // given nodes are seen as public/private within the given graphs. + checkNodes := func(nodes []*LightningNode, graphs []*ChannelGraph, + public bool) { + + t.Helper() + + for _, node := range nodes { + for _, graph := range graphs { + isPublic, err := graph.IsPublicNode(node.PubKeyBytes) + if err != nil { + t.Fatalf("unable to determine if pivot "+ + "is public: %v", err) + } + + switch { + case isPublic && !public: + t.Fatalf("expected %x to be private", + node.PubKeyBytes) + case !isPublic && public: + t.Fatalf("expected %x to be public", + node.PubKeyBytes) + } + } + } + } + + // Due to the way the edges were set up above, we'll make sure each node + // can correctly determine that every other node is public. + checkNodes(nodes, graphs, true) + + // Now, we'll remove the edge between Alice and Bob from everyone's + // graph. This will make Alice be seen as a private node as it no longer + // has any advertised edges. + for _, graph := range graphs { + err := graph.DeleteChannelEdges(aliceBobEdge.ChannelID) + if err != nil { + t.Fatalf("unable to remove edge: %v", err) + } + } + checkNodes( + []*LightningNode{aliceNode}, + []*ChannelGraph{bobGraph, carolGraph}, + false, + ) + + // We'll also make the edge between Bob and Carol private. Within Bob's + // and Carol's graph, the edge will exist, but it will not have a proof + // that allows it to be advertised. Within Alice's graph, we'll + // completely remove the edge as it is not possible for her to know of + // it without it being advertised. + for i, graph := range graphs { + err := graph.DeleteChannelEdges(bobCarolEdge.ChannelID) + if err != nil { + t.Fatalf("unable to remove edge: %v", err) + } + + if graph == aliceGraph { + continue + } + + bobCarolEdge.AuthProof = nil + bobCarolEdge.db = dbs[i] + if err := graph.AddChannelEdge(&bobCarolEdge); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + } + + // With the modifications above, Bob should now be seen as a private + // node from both Alice's and Carol's perspective. + checkNodes( + []*LightningNode{bobNode}, + []*ChannelGraph{aliceGraph, carolGraph}, + false, + ) +} + +// TestDisabledChannelIDs ensures that the disabled channels within the +// disabledEdgePolicyBucket are managed properly and the list returned from +// DisabledChannelIDs is correct. +func TestDisabledChannelIDs(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + graph := db.ChannelGraph() + + // Create first node and add it to the graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Create second node and add it to the graph. + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Adding a new channel edge to the graph. + edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + if err := graph.AddChannelEdge(edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Ensure no disabled channels exist in the bucket on start. + disabledChanIds, err := graph.DisabledChannelIDs() + if err != nil { + t.Fatalf("unable to get disabled channel ids: %v", err) + } + if len(disabledChanIds) > 0 { + t.Fatalf("expected empty disabled channels, got %v disabled channels", + len(disabledChanIds)) + } + + // Add one disabled policy and ensure the channel is still not in the + // disabled list. + edge1.ChannelFlags |= lnwire.ChanUpdateDisabled + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + disabledChanIds, err = graph.DisabledChannelIDs() + if err != nil { + t.Fatalf("unable to get disabled channel ids: %v", err) + } + if len(disabledChanIds) > 0 { + t.Fatalf("expected empty disabled channels, got %v disabled channels", + len(disabledChanIds)) + } + + // Add second disabled policy and ensure the channel is now in the + // disabled list. + edge2.ChannelFlags |= lnwire.ChanUpdateDisabled + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + disabledChanIds, err = graph.DisabledChannelIDs() + if err != nil { + t.Fatalf("unable to get disabled channel ids: %v", err) + } + if len(disabledChanIds) != 1 || disabledChanIds[0] != edgeInfo.ChannelID { + t.Fatalf("expected disabled channel with id %v, "+ + "got %v", edgeInfo.ChannelID, disabledChanIds) + } + + // Delete the channel edge and ensure it is removed from the disabled list. + if err = graph.DeleteChannelEdges(edgeInfo.ChannelID); err != nil { + t.Fatalf("unable to delete channel edge: %v", err) + } + disabledChanIds, err = graph.DisabledChannelIDs() + if err != nil { + t.Fatalf("unable to get disabled channel ids: %v", err) + } + if len(disabledChanIds) > 0 { + t.Fatalf("expected empty disabled channels, got %v disabled channels", + len(disabledChanIds)) + } +} + +// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in +// the DB that indicates that it should support the htlc_maximum_value_msat +// field, but it is not part of the opaque data, then we'll handle it as it is +// unknown. It also checks that we are correctly able to overwrite it when we +// receive the proper update. +func TestEdgePolicyMissingMaxHtcl(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the update of edges inserted into the database, so + // we create two vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + if err := graph.AddChannelEdge(edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanID := edgeInfo.ChannelID + from := edge2.Node.PubKeyBytes[:] + to := edge1.Node.PubKeyBytes[:] + + // We'll remove the no max_htlc field from the first edge policy, and + // all other opaque data, and serialize it. + edge1.MessageFlags = 0 + edge1.ExtraOpaqueData = nil + + var b bytes.Buffer + err = serializeChanEdgePolicy(&b, edge1, to) + if err != nil { + t.Fatalf("unable to serialize policy") + } + + // Set the max_htlc field. The extra bytes added to the serialization + // will be the opaque data containing the serialized field. + edge1.MessageFlags = lnwire.ChanUpdateOptionMaxHtlc + edge1.MaxHTLC = 13928598 + var b2 bytes.Buffer + err = serializeChanEdgePolicy(&b2, edge1, to) + if err != nil { + t.Fatalf("unable to serialize policy") + } + + withMaxHtlc := b2.Bytes() + + // Remove the opaque data from the serialization. + stripped := withMaxHtlc[:len(b.Bytes())] + + // Attempting to deserialize these bytes should return an error. + r := bytes.NewReader(stripped) + err = db.View(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + _, err = deserializeChanEdgePolicy(r, nodes) + if err != ErrEdgePolicyOptionalFieldNotFound { + t.Fatalf("expected "+ + "ErrEdgePolicyOptionalFieldNotFound, got %v", + err) + } + + return nil + }) + if err != nil { + t.Fatalf("error reading db: %v", err) + } + + // Put the stripped bytes in the DB. + err = db.Update(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrEdgeNotFound + } + + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrEdgeNotFound + } + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], edge1.ChannelID) + + var scratch [8]byte + var indexKey [8 + 8]byte + copy(indexKey[:], scratch[:]) + byteOrder.PutUint64(indexKey[8:], edge1.ChannelID) + + updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + return edges.Put(edgeKey[:], stripped) + }) + if err != nil { + t.Fatalf("error writing db: %v", err) + } + + // And add the second, unmodified edge. + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // Attempt to fetch the edge and policies from the DB. Since the policy + // we added is invalid according to the new format, it should be as we + // are not aware of the policy (indicated by the policy returned being + // nil) + dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID) + if err != nil { + t.Fatalf("unable to fetch channel by ID: %v", err) + } + + // The first edge should have a nil-policy returned + if dbEdge1 != nil { + t.Fatalf("expected db edge to be nil") + } + if err := compareEdgePolicies(dbEdge2, edge2); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) + + // Now add the original, unmodified edge policy, and make sure the edge + // policies then become fully populated. + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByID(chanID) + if err != nil { + t.Fatalf("unable to fetch channel by ID: %v", err) + } + if err := compareEdgePolicies(dbEdge1, edge1); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + if err := compareEdgePolicies(dbEdge2, edge2); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) +} + +// assertNumZombies queries the provided ChannelGraph for NumZombies, and +// asserts that the returned number is equal to expZombies. +func assertNumZombies(t *testing.T, graph *ChannelGraph, expZombies uint64) { + t.Helper() + + numZombies, err := graph.NumZombies() + if err != nil { + t.Fatalf("unable to query number of zombies: %v", err) + } + + if numZombies != expZombies { + t.Fatalf("expected %d zombies, found %d", + expZombies, numZombies) + } +} + +// TestGraphZombieIndex ensures that we can mark edges correctly as zombie/live. +func TestGraphZombieIndex(t *testing.T) { + t.Parallel() + + // We'll start by creating our test graph along with a test edge. + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to create test database: %v", err) + } + graph := db.ChannelGraph() + + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test vertex: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test vertex: %v", err) + } + + // Swap the nodes if the second's pubkey is smaller than the first. + // Without this, the comparisons at the end will fail probabilistically. + if bytes.Compare(node2.PubKeyBytes[:], node1.PubKeyBytes[:]) < 0 { + node1, node2 = node2, node1 + } + + edge, _, _ := createChannelEdge(db, node1, node2) + if err := graph.AddChannelEdge(edge); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Since the edge is known the graph and it isn't a zombie, IsZombieEdge + // should not report the channel as a zombie. + isZombie, _, _ := graph.IsZombieEdge(edge.ChannelID) + if isZombie { + t.Fatal("expected edge to not be marked as zombie") + } + assertNumZombies(t, graph, 0) + + // If we delete the edge and mark it as a zombie, then we should expect + // to see it within the index. + err = graph.DeleteChannelEdges(edge.ChannelID) + if err != nil { + t.Fatalf("unable to mark edge as zombie: %v", err) + } + isZombie, pubKey1, pubKey2 := graph.IsZombieEdge(edge.ChannelID) + if !isZombie { + t.Fatal("expected edge to be marked as zombie") + } + if pubKey1 != node1.PubKeyBytes { + t.Fatalf("expected pubKey1 %x, got %x", node1.PubKeyBytes, + pubKey1) + } + if pubKey2 != node2.PubKeyBytes { + t.Fatalf("expected pubKey2 %x, got %x", node2.PubKeyBytes, + pubKey2) + } + assertNumZombies(t, graph, 1) + + // Similarly, if we mark the same edge as live, we should no longer see + // it within the index. + if err := graph.MarkEdgeLive(edge.ChannelID); err != nil { + t.Fatalf("unable to mark edge as live: %v", err) + } + isZombie, _, _ = graph.IsZombieEdge(edge.ChannelID) + if isZombie { + t.Fatal("expected edge to not be marked as zombie") + } + assertNumZombies(t, graph, 0) +} + +// compareNodes is used to compare two LightningNodes while excluding the +// Features struct, which cannot be compared as the semantics for reserializing +// the featuresMap have not been defined. +func compareNodes(a, b *LightningNode) error { + if a.LastUpdate != b.LastUpdate { + return fmt.Errorf("node LastUpdate doesn't match: expected %v, \n"+ + "got %v", a.LastUpdate, b.LastUpdate) + } + if !reflect.DeepEqual(a.Addresses, b.Addresses) { + return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+ + "got %#v", a.Addresses, b.Addresses) + } + if !reflect.DeepEqual(a.PubKeyBytes, b.PubKeyBytes) { + return fmt.Errorf("PubKey doesn't match: expected %#v, \n "+ + "got %#v", a.PubKeyBytes, b.PubKeyBytes) + } + if !reflect.DeepEqual(a.Color, b.Color) { + return fmt.Errorf("Color doesn't match: expected %#v, \n "+ + "got %#v", a.Color, b.Color) + } + if !reflect.DeepEqual(a.Alias, b.Alias) { + return fmt.Errorf("Alias doesn't match: expected %#v, \n "+ + "got %#v", a.Alias, b.Alias) + } + if !reflect.DeepEqual(a.db, b.db) { + return fmt.Errorf("db doesn't match: expected %#v, \n "+ + "got %#v", a.db, b.db) + } + if !reflect.DeepEqual(a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) { + return fmt.Errorf("HaveNodeAnnouncement doesn't match: expected %#v, \n "+ + "got %#v", a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) + } + if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { + return fmt.Errorf("extra data doesn't match: %v vs %v", + a.ExtraOpaqueData, b.ExtraOpaqueData) + } + + return nil +} + +// compareEdgePolicies is used to compare two ChannelEdgePolices using +// compareNodes, so as to exclude comparisons of the Nodes' Features struct. +func compareEdgePolicies(a, b *ChannelEdgePolicy) error { + if a.ChannelID != b.ChannelID { + return fmt.Errorf("ChannelID doesn't match: expected %v, "+ + "got %v", a.ChannelID, b.ChannelID) + } + if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { + return fmt.Errorf("edge LastUpdate doesn't match: expected %#v, \n "+ + "got %#v", a.LastUpdate, b.LastUpdate) + } + if a.MessageFlags != b.MessageFlags { + return fmt.Errorf("MessageFlags doesn't match: expected %v, "+ + "got %v", a.MessageFlags, b.MessageFlags) + } + if a.ChannelFlags != b.ChannelFlags { + return fmt.Errorf("ChannelFlags doesn't match: expected %v, "+ + "got %v", a.ChannelFlags, b.ChannelFlags) + } + if a.TimeLockDelta != b.TimeLockDelta { + return fmt.Errorf("TimeLockDelta doesn't match: expected %v, "+ + "got %v", a.TimeLockDelta, b.TimeLockDelta) + } + if a.MinHTLC != b.MinHTLC { + return fmt.Errorf("MinHTLC doesn't match: expected %v, "+ + "got %v", a.MinHTLC, b.MinHTLC) + } + if a.MaxHTLC != b.MaxHTLC { + return fmt.Errorf("MaxHTLC doesn't match: expected %v, "+ + "got %v", a.MaxHTLC, b.MaxHTLC) + } + if a.FeeBaseMSat != b.FeeBaseMSat { + return fmt.Errorf("FeeBaseMSat doesn't match: expected %v, "+ + "got %v", a.FeeBaseMSat, b.FeeBaseMSat) + } + if a.FeeProportionalMillionths != b.FeeProportionalMillionths { + return fmt.Errorf("FeeProportionalMillionths doesn't match: "+ + "expected %v, got %v", a.FeeProportionalMillionths, + b.FeeProportionalMillionths) + } + if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { + return fmt.Errorf("extra data doesn't match: %v vs %v", + a.ExtraOpaqueData, b.ExtraOpaqueData) + } + if err := compareNodes(a.Node, b.Node); err != nil { + return err + } + if !reflect.DeepEqual(a.db, b.db) { + return fmt.Errorf("db doesn't match: expected %#v, \n "+ + "got %#v", a.db, b.db) + } + return nil +} + +// TestLightningNodeSigVerifcation checks that we can use the LightningNode's +// pubkey to verify signatures. +func TestLightningNodeSigVerification(t *testing.T) { + t.Parallel() + + // Create some dummy data to sign. + var data [32]byte + if _, err := prand.Read(data[:]); err != nil { + t.Fatalf("unable to read prand: %v", err) + } + + // Create private key and sign the data with it. + priv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to crete priv key: %v", err) + } + + sign, err := priv.Sign(data[:]) + if err != nil { + t.Fatalf("unable to sign: %v", err) + } + + // Sanity check that the signature checks out. + if !sign.Verify(data[:], priv.PubKey()) { + t.Fatalf("signature doesn't check out") + } + + // Create a LightningNode from the same private key. + db, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + node, err := createLightningNode(db, priv) + if err != nil { + t.Fatalf("unable to create node: %v", err) + } + + // And finally check that we can verify the same signature from the + // pubkey returned from the lightning node. + nodePub, err := node.PubKey() + if err != nil { + t.Fatalf("unable to get pubkey: %v", err) + } + + if !sign.Verify(data[:], nodePub) { + t.Fatalf("unable to verify sig") + } +} + +// TestComputeFee tests fee calculation based on both in- and outgoing amt. +func TestComputeFee(t *testing.T) { + var ( + policy = ChannelEdgePolicy{ + FeeBaseMSat: 10000, + FeeProportionalMillionths: 30000, + } + outgoingAmt = lnwire.MilliSatoshi(1000000) + expectedFee = lnwire.MilliSatoshi(40000) + ) + + fee := policy.ComputeFee(outgoingAmt) + if fee != expectedFee { + t.Fatalf("expected fee %v, got %v", expectedFee, fee) + } + + fwdFee := policy.ComputeFeeFromIncoming(outgoingAmt + fee) + if fwdFee != expectedFee { + t.Fatalf("expected fee %v, but got %v", fee, fwdFee) + } +} diff --git a/channeldb/migration_01_to_11/invoice_test.go b/channeldb/migration_01_to_11/invoice_test.go new file mode 100644 index 00000000..795fe493 --- /dev/null +++ b/channeldb/migration_01_to_11/invoice_test.go @@ -0,0 +1,694 @@ +package migration_01_to_11 + +import ( + "crypto/rand" + "reflect" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lnwire" +) + +func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { + var pre [32]byte + if _, err := rand.Read(pre[:]); err != nil { + return nil, err + } + + i := &Invoice{ + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + Terms: ContractTerm{ + PaymentPreimage: pre, + Value: value, + }, + Htlcs: map[CircuitKey]*InvoiceHTLC{}, + Expiry: 4000, + } + i.Memo = []byte("memo") + i.Receipt = []byte("receipt") + + // Create a random byte slice of MaxPaymentRequestSize bytes to be used + // as a dummy paymentrequest, and determine if it should be set based + // on one of the random bytes. + var r [MaxPaymentRequestSize]byte + if _, err := rand.Read(r[:]); err != nil { + return nil, err + } + if r[0]&1 == 0 { + i.PaymentRequest = r[:] + } else { + i.PaymentRequest = []byte("") + } + + return i, nil +} + +func TestInvoiceWorkflow(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // Create a fake invoice which we'll use several times in the tests + // below. + fakeInvoice := &Invoice{ + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + Htlcs: map[CircuitKey]*InvoiceHTLC{}, + } + fakeInvoice.Memo = []byte("memo") + fakeInvoice.Receipt = []byte("receipt") + fakeInvoice.PaymentRequest = []byte("") + copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) + fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) + + paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() + + // Add the invoice to the database, this should succeed as there aren't + // any existing invoices within the database with the same payment + // hash. + if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil { + t.Fatalf("unable to find invoice: %v", err) + } + + // Attempt to retrieve the invoice which was just added to the + // database. It should be found, and the invoice returned should be + // identical to the one created above. + dbInvoice, err := db.LookupInvoice(paymentHash) + if err != nil { + t.Fatalf("unable to find invoice: %v", err) + } + if !reflect.DeepEqual(*fakeInvoice, dbInvoice) { + t.Fatalf("invoice fetched from db doesn't match original %v vs %v", + spew.Sdump(fakeInvoice), spew.Sdump(dbInvoice)) + } + + // The add index of the invoice retrieved from the database should now + // be fully populated. As this is the first index written to the DB, + // the addIndex should be 1. + if dbInvoice.AddIndex != 1 { + t.Fatalf("wrong add index: expected %v, got %v", 1, + dbInvoice.AddIndex) + } + + // Settle the invoice, the version retrieved from the database should + // now have the settled bit toggle to true and a non-default + // SettledDate + payAmt := fakeInvoice.Terms.Value * 2 + _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + dbInvoice2, err := db.LookupInvoice(paymentHash) + if err != nil { + t.Fatalf("unable to fetch invoice: %v", err) + } + if dbInvoice2.Terms.State != ContractSettled { + t.Fatalf("invoice should now be settled but isn't") + } + if dbInvoice2.SettleDate.IsZero() { + t.Fatalf("invoice should have non-zero SettledDate but isn't") + } + + // Our 2x payment should be reflected, and also the settle index of 1 + // should also have been committed for this index. + if dbInvoice2.AmtPaid != payAmt { + t.Fatalf("wrong amt paid: expected %v, got %v", payAmt, + dbInvoice2.AmtPaid) + } + if dbInvoice2.SettleIndex != 1 { + t.Fatalf("wrong settle index: expected %v, got %v", 1, + dbInvoice2.SettleIndex) + } + + // Attempt to insert generated above again, this should fail as + // duplicates are rejected by the processing logic. + if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice { + t.Fatalf("invoice insertion should fail due to duplication, "+ + "instead %v", err) + } + + // Attempt to look up a non-existent invoice, this should also fail but + // with a "not found" error. + var fakeHash [32]byte + if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound { + t.Fatalf("lookup should have failed, instead %v", err) + } + + // Add 10 random invoices. + const numInvoices = 10 + amt := lnwire.NewMSatFromSatoshis(1000) + invoices := make([]*Invoice, numInvoices+1) + invoices[0] = &dbInvoice2 + for i := 1; i < len(invoices)-1; i++ { + invoice, err := randInvoice(amt) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + hash := invoice.Terms.PaymentPreimage.Hash() + if _, err := db.AddInvoice(invoice, hash); err != nil { + t.Fatalf("unable to add invoice %v", err) + } + + invoices[i] = invoice + } + + // Perform a scan to collect all the active invoices. + dbInvoices, err := db.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to fetch all invoices: %v", err) + } + + // The retrieve list of invoices should be identical as since we're + // using big endian, the invoices should be retrieved in ascending + // order (and the primary key should be incremented with each + // insertion). + for i := 0; i < len(invoices)-1; i++ { + if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) { + t.Fatalf("retrieved invoices don't match %v vs %v", + spew.Sdump(invoices[i]), + spew.Sdump(dbInvoices[i])) + } + } +} + +// TestInvoiceTimeSeries tests that newly added invoices invoices, as well as +// settled invoices are added to the database are properly placed in the add +// add or settle index which serves as an event time series. +func TestInvoiceAddTimeSeries(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // We'll start off by creating 20 random invoices, and inserting them + // into the database. + const numInvoices = 20 + amt := lnwire.NewMSatFromSatoshis(1000) + invoices := make([]Invoice, numInvoices) + for i := 0; i < len(invoices); i++ { + invoice, err := randInvoice(amt) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + if _, err := db.AddInvoice(invoice, paymentHash); err != nil { + t.Fatalf("unable to add invoice %v", err) + } + + invoices[i] = *invoice + } + + // With the invoices constructed, we'll now create a series of queries + // that we'll use to assert expected return values of + // InvoicesAddedSince. + addQueries := []struct { + sinceAddIndex uint64 + + resp []Invoice + }{ + // If we specify a value of zero, we shouldn't get any invoices + // back. + { + sinceAddIndex: 0, + }, + + // If we specify a value well beyond the number of inserted + // invoices, we shouldn't get any invoices back. + { + sinceAddIndex: 99999999, + }, + + // Using an index of 1 should result in all values, but the + // first one being returned. + { + sinceAddIndex: 1, + resp: invoices[1:], + }, + + // If we use an index of 10, then we should retrieve the + // reaming 10 invoices. + { + sinceAddIndex: 10, + resp: invoices[10:], + }, + } + + for i, query := range addQueries { + resp, err := db.InvoicesAddedSince(query.sinceAddIndex) + if err != nil { + t.Fatalf("unable to query: %v", err) + } + + if !reflect.DeepEqual(query.resp, resp) { + t.Fatalf("test #%v: expected %v, got %v", i, + spew.Sdump(query.resp), spew.Sdump(resp)) + } + } + + // We'll now only settle the latter half of each of those invoices. + for i := 10; i < len(invoices); i++ { + invoice := &invoices[i] + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + _, err := db.UpdateInvoice( + paymentHash, getUpdateInvoice(0), + ) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + } + + invoices, err = db.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to fetch invoices: %v", err) + } + + // We'll slice off the first 10 invoices, as we only settled the last + // 10. + invoices = invoices[10:] + + // We'll now prepare an additional set of queries to ensure the settle + // time series has properly been maintained in the database. + settleQueries := []struct { + sinceSettleIndex uint64 + + resp []Invoice + }{ + // If we specify a value of zero, we shouldn't get any settled + // invoices back. + { + sinceSettleIndex: 0, + }, + + // If we specify a value well beyond the number of settled + // invoices, we shouldn't get any invoices back. + { + sinceSettleIndex: 99999999, + }, + + // Using an index of 1 should result in the final 10 invoices + // being returned, as we only settled those. + { + sinceSettleIndex: 1, + resp: invoices[1:], + }, + } + + for i, query := range settleQueries { + resp, err := db.InvoicesSettledSince(query.sinceSettleIndex) + if err != nil { + t.Fatalf("unable to query: %v", err) + } + + if !reflect.DeepEqual(query.resp, resp) { + t.Fatalf("test #%v: expected %v, got %v", i, + spew.Sdump(query.resp), spew.Sdump(resp)) + } + } +} + +// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it +// twice, then the second time we also receive the invoice that we settled as a +// return argument. +func TestDuplicateSettleInvoice(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + db.now = func() time.Time { return time.Unix(1, 0) } + + // We'll start out by creating an invoice and writing it to the DB. + amt := lnwire.NewMSatFromSatoshis(1000) + invoice, err := randInvoice(amt) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + payHash := invoice.Terms.PaymentPreimage.Hash() + + if _, err := db.AddInvoice(invoice, payHash); err != nil { + t.Fatalf("unable to add invoice %v", err) + } + + // With the invoice in the DB, we'll now attempt to settle the invoice. + dbInvoice, err := db.UpdateInvoice( + payHash, getUpdateInvoice(amt), + ) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + + // We'll update what we expect the settle invoice to be so that our + // comparison below has the correct assumption. + invoice.SettleIndex = 1 + invoice.Terms.State = ContractSettled + invoice.AmtPaid = amt + invoice.SettleDate = dbInvoice.SettleDate + invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{ + {}: { + Amt: amt, + AcceptTime: time.Unix(1, 0), + ResolveTime: time.Unix(1, 0), + State: HtlcStateSettled, + }, + } + + // We should get back the exact same invoice that we just inserted. + if !reflect.DeepEqual(dbInvoice, invoice) { + t.Fatalf("wrong invoice after settle, expected %v got %v", + spew.Sdump(invoice), spew.Sdump(dbInvoice)) + } + + // If we try to settle the invoice again, then we should get the very + // same invoice back, but with an error this time. + dbInvoice, err = db.UpdateInvoice( + payHash, getUpdateInvoice(amt), + ) + if err != ErrInvoiceAlreadySettled { + t.Fatalf("expected ErrInvoiceAlreadySettled") + } + + if dbInvoice == nil { + t.Fatalf("invoice from db is nil after settle!") + } + + invoice.SettleDate = dbInvoice.SettleDate + if !reflect.DeepEqual(dbInvoice, invoice) { + t.Fatalf("wrong invoice after second settle, expected %v got %v", + spew.Sdump(invoice), spew.Sdump(dbInvoice)) + } +} + +// TestQueryInvoices ensures that we can properly query the invoice database for +// invoices using different types of queries. +func TestQueryInvoices(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // To begin the test, we'll add 50 invoices to the database. We'll + // assume that the index of the invoice within the database is the same + // as the amount of the invoice itself. + const numInvoices = 50 + for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ { + invoice, err := randInvoice(i) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + if _, err := db.AddInvoice(invoice, paymentHash); err != nil { + t.Fatalf("unable to add invoice: %v", err) + } + + // We'll only settle half of all invoices created. + if i%2 == 0 { + _, err := db.UpdateInvoice( + paymentHash, getUpdateInvoice(i), + ) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + } + } + + // We'll then retrieve the set of all invoices and pending invoices. + // This will serve useful when comparing the expected responses of the + // query with the actual ones. + invoices, err := db.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to retrieve invoices: %v", err) + } + pendingInvoices, err := db.FetchAllInvoices(true) + if err != nil { + t.Fatalf("unable to retrieve pending invoices: %v", err) + } + + // The test will consist of several queries along with their respective + // expected response. Each query response should match its expected one. + testCases := []struct { + query InvoiceQuery + expected []Invoice + }{ + // Fetch all invoices with a single query. + { + query: InvoiceQuery{ + NumMaxInvoices: numInvoices, + }, + expected: invoices, + }, + // Fetch all invoices with a single query, reversed. + { + query: InvoiceQuery{ + Reversed: true, + NumMaxInvoices: numInvoices, + }, + expected: invoices, + }, + // Fetch the first 25 invoices. + { + query: InvoiceQuery{ + NumMaxInvoices: numInvoices / 2, + }, + expected: invoices[:numInvoices/2], + }, + // Fetch the first 10 invoices, but this time iterating + // backwards. + { + query: InvoiceQuery{ + IndexOffset: 11, + Reversed: true, + NumMaxInvoices: numInvoices, + }, + expected: invoices[:10], + }, + // Fetch the last 40 invoices. + { + query: InvoiceQuery{ + IndexOffset: 10, + NumMaxInvoices: numInvoices, + }, + expected: invoices[10:], + }, + // Fetch all but the first invoice. + { + query: InvoiceQuery{ + IndexOffset: 1, + NumMaxInvoices: numInvoices, + }, + expected: invoices[1:], + }, + // Fetch one invoice, reversed, with index offset 3. This + // should give us the second invoice in the array. + { + query: InvoiceQuery{ + IndexOffset: 3, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[1:2], + }, + // Same as above, at index 2. + { + query: InvoiceQuery{ + IndexOffset: 2, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[0:1], + }, + // Fetch one invoice, at index 1, reversed. Since invoice#1 is + // the very first, there won't be any left in a reverse search, + // so we expect no invoices to be returned. + { + query: InvoiceQuery{ + IndexOffset: 1, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: nil, + }, + // Same as above, but don't restrict the number of invoices to + // 1. + { + query: InvoiceQuery{ + IndexOffset: 1, + Reversed: true, + NumMaxInvoices: numInvoices, + }, + expected: nil, + }, + // Fetch one invoice, reversed, with no offset set. We expect + // the last invoice in the response. + { + query: InvoiceQuery{ + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[numInvoices-1:], + }, + // Fetch one invoice, reversed, the offset set at numInvoices+1. + // We expect this to return the last invoice. + { + query: InvoiceQuery{ + IndexOffset: numInvoices + 1, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[numInvoices-1:], + }, + // Same as above, at offset numInvoices. + { + query: InvoiceQuery{ + IndexOffset: numInvoices, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[numInvoices-2 : numInvoices-1], + }, + // Fetch one invoice, at no offset (same as offset 0). We + // expect the first invoice only in the response. + { + query: InvoiceQuery{ + NumMaxInvoices: 1, + }, + expected: invoices[:1], + }, + // Same as above, at offset 1. + { + query: InvoiceQuery{ + IndexOffset: 1, + NumMaxInvoices: 1, + }, + expected: invoices[1:2], + }, + // Same as above, at offset 2. + { + query: InvoiceQuery{ + IndexOffset: 2, + NumMaxInvoices: 1, + }, + expected: invoices[2:3], + }, + // Same as above, at offset numInvoices-1. Expect the last + // invoice to be returned. + { + query: InvoiceQuery{ + IndexOffset: numInvoices - 1, + NumMaxInvoices: 1, + }, + expected: invoices[numInvoices-1:], + }, + // Same as above, at offset numInvoices. No invoices should be + // returned, as there are no invoices after this offset. + { + query: InvoiceQuery{ + IndexOffset: numInvoices, + NumMaxInvoices: 1, + }, + expected: nil, + }, + // Fetch all pending invoices with a single query. + { + query: InvoiceQuery{ + PendingOnly: true, + NumMaxInvoices: numInvoices, + }, + expected: pendingInvoices, + }, + // Fetch the first 12 pending invoices. + { + query: InvoiceQuery{ + PendingOnly: true, + NumMaxInvoices: numInvoices / 4, + }, + expected: pendingInvoices[:len(pendingInvoices)/2], + }, + // Fetch the first 5 pending invoices, but this time iterating + // backwards. + { + query: InvoiceQuery{ + IndexOffset: 10, + PendingOnly: true, + Reversed: true, + NumMaxInvoices: numInvoices, + }, + // Since we seek to the invoice with index 10 and + // iterate backwards, there should only be 5 pending + // invoices before it as every other invoice within the + // index is settled. + expected: pendingInvoices[:5], + }, + // Fetch the last 15 invoices. + { + query: InvoiceQuery{ + IndexOffset: 20, + PendingOnly: true, + NumMaxInvoices: numInvoices, + }, + // Since we seek to the invoice with index 20, there are + // 30 invoices left. From these 30, only 15 of them are + // still pending. + expected: pendingInvoices[len(pendingInvoices)-15:], + }, + } + + for i, testCase := range testCases { + response, err := db.QueryInvoices(testCase.query) + if err != nil { + t.Fatalf("unable to query invoice database: %v", err) + } + + if !reflect.DeepEqual(response.Invoices, testCase.expected) { + t.Fatalf("test #%d: query returned incorrect set of "+ + "invoices: expcted %v, got %v", i, + spew.Sdump(response.Invoices), + spew.Sdump(testCase.expected)) + } + } +} + +// getUpdateInvoice returns an invoice update callback that, when called, +// settles the invoice with the given amount. +func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { + return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { + if invoice.Terms.State == ContractSettled { + return nil, ErrInvoiceAlreadySettled + } + + update := &InvoiceUpdateDesc{ + Preimage: invoice.Terms.PaymentPreimage, + State: ContractSettled, + Htlcs: map[CircuitKey]*HtlcAcceptDesc{ + {}: { + Amt: amt, + }, + }, + } + + return update, nil + } +} diff --git a/channeldb/migration_01_to_11/invoices.go b/channeldb/migration_01_to_11/invoices.go new file mode 100644 index 00000000..5f40454a --- /dev/null +++ b/channeldb/migration_01_to_11/invoices.go @@ -0,0 +1,1320 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "time" + + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // UnknownPreimage is an all-zeroes preimage that indicates that the + // preimage for this invoice is not yet known. + UnknownPreimage lntypes.Preimage + + // invoiceBucket is the name of the bucket within the database that + // stores all data related to invoices no matter their final state. + // Within the invoice bucket, each invoice is keyed by its invoice ID + // which is a monotonically increasing uint32. + invoiceBucket = []byte("invoices") + + // paymentHashIndexBucket is the name of the sub-bucket within the + // invoiceBucket which indexes all invoices by their payment hash. The + // payment hash is the sha256 of the invoice's payment preimage. This + // index is used to detect duplicates, and also to provide a fast path + // for looking up incoming HTLCs to determine if we're able to settle + // them fully. + // + // maps: payHash => invoiceKey + invoiceIndexBucket = []byte("paymenthashes") + + // numInvoicesKey is the name of key which houses the auto-incrementing + // invoice ID which is essentially used as a primary key. With each + // invoice inserted, the primary key is incremented by one. This key is + // stored within the invoiceIndexBucket. Within the invoiceBucket + // invoices are uniquely identified by the invoice ID. + numInvoicesKey = []byte("nik") + + // addIndexBucket is an index bucket that we'll use to create a + // monotonically increasing set of add indexes. Each time we add a new + // invoice, this sequence number will be incremented and then populated + // within the new invoice. + // + // In addition to this sequence number, we map: + // + // addIndexNo => invoiceKey + addIndexBucket = []byte("invoice-add-index") + + // settleIndexBucket is an index bucket that we'll use to create a + // monotonically increasing integer for tracking a "settle index". Each + // time an invoice is settled, this sequence number will be incremented + // as populate within the newly settled invoice. + // + // In addition to this sequence number, we map: + // + // settleIndexNo => invoiceKey + settleIndexBucket = []byte("invoice-settle-index") + + // ErrInvoiceAlreadySettled is returned when the invoice is already + // settled. + ErrInvoiceAlreadySettled = errors.New("invoice already settled") + + // ErrInvoiceAlreadyCanceled is returned when the invoice is already + // canceled. + ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled") + + // ErrInvoiceAlreadyAccepted is returned when the invoice is already + // accepted. + ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted") + + // ErrInvoiceStillOpen is returned when the invoice is still open. + ErrInvoiceStillOpen = errors.New("invoice still open") +) + +const ( + // MaxMemoSize is maximum size of the memo field within invoices stored + // in the database. + MaxMemoSize = 1024 + + // MaxReceiptSize is the maximum size of the payment receipt stored + // within the database along side incoming/outgoing invoices. + MaxReceiptSize = 1024 + + // MaxPaymentRequestSize is the max size of a payment request for + // this invoice. + // TODO(halseth): determine the max length payment request when field + // lengths are final. + MaxPaymentRequestSize = 4096 + + // A set of tlv type definitions used to serialize invoice htlcs to the + // database. + chanIDType tlv.Type = 1 + htlcIDType tlv.Type = 3 + amtType tlv.Type = 5 + acceptHeightType tlv.Type = 7 + acceptTimeType tlv.Type = 9 + resolveTimeType tlv.Type = 11 + expiryHeightType tlv.Type = 13 + stateType tlv.Type = 15 +) + +// ContractState describes the state the invoice is in. +type ContractState uint8 + +const ( + // ContractOpen means the invoice has only been created. + ContractOpen ContractState = 0 + + // ContractSettled means the htlc is settled and the invoice has been + // paid. + ContractSettled ContractState = 1 + + // ContractCanceled means the invoice has been canceled. + ContractCanceled ContractState = 2 + + // ContractAccepted means the HTLC has been accepted but not settled + // yet. + ContractAccepted ContractState = 3 +) + +// String returns a human readable identifier for the ContractState type. +func (c ContractState) String() string { + switch c { + case ContractOpen: + return "Open" + case ContractSettled: + return "Settled" + case ContractCanceled: + return "Canceled" + case ContractAccepted: + return "Accepted" + } + + return "Unknown" +} + +// ContractTerm is a companion struct to the Invoice struct. This struct houses +// the necessary conditions required before the invoice can be considered fully +// settled by the payee. +type ContractTerm struct { + // PaymentPreimage is the preimage which is to be revealed in the + // occasion that an HTLC paying to the hash of this preimage is + // extended. + PaymentPreimage lntypes.Preimage + + // Value is the expected amount of milli-satoshis to be paid to an HTLC + // which can be satisfied by the above preimage. + Value lnwire.MilliSatoshi + + // State describes the state the invoice is in. + State ContractState +} + +// Invoice is a payment invoice generated by a payee in order to request +// payment for some good or service. The inclusion of invoices within Lightning +// creates a payment work flow for merchants very similar to that of the +// existing financial system within PayPal, etc. Invoices are added to the +// database when a payment is requested, then can be settled manually once the +// payment is received at the upper layer. For record keeping purposes, +// invoices are never deleted from the database, instead a bit is toggled +// denoting the invoice has been fully settled. Within the database, all +// invoices must have a unique payment hash which is generated by taking the +// sha256 of the payment preimage. +type Invoice struct { + // Memo is an optional memo to be stored along side an invoice. The + // memo may contain further details pertaining to the invoice itself, + // or any other message which fits within the size constraints. + Memo []byte + + // Receipt is an optional field dedicated for storing a + // cryptographically binding receipt of payment. + // + // TODO(roasbeef): document scheme. + Receipt []byte + + // PaymentRequest is an optional field where a payment request created + // for this invoice can be stored. + PaymentRequest []byte + + // FinalCltvDelta is the minimum required number of blocks before htlc + // expiry when the invoice is accepted. + FinalCltvDelta int32 + + // Expiry defines how long after creation this invoice should expire. + Expiry time.Duration + + // CreationDate is the exact time the invoice was created. + CreationDate time.Time + + // SettleDate is the exact time the invoice was settled. + SettleDate time.Time + + // Terms are the contractual payment terms of the invoice. Once all the + // terms have been satisfied by the payer, then the invoice can be + // considered fully fulfilled. + // + // TODO(roasbeef): later allow for multiple terms to fulfill the final + // invoice: payment fragmentation, etc. + Terms ContractTerm + + // AddIndex is an auto-incrementing integer that acts as a + // monotonically increasing sequence number for all invoices created. + // Clients can then use this field as a "checkpoint" of sorts when + // implementing a streaming RPC to notify consumers of instances where + // an invoice has been added before they re-connected. + // + // NOTE: This index starts at 1. + AddIndex uint64 + + // SettleIndex is an auto-incrementing integer that acts as a + // monotonically increasing sequence number for all settled invoices. + // Clients can then use this field as a "checkpoint" of sorts when + // implementing a streaming RPC to notify consumers of instances where + // an invoice has been settled before they re-connected. + // + // NOTE: This index starts at 1. + SettleIndex uint64 + + // AmtPaid is the final amount that we ultimately accepted for pay for + // this invoice. We specify this value independently as it's possible + // that the invoice originally didn't specify an amount, or the sender + // overpaid. + AmtPaid lnwire.MilliSatoshi + + // Htlcs records all htlcs that paid to this invoice. Some of these + // htlcs may have been marked as canceled. + Htlcs map[CircuitKey]*InvoiceHTLC +} + +// HtlcState defines the states an htlc paying to an invoice can be in. +type HtlcState uint8 + +const ( + // HtlcStateAccepted indicates the htlc is locked-in, but not resolved. + HtlcStateAccepted HtlcState = iota + + // HtlcStateCanceled indicates the htlc is canceled back to the + // sender. + HtlcStateCanceled + + // HtlcStateSettled indicates the htlc is settled. + HtlcStateSettled +) + +// InvoiceHTLC contains details about an htlc paying to this invoice. +type InvoiceHTLC struct { + // Amt is the amount that is carried by this htlc. + Amt lnwire.MilliSatoshi + + // AcceptHeight is the block height at which the invoice registry + // decided to accept this htlc as a payment to the invoice. At this + // height, the invoice cltv delay must have been met. + AcceptHeight uint32 + + // AcceptTime is the wall clock time at which the invoice registry + // decided to accept the htlc. + AcceptTime time.Time + + // ResolveTime is the wall clock time at which the invoice registry + // decided to settle the htlc. + ResolveTime time.Time + + // Expiry is the expiry height of this htlc. + Expiry uint32 + + // State indicates the state the invoice htlc is currently in. A + // canceled htlc isn't just removed from the invoice htlcs map, because + // we need AcceptHeight to properly cancel the htlc back. + State HtlcState +} + +// HtlcAcceptDesc describes the details of a newly accepted htlc. +type HtlcAcceptDesc struct { + // AcceptHeight is the block height at which this htlc was accepted. + AcceptHeight int32 + + // Amt is the amount that is carried by this htlc. + Amt lnwire.MilliSatoshi + + // Expiry is the expiry height of this htlc. + Expiry uint32 +} + +// InvoiceUpdateDesc describes the changes that should be applied to the +// invoice. +type InvoiceUpdateDesc struct { + // State is the new state that this invoice should progress to. + State ContractState + + // Htlcs describes the changes that need to be made to the invoice htlcs + // in the database. Htlc map entries with their value set should be + // added. If the map value is nil, the htlc should be canceled. + Htlcs map[CircuitKey]*HtlcAcceptDesc + + // Preimage must be set to the preimage when state is settled. + Preimage lntypes.Preimage +} + +// InvoiceUpdateCallback is a callback used in the db transaction to update the +// invoice. +type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) + +func validateInvoice(i *Invoice) error { + if len(i.Memo) > MaxMemoSize { + return fmt.Errorf("max length a memo is %v, and invoice "+ + "of length %v was provided", MaxMemoSize, len(i.Memo)) + } + if len(i.Receipt) > MaxReceiptSize { + return fmt.Errorf("max length a receipt is %v, and invoice "+ + "of length %v was provided", MaxReceiptSize, + len(i.Receipt)) + } + if len(i.PaymentRequest) > MaxPaymentRequestSize { + return fmt.Errorf("max length of payment request is %v, length "+ + "provided was %v", MaxPaymentRequestSize, + len(i.PaymentRequest)) + } + return nil +} + +// AddInvoice inserts the targeted invoice into the database. If the invoice has +// *any* payment hashes which already exists within the database, then the +// insertion will be aborted and rejected due to the strict policy banning any +// duplicate payment hashes. A side effect of this function is that it sets +// AddIndex on newInvoice. +func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( + uint64, error) { + + if err := validateInvoice(newInvoice); err != nil { + return 0, err + } + + var invoiceAddIndex uint64 + err := d.Update(func(tx *bbolt.Tx) error { + invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) + if err != nil { + return err + } + + invoiceIndex, err := invoices.CreateBucketIfNotExists( + invoiceIndexBucket, + ) + if err != nil { + return err + } + addIndex, err := invoices.CreateBucketIfNotExists( + addIndexBucket, + ) + if err != nil { + return err + } + + // Ensure that an invoice an identical payment hash doesn't + // already exist within the index. + if invoiceIndex.Get(paymentHash[:]) != nil { + return ErrDuplicateInvoice + } + + // If the current running payment ID counter hasn't yet been + // created, then create it now. + var invoiceNum uint32 + invoiceCounter := invoiceIndex.Get(numInvoicesKey) + if invoiceCounter == nil { + var scratch [4]byte + byteOrder.PutUint32(scratch[:], invoiceNum) + err := invoiceIndex.Put(numInvoicesKey, scratch[:]) + if err != nil { + return err + } + } else { + invoiceNum = byteOrder.Uint32(invoiceCounter) + } + + newIndex, err := putInvoice( + invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, + paymentHash, + ) + if err != nil { + return err + } + + invoiceAddIndex = newIndex + return nil + }) + if err != nil { + return 0, err + } + + return invoiceAddIndex, err +} + +// InvoicesAddedSince can be used by callers to seek into the event time series +// of all the invoices added in the database. The specified sinceAddIndex +// should be the highest add index that the caller knows of. This method will +// return all invoices with an add index greater than the specified +// sinceAddIndex. +// +// NOTE: The index starts from 1, as a result. We enforce that specifying a +// value below the starting index value is a noop. +func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { + var newInvoices []Invoice + + // If an index of zero was specified, then in order to maintain + // backwards compat, we won't send out any new invoices. + if sinceAddIndex == 0 { + return newInvoices, nil + } + + var startIndex [8]byte + byteOrder.PutUint64(startIndex[:], sinceAddIndex) + + err := d.DB.View(func(tx *bbolt.Tx) error { + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + addIndex := invoices.Bucket(addIndexBucket) + if addIndex == nil { + return ErrNoInvoicesCreated + } + + // We'll now run through each entry in the add index starting + // at our starting index. We'll continue until we reach the + // very end of the current key space. + invoiceCursor := addIndex.Cursor() + + // We'll seek to the starting index, then manually advance the + // cursor in order to skip the entry with the since add index. + invoiceCursor.Seek(startIndex[:]) + addSeqNo, invoiceKey := invoiceCursor.Next() + + for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() { + + // For each key found, we'll look up the actual + // invoice, then accumulate it into our return value. + invoice, err := fetchInvoice(invoiceKey, invoices) + if err != nil { + return err + } + + newInvoices = append(newInvoices, invoice) + } + + return nil + }) + switch { + // If no invoices have been created, then we'll return the empty set of + // invoices. + case err == ErrNoInvoicesCreated: + + case err != nil: + return nil, err + } + + return newInvoices, nil +} + +// LookupInvoice attempts to look up an invoice according to its 32 byte +// payment hash. If an invoice which can settle the HTLC identified by the +// passed payment hash isn't found, then an error is returned. Otherwise, the +// full invoice is returned. Before setting the incoming HTLC, the values +// SHOULD be checked to ensure the payer meets the agreed upon contractual +// terms of the payment. +func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { + var invoice Invoice + err := d.View(func(tx *bbolt.Tx) error { + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + invoiceIndex := invoices.Bucket(invoiceIndexBucket) + if invoiceIndex == nil { + return ErrNoInvoicesCreated + } + + // Check the invoice index to see if an invoice paying to this + // hash exists within the DB. + invoiceNum := invoiceIndex.Get(paymentHash[:]) + if invoiceNum == nil { + return ErrInvoiceNotFound + } + + // An invoice matching the payment hash has been found, so + // retrieve the record of the invoice itself. + i, err := fetchInvoice(invoiceNum, invoices) + if err != nil { + return err + } + invoice = i + + return nil + }) + if err != nil { + return invoice, err + } + + return invoice, nil +} + +// FetchAllInvoices returns all invoices currently stored within the database. +// If the pendingOnly param is true, then only unsettled invoices will be +// returned, skipping all invoices that are fully settled. +func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { + var invoices []Invoice + + err := d.View(func(tx *bbolt.Tx) error { + invoiceB := tx.Bucket(invoiceBucket) + if invoiceB == nil { + return ErrNoInvoicesCreated + } + + // Iterate through the entire key space of the top-level + // invoice bucket. If key with a non-nil value stores the next + // invoice ID which maps to the corresponding invoice. + return invoiceB.ForEach(func(k, v []byte) error { + if v == nil { + return nil + } + + invoiceReader := bytes.NewReader(v) + invoice, err := deserializeInvoice(invoiceReader) + if err != nil { + return err + } + + if pendingOnly && + invoice.Terms.State == ContractSettled { + + return nil + } + + invoices = append(invoices, invoice) + + return nil + }) + }) + if err != nil { + return nil, err + } + + return invoices, nil +} + +// InvoiceQuery represents a query to the invoice database. The query allows a +// caller to retrieve all invoices starting from a particular add index and +// limit the number of results returned. +type InvoiceQuery struct { + // IndexOffset is the offset within the add indices to start at. This + // can be used to start the response at a particular invoice. + IndexOffset uint64 + + // NumMaxInvoices is the maximum number of invoices that should be + // starting from the add index. + NumMaxInvoices uint64 + + // PendingOnly, if set, returns unsettled invoices starting from the + // add index. + PendingOnly bool + + // Reversed, if set, indicates that the invoices returned should start + // from the IndexOffset and go backwards. + Reversed bool +} + +// InvoiceSlice is the response to a invoice query. It includes the original +// query, the set of invoices that match the query, and an integer which +// represents the offset index of the last item in the set of returned invoices. +// This integer allows callers to resume their query using this offset in the +// event that the query's response exceeds the maximum number of returnable +// invoices. +type InvoiceSlice struct { + InvoiceQuery + + // Invoices is the set of invoices that matched the query above. + Invoices []Invoice + + // FirstIndexOffset is the index of the first element in the set of + // returned Invoices above. Callers can use this to resume their query + // in the event that the slice has too many events to fit into a single + // response. + FirstIndexOffset uint64 + + // LastIndexOffset is the index of the last element in the set of + // returned Invoices above. Callers can use this to resume their query + // in the event that the slice has too many events to fit into a single + // response. + LastIndexOffset uint64 +} + +// QueryInvoices allows a caller to query the invoice database for invoices +// within the specified add index range. +func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { + resp := InvoiceSlice{ + InvoiceQuery: q, + } + + err := d.View(func(tx *bbolt.Tx) error { + // If the bucket wasn't found, then there aren't any invoices + // within the database yet, so we can simply exit. + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + invoiceAddIndex := invoices.Bucket(addIndexBucket) + if invoiceAddIndex == nil { + return ErrNoInvoicesCreated + } + + // keyForIndex is a helper closure that retrieves the invoice + // key for the given add index of an invoice. + keyForIndex := func(c *bbolt.Cursor, index uint64) []byte { + var keyIndex [8]byte + byteOrder.PutUint64(keyIndex[:], index) + _, invoiceKey := c.Seek(keyIndex[:]) + return invoiceKey + } + + // nextKey is a helper closure to determine what the next + // invoice key is when iterating over the invoice add index. + nextKey := func(c *bbolt.Cursor) ([]byte, []byte) { + if q.Reversed { + return c.Prev() + } + return c.Next() + } + + // We'll be using a cursor to seek into the database and return + // a slice of invoices. We'll need to determine where to start + // our cursor depending on the parameters set within the query. + c := invoiceAddIndex.Cursor() + invoiceKey := keyForIndex(c, q.IndexOffset+1) + + // If the query is specifying reverse iteration, then we must + // handle a few offset cases. + if q.Reversed { + switch q.IndexOffset { + + // This indicates the default case, where no offset was + // specified. In that case we just start from the last + // invoice. + case 0: + _, invoiceKey = c.Last() + + // This indicates the offset being set to the very + // first invoice. Since there are no invoices before + // this offset, and the direction is reversed, we can + // return without adding any invoices to the response. + case 1: + return nil + + // Otherwise we start iteration at the invoice prior to + // the offset. + default: + invoiceKey = keyForIndex(c, q.IndexOffset-1) + } + } + + // If we know that a set of invoices exists, then we'll begin + // our seek through the bucket in order to satisfy the query. + // We'll continue until either we reach the end of the range, or + // reach our max number of invoices. + for ; invoiceKey != nil; _, invoiceKey = nextKey(c) { + // If our current return payload exceeds the max number + // of invoices, then we'll exit now. + if uint64(len(resp.Invoices)) >= q.NumMaxInvoices { + break + } + + invoice, err := fetchInvoice(invoiceKey, invoices) + if err != nil { + return err + } + + // Skip any settled invoices if the caller is only + // interested in unsettled. + if q.PendingOnly && + invoice.Terms.State == ContractSettled { + + continue + } + + // At this point, we've exhausted the offset, so we'll + // begin collecting invoices found within the range. + resp.Invoices = append(resp.Invoices, invoice) + } + + // If we iterated through the add index in reverse order, then + // we'll need to reverse the slice of invoices to return them in + // forward order. + if q.Reversed { + numInvoices := len(resp.Invoices) + for i := 0; i < numInvoices/2; i++ { + opposite := numInvoices - i - 1 + resp.Invoices[i], resp.Invoices[opposite] = + resp.Invoices[opposite], resp.Invoices[i] + } + } + + return nil + }) + if err != nil && err != ErrNoInvoicesCreated { + return resp, err + } + + // Finally, record the indexes of the first and last invoices returned + // so that the caller can resume from this point later on. + if len(resp.Invoices) > 0 { + resp.FirstIndexOffset = resp.Invoices[0].AddIndex + resp.LastIndexOffset = resp.Invoices[len(resp.Invoices)-1].AddIndex + } + + return resp, nil +} + +// UpdateInvoice attempts to update an invoice corresponding to the passed +// payment hash. If an invoice matching the passed payment hash doesn't exist +// within the database, then the action will fail with a "not found" error. +// +// The update is performed inside the same database transaction that fetches the +// invoice and is therefore atomic. The fields to update are controlled by the +// supplied callback. +func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, + callback InvoiceUpdateCallback) (*Invoice, error) { + + var updatedInvoice *Invoice + err := d.Update(func(tx *bbolt.Tx) error { + invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) + if err != nil { + return err + } + invoiceIndex, err := invoices.CreateBucketIfNotExists( + invoiceIndexBucket, + ) + if err != nil { + return err + } + settleIndex, err := invoices.CreateBucketIfNotExists( + settleIndexBucket, + ) + if err != nil { + return err + } + + // Check the invoice index to see if an invoice paying to this + // hash exists within the DB. + invoiceNum := invoiceIndex.Get(paymentHash[:]) + if invoiceNum == nil { + return ErrInvoiceNotFound + } + + updatedInvoice, err = d.updateInvoice( + paymentHash, invoices, settleIndex, invoiceNum, + callback, + ) + + return err + }) + + return updatedInvoice, err +} + +// InvoicesSettledSince can be used by callers to catch up any settled invoices +// they missed within the settled invoice time series. We'll return all known +// settled invoice that have a settle index higher than the passed +// sinceSettleIndex. +// +// NOTE: The index starts from 1, as a result. We enforce that specifying a +// value below the starting index value is a noop. +func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { + var settledInvoices []Invoice + + // If an index of zero was specified, then in order to maintain + // backwards compat, we won't send out any new invoices. + if sinceSettleIndex == 0 { + return settledInvoices, nil + } + + var startIndex [8]byte + byteOrder.PutUint64(startIndex[:], sinceSettleIndex) + + err := d.DB.View(func(tx *bbolt.Tx) error { + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + settleIndex := invoices.Bucket(settleIndexBucket) + if settleIndex == nil { + return ErrNoInvoicesCreated + } + + // We'll now run through each entry in the add index starting + // at our starting index. We'll continue until we reach the + // very end of the current key space. + invoiceCursor := settleIndex.Cursor() + + // We'll seek to the starting index, then manually advance the + // cursor in order to skip the entry with the since add index. + invoiceCursor.Seek(startIndex[:]) + seqNo, invoiceKey := invoiceCursor.Next() + + for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, invoiceKey = invoiceCursor.Next() { + + // For each key found, we'll look up the actual + // invoice, then accumulate it into our return value. + invoice, err := fetchInvoice(invoiceKey, invoices) + if err != nil { + return err + } + + settledInvoices = append(settledInvoices, invoice) + } + + return nil + }) + if err != nil { + return nil, err + } + + return settledInvoices, nil +} + +func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket, + i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( + uint64, error) { + + // Create the invoice key which is just the big-endian representation + // of the invoice number. + var invoiceKey [4]byte + byteOrder.PutUint32(invoiceKey[:], invoiceNum) + + // Increment the num invoice counter index so the next invoice bares + // the proper ID. + var scratch [4]byte + invoiceCounter := invoiceNum + 1 + byteOrder.PutUint32(scratch[:], invoiceCounter) + if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil { + return 0, err + } + + // Add the payment hash to the invoice index. This will let us quickly + // identify if we can settle an incoming payment, and also to possibly + // allow a single invoice to have multiple payment installations. + err := invoiceIndex.Put(paymentHash[:], invoiceKey[:]) + if err != nil { + return 0, err + } + + // Next, we'll obtain the next add invoice index (sequence + // number), so we can properly place this invoice within this + // event stream. + nextAddSeqNo, err := addIndex.NextSequence() + if err != nil { + return 0, err + } + + // With the next sequence obtained, we'll updating the event series in + // the add index bucket to map this current add counter to the index of + // this new invoice. + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) + if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil { + return 0, err + } + + i.AddIndex = nextAddSeqNo + + // Finally, serialize the invoice itself to be written to the disk. + var buf bytes.Buffer + if err := serializeInvoice(&buf, i); err != nil { + return 0, err + } + + if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil { + return 0, err + } + + return nextAddSeqNo, nil +} + +// serializeInvoice serializes an invoice to a writer. +// +// Note: this function is in use for a migration. Before making changes that +// would modify the on disk format, make a copy of the original code and store +// it with the migration. +func serializeInvoice(w io.Writer, i *Invoice) error { + if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.FinalCltvDelta); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, int64(i.Expiry)); err != nil { + return err + } + + birthBytes, err := i.CreationDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil { + return err + } + + settleBytes, err := i.SettleDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil { + return err + } + + if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil { + return err + } + + var scratch [8]byte + byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.AddIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil { + return err + } + + if err := serializeHtlcs(w, i.Htlcs); err != nil { + return err + } + + return nil +} + +// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to +// a writer. +func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error { + for key, htlc := range htlcs { + // Encode the htlc in a tlv stream. + chanID := key.ChanID.ToUint64() + amt := uint64(htlc.Amt) + acceptTime := uint64(htlc.AcceptTime.UnixNano()) + resolveTime := uint64(htlc.ResolveTime.UnixNano()) + state := uint8(htlc.State) + + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(chanIDType, &chanID), + tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID), + tlv.MakePrimitiveRecord(amtType, &amt), + tlv.MakePrimitiveRecord( + acceptHeightType, &htlc.AcceptHeight, + ), + tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), + tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), + tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), + tlv.MakePrimitiveRecord(stateType, &state), + ) + if err != nil { + return err + } + + var b bytes.Buffer + if err := tlvStream.Encode(&b); err != nil { + return err + } + + // Write the length of the tlv stream followed by the stream + // bytes. + err = binary.Write(w, byteOrder, uint64(b.Len())) + if err != nil { + return err + } + + if _, err := w.Write(b.Bytes()); err != nil { + return err + } + } + + return nil +} + +func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) { + invoiceBytes := invoices.Get(invoiceNum) + if invoiceBytes == nil { + return Invoice{}, ErrInvoiceNotFound + } + + invoiceReader := bytes.NewReader(invoiceBytes) + + return deserializeInvoice(invoiceReader) +} + +func deserializeInvoice(r io.Reader) (Invoice, error) { + var err error + invoice := Invoice{} + + // TODO(roasbeef): use read full everywhere + invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "") + if err != nil { + return invoice, err + } + invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "") + if err != nil { + return invoice, err + } + + invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "") + if err != nil { + return invoice, err + } + + if err := binary.Read(r, byteOrder, &invoice.FinalCltvDelta); err != nil { + return invoice, err + } + + var expiry int64 + if err := binary.Read(r, byteOrder, &expiry); err != nil { + return invoice, err + } + invoice.Expiry = time.Duration(expiry) + + birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth") + if err != nil { + return invoice, err + } + if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil { + return invoice, err + } + + settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled") + if err != nil { + return invoice, err + } + if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil { + return invoice, err + } + + if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil { + return invoice, err + } + var scratch [8]byte + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return invoice, err + } + invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { + return invoice, err + } + + if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil { + return invoice, err + } + + invoice.Htlcs, err = deserializeHtlcs(r) + if err != nil { + return Invoice{}, err + } + + return invoice, nil +} + +// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it +// as a map. +func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { + htlcs := make(map[CircuitKey]*InvoiceHTLC, 0) + + for { + // Read the length of the tlv stream for this htlc. + var streamLen uint64 + if err := binary.Read(r, byteOrder, &streamLen); err != nil { + if err == io.EOF { + break + } + + return nil, err + } + + streamBytes := make([]byte, streamLen) + if _, err := r.Read(streamBytes); err != nil { + return nil, err + } + streamReader := bytes.NewReader(streamBytes) + + // Decode the contents into the htlc fields. + var ( + htlc InvoiceHTLC + key CircuitKey + chanID uint64 + state uint8 + acceptTime, resolveTime uint64 + amt uint64 + ) + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(chanIDType, &chanID), + tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID), + tlv.MakePrimitiveRecord(amtType, &amt), + tlv.MakePrimitiveRecord( + acceptHeightType, &htlc.AcceptHeight, + ), + tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), + tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), + tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), + tlv.MakePrimitiveRecord(stateType, &state), + ) + if err != nil { + return nil, err + } + + if err := tlvStream.Decode(streamReader); err != nil { + return nil, err + } + + key.ChanID = lnwire.NewShortChanIDFromInt(chanID) + htlc.AcceptTime = time.Unix(0, int64(acceptTime)) + htlc.ResolveTime = time.Unix(0, int64(resolveTime)) + htlc.State = HtlcState(state) + htlc.Amt = lnwire.MilliSatoshi(amt) + + htlcs[key] = &htlc + } + + return htlcs, nil +} + +// copySlice allocates a new slice and copies the source into it. +func copySlice(src []byte) []byte { + dest := make([]byte, len(src)) + copy(dest, src) + return dest +} + +// copyInvoice makes a deep copy of the supplied invoice. +func copyInvoice(src *Invoice) *Invoice { + dest := Invoice{ + Memo: copySlice(src.Memo), + Receipt: copySlice(src.Receipt), + PaymentRequest: copySlice(src.PaymentRequest), + FinalCltvDelta: src.FinalCltvDelta, + CreationDate: src.CreationDate, + SettleDate: src.SettleDate, + Terms: src.Terms, + AddIndex: src.AddIndex, + SettleIndex: src.SettleIndex, + AmtPaid: src.AmtPaid, + Htlcs: make( + map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), + ), + } + + for k, v := range src.Htlcs { + dest.Htlcs[k] = v + } + + return &dest +} + +// updateInvoice fetches the invoice, obtains the update descriptor from the +// callback and applies the updates in a single db transaction. +func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucket, + invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) { + + invoice, err := fetchInvoice(invoiceNum, invoices) + if err != nil { + return nil, err + } + + preUpdateState := invoice.Terms.State + + // Create deep copy to prevent any accidental modification in the + // callback. + copy := copyInvoice(&invoice) + + // Call the callback and obtain the update descriptor. + update, err := callback(copy) + if err != nil { + return &invoice, err + } + + // Update invoice state. + invoice.Terms.State = update.State + + now := d.now() + + // Update htlc set. + for key, htlcUpdate := range update.Htlcs { + htlc, ok := invoice.Htlcs[key] + + // No update means the htlc needs to be canceled. + if htlcUpdate == nil { + if !ok { + return nil, fmt.Errorf("unknown htlc %v", key) + } + if htlc.State != HtlcStateAccepted { + return nil, fmt.Errorf("can only cancel " + + "accepted htlcs") + } + + htlc.State = HtlcStateCanceled + htlc.ResolveTime = now + invoice.AmtPaid -= htlc.Amt + + continue + } + + // Add new htlc paying to the invoice. + if ok { + return nil, fmt.Errorf("htlc %v already exists", key) + } + htlc = &InvoiceHTLC{ + Amt: htlcUpdate.Amt, + Expiry: htlcUpdate.Expiry, + AcceptHeight: uint32(htlcUpdate.AcceptHeight), + AcceptTime: now, + } + if preUpdateState == ContractSettled { + htlc.State = HtlcStateSettled + htlc.ResolveTime = now + } else { + htlc.State = HtlcStateAccepted + } + + invoice.Htlcs[key] = htlc + invoice.AmtPaid += htlc.Amt + } + + // If invoice moved to the settled state, update settle index and settle + // time. + if preUpdateState != invoice.Terms.State && + invoice.Terms.State == ContractSettled { + + if update.Preimage.Hash() != hash { + return nil, fmt.Errorf("preimage does not match") + } + invoice.Terms.PaymentPreimage = update.Preimage + + // Settle all accepted htlcs. + for _, htlc := range invoice.Htlcs { + if htlc.State != HtlcStateAccepted { + continue + } + + htlc.State = HtlcStateSettled + htlc.ResolveTime = now + } + + err := setSettleFields(settleIndex, invoiceNum, &invoice, now) + if err != nil { + return nil, err + } + } + + var buf bytes.Buffer + if err := serializeInvoice(&buf, &invoice); err != nil { + return nil, err + } + + if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil { + return nil, err + } + + return &invoice, nil +} + +func setSettleFields(settleIndex *bbolt.Bucket, invoiceNum []byte, + invoice *Invoice, now time.Time) error { + + // Now that we know the invoice hasn't already been settled, we'll + // update the settle index so we can place this settle event in the + // proper location within our time series. + nextSettleSeqNo, err := settleIndex.NextSequence() + if err != nil { + return err + } + + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) + if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil { + return err + } + + invoice.Terms.State = ContractSettled + invoice.SettleDate = now + invoice.SettleIndex = nextSettleSeqNo + + return nil +} diff --git a/channeldb/migration_01_to_11/legacy_serialization.go b/channeldb/migration_01_to_11/legacy_serialization.go new file mode 100644 index 00000000..5d731bff --- /dev/null +++ b/channeldb/migration_01_to_11/legacy_serialization.go @@ -0,0 +1,55 @@ +package migration_01_to_11 + +import ( + "io" +) + +// deserializeCloseChannelSummaryV6 reads the v6 database format for +// ChannelCloseSummary. +// +// NOTE: deprecated, only for migration. +func deserializeCloseChannelSummaryV6(r io.Reader) (*ChannelCloseSummary, error) { + c := &ChannelCloseSummary{} + + err := ReadElements(r, + &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, + &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, + &c.TimeLockedBalance, &c.CloseType, &c.IsPending, + ) + if err != nil { + return nil, err + } + + // We'll now check to see if the channel close summary was encoded with + // any of the additional optional fields. + err = ReadElements(r, &c.RemoteCurrentRevocation) + switch { + case err == io.EOF: + return c, nil + + // If we got a non-eof error, then we know there's an actually issue. + // Otherwise, it may have been the case that this summary didn't have + // the set of optional fields. + case err != nil: + return nil, err + } + + if err := readChanConfig(r, &c.LocalChanConfig); err != nil { + return nil, err + } + + // Finally, we'll attempt to read the next unrevoked commitment point + // for the remote party. If we closed the channel before receiving a + // funding locked message, then this can be nil. As a result, we'll use + // the same technique to read the field, only if there's still data + // left in the buffer. + err = ReadElements(r, &c.RemoteNextRevocation) + if err != nil && err != io.EOF { + // If we got a non-eof error, then we know there's an actually + // issue. Otherwise, it may have been the case that this + // summary didn't have the set of optional fields. + return nil, err + } + + return c, nil +} diff --git a/channeldb/migration_01_to_11/log.go b/channeldb/migration_01_to_11/log.go new file mode 100644 index 00000000..17958b19 --- /dev/null +++ b/channeldb/migration_01_to_11/log.go @@ -0,0 +1,28 @@ +package migration_01_to_11 + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +func init() { + UseLogger(build.NewSubLogger("CHDB", nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/channeldb/migration_01_to_11/meta.go b/channeldb/migration_01_to_11/meta.go new file mode 100644 index 00000000..fbe7a0e4 --- /dev/null +++ b/channeldb/migration_01_to_11/meta.go @@ -0,0 +1,78 @@ +package migration_01_to_11 + +import "github.com/coreos/bbolt" + +var ( + // metaBucket stores all the meta information concerning the state of + // the database. + metaBucket = []byte("metadata") + + // dbVersionKey is a boltdb key and it's used for storing/retrieving + // current database version. + dbVersionKey = []byte("dbp") +) + +// Meta structure holds the database meta information. +type Meta struct { + // DbVersionNumber is the current schema version of the database. + DbVersionNumber uint32 +} + +// FetchMeta fetches the meta data from boltdb and returns filled meta +// structure. +func (d *DB) FetchMeta(tx *bbolt.Tx) (*Meta, error) { + meta := &Meta{} + + err := d.View(func(tx *bbolt.Tx) error { + return fetchMeta(meta, tx) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +// fetchMeta is an internal helper function used in order to allow callers to +// re-use a database transaction. See the publicly exported FetchMeta method +// for more information. +func fetchMeta(meta *Meta, tx *bbolt.Tx) error { + metaBucket := tx.Bucket(metaBucket) + if metaBucket == nil { + return ErrMetaNotFound + } + + data := metaBucket.Get(dbVersionKey) + if data == nil { + meta.DbVersionNumber = getLatestDBVersion(dbVersions) + } else { + meta.DbVersionNumber = byteOrder.Uint32(data) + } + + return nil +} + +// PutMeta writes the passed instance of the database met-data struct to disk. +func (d *DB) PutMeta(meta *Meta) error { + return d.Update(func(tx *bbolt.Tx) error { + return putMeta(meta, tx) + }) +} + +// putMeta is an internal helper function used in order to allow callers to +// re-use a database transaction. See the publicly exported PutMeta method for +// more information. +func putMeta(meta *Meta, tx *bbolt.Tx) error { + metaBucket, err := tx.CreateBucketIfNotExists(metaBucket) + if err != nil { + return err + } + + return putDbVersion(metaBucket, meta) +} + +func putDbVersion(metaBucket *bbolt.Bucket, meta *Meta) error { + scratch := make([]byte, 4) + byteOrder.PutUint32(scratch, meta.DbVersionNumber) + return metaBucket.Put(dbVersionKey, scratch) +} diff --git a/channeldb/migration_01_to_11/meta_test.go b/channeldb/migration_01_to_11/meta_test.go new file mode 100644 index 00000000..27e9369c --- /dev/null +++ b/channeldb/migration_01_to_11/meta_test.go @@ -0,0 +1,442 @@ +package migration_01_to_11 + +import ( + "bytes" + "io/ioutil" + "testing" + + "github.com/coreos/bbolt" + "github.com/go-errors/errors" +) + +// applyMigration is a helper test function that encapsulates the general steps +// which are needed to properly check the result of applying migration function. +func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), + migrationFunc migration, shouldFail bool) { + + cdb, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatal(err) + } + + // Create a test node that will be our source node. + testNode, err := createTestVertex(cdb) + if err != nil { + t.Fatal(err) + } + graph := cdb.ChannelGraph() + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatal(err) + } + + // beforeMigration usually used for populating the database + // with test data. + beforeMigration(cdb) + + // Create test meta info with zero database version and put it on disk. + // Than creating the version list pretending that new version was added. + meta := &Meta{DbVersionNumber: 0} + if err := cdb.PutMeta(meta); err != nil { + t.Fatalf("unable to store meta data: %v", err) + } + + versions := []version{ + { + number: 0, + migration: nil, + }, + { + number: 1, + migration: migrationFunc, + }, + } + + defer func() { + if r := recover(); r != nil { + err = errors.New(r) + } + + if err == nil && shouldFail { + t.Fatal("error wasn't received on migration stage") + } else if err != nil && !shouldFail { + t.Fatalf("error was received on migration stage: %v", err) + } + + // afterMigration usually used for checking the database state and + // throwing the error if something went wrong. + afterMigration(cdb) + }() + + // Sync with the latest version - applying migration function. + err = cdb.syncVersions(versions) + if err != nil { + log.Error(err) + } +} + +// TestVersionFetchPut checks the propernces of fetch/put methods +// and also initialization of meta data in case if don't have any in +// database. +func TestVersionFetchPut(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatal(err) + } + + meta, err := db.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != getLatestDBVersion(dbVersions) { + t.Fatal("initialization of meta information wasn't performed") + } + + newVersion := getLatestDBVersion(dbVersions) + 1 + meta.DbVersionNumber = newVersion + + if err := db.PutMeta(meta); err != nil { + t.Fatalf("update of meta failed %v", err) + } + + meta, err = db.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != newVersion { + t.Fatal("update of meta information wasn't performed") + } +} + +// TestOrderOfMigrations checks that migrations are applied in proper order. +func TestOrderOfMigrations(t *testing.T) { + t.Parallel() + + appliedMigration := -1 + versions := []version{ + {0, nil}, + {1, nil}, + {2, func(tx *bbolt.Tx) error { + appliedMigration = 2 + return nil + }}, + {3, func(tx *bbolt.Tx) error { + appliedMigration = 3 + return nil + }}, + } + + // Retrieve the migration that should be applied to db, as far as + // current version is 1, we skip zero and first versions. + migrations, _ := getMigrationsToApply(versions, 1) + + if len(migrations) != 2 { + t.Fatal("incorrect number of migrations to apply") + } + + // Apply first migration. + migrations[0](nil) + + // Check that first migration corresponds to the second version. + if appliedMigration != 2 { + t.Fatal("incorrect order of applying migrations") + } + + // Apply second migration. + migrations[1](nil) + + // Check that second migration corresponds to the third version. + if appliedMigration != 3 { + t.Fatal("incorrect order of applying migrations") + } +} + +// TestGlobalVersionList checks that there is no mistake in global version list +// in terms of version ordering. +func TestGlobalVersionList(t *testing.T) { + t.Parallel() + + if dbVersions == nil { + t.Fatal("can't find versions list") + } + + if len(dbVersions) == 0 { + t.Fatal("db versions list is empty") + } + + prev := dbVersions[0].number + for i := 1; i < len(dbVersions); i++ { + version := dbVersions[i].number + + if version == prev { + t.Fatal("duplicates db versions") + } + if version < prev { + t.Fatal("order of db versions is wrong") + } + + prev = version + } +} + +// TestMigrationWithPanic asserts that if migration logic panics, we will return +// to the original state unaltered. +func TestMigrationWithPanic(t *testing.T) { + t.Parallel() + + bucketPrefix := []byte("somebucket") + keyPrefix := []byte("someprefix") + beforeMigration := []byte("beforemigration") + afterMigration := []byte("aftermigration") + + beforeMigrationFunc := func(d *DB) { + // Insert data in database and in order then make sure that the + // key isn't changes in case of panic or fail. + d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, beforeMigration) + return nil + }) + } + + // Create migration function which changes the initially created data and + // throw the panic, in this case we pretending that something goes. + migrationWithPanic := func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, afterMigration) + panic("panic!") + } + + // Check that version of database and data wasn't changed. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 0 { + t.Fatal("migration panicked but version is changed") + } + + err = d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + value := bucket.Get(keyPrefix) + if !bytes.Equal(value, beforeMigration) { + return errors.New("migration failed but data is " + + "changed") + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrationWithPanic, + true) +} + +// TestMigrationWithFatal asserts that migrations which fail do not modify the +// database. +func TestMigrationWithFatal(t *testing.T) { + t.Parallel() + + bucketPrefix := []byte("somebucket") + keyPrefix := []byte("someprefix") + beforeMigration := []byte("beforemigration") + afterMigration := []byte("aftermigration") + + beforeMigrationFunc := func(d *DB) { + d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, beforeMigration) + return nil + }) + } + + // Create migration function which changes the initially created data and + // return the error, in this case we pretending that something goes + // wrong. + migrationWithFatal := func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, afterMigration) + return errors.New("some error") + } + + // Check that version of database and initial data wasn't changed. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 0 { + t.Fatal("migration failed but version is changed") + } + + err = d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + value := bucket.Get(keyPrefix) + if !bytes.Equal(value, beforeMigration) { + return errors.New("migration failed but data is " + + "changed") + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrationWithFatal, + true) +} + +// TestMigrationWithoutErrors asserts that a successful migration has its +// changes applied to the database. +func TestMigrationWithoutErrors(t *testing.T) { + t.Parallel() + + bucketPrefix := []byte("somebucket") + keyPrefix := []byte("someprefix") + beforeMigration := []byte("beforemigration") + afterMigration := []byte("aftermigration") + + // Populate database with initial data. + beforeMigrationFunc := func(d *DB) { + d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, beforeMigration) + return nil + }) + } + + // Create migration function which changes the initially created data. + migrationWithoutErrors := func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, afterMigration) + return nil + } + + // Check that version of database and data was properly changed. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("version number isn't changed after " + + "successfully applied migration") + } + + err = d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + value := bucket.Get(keyPrefix) + if !bytes.Equal(value, afterMigration) { + return errors.New("migration wasn't applied " + + "properly") + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrationWithoutErrors, + false) +} + +// TestMigrationReversion tests after performing a migration to a higher +// database version, opening the database with a lower latest db version returns +// ErrDBReversion. +func TestMigrationReversion(t *testing.T) { + t.Parallel() + + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + + cdb, err := Open(tempDirName) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + // Update the database metadata to point to one more than the highest + // known version. + err = cdb.Update(func(tx *bbolt.Tx) error { + newMeta := &Meta{ + DbVersionNumber: getLatestDBVersion(dbVersions) + 1, + } + + return putMeta(newMeta, tx) + }) + + // Close the database. Even if we succeeded, our next step is to reopen. + cdb.Close() + + if err != nil { + t.Fatalf("unable to increase db version: %v", err) + } + + _, err = Open(tempDirName) + if err != ErrDBReversion { + t.Fatalf("unexpected error when opening channeldb, "+ + "want: %v, got: %v", ErrDBReversion, err) + } +} diff --git a/channeldb/migration_01_to_11/migration_09_legacy_serialization.go b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go new file mode 100644 index 00000000..cc6614c9 --- /dev/null +++ b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go @@ -0,0 +1,497 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "sort" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +var ( + // paymentBucket is the name of the bucket within the database that + // stores all data related to payments. + // + // Within the payments bucket, each invoice is keyed by its invoice ID + // which is a monotonically increasing uint64. BoltDB's sequence + // feature is used for generating monotonically increasing id. + // + // NOTE: Deprecated. Kept around for migration purposes. + paymentBucket = []byte("payments") + + // paymentStatusBucket is the name of the bucket within the database + // that stores the status of a payment indexed by the payment's + // preimage. + // + // NOTE: Deprecated. Kept around for migration purposes. + paymentStatusBucket = []byte("payment-status") +) + +// outgoingPayment represents a successful payment between the daemon and a +// remote node. Details such as the total fee paid, and the time of the payment +// are stored. +// +// NOTE: Deprecated. Kept around for migration purposes. +type outgoingPayment struct { + Invoice + + // Fee is the total fee paid for the payment in milli-satoshis. + Fee lnwire.MilliSatoshi + + // TotalTimeLock is the total cumulative time-lock in the HTLC extended + // from the second-to-last hop to the destination. + TimeLockLength uint32 + + // Path encodes the path the payment took through the network. The path + // excludes the outgoing node and consists of the hex-encoded + // compressed public key of each of the nodes involved in the payment. + Path [][33]byte + + // PaymentPreimage is the preImage of a successful payment. This is used + // to calculate the PaymentHash as well as serve as a proof of payment. + PaymentPreimage [32]byte +} + +// addPayment saves a successful payment to the database. It is assumed that +// all payment are sent using unique payment hashes. +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) addPayment(payment *outgoingPayment) error { + // Validate the field of the inner voice within the outgoing payment, + // these must also adhere to the same constraints as regular invoices. + if err := validateInvoice(&payment.Invoice); err != nil { + return err + } + + // We first serialize the payment before starting the database + // transaction so we can avoid creating a DB payment in the case of a + // serialization error. + var b bytes.Buffer + if err := serializeOutgoingPayment(&b, payment); err != nil { + return err + } + paymentBytes := b.Bytes() + + return db.Batch(func(tx *bbolt.Tx) error { + payments, err := tx.CreateBucketIfNotExists(paymentBucket) + if err != nil { + return err + } + + // Obtain the new unique sequence number for this payment. + paymentID, err := payments.NextSequence() + if err != nil { + return err + } + + // We use BigEndian for keys as it orders keys in + // ascending order. This allows bucket scans to order payments + // in the order in which they were created. + paymentIDBytes := make([]byte, 8) + binary.BigEndian.PutUint64(paymentIDBytes, paymentID) + + return payments.Put(paymentIDBytes, paymentBytes) + }) +} + +// fetchAllPayments returns all outgoing payments in DB. +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) { + var payments []*outgoingPayment + + err := db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(paymentBucket) + if bucket == nil { + return ErrNoPaymentsCreated + } + + return bucket.ForEach(func(k, v []byte) error { + // If the value is nil, then we ignore it as it may be + // a sub-bucket. + if v == nil { + return nil + } + + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + payments = append(payments, payment) + return nil + }) + }) + if err != nil { + return nil, err + } + + return payments, nil +} + +// fetchPaymentStatus returns the payment status for outgoing payment. +// If status of the payment isn't found, it will default to "StatusUnknown". +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { + var paymentStatus = StatusUnknown + err := db.View(func(tx *bbolt.Tx) error { + var err error + paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash) + return err + }) + if err != nil { + return StatusUnknown, err + } + + return paymentStatus, nil +} + +// fetchPaymentStatusTx is a helper method that returns the payment status for +// outgoing payment. If status of the payment isn't found, it will default to +// "StatusUnknown". It accepts the boltdb transactions such that this method +// can be composed into other atomic operations. +// +// NOTE: Deprecated. Kept around for migration purposes. +func fetchPaymentStatusTx(tx *bbolt.Tx, paymentHash [32]byte) (PaymentStatus, error) { + // The default status for all payments that aren't recorded in database. + var paymentStatus = StatusUnknown + + bucket := tx.Bucket(paymentStatusBucket) + if bucket == nil { + return paymentStatus, nil + } + + paymentStatusBytes := bucket.Get(paymentHash[:]) + if paymentStatusBytes == nil { + return paymentStatus, nil + } + + paymentStatus.FromBytes(paymentStatusBytes) + + return paymentStatus, nil +} + +func serializeOutgoingPayment(w io.Writer, p *outgoingPayment) error { + var scratch [8]byte + + if err := serializeInvoiceLegacy(w, &p.Invoice); err != nil { + return err + } + + byteOrder.PutUint64(scratch[:], uint64(p.Fee)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + // First write out the length of the bytes to prefix the value. + pathLen := uint32(len(p.Path)) + byteOrder.PutUint32(scratch[:4], pathLen) + if _, err := w.Write(scratch[:4]); err != nil { + return err + } + + // Then with the path written, we write out the series of public keys + // involved in the path. + for _, hop := range p.Path { + if _, err := w.Write(hop[:]); err != nil { + return err + } + } + + byteOrder.PutUint32(scratch[:4], p.TimeLockLength) + if _, err := w.Write(scratch[:4]); err != nil { + return err + } + + if _, err := w.Write(p.PaymentPreimage[:]); err != nil { + return err + } + + return nil +} + +func deserializeOutgoingPayment(r io.Reader) (*outgoingPayment, error) { + var scratch [8]byte + + p := &outgoingPayment{} + + inv, err := deserializeInvoiceLegacy(r) + if err != nil { + return nil, err + } + p.Invoice = inv + + if _, err := r.Read(scratch[:]); err != nil { + return nil, err + } + p.Fee = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if _, err = r.Read(scratch[:4]); err != nil { + return nil, err + } + pathLen := byteOrder.Uint32(scratch[:4]) + + path := make([][33]byte, pathLen) + for i := uint32(0); i < pathLen; i++ { + if _, err := r.Read(path[i][:]); err != nil { + return nil, err + } + } + p.Path = path + + if _, err = r.Read(scratch[:4]); err != nil { + return nil, err + } + p.TimeLockLength = byteOrder.Uint32(scratch[:4]) + + if _, err := r.Read(p.PaymentPreimage[:]); err != nil { + return nil, err + } + + return p, nil +} + +// serializePaymentAttemptInfoMigration9 is the serializePaymentAttemptInfo +// version as existed when migration #9 was created. We keep this around, along +// with the methods below to ensure that clients that upgrade will use the +// correct version of this method. +func serializePaymentAttemptInfoMigration9(w io.Writer, a *PaymentAttemptInfo) error { + if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { + return err + } + + if err := serializeRouteMigration9(w, a.Route); err != nil { + return err + } + + return nil +} + +func serializeHopMigration9(w io.Writer, h *route.Hop) error { + if err := WriteElements(w, + h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, + h.AmtToForward, + ); err != nil { + return err + } + + return nil +} + +func serializeRouteMigration9(w io.Writer, r route.Route) error { + if err := WriteElements(w, + r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.Hops))); err != nil { + return err + } + + for _, h := range r.Hops { + if err := serializeHopMigration9(w, h); err != nil { + return err + } + } + + return nil +} + +func deserializePaymentAttemptInfoMigration9(r io.Reader) (*PaymentAttemptInfo, error) { + a := &PaymentAttemptInfo{} + err := ReadElements(r, &a.PaymentID, &a.SessionKey) + if err != nil { + return nil, err + } + a.Route, err = deserializeRouteMigration9(r) + if err != nil { + return nil, err + } + return a, nil +} + +func deserializeRouteMigration9(r io.Reader) (route.Route, error) { + rt := route.Route{} + if err := ReadElements(r, + &rt.TotalTimeLock, &rt.TotalAmount, + ); err != nil { + return rt, err + } + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return rt, err + } + copy(rt.SourcePubKey[:], pub) + + var numHops uint32 + if err := ReadElements(r, &numHops); err != nil { + return rt, err + } + + var hops []*route.Hop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHopMigration9(r) + if err != nil { + return rt, err + } + hops = append(hops, hop) + } + rt.Hops = hops + + return rt, nil +} + +func deserializeHopMigration9(r io.Reader) (*route.Hop, error) { + h := &route.Hop{} + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.PubKeyBytes[:], pub) + + if err := ReadElements(r, + &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, + ); err != nil { + return nil, err + } + + return h, nil +} + +// fetchPaymentsMigration9 returns all sent payments found in the DB using the +// payment attempt info format that was present as of migration #9. We need +// this as otherwise, the current FetchPayments version will use the latest +// decoding format. Note that we only need this for the +// TestOutgoingPaymentsMigration migration test case. +func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) { + var payments []*Payment + + err := db.View(func(tx *bbolt.Tx) error { + paymentsBucket := tx.Bucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil + } + + return paymentsBucket.ForEach(func(k, v []byte) error { + bucket := paymentsBucket.Bucket(k) + if bucket == nil { + // We only expect sub-buckets to be found in + // this top-level bucket. + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + + p, err := fetchPaymentMigration9(bucket) + if err != nil { + return err + } + + payments = append(payments, p) + + // For older versions of lnd, duplicate payments to a + // payment has was possible. These will be found in a + // sub-bucket indexed by their sequence number if + // available. + dup := bucket.Bucket(paymentDuplicateBucket) + if dup == nil { + return nil + } + + return dup.ForEach(func(k, v []byte) error { + subBucket := dup.Bucket(k) + if subBucket == nil { + // We one bucket for each duplicate to + // be found. + return fmt.Errorf("non bucket element" + + "in duplicate bucket") + } + + p, err := fetchPaymentMigration9(subBucket) + if err != nil { + return err + } + + payments = append(payments, p) + return nil + }) + }) + }) + if err != nil { + return nil, err + } + + // Before returning, sort the payments by their sequence number. + sort.Slice(payments, func(i, j int) bool { + return payments[i].sequenceNum < payments[j].sequenceNum + }) + + return payments, nil +} + +func fetchPaymentMigration9(bucket *bbolt.Bucket) (*Payment, error) { + var ( + err error + p = &Payment{} + ) + + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return nil, fmt.Errorf("sequence number not found") + } + + p.sequenceNum = binary.BigEndian.Uint64(seqBytes) + + // Get the payment status. + p.Status = fetchPaymentStatus(bucket) + + // Get the PaymentCreationInfo. + b := bucket.Get(paymentCreationInfoKey) + if b == nil { + return nil, fmt.Errorf("creation info not found") + } + + r := bytes.NewReader(b) + p.Info, err = deserializePaymentCreationInfo(r) + if err != nil { + return nil, err + + } + + // Get the PaymentAttemptInfo. This can be unset. + b = bucket.Get(paymentAttemptInfoKey) + if b != nil { + r = bytes.NewReader(b) + p.Attempt, err = deserializePaymentAttemptInfoMigration9(r) + if err != nil { + return nil, err + } + } + + // Get the payment preimage. This is only found for + // completed payments. + b = bucket.Get(paymentSettleInfoKey) + if b != nil { + var preimg lntypes.Preimage + copy(preimg[:], b[:]) + p.PaymentPreimage = &preimg + } + + // Get failure reason if available. + b = bucket.Get(paymentFailInfoKey) + if b != nil { + reason := FailureReason(b[0]) + p.Failure = &reason + } + + return p, nil +} diff --git a/channeldb/migration_01_to_11/migration_10_route_tlv_records.go b/channeldb/migration_01_to_11/migration_10_route_tlv_records.go new file mode 100644 index 00000000..a8478cda --- /dev/null +++ b/channeldb/migration_01_to_11/migration_10_route_tlv_records.go @@ -0,0 +1,236 @@ +package migration_01_to_11 + +import ( + "bytes" + "io" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/routing/route" +) + +// migrateRouteSerialization migrates the way we serialize routes across the +// entire database. At the time of writing of this migration, this includes our +// payment attempts, as well as the payment results in mission control. +func migrateRouteSerialization(tx *bbolt.Tx) error { + // First, we'll do all the payment attempts. + rootPaymentBucket := tx.Bucket(paymentsRootBucket) + if rootPaymentBucket == nil { + return nil + } + + // As we can't mutate a bucket while we're iterating over it with + // ForEach, we'll need to collect all the known payment hashes in + // memory first. + var payHashes [][]byte + err := rootPaymentBucket.ForEach(func(k, v []byte) error { + if v != nil { + return nil + } + + payHashes = append(payHashes, k) + return nil + }) + if err != nil { + return err + } + + // Now that we have all the payment hashes, we can carry out the + // migration itself. + for _, payHash := range payHashes { + payHashBucket := rootPaymentBucket.Bucket(payHash) + + // First, we'll migrate the main (non duplicate) payment to + // this hash. + err := migrateAttemptEncoding(tx, payHashBucket) + if err != nil { + return err + } + + // Now that we've migrated the main payment, we'll also check + // for any duplicate payments to the same payment hash. + dupBucket := payHashBucket.Bucket(paymentDuplicateBucket) + + // If there's no dup bucket, then we can move on to the next + // payment. + if dupBucket == nil { + continue + } + + // Otherwise, we'll now iterate through all the duplicate pay + // hashes and migrate those. + var dupSeqNos [][]byte + err = dupBucket.ForEach(func(k, v []byte) error { + dupSeqNos = append(dupSeqNos, k) + return nil + }) + if err != nil { + return err + } + + // Now in this second pass, we'll re-serialize their duplicate + // payment attempts under the new encoding. + for _, seqNo := range dupSeqNos { + dupPayHashBucket := dupBucket.Bucket(seqNo) + err := migrateAttemptEncoding(tx, dupPayHashBucket) + if err != nil { + return err + } + } + } + + log.Infof("Migration of route/hop serialization complete!") + + log.Infof("Migrating to new mission control store by clearing " + + "existing data") + + resultsKey := []byte("missioncontrol-results") + err = tx.DeleteBucket(resultsKey) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + log.Infof("Migration to new mission control completed!") + + return nil +} + +// migrateAttemptEncoding migrates payment attempts using the legacy format to +// the new format. +func migrateAttemptEncoding(tx *bbolt.Tx, payHashBucket *bbolt.Bucket) error { + payAttemptBytes := payHashBucket.Get(paymentAttemptInfoKey) + if payAttemptBytes == nil { + return nil + } + + // For our migration, we'll first read out the existing payment attempt + // using the legacy serialization of the attempt. + payAttemptReader := bytes.NewReader(payAttemptBytes) + payAttempt, err := deserializePaymentAttemptInfoLegacy( + payAttemptReader, + ) + if err != nil { + return err + } + + // Now that we have the old attempts, we'll explicitly mark this as + // needing a legacy payload, since after this migration, the modern + // payload will be the default if signalled. + for _, hop := range payAttempt.Route.Hops { + hop.LegacyPayload = true + } + + // Finally, we'll write out the payment attempt using the new encoding. + var b bytes.Buffer + err = serializePaymentAttemptInfo(&b, payAttempt) + if err != nil { + return err + } + + return payHashBucket.Put(paymentAttemptInfoKey, b.Bytes()) +} + +func deserializePaymentAttemptInfoLegacy(r io.Reader) (*PaymentAttemptInfo, error) { + a := &PaymentAttemptInfo{} + err := ReadElements(r, &a.PaymentID, &a.SessionKey) + if err != nil { + return nil, err + } + a.Route, err = deserializeRouteLegacy(r) + if err != nil { + return nil, err + } + return a, nil +} + +func serializePaymentAttemptInfoLegacy(w io.Writer, a *PaymentAttemptInfo) error { + if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { + return err + } + + if err := serializeRouteLegacy(w, a.Route); err != nil { + return err + } + + return nil +} + +func deserializeHopLegacy(r io.Reader) (*route.Hop, error) { + h := &route.Hop{} + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.PubKeyBytes[:], pub) + + if err := ReadElements(r, + &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, + ); err != nil { + return nil, err + } + + return h, nil +} + +func serializeHopLegacy(w io.Writer, h *route.Hop) error { + if err := WriteElements(w, + h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, + h.AmtToForward, + ); err != nil { + return err + } + + return nil +} + +func deserializeRouteLegacy(r io.Reader) (route.Route, error) { + rt := route.Route{} + if err := ReadElements(r, + &rt.TotalTimeLock, &rt.TotalAmount, + ); err != nil { + return rt, err + } + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return rt, err + } + copy(rt.SourcePubKey[:], pub) + + var numHops uint32 + if err := ReadElements(r, &numHops); err != nil { + return rt, err + } + + var hops []*route.Hop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHopLegacy(r) + if err != nil { + return rt, err + } + hops = append(hops, hop) + } + rt.Hops = hops + + return rt, nil +} + +func serializeRouteLegacy(w io.Writer, r route.Route) error { + if err := WriteElements(w, + r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.Hops))); err != nil { + return err + } + + for _, h := range r.Hops { + if err := serializeHopLegacy(w, h); err != nil { + return err + } + } + + return nil +} diff --git a/channeldb/migration_01_to_11/migration_11_invoices.go b/channeldb/migration_01_to_11/migration_11_invoices.go new file mode 100644 index 00000000..1ae969be --- /dev/null +++ b/channeldb/migration_01_to_11/migration_11_invoices.go @@ -0,0 +1,230 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + bitcoinCfg "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/zpay32" + litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" +) + +// migrateInvoices adds invoice htlcs and a separate cltv delta field to the +// invoices. +func migrateInvoices(tx *bbolt.Tx) error { + log.Infof("Migrating invoices to new invoice format") + + invoiceB := tx.Bucket(invoiceBucket) + if invoiceB == nil { + return nil + } + + // Iterate through the entire key space of the top-level invoice bucket. + // If key with a non-nil value stores the next invoice ID which maps to + // the corresponding invoice. Store those keys first, because it isn't + // safe to modify the bucket inside a ForEach loop. + var invoiceKeys [][]byte + err := invoiceB.ForEach(func(k, v []byte) error { + if v == nil { + return nil + } + + invoiceKeys = append(invoiceKeys, k) + + return nil + }) + if err != nil { + return err + } + + nets := []*bitcoinCfg.Params{ + &bitcoinCfg.MainNetParams, &bitcoinCfg.SimNetParams, + &bitcoinCfg.RegressionNetParams, &bitcoinCfg.TestNet3Params, + } + + ltcNets := []*litecoinCfg.Params{ + &litecoinCfg.MainNetParams, &litecoinCfg.SimNetParams, + &litecoinCfg.RegressionNetParams, &litecoinCfg.TestNet4Params, + } + for _, net := range ltcNets { + var convertedNet bitcoinCfg.Params + convertedNet.Bech32HRPSegwit = net.Bech32HRPSegwit + nets = append(nets, &convertedNet) + } + + // Iterate over all stored keys and migrate the invoices. + for _, k := range invoiceKeys { + v := invoiceB.Get(k) + + // Deserialize the invoice with the deserializing function that + // was in use for this version of the database. + invoiceReader := bytes.NewReader(v) + invoice, err := deserializeInvoiceLegacy(invoiceReader) + if err != nil { + return err + } + + if invoice.Terms.State == ContractAccepted { + return fmt.Errorf("cannot upgrade with invoice(s) " + + "in accepted state, see release notes") + } + + // Try to decode the payment request for every possible net to + // avoid passing a the active network to channeldb. This would + // be a layering violation, while this migration is only running + // once and will likely be removed in the future. + var payReq *zpay32.Invoice + for _, net := range nets { + payReq, err = zpay32.Decode( + string(invoice.PaymentRequest), net, + ) + if err == nil { + break + } + } + if payReq == nil { + return fmt.Errorf("cannot decode payreq") + } + invoice.FinalCltvDelta = int32(payReq.MinFinalCLTVExpiry()) + invoice.Expiry = payReq.Expiry() + + // Serialize the invoice in the new format and use it to replace + // the old invoice in the database. + var buf bytes.Buffer + if err := serializeInvoice(&buf, &invoice); err != nil { + return err + } + + err = invoiceB.Put(k, buf.Bytes()) + if err != nil { + return err + } + } + + log.Infof("Migration of invoices completed!") + return nil +} + +func deserializeInvoiceLegacy(r io.Reader) (Invoice, error) { + var err error + invoice := Invoice{} + + // TODO(roasbeef): use read full everywhere + invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "") + if err != nil { + return invoice, err + } + invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "") + if err != nil { + return invoice, err + } + + invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "") + if err != nil { + return invoice, err + } + + birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth") + if err != nil { + return invoice, err + } + if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil { + return invoice, err + } + + settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled") + if err != nil { + return invoice, err + } + if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil { + return invoice, err + } + + if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil { + return invoice, err + } + var scratch [8]byte + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return invoice, err + } + invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { + return invoice, err + } + + if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil { + return invoice, err + } + + return invoice, nil +} + +// serializeInvoiceLegacy serializes an invoice in the format of the previous db +// version. +func serializeInvoiceLegacy(w io.Writer, i *Invoice) error { + if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil { + return err + } + + birthBytes, err := i.CreationDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil { + return err + } + + settleBytes, err := i.SettleDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil { + return err + } + + if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil { + return err + } + + var scratch [8]byte + byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.AddIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil { + return err + } + + return nil +} diff --git a/channeldb/migration_01_to_11/migration_11_invoices_test.go b/channeldb/migration_01_to_11/migration_11_invoices_test.go new file mode 100644 index 00000000..9c0c877a --- /dev/null +++ b/channeldb/migration_01_to_11/migration_11_invoices_test.go @@ -0,0 +1,193 @@ +package migration_01_to_11 + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + bitcoinCfg "github.com/btcsuite/btcd/chaincfg" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/zpay32" + litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" +) + +var ( + testPrivKeyBytes = []byte{ + 0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf, + 0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9, + 0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f, + 0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90, + } + + testCltvDelta = int32(50) +) + +// beforeMigrationFuncV11 insert the test invoices in the database. +func beforeMigrationFuncV11(t *testing.T, d *DB, invoices []Invoice) { + err := d.Update(func(tx *bbolt.Tx) error { + invoicesBucket, err := tx.CreateBucketIfNotExists( + invoiceBucket, + ) + if err != nil { + return err + } + + invoiceNum := uint32(1) + for _, invoice := range invoices { + var invoiceKey [4]byte + byteOrder.PutUint32(invoiceKey[:], invoiceNum) + invoiceNum++ + + var buf bytes.Buffer + err := serializeInvoiceLegacy(&buf, &invoice) // nolint:scopelint + if err != nil { + return err + } + + err = invoicesBucket.Put( + invoiceKey[:], buf.Bytes(), + ) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestMigrateInvoices checks that invoices are migrated correctly. +func TestMigrateInvoices(t *testing.T) { + t.Parallel() + + payReqBtc, err := getPayReq(&bitcoinCfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + var ltcNetParams bitcoinCfg.Params + ltcNetParams.Bech32HRPSegwit = litecoinCfg.MainNetParams.Bech32HRPSegwit + payReqLtc, err := getPayReq(<cNetParams) + if err != nil { + t.Fatal(err) + } + + invoices := []Invoice{ + { + PaymentRequest: []byte(payReqBtc), + }, + { + PaymentRequest: []byte(payReqLtc), + }, + } + + // Verify that all invoices were migrated. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'invoices' wasn't applied") + } + + dbInvoices, err := d.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to fetch invoices: %v", err) + } + + if len(invoices) != len(dbInvoices) { + t.Fatalf("expected %d invoices, got %d", len(invoices), + len(dbInvoices)) + } + + for _, dbInvoice := range dbInvoices { + if dbInvoice.FinalCltvDelta != testCltvDelta { + t.Fatal("incorrect final cltv delta") + } + if dbInvoice.Expiry != 3600*time.Second { + t.Fatal("incorrect expiry") + } + if len(dbInvoice.Htlcs) != 0 { + t.Fatal("expected no htlcs after migration") + } + } + } + + applyMigration(t, + func(d *DB) { beforeMigrationFuncV11(t, d, invoices) }, + afterMigrationFunc, + migrateInvoices, + false) +} + +// TestMigrateInvoicesHodl checks that a hodl invoice in the accepted state +// fails the migration. +func TestMigrateInvoicesHodl(t *testing.T) { + t.Parallel() + + payReqBtc, err := getPayReq(&bitcoinCfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + invoices := []Invoice{ + { + PaymentRequest: []byte(payReqBtc), + Terms: ContractTerm{ + State: ContractAccepted, + }, + }, + } + + applyMigration(t, + func(d *DB) { beforeMigrationFuncV11(t, d, invoices) }, + func(d *DB) {}, + migrateInvoices, + true) +} + +// signDigestCompact generates a test signature to be used in the generation of +// test payment requests. +func signDigestCompact(hash []byte) ([]byte, error) { + // Should the signature reference a compressed public key or not. + isCompressedKey := true + + privKey, _ := btcec.PrivKeyFromBytes(btcec.S256(), testPrivKeyBytes) + + // btcec.SignCompact returns a pubkey-recoverable signature + sig, err := btcec.SignCompact( + btcec.S256(), privKey, hash, isCompressedKey, + ) + if err != nil { + return nil, fmt.Errorf("can't sign the hash: %v", err) + } + + return sig, nil +} + +// getPayReq creates a payment request for the given net. +func getPayReq(net *bitcoinCfg.Params) (string, error) { + options := []func(*zpay32.Invoice){ + zpay32.CLTVExpiry(uint64(testCltvDelta)), + zpay32.Description("test"), + } + + payReq, err := zpay32.NewInvoice( + net, [32]byte{}, time.Unix(1, 0), options..., + ) + if err != nil { + return "", err + } + return payReq.Encode( + zpay32.MessageSigner{ + SignCompact: signDigestCompact, + }, + ) +} diff --git a/channeldb/migration_01_to_11/migrations.go b/channeldb/migration_01_to_11/migrations.go new file mode 100644 index 00000000..3e296d02 --- /dev/null +++ b/channeldb/migration_01_to_11/migrations.go @@ -0,0 +1,939 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "fmt" + + "github.com/btcsuite/btcd/btcec" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// migrateNodeAndEdgeUpdateIndex is a migration function that will update the +// database from version 0 to version 1. In version 1, we add two new indexes +// (one for nodes and one for edges) to keep track of the last time a node or +// edge was updated on the network. These new indexes allow us to implement the +// new graph sync protocol added. +func migrateNodeAndEdgeUpdateIndex(tx *bbolt.Tx) error { + // First, we'll populating the node portion of the new index. Before we + // can add new values to the index, we'll first create the new bucket + // where these items will be housed. + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return fmt.Errorf("unable to create node bucket: %v", err) + } + nodeUpdateIndex, err := nodes.CreateBucketIfNotExists( + nodeUpdateIndexBucket, + ) + if err != nil { + return fmt.Errorf("unable to create node update index: %v", err) + } + + log.Infof("Populating new node update index bucket") + + // Now that we know the bucket has been created, we'll iterate over the + // entire node bucket so we can add the (updateTime || nodePub) key + // into the node update index. + err = nodes.ForEach(func(nodePub, nodeInfo []byte) error { + if len(nodePub) != 33 { + return nil + } + + log.Tracef("Adding %x to node update index", nodePub) + + // The first 8 bytes of a node's serialize data is the update + // time, so we can extract that without decoding the entire + // structure. + updateTime := nodeInfo[:8] + + // Now that we have the update time, we can construct the key + // to insert into the index. + var indexKey [8 + 33]byte + copy(indexKey[:8], updateTime) + copy(indexKey[8:], nodePub) + + return nodeUpdateIndex.Put(indexKey[:], nil) + }) + if err != nil { + return fmt.Errorf("unable to update node indexes: %v", err) + } + + log.Infof("Populating new edge update index bucket") + + // With the set of nodes updated, we'll now update all edges to have a + // corresponding entry in the edge update index. + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return fmt.Errorf("unable to create edge bucket: %v", err) + } + edgeUpdateIndex, err := edges.CreateBucketIfNotExists( + edgeUpdateIndexBucket, + ) + if err != nil { + return fmt.Errorf("unable to create edge update index: %v", err) + } + + // We'll now run through each edge policy in the database, and update + // the index to ensure each edge has the proper record. + err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error { + if len(edgeKey) != 41 { + return nil + } + + // Now that we know this is the proper record, we'll grab the + // channel ID (last 8 bytes of the key), and then decode the + // edge policy so we can access the update time. + chanID := edgeKey[33:] + edgePolicyReader := bytes.NewReader(edgePolicyBytes) + + edgePolicy, err := deserializeChanEdgePolicy( + edgePolicyReader, nodes, + ) + if err != nil { + return err + } + + log.Tracef("Adding chan_id=%v to edge update index", + edgePolicy.ChannelID) + + // We'll now construct the index key using the channel ID, and + // the last time it was updated: (updateTime || chanID). + var indexKey [8 + 8]byte + byteOrder.PutUint64( + indexKey[:], uint64(edgePolicy.LastUpdate.Unix()), + ) + copy(indexKey[8:], chanID) + + return edgeUpdateIndex.Put(indexKey[:], nil) + }) + if err != nil { + return fmt.Errorf("unable to update edge indexes: %v", err) + } + + log.Infof("Migration to node and edge update indexes complete!") + + return nil +} + +// migrateInvoiceTimeSeries is a database migration that assigns all existing +// invoices an index in the add and/or the settle index. Additionally, all +// existing invoices will have their bytes padded out in order to encode the +// add+settle index as well as the amount paid. +func migrateInvoiceTimeSeries(tx *bbolt.Tx) error { + invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) + if err != nil { + return err + } + + addIndex, err := invoices.CreateBucketIfNotExists( + addIndexBucket, + ) + if err != nil { + return err + } + settleIndex, err := invoices.CreateBucketIfNotExists( + settleIndexBucket, + ) + if err != nil { + return err + } + + log.Infof("Migrating invoice database to new time series format") + + // Now that we have all the buckets we need, we'll run through each + // invoice in the database, and update it to reflect the new format + // expected post migration. + // NOTE: we store the converted invoices and put them back into the + // database after the loop, since modifying the bucket within the + // ForEach loop is not safe. + var invoicesKeys [][]byte + var invoicesValues [][]byte + err = invoices.ForEach(func(invoiceNum, invoiceBytes []byte) error { + // If this is a sub bucket, then we'll skip it. + if invoiceBytes == nil { + return nil + } + + // First, we'll make a copy of the encoded invoice bytes. + invoiceBytesCopy := make([]byte, len(invoiceBytes)) + copy(invoiceBytesCopy, invoiceBytes) + + // With the bytes copied over, we'll append 24 additional + // bytes. We do this so we can decode the invoice under the new + // serialization format. + padding := bytes.Repeat([]byte{0}, 24) + invoiceBytesCopy = append(invoiceBytesCopy, padding...) + + invoiceReader := bytes.NewReader(invoiceBytesCopy) + invoice, err := deserializeInvoiceLegacy(invoiceReader) + if err != nil { + return fmt.Errorf("unable to decode invoice: %v", err) + } + + // Now that we have the fully decoded invoice, we can update + // the various indexes that we're added, and finally the + // invoice itself before re-inserting it. + + // First, we'll get the new sequence in the addIndex in order + // to create the proper mapping. + nextAddSeqNo, err := addIndex.NextSequence() + if err != nil { + return err + } + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) + err = addIndex.Put(seqNoBytes[:], invoiceNum[:]) + if err != nil { + return err + } + + log.Tracef("Adding invoice (preimage=%x, add_index=%v) to add "+ + "time series", invoice.Terms.PaymentPreimage[:], + nextAddSeqNo) + + // Next, we'll check if the invoice has been settled or not. If + // so, then we'll also add it to the settle index. + var nextSettleSeqNo uint64 + if invoice.Terms.State == ContractSettled { + nextSettleSeqNo, err = settleIndex.NextSequence() + if err != nil { + return err + } + + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) + err := settleIndex.Put(seqNoBytes[:], invoiceNum) + if err != nil { + return err + } + + invoice.AmtPaid = invoice.Terms.Value + + log.Tracef("Adding invoice (preimage=%x, "+ + "settle_index=%v) to add time series", + invoice.Terms.PaymentPreimage[:], + nextSettleSeqNo) + } + + // Finally, we'll update the invoice itself with the new + // indexing information as well as the amount paid if it has + // been settled or not. + invoice.AddIndex = nextAddSeqNo + invoice.SettleIndex = nextSettleSeqNo + + // We've fully migrated an invoice, so we'll now update the + // invoice in-place. + var b bytes.Buffer + if err := serializeInvoiceLegacy(&b, &invoice); err != nil { + return err + } + + // Save the key and value pending update for after the ForEach + // is done. + invoicesKeys = append(invoicesKeys, invoiceNum) + invoicesValues = append(invoicesValues, b.Bytes()) + return nil + }) + if err != nil { + return err + } + + // Now put the converted invoices into the DB. + for i := range invoicesKeys { + key := invoicesKeys[i] + value := invoicesValues[i] + if err := invoices.Put(key, value); err != nil { + return err + } + } + + log.Infof("Migration to invoice time series index complete!") + + return nil +} + +// migrateInvoiceTimeSeriesOutgoingPayments is a follow up to the +// migrateInvoiceTimeSeries migration. As at the time of writing, the +// OutgoingPayment struct embeddeds an instance of the Invoice struct. As a +// result, we also need to migrate the internal invoice to the new format. +func migrateInvoiceTimeSeriesOutgoingPayments(tx *bbolt.Tx) error { + payBucket := tx.Bucket(paymentBucket) + if payBucket == nil { + return nil + } + + log.Infof("Migrating invoice database to new outgoing payment format") + + // We store the keys and values we want to modify since it is not safe + // to modify them directly within the ForEach loop. + var paymentKeys [][]byte + var paymentValues [][]byte + err := payBucket.ForEach(func(payID, paymentBytes []byte) error { + log.Tracef("Migrating payment %x", payID[:]) + + // The internal invoices for each payment only contain a + // populated contract term, and creation date, as a result, + // most of the bytes will be "empty". + + // We'll calculate the end of the invoice index assuming a + // "minimal" index that's embedded within the greater + // OutgoingPayment. The breakdown is: + // 3 bytes empty var bytes, 16 bytes creation date, 16 bytes + // settled date, 32 bytes payment pre-image, 8 bytes value, 1 + // byte settled. + endOfInvoiceIndex := 1 + 1 + 1 + 16 + 16 + 32 + 8 + 1 + + // We'll now extract the prefix of the pure invoice embedded + // within. + invoiceBytes := paymentBytes[:endOfInvoiceIndex] + + // With the prefix extracted, we'll copy over the invoice, and + // also add padding for the new 24 bytes of fields, and finally + // append the remainder of the outgoing payment. + paymentCopy := make([]byte, len(invoiceBytes)) + copy(paymentCopy[:], invoiceBytes) + + padding := bytes.Repeat([]byte{0}, 24) + paymentCopy = append(paymentCopy, padding...) + paymentCopy = append( + paymentCopy, paymentBytes[endOfInvoiceIndex:]..., + ) + + // At this point, we now have the new format of the outgoing + // payments, we'll attempt to deserialize it to ensure the + // bytes are properly formatted. + paymentReader := bytes.NewReader(paymentCopy) + _, err := deserializeOutgoingPayment(paymentReader) + if err != nil { + return fmt.Errorf("unable to deserialize payment: %v", err) + } + + // Now that we know the modifications was successful, we'll + // store it to our slice of keys and values, and write it back + // to disk in the new format after the ForEach loop is over. + paymentKeys = append(paymentKeys, payID) + paymentValues = append(paymentValues, paymentCopy) + return nil + }) + if err != nil { + return err + } + + // Finally store the updated payments to the bucket. + for i := range paymentKeys { + key := paymentKeys[i] + value := paymentValues[i] + if err := payBucket.Put(key, value); err != nil { + return err + } + } + + log.Infof("Migration to outgoing payment invoices complete!") + + return nil +} + +// migrateEdgePolicies is a migration function that will update the edges +// bucket. It ensure that edges with unknown policies will also have an entry +// in the bucket. After the migration, there will be two edge entries for +// every channel, regardless of whether the policies are known. +func migrateEdgePolicies(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return nil + } + + edges := tx.Bucket(edgeBucket) + if edges == nil { + return nil + } + + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return nil + } + + // checkKey gets the policy from the database with a low-level call + // so that it is still possible to distinguish between unknown and + // not present. + checkKey := func(channelId uint64, keyBytes []byte) error { + var channelID [8]byte + byteOrder.PutUint64(channelID[:], channelId) + + _, err := fetchChanEdgePolicy(edges, + channelID[:], keyBytes, nodes) + + if err == ErrEdgeNotFound { + log.Tracef("Adding unknown edge policy present for node %x, channel %v", + keyBytes, channelId) + + err := putChanEdgePolicyUnknown(edges, channelId, keyBytes) + if err != nil { + return err + } + + return nil + } + + return err + } + + // Iterate over all channels and check both edge policies. + err := edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { + infoReader := bytes.NewReader(edgeInfoBytes) + edgeInfo, err := deserializeChanEdgeInfo(infoReader) + if err != nil { + return err + } + + for _, key := range [][]byte{edgeInfo.NodeKey1Bytes[:], + edgeInfo.NodeKey2Bytes[:]} { + + if err := checkKey(edgeInfo.ChannelID, key); err != nil { + return err + } + } + + return nil + }) + + if err != nil { + return fmt.Errorf("unable to update edge policies: %v", err) + } + + log.Infof("Migration of edge policies complete!") + + return nil +} + +// paymentStatusesMigration is a database migration intended for adding payment +// statuses for each existing payment entity in bucket to be able control +// transitions of statuses and prevent cases such as double payment +func paymentStatusesMigration(tx *bbolt.Tx) error { + // Get the bucket dedicated to storing statuses of payments, + // where a key is payment hash, value is payment status. + paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) + if err != nil { + return err + } + + log.Infof("Migrating database to support payment statuses") + + circuitAddKey := []byte("circuit-adds") + circuits := tx.Bucket(circuitAddKey) + if circuits != nil { + log.Infof("Marking all known circuits with status InFlight") + + err = circuits.ForEach(func(k, v []byte) error { + // Parse the first 8 bytes as the short chan ID for the + // circuit. We'll skip all short chan IDs are not + // locally initiated, which includes all non-zero short + // chan ids. + chanID := binary.BigEndian.Uint64(k[:8]) + if chanID != 0 { + return nil + } + + // The payment hash is the third item in the serialized + // payment circuit. The first two items are an AddRef + // (10 bytes) and the incoming circuit key (16 bytes). + const payHashOffset = 10 + 16 + + paymentHash := v[payHashOffset : payHashOffset+32] + + return paymentStatuses.Put( + paymentHash[:], StatusInFlight.Bytes(), + ) + }) + if err != nil { + return err + } + } + + log.Infof("Marking all existing payments with status Completed") + + // Get the bucket dedicated to storing payments + bucket := tx.Bucket(paymentBucket) + if bucket == nil { + return nil + } + + // For each payment in the bucket, deserialize the payment and mark it + // as completed. + err = bucket.ForEach(func(k, v []byte) error { + // Ignores if it is sub-bucket. + if v == nil { + return nil + } + + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + // Calculate payment hash for current payment. + paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) + + // Update status for current payment to completed. If it fails, + // the migration is aborted and the payment bucket is returned + // to its previous state. + return paymentStatuses.Put(paymentHash[:], StatusSucceeded.Bytes()) + }) + if err != nil { + return err + } + + log.Infof("Migration of payment statuses complete!") + + return nil +} + +// migratePruneEdgeUpdateIndex is a database migration that attempts to resolve +// some lingering bugs with regards to edge policies and their update index. +// Stale entries within the edge update index were not being properly pruned due +// to a miscalculation on the offset of an edge's policy last update. This +// migration also fixes the case where the public keys within edge policies were +// being serialized with an extra byte, causing an even greater error when +// attempting to perform the offset calculation described earlier. +func migratePruneEdgeUpdateIndex(tx *bbolt.Tx) error { + // To begin the migration, we'll retrieve the update index bucket. If it + // does not exist, we have nothing left to do so we can simply exit. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return nil + } + edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) + if edgeUpdateIndex == nil { + return nil + } + + // Retrieve some buckets that will be needed later on. These should + // already exist given the assumption that the buckets above do as + // well. + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return fmt.Errorf("error creating edge index bucket: %s", err) + } + if edgeIndex == nil { + return fmt.Errorf("unable to create/fetch edge index " + + "bucket") + } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return fmt.Errorf("unable to make node bucket") + } + + log.Info("Migrating database to properly prune edge update index") + + // We'll need to properly prune all the outdated entries within the edge + // update index. To do so, we'll gather all of the existing policies + // within the graph to re-populate them later on. + var edgeKeys [][]byte + err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error { + // All valid entries are indexed by a public key (33 bytes) + // followed by a channel ID (8 bytes), so we'll skip any entries + // with keys that do not match this. + if len(edgeKey) != 33+8 { + return nil + } + + edgeKeys = append(edgeKeys, edgeKey) + + return nil + }) + if err != nil { + return fmt.Errorf("unable to gather existing edge policies: %v", + err) + } + + log.Info("Constructing set of edge update entries to purge.") + + // Build the set of keys that we will remove from the edge update index. + // This will include all keys contained within the bucket. + var updateKeysToRemove [][]byte + err = edgeUpdateIndex.ForEach(func(updKey, _ []byte) error { + updateKeysToRemove = append(updateKeysToRemove, updKey) + return nil + }) + if err != nil { + return fmt.Errorf("unable to gather existing edge updates: %v", + err) + } + + log.Infof("Removing %d entries from edge update index.", + len(updateKeysToRemove)) + + // With the set of keys contained in the edge update index constructed, + // we'll proceed in purging all of them from the index. + for _, updKey := range updateKeysToRemove { + if err := edgeUpdateIndex.Delete(updKey); err != nil { + return err + } + } + + log.Infof("Repopulating edge update index with %d valid entries.", + len(edgeKeys)) + + // For each edge key, we'll retrieve the policy, deserialize it, and + // re-add it to the different buckets. By doing so, we'll ensure that + // all existing edge policies are serialized correctly within their + // respective buckets and that the correct entries are populated within + // the edge update index. + for _, edgeKey := range edgeKeys { + edgePolicyBytes := edges.Get(edgeKey) + + // Skip any entries with unknown policies as there will not be + // any entries for them in the edge update index. + if bytes.Equal(edgePolicyBytes[:], unknownPolicy) { + continue + } + + edgePolicy, err := deserializeChanEdgePolicy( + bytes.NewReader(edgePolicyBytes), nodes, + ) + if err != nil { + return err + } + + _, err = updateEdgePolicy(tx, edgePolicy) + if err != nil { + return err + } + } + + log.Info("Migration to properly prune edge update index complete!") + + return nil +} + +// migrateOptionalChannelCloseSummaryFields migrates the serialized format of +// ChannelCloseSummary to a format where optional fields' presence is indicated +// with boolean markers. +func migrateOptionalChannelCloseSummaryFields(tx *bbolt.Tx) error { + closedChanBucket := tx.Bucket(closedChannelBucket) + if closedChanBucket == nil { + return nil + } + + log.Info("Migrating to new closed channel format...") + + // We store the converted keys and values and put them back into the + // database after the loop, since modifying the bucket within the + // ForEach loop is not safe. + var closedChansKeys [][]byte + var closedChansValues [][]byte + err := closedChanBucket.ForEach(func(chanID, summary []byte) error { + r := bytes.NewReader(summary) + + // Read the old (v6) format from the database. + c, err := deserializeCloseChannelSummaryV6(r) + if err != nil { + return err + } + + // Serialize using the new format, and put back into the + // bucket. + var b bytes.Buffer + if err := serializeChannelCloseSummary(&b, c); err != nil { + return err + } + + // Now that we know the modifications was successful, we'll + // Store the key and value to our slices, and write it back to + // disk in the new format after the ForEach loop is over. + closedChansKeys = append(closedChansKeys, chanID) + closedChansValues = append(closedChansValues, b.Bytes()) + return nil + }) + if err != nil { + return fmt.Errorf("unable to update closed channels: %v", err) + } + + // Now put the new format back into the DB. + for i := range closedChansKeys { + key := closedChansKeys[i] + value := closedChansValues[i] + if err := closedChanBucket.Put(key, value); err != nil { + return err + } + } + + log.Info("Migration to new closed channel format complete!") + + return nil +} + +var messageStoreBucket = []byte("message-store") + +// migrateGossipMessageStoreKeys migrates the key format for gossip messages +// found in the message store to a new one that takes into consideration the of +// the message being stored. +func migrateGossipMessageStoreKeys(tx *bbolt.Tx) error { + // We'll start by retrieving the bucket in which these messages are + // stored within. If there isn't one, there's nothing left for us to do + // so we can avoid the migration. + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return nil + } + + log.Info("Migrating to the gossip message store new key format") + + // Otherwise we'll proceed with the migration. We'll start by coalescing + // all the current messages within the store, which are indexed by the + // public key of the peer which they should be sent to, followed by the + // short channel ID of the channel for which the message belongs to. We + // should only expect to find channel announcement signatures as that + // was the only support message type previously. + msgs := make(map[[33 + 8]byte]*lnwire.AnnounceSignatures) + err := messageStore.ForEach(func(k, v []byte) error { + var msgKey [33 + 8]byte + copy(msgKey[:], k) + + msg := &lnwire.AnnounceSignatures{} + if err := msg.Decode(bytes.NewReader(v), 0); err != nil { + return err + } + + msgs[msgKey] = msg + + return nil + + }) + if err != nil { + return err + } + + // Then, we'll go over all of our messages, remove their previous entry, + // and add another with the new key format. Once we've done this for + // every message, we can consider the migration complete. + for oldMsgKey, msg := range msgs { + if err := messageStore.Delete(oldMsgKey[:]); err != nil { + return err + } + + // Construct the new key for which we'll find this message with + // in the store. It'll be the same as the old, but we'll also + // include the message type. + var msgType [2]byte + binary.BigEndian.PutUint16(msgType[:], uint16(msg.MsgType())) + newMsgKey := append(oldMsgKey[:], msgType[:]...) + + // Serialize the message with its wire encoding. + var b bytes.Buffer + if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil { + return err + } + + if err := messageStore.Put(newMsgKey, b.Bytes()); err != nil { + return err + } + } + + log.Info("Migration to the gossip message store new key format complete!") + + return nil +} + +// migrateOutgoingPayments moves the OutgoingPayments into a new bucket format +// where they all reside in a top-level bucket indexed by the payment hash. In +// this sub-bucket we store information relevant to this payment, such as the +// payment status. +// +// Since the router cannot handle resumed payments that have the status +// InFlight (we have no PaymentAttemptInfo available for pre-migration +// payments) we delete those statuses, so only Completed payments remain in the +// new bucket structure. +func migrateOutgoingPayments(tx *bbolt.Tx) error { + log.Infof("Migrating outgoing payments to new bucket structure") + + oldPayments := tx.Bucket(paymentBucket) + + // Return early if there are no payments to migrate. + if oldPayments == nil { + log.Infof("No outgoing payments found, nothing to migrate.") + return nil + } + + newPayments, err := tx.CreateBucket(paymentsRootBucket) + if err != nil { + return err + } + + // Helper method to get the source pubkey. We define it such that we + // only attempt to fetch it if needed. + sourcePub := func() ([33]byte, error) { + var pub [33]byte + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return pub, ErrGraphNotFound + } + + selfPub := nodes.Get(sourceKey) + if selfPub == nil { + return pub, ErrSourceNodeNotSet + } + copy(pub[:], selfPub[:]) + return pub, nil + } + + err = oldPayments.ForEach(func(k, v []byte) error { + // Ignores if it is sub-bucket. + if v == nil { + return nil + } + + // Read the old payment format. + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + // Calculate payment hash from the payment preimage. + paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) + + // Now create and add a PaymentCreationInfo to the bucket. + c := &PaymentCreationInfo{ + PaymentHash: paymentHash, + Value: payment.Terms.Value, + CreationDate: payment.CreationDate, + PaymentRequest: payment.PaymentRequest, + } + + var infoBuf bytes.Buffer + if err := serializePaymentCreationInfo(&infoBuf, c); err != nil { + return err + } + + sourcePubKey, err := sourcePub() + if err != nil { + return err + } + + // Do the same for the PaymentAttemptInfo. + totalAmt := payment.Terms.Value + payment.Fee + rt := route.Route{ + TotalTimeLock: payment.TimeLockLength, + TotalAmount: totalAmt, + SourcePubKey: sourcePubKey, + Hops: []*route.Hop{}, + } + for _, hop := range payment.Path { + rt.Hops = append(rt.Hops, &route.Hop{ + PubKeyBytes: hop, + AmtToForward: totalAmt, + }) + } + + // Since the old format didn't store the fee for individual + // hops, we let the last hop eat the whole fee for the total to + // add up. + if len(rt.Hops) > 0 { + rt.Hops[len(rt.Hops)-1].AmtToForward = payment.Terms.Value + } + + // Since we don't have the session key for old payments, we + // create a random one to be able to serialize the attempt + // info. + priv, _ := btcec.NewPrivateKey(btcec.S256()) + s := &PaymentAttemptInfo{ + PaymentID: 0, // unknown. + SessionKey: priv, // unknown. + Route: rt, + } + + var attemptBuf bytes.Buffer + if err := serializePaymentAttemptInfoMigration9(&attemptBuf, s); err != nil { + return err + } + + // Reuse the existing payment sequence number. + var seqNum [8]byte + copy(seqNum[:], k) + + // Create a bucket indexed by the payment hash. + bucket, err := newPayments.CreateBucket(paymentHash[:]) + + // If the bucket already exists, it means that we are migrating + // from a database containing duplicate payments to a payment + // hash. To keep this information, we store such duplicate + // payments in a sub-bucket. + if err == bbolt.ErrBucketExists { + pHashBucket := newPayments.Bucket(paymentHash[:]) + + // Create a bucket for duplicate payments within this + // payment hash's bucket. + dup, err := pHashBucket.CreateBucketIfNotExists( + paymentDuplicateBucket, + ) + if err != nil { + return err + } + + // Each duplicate will get its own sub-bucket within + // this bucket, so use their sequence number to index + // them by. + bucket, err = dup.CreateBucket(seqNum[:]) + if err != nil { + return err + } + + } else if err != nil { + return err + } + + // Store the payment's information to the bucket. + err = bucket.Put(paymentSequenceKey, seqNum[:]) + if err != nil { + return err + } + + err = bucket.Put(paymentCreationInfoKey, infoBuf.Bytes()) + if err != nil { + return err + } + + err = bucket.Put(paymentAttemptInfoKey, attemptBuf.Bytes()) + if err != nil { + return err + } + + err = bucket.Put(paymentSettleInfoKey, payment.PaymentPreimage[:]) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return err + } + + // To continue producing unique sequence numbers, we set the sequence + // of the new bucket to that of the old one. + seq := oldPayments.Sequence() + if err := newPayments.SetSequence(seq); err != nil { + return err + } + + // Now we delete the old buckets. Deleting the payment status buckets + // deletes all payment statuses other than Complete. + err = tx.DeleteBucket(paymentStatusBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + // Finally delete the old payment bucket. + err = tx.DeleteBucket(paymentBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + log.Infof("Migration of outgoing payment bucket structure completed!") + return nil +} diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go new file mode 100644 index 00000000..8a9076fb --- /dev/null +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -0,0 +1,952 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "fmt" + "math/rand" + "reflect" + "testing" + "time" + + "github.com/btcsuite/btcutil" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// TestPaymentStatusesMigration checks that already completed payments will have +// their payment statuses set to Completed after the migration. +func TestPaymentStatusesMigration(t *testing.T) { + t.Parallel() + + fakePayment := makeFakePayment() + paymentHash := sha256.Sum256(fakePayment.PaymentPreimage[:]) + + // Add fake payment to test database, verifying that it was created, + // that we have only one payment, and its status is not "Completed". + beforeMigrationFunc := func(d *DB) { + if err := d.addPayment(fakePayment); err != nil { + t.Fatalf("unable to add payment: %v", err) + } + + payments, err := d.fetchAllPayments() + if err != nil { + t.Fatalf("unable to fetch payments: %v", err) + } + + if len(payments) != 1 { + t.Fatalf("wrong qty of paymets: expected 1, got %v", + len(payments)) + } + + paymentStatus, err := d.fetchPaymentStatus(paymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + // We should receive default status if we have any in database. + if paymentStatus != StatusUnknown { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusUnknown.String(), paymentStatus.String()) + } + + // Lastly, we'll add a locally-sourced circuit and + // non-locally-sourced circuit to the circuit map. The + // locally-sourced payment should end up with an InFlight + // status, while the other should remain unchanged, which + // defaults to Grounded. + err = d.Update(func(tx *bbolt.Tx) error { + circuits, err := tx.CreateBucketIfNotExists( + []byte("circuit-adds"), + ) + if err != nil { + return err + } + + groundedKey := make([]byte, 16) + binary.BigEndian.PutUint64(groundedKey[:8], 1) + binary.BigEndian.PutUint64(groundedKey[8:], 1) + + // Generated using TestHalfCircuitSerialization with nil + // ErrorEncrypter, which is the case for locally-sourced + // payments. No payment status should end up being set + // for this circuit, since the short channel id of the + // key is non-zero (e.g., a forwarded circuit). This + // will default it to Grounded. + groundedCircuit := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + // start payment hash + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // end payment hash + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, + 0x42, 0x40, 0x00, + } + + err = circuits.Put(groundedKey, groundedCircuit) + if err != nil { + return err + } + + inFlightKey := make([]byte, 16) + binary.BigEndian.PutUint64(inFlightKey[:8], 0) + binary.BigEndian.PutUint64(inFlightKey[8:], 1) + + // Generated using TestHalfCircuitSerialization with nil + // ErrorEncrypter, which is not the case for forwarded + // payments, but should have no impact on the + // correctness of the test. The payment status for this + // circuit should be set to InFlight, since the short + // channel id in the key is 0 (sourceHop). + inFlightCircuit := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + // start payment hash + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // end payment hash + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, + 0x42, 0x40, 0x00, + } + + return circuits.Put(inFlightKey, inFlightCircuit) + }) + if err != nil { + t.Fatalf("unable to add circuit map entry: %v", err) + } + } + + // Verify that the created payment status is "Completed" for our one + // fake payment. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'paymentStatusesMigration' wasn't applied") + } + + // Check that our completed payments were migrated. + paymentStatus, err := d.fetchPaymentStatus(paymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusSucceeded { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusSucceeded.String(), paymentStatus.String()) + } + + inFlightHash := [32]byte{ + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + + // Check that the locally sourced payment was transitioned to + // InFlight. + paymentStatus, err = d.fetchPaymentStatus(inFlightHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusInFlight { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusInFlight.String(), paymentStatus.String()) + } + + groundedHash := [32]byte{ + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + + // Check that non-locally sourced payments remain in the default + // Grounded state. + paymentStatus, err = d.fetchPaymentStatus(groundedHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusUnknown { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusUnknown.String(), paymentStatus.String()) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + paymentStatusesMigration, + false) +} + +// TestMigrateOptionalChannelCloseSummaryFields properly converts a +// ChannelCloseSummary to the v7 format, where optional fields have their +// presence indicated with boolean markers. +func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) { + t.Parallel() + + chanState, err := createTestChannelState(nil) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + var chanPointBuf bytes.Buffer + err = writeOutpoint(&chanPointBuf, &chanState.FundingOutpoint) + if err != nil { + t.Fatalf("unable to write outpoint: %v", err) + } + + chanID := chanPointBuf.Bytes() + + testCases := []struct { + closeSummary *ChannelCloseSummary + oldSerialization func(c *ChannelCloseSummary) []byte + }{ + { + // A close summary where none of the new fields are + // set. + closeSummary: &ChannelCloseSummary{ + ChanPoint: chanState.FundingOutpoint, + ShortChanID: chanState.ShortChanID(), + ChainHash: chanState.ChainHash, + ClosingTXID: testTx.TxHash(), + CloseHeight: 100, + RemotePub: chanState.IdentityPub, + Capacity: chanState.Capacity, + SettledBalance: btcutil.Amount(50000), + CloseType: RemoteForceClose, + IsPending: true, + + // The last fields will be unset. + RemoteCurrentRevocation: nil, + LocalChanConfig: ChannelConfig{}, + RemoteNextRevocation: nil, + }, + + // In the old format the last field written is the + // IsPendingField. It should be converted by adding an + // extra boolean marker at the end to indicate that the + // remaining fields are not there. + oldSerialization: func(cs *ChannelCloseSummary) []byte { + var buf bytes.Buffer + err := WriteElements(&buf, cs.ChanPoint, + cs.ShortChanID, cs.ChainHash, + cs.ClosingTXID, cs.CloseHeight, + cs.RemotePub, cs.Capacity, + cs.SettledBalance, cs.TimeLockedBalance, + cs.CloseType, cs.IsPending, + ) + if err != nil { + t.Fatal(err) + } + + // For the old format, these are all the fields + // that are written. + return buf.Bytes() + }, + }, + { + // A close summary where the new fields are present, + // but the optional RemoteNextRevocation field is not + // set. + closeSummary: &ChannelCloseSummary{ + ChanPoint: chanState.FundingOutpoint, + ShortChanID: chanState.ShortChanID(), + ChainHash: chanState.ChainHash, + ClosingTXID: testTx.TxHash(), + CloseHeight: 100, + RemotePub: chanState.IdentityPub, + Capacity: chanState.Capacity, + SettledBalance: btcutil.Amount(50000), + CloseType: RemoteForceClose, + IsPending: true, + RemoteCurrentRevocation: chanState.RemoteCurrentRevocation, + LocalChanConfig: chanState.LocalChanCfg, + + // RemoteNextRevocation is optional, and here + // it is not set. + RemoteNextRevocation: nil, + }, + + // In the old format the last field written is the + // LocalChanConfig. This indicates that the optional + // RemoteNextRevocation field is not present. It should + // be converted by adding boolean markers for all these + // fields. + oldSerialization: func(cs *ChannelCloseSummary) []byte { + var buf bytes.Buffer + err := WriteElements(&buf, cs.ChanPoint, + cs.ShortChanID, cs.ChainHash, + cs.ClosingTXID, cs.CloseHeight, + cs.RemotePub, cs.Capacity, + cs.SettledBalance, cs.TimeLockedBalance, + cs.CloseType, cs.IsPending, + ) + if err != nil { + t.Fatal(err) + } + + err = WriteElements(&buf, cs.RemoteCurrentRevocation) + if err != nil { + t.Fatal(err) + } + + err = writeChanConfig(&buf, &cs.LocalChanConfig) + if err != nil { + t.Fatal(err) + } + + // RemoteNextRevocation is not written. + return buf.Bytes() + }, + }, + { + // A close summary where all fields are present. + closeSummary: &ChannelCloseSummary{ + ChanPoint: chanState.FundingOutpoint, + ShortChanID: chanState.ShortChanID(), + ChainHash: chanState.ChainHash, + ClosingTXID: testTx.TxHash(), + CloseHeight: 100, + RemotePub: chanState.IdentityPub, + Capacity: chanState.Capacity, + SettledBalance: btcutil.Amount(50000), + CloseType: RemoteForceClose, + IsPending: true, + RemoteCurrentRevocation: chanState.RemoteCurrentRevocation, + LocalChanConfig: chanState.LocalChanCfg, + + // RemoteNextRevocation is optional, and in + // this case we set it. + RemoteNextRevocation: chanState.RemoteNextRevocation, + }, + + // In the old format all the fields are written. It + // should be converted by adding boolean markers for + // all these fields. + oldSerialization: func(cs *ChannelCloseSummary) []byte { + var buf bytes.Buffer + err := WriteElements(&buf, cs.ChanPoint, + cs.ShortChanID, cs.ChainHash, + cs.ClosingTXID, cs.CloseHeight, + cs.RemotePub, cs.Capacity, + cs.SettledBalance, cs.TimeLockedBalance, + cs.CloseType, cs.IsPending, + ) + if err != nil { + t.Fatal(err) + } + + err = WriteElements(&buf, cs.RemoteCurrentRevocation) + if err != nil { + t.Fatal(err) + } + + err = writeChanConfig(&buf, &cs.LocalChanConfig) + if err != nil { + t.Fatal(err) + } + + err = WriteElements(&buf, cs.RemoteNextRevocation) + if err != nil { + t.Fatal(err) + } + + return buf.Bytes() + }, + }, + } + + for _, test := range testCases { + + // Before the migration we must add the old format to the DB. + beforeMigrationFunc := func(d *DB) { + + // Get the old serialization format for this test's + // close summary, and it to the closed channel bucket. + old := test.oldSerialization(test.closeSummary) + err = d.Update(func(tx *bbolt.Tx) error { + closedChanBucket, err := tx.CreateBucketIfNotExists( + closedChannelBucket, + ) + if err != nil { + return err + } + return closedChanBucket.Put(chanID, old) + }) + if err != nil { + t.Fatalf("unable to add old serialization: %v", + err) + } + } + + // After the migration it should be found in the new format. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration wasn't applied") + } + + // We generate the new serialized version, to check + // against what is found in the DB. + var b bytes.Buffer + err = serializeChannelCloseSummary(&b, test.closeSummary) + if err != nil { + t.Fatalf("unable to serialize: %v", err) + } + newSerialization := b.Bytes() + + var dbSummary []byte + err = d.View(func(tx *bbolt.Tx) error { + closedChanBucket := tx.Bucket(closedChannelBucket) + if closedChanBucket == nil { + return errors.New("unable to find bucket") + } + + // Get the serialized verision from the DB and + // make sure it matches what we expected. + dbSummary = closedChanBucket.Get(chanID) + if !bytes.Equal(dbSummary, newSerialization) { + return fmt.Errorf("unexpected new " + + "serialization") + } + return nil + }) + if err != nil { + t.Fatalf("unable to view DB: %v", err) + } + + // Finally we fetch the deserialized summary from the + // DB and check that it is equal to our original one. + dbChannels, err := d.FetchClosedChannels(false) + if err != nil { + t.Fatalf("unable to fetch closed channels: %v", + err) + } + + if len(dbChannels) != 1 { + t.Fatalf("expected 1 closed channels, found %v", + len(dbChannels)) + } + + dbChan := dbChannels[0] + if !reflect.DeepEqual(dbChan, test.closeSummary) { + dbChan.RemotePub.Curve = nil + test.closeSummary.RemotePub.Curve = nil + t.Fatalf("not equal: %v vs %v", + spew.Sdump(dbChan), + spew.Sdump(test.closeSummary)) + } + + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrateOptionalChannelCloseSummaryFields, + false) + } +} + +// TestMigrateGossipMessageStoreKeys ensures that the migration to the new +// gossip message store key format is successful/unsuccessful under various +// scenarios. +func TestMigrateGossipMessageStoreKeys(t *testing.T) { + t.Parallel() + + // Construct the message which we'll use to test the migration, along + // with its old and new key formats. + shortChanID := lnwire.ShortChannelID{BlockHeight: 10} + msg := &lnwire.AnnounceSignatures{ShortChannelID: shortChanID} + + var oldMsgKey [33 + 8]byte + copy(oldMsgKey[:33], pubKey.SerializeCompressed()) + binary.BigEndian.PutUint64(oldMsgKey[33:41], shortChanID.ToUint64()) + + var newMsgKey [33 + 8 + 2]byte + copy(newMsgKey[:41], oldMsgKey[:]) + binary.BigEndian.PutUint16(newMsgKey[41:43], uint16(msg.MsgType())) + + // Before the migration, we'll create the bucket where the messages + // should live and insert them. + beforeMigration := func(db *DB) { + var b bytes.Buffer + if err := msg.Encode(&b, 0); err != nil { + t.Fatalf("unable to serialize message: %v", err) + } + + err := db.Update(func(tx *bbolt.Tx) error { + messageStore, err := tx.CreateBucketIfNotExists( + messageStoreBucket, + ) + if err != nil { + return err + } + + return messageStore.Put(oldMsgKey[:], b.Bytes()) + }) + if err != nil { + t.Fatal(err) + } + } + + // After the migration, we'll make sure that: + // 1. We cannot find the message under its old key. + // 2. We can find the message under its new key. + // 3. The message matches the original. + afterMigration := func(db *DB) { + meta, err := db.FetchMeta(nil) + if err != nil { + t.Fatalf("unable to fetch db version: %v", err) + } + if meta.DbVersionNumber != 1 { + t.Fatalf("migration should have succeeded but didn't") + } + + var rawMsg []byte + err = db.View(func(tx *bbolt.Tx) error { + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return errors.New("message store bucket not " + + "found") + } + rawMsg = messageStore.Get(oldMsgKey[:]) + if rawMsg != nil { + t.Fatal("expected to not find message under " + + "old key, but did") + } + rawMsg = messageStore.Get(newMsgKey[:]) + if rawMsg == nil { + return fmt.Errorf("expected to find message " + + "under new key, but didn't") + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + + gotMsg, err := lnwire.ReadMessage(bytes.NewReader(rawMsg), 0) + if err != nil { + t.Fatalf("unable to deserialize raw message: %v", err) + } + if !reflect.DeepEqual(msg, gotMsg) { + t.Fatalf("expected message: %v\ngot message: %v", + spew.Sdump(msg), spew.Sdump(gotMsg)) + } + } + + applyMigration( + t, beforeMigration, afterMigration, + migrateGossipMessageStoreKeys, false, + ) +} + +// TestOutgoingPaymentsMigration checks that OutgoingPayments are migrated to a +// new bucket structure after the migration. +func TestOutgoingPaymentsMigration(t *testing.T) { + t.Parallel() + + const numPayments = 4 + var oldPayments []*outgoingPayment + + // Add fake payments to test database, verifying that it was created. + beforeMigrationFunc := func(d *DB) { + for i := 0; i < numPayments; i++ { + var p *outgoingPayment + var err error + + // We fill the database with random payments. For the + // very last one we'll use a duplicate of the first, to + // ensure we are able to handle migration from a + // database that has copies. + if i < numPayments-1 { + p, err = makeRandomFakePayment() + if err != nil { + t.Fatalf("unable to create payment: %v", + err) + } + } else { + p = oldPayments[0] + } + + if err := d.addPayment(p); err != nil { + t.Fatalf("unable to add payment: %v", err) + } + + oldPayments = append(oldPayments, p) + } + + payments, err := d.fetchAllPayments() + if err != nil { + t.Fatalf("unable to fetch payments: %v", err) + } + + if len(payments) != numPayments { + t.Fatalf("wrong qty of paymets: expected %d got %v", + numPayments, len(payments)) + } + } + + // Verify that all payments were migrated. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'paymentStatusesMigration' wasn't applied") + } + + sentPayments, err := d.fetchPaymentsMigration9() + if err != nil { + t.Fatalf("unable to fetch sent payments: %v", err) + } + + if len(sentPayments) != numPayments { + t.Fatalf("expected %d payments, got %d", numPayments, + len(sentPayments)) + } + + graph := d.ChannelGraph() + sourceNode, err := graph.SourceNode() + if err != nil { + t.Fatalf("unable to fetch source node: %v", err) + } + + for i, p := range sentPayments { + // The payment status should be Completed. + if p.Status != StatusSucceeded { + t.Fatalf("expected Completed, got %v", p.Status) + } + + // Check that the sequence number is preserved. They + // start counting at 1. + if p.sequenceNum != uint64(i+1) { + t.Fatalf("expected seqnum %d, got %d", i, + p.sequenceNum) + } + + // Order of payments should be be preserved. + old := oldPayments[i] + + // Check the individial fields. + if p.Info.Value != old.Terms.Value { + t.Fatalf("value mismatch") + } + + if p.Info.CreationDate != old.CreationDate { + t.Fatalf("date mismatch") + } + + if !bytes.Equal(p.Info.PaymentRequest, old.PaymentRequest) { + t.Fatalf("payreq mismatch") + } + + if *p.PaymentPreimage != old.PaymentPreimage { + t.Fatalf("preimage mismatch") + } + + if p.Attempt.Route.TotalFees() != old.Fee { + t.Fatalf("Fee mismatch") + } + + if p.Attempt.Route.TotalAmount != old.Fee+old.Terms.Value { + t.Fatalf("Total amount mismatch") + } + + if p.Attempt.Route.TotalTimeLock != old.TimeLockLength { + t.Fatalf("timelock mismatch") + } + + if p.Attempt.Route.SourcePubKey != sourceNode.PubKeyBytes { + t.Fatalf("source mismatch: %x vs %x", + p.Attempt.Route.SourcePubKey[:], + sourceNode.PubKeyBytes[:]) + } + + for i, hop := range old.Path { + if hop != p.Attempt.Route.Hops[i].PubKeyBytes { + t.Fatalf("path mismatch") + } + } + } + + // Finally, check that the payment sequence number is updated + // to reflect the migrated payments. + err = d.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return fmt.Errorf("payments bucket not found") + } + + seq := payments.Sequence() + if seq != numPayments { + return fmt.Errorf("expected sequence to be "+ + "%d, got %d", numPayments, seq) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrateOutgoingPayments, + false) +} + +func makeRandPaymentCreationInfo() (*PaymentCreationInfo, error) { + var payHash lntypes.Hash + if _, err := rand.Read(payHash[:]); err != nil { + return nil, err + } + + return &PaymentCreationInfo{ + PaymentHash: payHash, + Value: lnwire.MilliSatoshi(rand.Int63()), + CreationDate: time.Now(), + PaymentRequest: []byte("test"), + }, nil +} + +// TestPaymentRouteSerialization tests that we're able to properly migrate +// existing payments on disk that contain the traversed routes to the new +// routing format which supports the TLV payloads. We also test that the +// migration is able to handle duplicate payment attempts. +func TestPaymentRouteSerialization(t *testing.T) { + t.Parallel() + + legacyHop1 := &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + LegacyPayload: true, + AmtToForward: 555, + } + legacyHop2 := &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + LegacyPayload: true, + AmtToForward: 555, + } + legacyRoute := route.Route{ + TotalTimeLock: 123, + TotalAmount: 1234567, + SourcePubKey: route.NewVertex(pub), + Hops: []*route.Hop{legacyHop1, legacyHop2}, + } + + const numPayments = 4 + var oldPayments []*Payment + + sharedPayAttempt := PaymentAttemptInfo{ + PaymentID: 1, + SessionKey: priv, + Route: legacyRoute, + } + + // We'll first add a series of fake payments, using the existing legacy + // serialization format. + beforeMigrationFunc := func(d *DB) { + err := d.Update(func(tx *bbolt.Tx) error { + paymentsBucket, err := tx.CreateBucket( + paymentsRootBucket, + ) + if err != nil { + t.Fatalf("unable to create new payments "+ + "bucket: %v", err) + } + + for i := 0; i < numPayments; i++ { + var seqNum [8]byte + byteOrder.PutUint64(seqNum[:], uint64(i)) + + // All payments will be randomly generated, + // other than the final payment. We'll force + // the final payment to re-use an existing + // payment hash so we can insert it into the + // duplicate payment hash bucket. + var payInfo *PaymentCreationInfo + if i < numPayments-1 { + payInfo, err = makeRandPaymentCreationInfo() + if err != nil { + t.Fatalf("unable to create "+ + "payment: %v", err) + } + } else { + payInfo = oldPayments[0].Info + } + + // Next, legacy encoded when needed, we'll + // serialize the info and the attempt. + var payInfoBytes bytes.Buffer + err = serializePaymentCreationInfo( + &payInfoBytes, payInfo, + ) + if err != nil { + t.Fatalf("unable to encode pay "+ + "info: %v", err) + } + var payAttemptBytes bytes.Buffer + err = serializePaymentAttemptInfoLegacy( + &payAttemptBytes, &sharedPayAttempt, + ) + if err != nil { + t.Fatalf("unable to encode payment attempt: "+ + "%v", err) + } + + // Before we write to disk, we'll need to fetch + // the proper bucket. If this is the duplicate + // payment, then we'll grab the dup bucket, + // otherwise, we'll use the top level bucket. + var payHashBucket *bbolt.Bucket + if i < numPayments-1 { + payHashBucket, err = paymentsBucket.CreateBucket( + payInfo.PaymentHash[:], + ) + if err != nil { + t.Fatalf("unable to create payments bucket: %v", err) + } + } else { + payHashBucket = paymentsBucket.Bucket( + payInfo.PaymentHash[:], + ) + dupPayBucket, err := payHashBucket.CreateBucket( + paymentDuplicateBucket, + ) + if err != nil { + t.Fatalf("unable to create "+ + "dup hash bucket: %v", err) + } + + payHashBucket, err = dupPayBucket.CreateBucket( + seqNum[:], + ) + if err != nil { + t.Fatalf("unable to make dup "+ + "bucket: %v", err) + } + } + + err = payHashBucket.Put(paymentSequenceKey, seqNum[:]) + if err != nil { + t.Fatalf("unable to write seqno: %v", err) + } + + err = payHashBucket.Put( + paymentCreationInfoKey, payInfoBytes.Bytes(), + ) + if err != nil { + t.Fatalf("unable to write creation "+ + "info: %v", err) + } + + err = payHashBucket.Put( + paymentAttemptInfoKey, payAttemptBytes.Bytes(), + ) + if err != nil { + t.Fatalf("unable to write attempt "+ + "info: %v", err) + } + + oldPayments = append(oldPayments, &Payment{ + Info: payInfo, + Attempt: &sharedPayAttempt, + }) + } + + return nil + }) + if err != nil { + t.Fatalf("unable to create test payments: %v", err) + } + } + + afterMigrationFunc := func(d *DB) { + newPayments, err := d.FetchPayments() + if err != nil { + t.Fatalf("unable to fetch new payments: %v", err) + } + + if len(newPayments) != numPayments { + t.Fatalf("expected %d payments, got %d", numPayments, + len(newPayments)) + } + + for i, p := range newPayments { + // Order of payments should be be preserved. + old := oldPayments[i] + + if p.Attempt.PaymentID != old.Attempt.PaymentID { + t.Fatalf("wrong pay ID: expected %v, got %v", + p.Attempt.PaymentID, + old.Attempt.PaymentID) + } + + if p.Attempt.Route.TotalFees() != old.Attempt.Route.TotalFees() { + t.Fatalf("Fee mismatch") + } + + if p.Attempt.Route.TotalAmount != old.Attempt.Route.TotalAmount { + t.Fatalf("Total amount mismatch") + } + + if p.Attempt.Route.TotalTimeLock != old.Attempt.Route.TotalTimeLock { + t.Fatalf("timelock mismatch") + } + + if p.Attempt.Route.SourcePubKey != old.Attempt.Route.SourcePubKey { + t.Fatalf("source mismatch: %x vs %x", + p.Attempt.Route.SourcePubKey[:], + old.Attempt.Route.SourcePubKey[:]) + } + + for i, hop := range p.Attempt.Route.Hops { + if !reflect.DeepEqual(hop, legacyRoute.Hops[i]) { + t.Fatalf("hop mismatch") + } + } + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrateRouteSerialization, + false) +} diff --git a/channeldb/migration_01_to_11/nodes.go b/channeldb/migration_01_to_11/nodes.go new file mode 100644 index 00000000..f40359e8 --- /dev/null +++ b/channeldb/migration_01_to_11/nodes.go @@ -0,0 +1,316 @@ +package migration_01_to_11 + +import ( + "bytes" + "io" + "net" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" +) + +var ( + // nodeInfoBucket stores metadata pertaining to nodes that we've had + // direct channel-based correspondence with. This bucket allows one to + // query for all open channels pertaining to the node by exploring each + // node's sub-bucket within the openChanBucket. + nodeInfoBucket = []byte("nib") +) + +// LinkNode stores metadata related to node's that we have/had a direct +// channel open with. Information such as the Bitcoin network the node +// advertised, and its identity public key are also stored. Additionally, this +// struct and the bucket its stored within have store data similar to that of +// Bitcoin's addrmanager. The TCP address information stored within the struct +// can be used to establish persistent connections will all channel +// counterparties on daemon startup. +// +// TODO(roasbeef): also add current OnionKey plus rotation schedule? +// TODO(roasbeef): add bitfield for supported services +// * possibly add a wire.NetAddress type, type +type LinkNode struct { + // Network indicates the Bitcoin network that the LinkNode advertises + // for incoming channel creation. + Network wire.BitcoinNet + + // IdentityPub is the node's current identity public key. Any + // channel/topology related information received by this node MUST be + // signed by this public key. + IdentityPub *btcec.PublicKey + + // LastSeen tracks the last time this node was seen within the network. + // A node should be marked as seen if the daemon either is able to + // establish an outgoing connection to the node or receives a new + // incoming connection from the node. This timestamp (stored in unix + // epoch) may be used within a heuristic which aims to determine when a + // channel should be unilaterally closed due to inactivity. + // + // TODO(roasbeef): replace with block hash/height? + // * possibly add a time-value metric into the heuristic? + LastSeen time.Time + + // Addresses is a list of IP address in which either we were able to + // reach the node over in the past, OR we received an incoming + // authenticated connection for the stored identity public key. + Addresses []net.Addr + + db *DB +} + +// NewLinkNode creates a new LinkNode from the provided parameters, which is +// backed by an instance of channeldb. +func (db *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, + addrs ...net.Addr) *LinkNode { + + return &LinkNode{ + Network: bitNet, + IdentityPub: pub, + LastSeen: time.Now(), + Addresses: addrs, + db: db, + } +} + +// UpdateLastSeen updates the last time this node was directly encountered on +// the Lightning Network. +func (l *LinkNode) UpdateLastSeen(lastSeen time.Time) error { + l.LastSeen = lastSeen + + return l.Sync() +} + +// AddAddress appends the specified TCP address to the list of known addresses +// this node is/was known to be reachable at. +func (l *LinkNode) AddAddress(addr net.Addr) error { + for _, a := range l.Addresses { + if a.String() == addr.String() { + return nil + } + } + + l.Addresses = append(l.Addresses, addr) + + return l.Sync() +} + +// Sync performs a full database sync which writes the current up-to-date data +// within the struct to the database. +func (l *LinkNode) Sync() error { + + // Finally update the database by storing the link node and updating + // any relevant indexes. + return l.db.Update(func(tx *bbolt.Tx) error { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return ErrLinkNodesNotFound + } + + return putLinkNode(nodeMetaBucket, l) + }) +} + +// putLinkNode serializes then writes the encoded version of the passed link +// node into the nodeMetaBucket. This function is provided in order to allow +// the ability to re-use a database transaction across many operations. +func putLinkNode(nodeMetaBucket *bbolt.Bucket, l *LinkNode) error { + // First serialize the LinkNode into its raw-bytes encoding. + var b bytes.Buffer + if err := serializeLinkNode(&b, l); err != nil { + return err + } + + // Finally insert the link-node into the node metadata bucket keyed + // according to the its pubkey serialized in compressed form. + nodePub := l.IdentityPub.SerializeCompressed() + return nodeMetaBucket.Put(nodePub, b.Bytes()) +} + +// DeleteLinkNode removes the link node with the given identity from the +// database. +func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error { + return db.Update(func(tx *bbolt.Tx) error { + return db.deleteLinkNode(tx, identity) + }) +} + +func (db *DB) deleteLinkNode(tx *bbolt.Tx, identity *btcec.PublicKey) error { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return ErrLinkNodesNotFound + } + + pubKey := identity.SerializeCompressed() + return nodeMetaBucket.Delete(pubKey) +} + +// FetchLinkNode attempts to lookup the data for a LinkNode based on a target +// identity public key. If a particular LinkNode for the passed identity public +// key cannot be found, then ErrNodeNotFound if returned. +func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { + var linkNode *LinkNode + err := db.View(func(tx *bbolt.Tx) error { + node, err := fetchLinkNode(tx, identity) + if err != nil { + return err + } + + linkNode = node + return nil + }) + + return linkNode, err +} + +func fetchLinkNode(tx *bbolt.Tx, targetPub *btcec.PublicKey) (*LinkNode, error) { + // First fetch the bucket for storing node metadata, bailing out early + // if it hasn't been created yet. + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return nil, ErrLinkNodesNotFound + } + + // If a link node for that particular public key cannot be located, + // then exit early with an ErrNodeNotFound. + pubKey := targetPub.SerializeCompressed() + nodeBytes := nodeMetaBucket.Get(pubKey) + if nodeBytes == nil { + return nil, ErrNodeNotFound + } + + // Finally, decode and allocate a fresh LinkNode object to be returned + // to the caller. + nodeReader := bytes.NewReader(nodeBytes) + return deserializeLinkNode(nodeReader) +} + +// TODO(roasbeef): update link node addrs in server upon connection + +// FetchAllLinkNodes starts a new database transaction to fetch all nodes with +// whom we have active channels with. +func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { + var linkNodes []*LinkNode + err := db.View(func(tx *bbolt.Tx) error { + nodes, err := db.fetchAllLinkNodes(tx) + if err != nil { + return err + } + + linkNodes = nodes + return nil + }) + if err != nil { + return nil, err + } + + return linkNodes, nil +} + +// fetchAllLinkNodes uses an existing database transaction to fetch all nodes +// with whom we have active channels with. +func (db *DB) fetchAllLinkNodes(tx *bbolt.Tx) ([]*LinkNode, error) { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return nil, ErrLinkNodesNotFound + } + + var linkNodes []*LinkNode + err := nodeMetaBucket.ForEach(func(k, v []byte) error { + if v == nil { + return nil + } + + nodeReader := bytes.NewReader(v) + linkNode, err := deserializeLinkNode(nodeReader) + if err != nil { + return err + } + + linkNodes = append(linkNodes, linkNode) + return nil + }) + if err != nil { + return nil, err + } + + return linkNodes, nil +} + +func serializeLinkNode(w io.Writer, l *LinkNode) error { + var buf [8]byte + + byteOrder.PutUint32(buf[:4], uint32(l.Network)) + if _, err := w.Write(buf[:4]); err != nil { + return err + } + + serializedID := l.IdentityPub.SerializeCompressed() + if _, err := w.Write(serializedID); err != nil { + return err + } + + seenUnix := uint64(l.LastSeen.Unix()) + byteOrder.PutUint64(buf[:], seenUnix) + if _, err := w.Write(buf[:]); err != nil { + return err + } + + numAddrs := uint32(len(l.Addresses)) + byteOrder.PutUint32(buf[:4], numAddrs) + if _, err := w.Write(buf[:4]); err != nil { + return err + } + + for _, addr := range l.Addresses { + if err := serializeAddr(w, addr); err != nil { + return err + } + } + + return nil +} + +func deserializeLinkNode(r io.Reader) (*LinkNode, error) { + var ( + err error + buf [8]byte + ) + + node := &LinkNode{} + + if _, err := io.ReadFull(r, buf[:4]); err != nil { + return nil, err + } + node.Network = wire.BitcoinNet(byteOrder.Uint32(buf[:4])) + + var pub [33]byte + if _, err := io.ReadFull(r, pub[:]); err != nil { + return nil, err + } + node.IdentityPub, err = btcec.ParsePubKey(pub[:], btcec.S256()) + if err != nil { + return nil, err + } + + if _, err := io.ReadFull(r, buf[:]); err != nil { + return nil, err + } + node.LastSeen = time.Unix(int64(byteOrder.Uint64(buf[:])), 0) + + if _, err := io.ReadFull(r, buf[:4]); err != nil { + return nil, err + } + numAddrs := byteOrder.Uint32(buf[:4]) + + node.Addresses = make([]net.Addr, numAddrs) + for i := uint32(0); i < numAddrs; i++ { + addr, err := deserializeAddr(r) + if err != nil { + return nil, err + } + node.Addresses[i] = addr + } + + return node, nil +} diff --git a/channeldb/migration_01_to_11/nodes_test.go b/channeldb/migration_01_to_11/nodes_test.go new file mode 100644 index 00000000..481dc5bd --- /dev/null +++ b/channeldb/migration_01_to_11/nodes_test.go @@ -0,0 +1,140 @@ +package migration_01_to_11 + +import ( + "bytes" + "net" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" +) + +func TestLinkNodeEncodeDecode(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First we'll create some initial data to use for populating our test + // LinkNode instances. + _, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + _, pub2 := btcec.PrivKeyFromBytes(btcec.S256(), rev[:]) + addr1, err := net.ResolveTCPAddr("tcp", "10.0.0.1:9000") + if err != nil { + t.Fatalf("unable to create test addr: %v", err) + } + addr2, err := net.ResolveTCPAddr("tcp", "10.0.0.2:9000") + if err != nil { + t.Fatalf("unable to create test addr: %v", err) + } + + // Create two fresh link node instances with the above dummy data, then + // fully sync both instances to disk. + node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1) + node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2) + if err := node1.Sync(); err != nil { + t.Fatalf("unable to sync node: %v", err) + } + if err := node2.Sync(); err != nil { + t.Fatalf("unable to sync node: %v", err) + } + + // Fetch all current link nodes from the database, they should exactly + // match the two created above. + originalNodes := []*LinkNode{node2, node1} + linkNodes, err := cdb.FetchAllLinkNodes() + if err != nil { + t.Fatalf("unable to fetch nodes: %v", err) + } + for i, node := range linkNodes { + if originalNodes[i].Network != node.Network { + t.Fatalf("node networks don't match: expected %v, got %v", + originalNodes[i].Network, node.Network) + } + + originalPubkey := originalNodes[i].IdentityPub.SerializeCompressed() + dbPubkey := node.IdentityPub.SerializeCompressed() + if !bytes.Equal(originalPubkey, dbPubkey) { + t.Fatalf("node pubkeys don't match: expected %x, got %x", + originalPubkey, dbPubkey) + } + if originalNodes[i].LastSeen.Unix() != node.LastSeen.Unix() { + t.Fatalf("last seen timestamps don't match: expected %v got %v", + originalNodes[i].LastSeen.Unix(), node.LastSeen.Unix()) + } + if originalNodes[i].Addresses[0].String() != node.Addresses[0].String() { + t.Fatalf("addresses don't match: expected %v, got %v", + originalNodes[i].Addresses, node.Addresses) + } + } + + // Next, we'll exercise the methods to append additional IP + // addresses, and also to update the last seen time. + if err := node1.UpdateLastSeen(time.Now()); err != nil { + t.Fatalf("unable to update last seen: %v", err) + } + if err := node1.AddAddress(addr2); err != nil { + t.Fatalf("unable to update addr: %v", err) + } + + // Fetch the same node from the database according to its public key. + node1DB, err := cdb.FetchLinkNode(pub1) + if err != nil { + t.Fatalf("unable to find node: %v", err) + } + + // Both the last seen timestamp and the list of reachable addresses for + // the node should be updated. + if node1DB.LastSeen.Unix() != node1.LastSeen.Unix() { + t.Fatalf("last seen timestamps don't match: expected %v got %v", + node1.LastSeen.Unix(), node1DB.LastSeen.Unix()) + } + if len(node1DB.Addresses) != 2 { + t.Fatalf("wrong length for node1 addresses: expected %v, got %v", + 2, len(node1DB.Addresses)) + } + if node1DB.Addresses[0].String() != addr1.String() { + t.Fatalf("wrong address for node: expected %v, got %v", + addr1.String(), node1DB.Addresses[0].String()) + } + if node1DB.Addresses[1].String() != addr2.String() { + t.Fatalf("wrong address for node: expected %v, got %v", + addr2.String(), node1DB.Addresses[1].String()) + } +} + +func TestDeleteLinkNode(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1337, + } + linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr) + if err := linkNode.Sync(); err != nil { + t.Fatalf("unable to write link node to db: %v", err) + } + + if _, err := cdb.FetchLinkNode(pubKey); err != nil { + t.Fatalf("unable to find link node: %v", err) + } + + if err := cdb.DeleteLinkNode(pubKey); err != nil { + t.Fatalf("unable to delete link node from db: %v", err) + } + + if _, err := cdb.FetchLinkNode(pubKey); err == nil { + t.Fatal("should not have found link node in db, but did") + } +} diff --git a/channeldb/migration_01_to_11/options.go b/channeldb/migration_01_to_11/options.go new file mode 100644 index 00000000..c3cc2c4a --- /dev/null +++ b/channeldb/migration_01_to_11/options.go @@ -0,0 +1,62 @@ +package migration_01_to_11 + +const ( + // DefaultRejectCacheSize is the default number of rejectCacheEntries to + // cache for use in the rejection cache of incoming gossip traffic. This + // produces a cache size of around 1MB. + DefaultRejectCacheSize = 50000 + + // DefaultChannelCacheSize is the default number of ChannelEdges cached + // in order to reply to gossip queries. This produces a cache size of + // around 40MB. + DefaultChannelCacheSize = 20000 +) + +// Options holds parameters for tuning and customizing a channeldb.DB. +type Options struct { + // RejectCacheSize is the maximum number of rejectCacheEntries to hold + // in the rejection cache. + RejectCacheSize int + + // ChannelCacheSize is the maximum number of ChannelEdges to hold in the + // channel cache. + ChannelCacheSize int + + // NoFreelistSync, if true, prevents the database from syncing its + // freelist to disk, resulting in improved performance at the expense of + // increased startup time. + NoFreelistSync bool +} + +// DefaultOptions returns an Options populated with default values. +func DefaultOptions() Options { + return Options{ + RejectCacheSize: DefaultRejectCacheSize, + ChannelCacheSize: DefaultChannelCacheSize, + NoFreelistSync: true, + } +} + +// OptionModifier is a function signature for modifying the default Options. +type OptionModifier func(*Options) + +// OptionSetRejectCacheSize sets the RejectCacheSize to n. +func OptionSetRejectCacheSize(n int) OptionModifier { + return func(o *Options) { + o.RejectCacheSize = n + } +} + +// OptionSetChannelCacheSize sets the ChannelCacheSize to n. +func OptionSetChannelCacheSize(n int) OptionModifier { + return func(o *Options) { + o.ChannelCacheSize = n + } +} + +// OptionSetSyncFreelist allows the database to sync its freelist. +func OptionSetSyncFreelist(b bool) OptionModifier { + return func(o *Options) { + o.NoFreelistSync = !b + } +} diff --git a/channeldb/migration_01_to_11/payment_control.go b/channeldb/migration_01_to_11/payment_control.go new file mode 100644 index 00000000..83b1649a --- /dev/null +++ b/channeldb/migration_01_to_11/payment_control.go @@ -0,0 +1,497 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" +) + +var ( + // ErrAlreadyPaid signals we have already paid this payment hash. + ErrAlreadyPaid = errors.New("invoice is already paid") + + // ErrPaymentInFlight signals that payment for this payment hash is + // already "in flight" on the network. + ErrPaymentInFlight = errors.New("payment is in transition") + + // ErrPaymentNotInitiated is returned if payment wasn't initiated in + // switch. + ErrPaymentNotInitiated = errors.New("payment isn't initiated") + + // ErrPaymentAlreadySucceeded is returned in the event we attempt to + // change the status of a payment already succeeded. + ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded") + + // ErrPaymentAlreadyFailed is returned in the event we attempt to + // re-fail a failed payment. + ErrPaymentAlreadyFailed = errors.New("payment has already failed") + + // ErrUnknownPaymentStatus is returned when we do not recognize the + // existing state of a payment. + ErrUnknownPaymentStatus = errors.New("unknown payment status") + + // errNoAttemptInfo is returned when no attempt info is stored yet. + errNoAttemptInfo = errors.New("unable to find attempt info for " + + "inflight payment") +) + +// PaymentControl implements persistence for payments and payment attempts. +type PaymentControl struct { + db *DB +} + +// NewPaymentControl creates a new instance of the PaymentControl. +func NewPaymentControl(db *DB) *PaymentControl { + return &PaymentControl{ + db: db, + } +} + +// InitPayment checks or records the given PaymentCreationInfo with the DB, +// making sure it does not already exist as an in-flight payment. Then this +// method returns successfully, the payment is guranteeed to be in the InFlight +// state. +func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, + info *PaymentCreationInfo) error { + + var b bytes.Buffer + if err := serializePaymentCreationInfo(&b, info); err != nil { + return err + } + infoBytes := b.Bytes() + + var updateErr error + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := createPaymentBucket(tx, paymentHash) + if err != nil { + return err + } + + // Get the existing status of this payment, if any. + paymentStatus := fetchPaymentStatus(bucket) + + switch paymentStatus { + + // We allow retrying failed payments. + case StatusFailed: + + // This is a new payment that is being initialized for the + // first time. + case StatusUnknown: + + // We already have an InFlight payment on the network. We will + // disallow any new payments. + case StatusInFlight: + updateErr = ErrPaymentInFlight + return nil + + // We've already succeeded a payment to this payment hash, + // forbid the switch from sending another. + case StatusSucceeded: + updateErr = ErrAlreadyPaid + return nil + + default: + updateErr = ErrUnknownPaymentStatus + return nil + } + + // Obtain a new sequence number for this payment. This is used + // to sort the payments in order of creation, and also acts as + // a unique identifier for each payment. + sequenceNum, err := nextPaymentSequence(tx) + if err != nil { + return err + } + + err = bucket.Put(paymentSequenceKey, sequenceNum) + if err != nil { + return err + } + + // Add the payment info to the bucket, which contains the + // static information for this payment + err = bucket.Put(paymentCreationInfoKey, infoBytes) + if err != nil { + return err + } + + // We'll delete any lingering attempt info to start with, in + // case we are initializing a payment that was attempted + // earlier, but left in a state where we could retry. + err = bucket.Delete(paymentAttemptInfoKey) + if err != nil { + return err + } + + // Also delete any lingering failure info now that we are + // re-attempting. + return bucket.Delete(paymentFailInfoKey) + }) + if err != nil { + return err + } + + return updateErr +} + +// RegisterAttempt atomically records the provided PaymentAttemptInfo to the +// DB. +func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, + attempt *PaymentAttemptInfo) error { + + // Serialize the information before opening the db transaction. + var a bytes.Buffer + if err := serializePaymentAttemptInfo(&a, attempt); err != nil { + return err + } + attemptBytes := a.Bytes() + + var updateErr error + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only register attempts for payments that are + // in-flight. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Add the payment attempt to the payments bucket. + return bucket.Put(paymentAttemptInfoKey, attemptBytes) + }) + if err != nil { + return err + } + + return updateErr +} + +// Success transitions a payment into the Succeeded state. After invoking this +// method, InitPayment should always return an error to prevent us from making +// duplicate payments to the same payment hash. The provided preimage is +// atomically saved to the DB for record keeping. +func (p *PaymentControl) Success(paymentHash lntypes.Hash, + preimage lntypes.Preimage) (*route.Route, error) { + + var ( + updateErr error + route *route.Route + ) + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only mark in-flight payments as succeeded. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Record the successful payment info atomically to the + // payments record. + err = bucket.Put(paymentSettleInfoKey, preimage[:]) + if err != nil { + return err + } + + // Retrieve attempt info for the notification. + attempt, err := fetchPaymentAttempt(bucket) + if err != nil { + return err + } + + route = &attempt.Route + + return nil + }) + if err != nil { + return nil, err + } + + return route, updateErr +} + +// Fail transitions a payment into the Failed state, and records the reason the +// payment failed. After invoking this method, InitPayment should return nil on +// its next call for this payment hash, allowing the switch to make a +// subsequent payment. +func (p *PaymentControl) Fail(paymentHash lntypes.Hash, + reason FailureReason) (*route.Route, error) { + + var ( + updateErr error + route *route.Route + ) + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only mark in-flight payments as failed. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Put the failure reason in the bucket for record keeping. + v := []byte{byte(reason)} + err = bucket.Put(paymentFailInfoKey, v) + if err != nil { + return err + } + + // Retrieve attempt info for the notification, if available. + attempt, err := fetchPaymentAttempt(bucket) + if err != nil && err != errNoAttemptInfo { + return err + } + if err != errNoAttemptInfo { + route = &attempt.Route + } + + return nil + }) + if err != nil { + return nil, err + } + + return route, updateErr +} + +// FetchPayment returns information about a payment from the database. +func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) ( + *Payment, error) { + + var payment *Payment + err := p.db.View(func(tx *bbolt.Tx) error { + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err != nil { + return err + } + + payment, err = fetchPayment(bucket) + + return err + }) + if err != nil { + return nil, err + } + + return payment, nil +} + +// createPaymentBucket creates or fetches the sub-bucket assigned to this +// payment hash. +func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( + *bbolt.Bucket, error) { + + payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) + if err != nil { + return nil, err + } + + return payments.CreateBucketIfNotExists(paymentHash[:]) +} + +// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If +// the bucket does not exist, it returns ErrPaymentNotInitiated. +func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( + *bbolt.Bucket, error) { + + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil, ErrPaymentNotInitiated + } + + bucket := payments.Bucket(paymentHash[:]) + if bucket == nil { + return nil, ErrPaymentNotInitiated + } + + return bucket, nil + +} + +// nextPaymentSequence returns the next sequence number to store for a new +// payment. +func nextPaymentSequence(tx *bbolt.Tx) ([]byte, error) { + payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) + if err != nil { + return nil, err + } + + seq, err := payments.NextSequence() + if err != nil { + return nil, err + } + + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, seq) + return b, nil +} + +// fetchPaymentStatus fetches the payment status of the payment. If the payment +// isn't found, it will default to "StatusUnknown". +func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus { + if bucket.Get(paymentSettleInfoKey) != nil { + return StatusSucceeded + } + + if bucket.Get(paymentFailInfoKey) != nil { + return StatusFailed + } + + if bucket.Get(paymentCreationInfoKey) != nil { + return StatusInFlight + } + + return StatusUnknown +} + +// ensureInFlight checks whether the payment found in the given bucket has +// status InFlight, and returns an error otherwise. This should be used to +// ensure we only mark in-flight payments as succeeded or failed. +func ensureInFlight(bucket *bbolt.Bucket) error { + paymentStatus := fetchPaymentStatus(bucket) + + switch { + + // The payment was indeed InFlight, return. + case paymentStatus == StatusInFlight: + return nil + + // Our records show the payment as unknown, meaning it never + // should have left the switch. + case paymentStatus == StatusUnknown: + return ErrPaymentNotInitiated + + // The payment succeeded previously. + case paymentStatus == StatusSucceeded: + return ErrPaymentAlreadySucceeded + + // The payment was already failed. + case paymentStatus == StatusFailed: + return ErrPaymentAlreadyFailed + + default: + return ErrUnknownPaymentStatus + } +} + +// fetchPaymentAttempt fetches the payment attempt from the bucket. +func fetchPaymentAttempt(bucket *bbolt.Bucket) (*PaymentAttemptInfo, error) { + attemptData := bucket.Get(paymentAttemptInfoKey) + if attemptData == nil { + return nil, errNoAttemptInfo + } + + r := bytes.NewReader(attemptData) + return deserializePaymentAttemptInfo(r) +} + +// InFlightPayment is a wrapper around a payment that has status InFlight. +type InFlightPayment struct { + // Info is the PaymentCreationInfo of the in-flight payment. + Info *PaymentCreationInfo + + // Attempt contains information about the last payment attempt that was + // made to this payment hash. + // + // NOTE: Might be nil. + Attempt *PaymentAttemptInfo +} + +// FetchInFlightPayments returns all payments with status InFlight. +func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { + var inFlights []*InFlightPayment + err := p.db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + return payments.ForEach(func(k, _ []byte) error { + bucket := payments.Bucket(k) + if bucket == nil { + return fmt.Errorf("non bucket element") + } + + // If the status is not InFlight, we can return early. + paymentStatus := fetchPaymentStatus(bucket) + if paymentStatus != StatusInFlight { + return nil + } + + var ( + inFlight = &InFlightPayment{} + err error + ) + + // Get the CreationInfo. + b := bucket.Get(paymentCreationInfoKey) + if b == nil { + return fmt.Errorf("unable to find creation " + + "info for inflight payment") + } + + r := bytes.NewReader(b) + inFlight.Info, err = deserializePaymentCreationInfo(r) + if err != nil { + return err + } + + // Now get the attempt info. It could be that there is + // no attempt info yet. + inFlight.Attempt, err = fetchPaymentAttempt(bucket) + if err != nil && err != errNoAttemptInfo { + return err + } + + inFlights = append(inFlights, inFlight) + return nil + }) + }) + if err != nil { + return nil, err + } + + return inFlights, nil +} diff --git a/channeldb/migration_01_to_11/payment_control_test.go b/channeldb/migration_01_to_11/payment_control_test.go new file mode 100644 index 00000000..9868475e --- /dev/null +++ b/channeldb/migration_01_to_11/payment_control_test.go @@ -0,0 +1,550 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "reflect" + "testing" + "time" + + "github.com/btcsuite/fastsha256" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" +) + +func initDB() (*DB, error) { + tempPath, err := ioutil.TempDir("", "switchdb") + if err != nil { + return nil, err + } + + db, err := Open(tempPath) + if err != nil { + return nil, err + } + + return db, err +} + +func genPreimage() ([32]byte, error) { + var preimage [32]byte + if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { + return preimage, err + } + return preimage, nil +} + +func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo, + lntypes.Preimage, error) { + + preimage, err := genPreimage() + if err != nil { + return nil, nil, preimage, fmt.Errorf("unable to "+ + "generate preimage: %v", err) + } + + rhash := fastsha256.Sum256(preimage[:]) + return &PaymentCreationInfo{ + PaymentHash: rhash, + Value: 1, + CreationDate: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("hola"), + }, + &PaymentAttemptInfo{ + PaymentID: 1, + SessionKey: priv, + Route: testRoute, + }, preimage, nil +} + +// TestPaymentControlSwitchFail checks that payment status returns to Failed +// status after failing, and that InitPayment allows another HTLC for the +// same payment hash. +func TestPaymentControlSwitchFail(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + _, err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Failed. + assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + &failReason, + ) + + // Sends the htlc again, which should succeed since the prior payment + // failed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Record a new attempt. + attempt.PaymentID = 2 + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + nil, + ) + + // Verifies that status was changed to StatusSucceeded. + var route *route.Route + route, err = pControl.Success(info.PaymentHash, preimg) + if err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + err = assertRouteEqual(route, &attempt.Route) + if err != nil { + t.Fatalf("unexpected route returned: %v vs %v: %v", + spew.Sdump(attempt.Route), spew.Sdump(*route), err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) + + // Attempt a final payment, which should now fail since the prior + // payment succeed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrAlreadyPaid { + t.Fatalf("unable to send htlc message: %v", err) + } +} + +// TestPaymentControlSwitchDoubleSend checks the ability of payment control to +// prevent double sending of htlc message, when message is in StatusInFlight. +func TestPaymentControlSwitchDoubleSend(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate base status and move it to + // StatusInFlight and verifies that it was changed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Try to initiate double sending of htlc message with the same + // payment hash, should result in error indicating that payment has + // already been sent. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") + } + + // Record an attempt. + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + nil, + ) + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") + } + + // After settling, the error should be ErrAlreadyPaid. + if _, err := pControl.Success(info.PaymentHash, preimg); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) + + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrAlreadyPaid { + t.Fatalf("unable to send htlc message: %v", err) + } +} + +// TestPaymentControlSuccessesWithoutInFlight checks that the payment +// control will disallow calls to Success when no payment is in flight. +func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, _, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Attempt to complete the payment should fail. + _, err = pControl.Success(info.PaymentHash, preimg) + if err != ErrPaymentNotInitiated { + t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) + assertPaymentInfo( + t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, + nil, + ) +} + +// TestPaymentControlFailsWithoutInFlight checks that a strict payment +// control will disallow calls to Fail when no payment is in flight. +func TestPaymentControlFailsWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, _, _, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Calling Fail should return an error. + _, err = pControl.Fail(info.PaymentHash, FailureReasonNoRoute) + if err != ErrPaymentNotInitiated { + t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) + assertPaymentInfo( + t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil, + ) +} + +// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only +// deletes payments from the database that are not in-flight. +func TestPaymentControlDeleteNonInFligt(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + payments := []struct { + failed bool + success bool + }{ + { + failed: true, + success: false, + }, + { + failed: false, + success: true, + }, + { + failed: false, + success: false, + }, + } + + for _, p := range payments { + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + if p.failed { + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + _, err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Failed. + assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, + lntypes.Preimage{}, &failReason, + ) + } else if p.success { + // Verifies that status was changed to StatusSucceeded. + _, err := pControl.Success(info.PaymentHash, preimg) + if err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, preimg, nil, + ) + } else { + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, + lntypes.Preimage{}, nil, + ) + } + } + + // Delete payments. + if err := db.DeletePayments(); err != nil { + t.Fatal(err) + } + + // This should leave the in-flight payment. + dbPayments, err := db.FetchPayments() + if err != nil { + t.Fatal(err) + } + + if len(dbPayments) != 1 { + t.Fatalf("expected one payment, got %d", len(dbPayments)) + } + + status := dbPayments[0].Status + if status != StatusInFlight { + t.Fatalf("expected in-fligth status, got %v", status) + } +} + +func assertPaymentStatus(t *testing.T, db *DB, + hash [32]byte, expStatus PaymentStatus) { + + t.Helper() + + var paymentStatus = StatusUnknown + err := db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + bucket := payments.Bucket(hash[:]) + if bucket == nil { + return nil + } + + // Get the existing status of this payment, if any. + paymentStatus = fetchPaymentStatus(bucket) + return nil + }) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != expStatus { + t.Fatalf("payment status mismatch: expected %v, got %v", + expStatus, paymentStatus) + } +} + +func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error { + b := bucket.Get(paymentCreationInfoKey) + switch { + case b == nil && c == nil: + return nil + case b == nil: + return fmt.Errorf("expected creation info not found") + case c == nil: + return fmt.Errorf("unexpected creation info found") + } + + r := bytes.NewReader(b) + c2, err := deserializePaymentCreationInfo(r) + if err != nil { + return err + } + if !reflect.DeepEqual(c, c2) { + return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v", + spew.Sdump(c), spew.Sdump(c2)) + } + + return nil +} + +func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error { + b := bucket.Get(paymentAttemptInfoKey) + switch { + case b == nil && a == nil: + return nil + case b == nil: + return fmt.Errorf("expected attempt info not found") + case a == nil: + return fmt.Errorf("unexpected attempt info found") + } + + r := bytes.NewReader(b) + a2, err := deserializePaymentAttemptInfo(r) + if err != nil { + return err + } + + return assertRouteEqual(&a.Route, &a2.Route) +} + +func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error { + zero := lntypes.Preimage{} + b := bucket.Get(paymentSettleInfoKey) + switch { + case b == nil && preimg == zero: + return nil + case b == nil: + return fmt.Errorf("expected preimage not found") + case preimg == zero: + return fmt.Errorf("unexpected preimage found") + } + + var pre2 lntypes.Preimage + copy(pre2[:], b[:]) + if preimg != pre2 { + return fmt.Errorf("Preimages don't match: %x vs %x", + preimg, pre2) + } + + return nil +} + +func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error { + b := bucket.Get(paymentFailInfoKey) + switch { + case b == nil && failReason == nil: + return nil + case b == nil: + return fmt.Errorf("expected fail info not found") + case failReason == nil: + return fmt.Errorf("unexpected fail info found") + } + + failReason2 := FailureReason(b[0]) + if *failReason != failReason2 { + return fmt.Errorf("Failure infos don't match: %v vs %v", + *failReason, failReason2) + } + + return nil +} + +func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash, + c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage, + f *FailureReason) { + + t.Helper() + + err := db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil && c == nil { + return nil + } + if payments == nil { + return fmt.Errorf("sent payments not found") + } + + bucket := payments.Bucket(hash[:]) + if bucket == nil && c == nil { + return nil + } + + if bucket == nil { + return fmt.Errorf("payment not found") + } + + if err := checkPaymentCreationInfo(bucket, c); err != nil { + return err + } + + if err := checkPaymentAttemptInfo(bucket, a); err != nil { + return err + } + + if err := checkSettleInfo(bucket, s); err != nil { + return err + } + + if err := checkFailInfo(bucket, f); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatalf("assert payment info failed: %v", err) + } + +} diff --git a/channeldb/migration_01_to_11/payments.go b/channeldb/migration_01_to_11/payments.go new file mode 100644 index 00000000..fd3db5a1 --- /dev/null +++ b/channeldb/migration_01_to_11/payments.go @@ -0,0 +1,669 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "sort" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // paymentsRootBucket is the name of the top-level bucket within the + // database that stores all data related to payments. Within this + // bucket, each payment hash its own sub-bucket keyed by its payment + // hash. + // + // Bucket hierarchy: + // + // root-bucket + // | + // |-- + // | |--sequence-key: + // | |--creation-info-key: + // | |--attempt-info-key: + // | |--settle-info-key: + // | |--fail-info-key: + // | | + // | |--duplicate-bucket (only for old, completed payments) + // | | + // | |-- + // | | |--sequence-key: + // | | |--creation-info-key: + // | | |--attempt-info-key: + // | | |--settle-info-key: + // | | |--fail-info-key: + // | | + // | |-- + // | | | + // | ... ... + // | + // |-- + // | | + // | ... + // ... + // + paymentsRootBucket = []byte("payments-root-bucket") + + // paymentDublicateBucket is the name of a optional sub-bucket within + // the payment hash bucket, that is used to hold duplicate payments to + // a payment hash. This is needed to support information from earlier + // versions of lnd, where it was possible to pay to a payment hash more + // than once. + paymentDuplicateBucket = []byte("payment-duplicate-bucket") + + // paymentSequenceKey is a key used in the payment's sub-bucket to + // store the sequence number of the payment. + paymentSequenceKey = []byte("payment-sequence-key") + + // paymentCreationInfoKey is a key used in the payment's sub-bucket to + // store the creation info of the payment. + paymentCreationInfoKey = []byte("payment-creation-info") + + // paymentAttemptInfoKey is a key used in the payment's sub-bucket to + // store the info about the latest attempt that was done for the + // payment in question. + paymentAttemptInfoKey = []byte("payment-attempt-info") + + // paymentSettleInfoKey is a key used in the payment's sub-bucket to + // store the settle info of the payment. + paymentSettleInfoKey = []byte("payment-settle-info") + + // paymentFailInfoKey is a key used in the payment's sub-bucket to + // store information about the reason a payment failed. + paymentFailInfoKey = []byte("payment-fail-info") +) + +// FailureReason encodes the reason a payment ultimately failed. +type FailureReason byte + +const ( + // FailureReasonTimeout indicates that the payment did timeout before a + // successful payment attempt was made. + FailureReasonTimeout FailureReason = 0 + + // FailureReasonNoRoute indicates no successful route to the + // destination was found during path finding. + FailureReasonNoRoute FailureReason = 1 + + // FailureReasonError indicates that an unexpected error happened during + // payment. + FailureReasonError FailureReason = 2 + + // FailureReasonIncorrectPaymentDetails indicates that either the hash + // is unknown or the final cltv delta or amount is incorrect. + FailureReasonIncorrectPaymentDetails FailureReason = 3 + + // TODO(halseth): cancel state. + + // TODO(joostjager): Add failure reasons for: + // LocalLiquidityInsufficient, RemoteCapacityInsufficient. +) + +// String returns a human readable FailureReason +func (r FailureReason) String() string { + switch r { + case FailureReasonTimeout: + return "timeout" + case FailureReasonNoRoute: + return "no_route" + case FailureReasonError: + return "error" + case FailureReasonIncorrectPaymentDetails: + return "incorrect_payment_details" + } + + return "unknown" +} + +// PaymentStatus represent current status of payment +type PaymentStatus byte + +const ( + // StatusUnknown is the status where a payment has never been initiated + // and hence is unknown. + StatusUnknown PaymentStatus = 0 + + // StatusInFlight is the status where a payment has been initiated, but + // a response has not been received. + StatusInFlight PaymentStatus = 1 + + // StatusSucceeded is the status where a payment has been initiated and + // the payment was completed successfully. + StatusSucceeded PaymentStatus = 2 + + // StatusFailed is the status where a payment has been initiated and a + // failure result has come back. + StatusFailed PaymentStatus = 3 +) + +// Bytes returns status as slice of bytes. +func (ps PaymentStatus) Bytes() []byte { + return []byte{byte(ps)} +} + +// FromBytes sets status from slice of bytes. +func (ps *PaymentStatus) FromBytes(status []byte) error { + if len(status) != 1 { + return errors.New("payment status is empty") + } + + switch PaymentStatus(status[0]) { + case StatusUnknown, StatusInFlight, StatusSucceeded, StatusFailed: + *ps = PaymentStatus(status[0]) + default: + return errors.New("unknown payment status") + } + + return nil +} + +// String returns readable representation of payment status. +func (ps PaymentStatus) String() string { + switch ps { + case StatusUnknown: + return "Unknown" + case StatusInFlight: + return "In Flight" + case StatusSucceeded: + return "Succeeded" + case StatusFailed: + return "Failed" + default: + return "Unknown" + } +} + +// PaymentCreationInfo is the information necessary to have ready when +// initiating a payment, moving it into state InFlight. +type PaymentCreationInfo struct { + // PaymentHash is the hash this payment is paying to. + PaymentHash lntypes.Hash + + // Value is the amount we are paying. + Value lnwire.MilliSatoshi + + // CreatingDate is the time when this payment was initiated. + CreationDate time.Time + + // PaymentRequest is the full payment request, if any. + PaymentRequest []byte +} + +// PaymentAttemptInfo contains information about a specific payment attempt for +// a given payment. This information is used by the router to handle any errors +// coming back after an attempt is made, and to query the switch about the +// status of a payment. For settled payment this will be the information for +// the succeeding payment attempt. +type PaymentAttemptInfo struct { + // PaymentID is the unique ID used for this attempt. + PaymentID uint64 + + // SessionKey is the ephemeral key used for this payment attempt. + SessionKey *btcec.PrivateKey + + // Route is the route attempted to send the HTLC. + Route route.Route +} + +// Payment is a wrapper around a payment's PaymentCreationInfo, +// PaymentAttemptInfo, and preimage. All payments will have the +// PaymentCreationInfo set, the PaymentAttemptInfo will be set only if at least +// one payment attempt has been made, while only completed payments will have a +// non-zero payment preimage. +type Payment struct { + // sequenceNum is a unique identifier used to sort the payments in + // order of creation. + sequenceNum uint64 + + // Status is the current PaymentStatus of this payment. + Status PaymentStatus + + // Info holds all static information about this payment, and is + // populated when the payment is initiated. + Info *PaymentCreationInfo + + // Attempt is the information about the last payment attempt made. + // + // NOTE: Can be nil if no attempt is yet made. + Attempt *PaymentAttemptInfo + + // PaymentPreimage is the preimage of a successful payment. This serves + // as a proof of payment. It will only be non-nil for settled payments. + // + // NOTE: Can be nil if payment is not settled. + PaymentPreimage *lntypes.Preimage + + // Failure is a failure reason code indicating the reason the payment + // failed. It is only non-nil for failed payments. + // + // NOTE: Can be nil if payment is not failed. + Failure *FailureReason +} + +// FetchPayments returns all sent payments found in the DB. +func (db *DB) FetchPayments() ([]*Payment, error) { + var payments []*Payment + + err := db.View(func(tx *bbolt.Tx) error { + paymentsBucket := tx.Bucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil + } + + return paymentsBucket.ForEach(func(k, v []byte) error { + bucket := paymentsBucket.Bucket(k) + if bucket == nil { + // We only expect sub-buckets to be found in + // this top-level bucket. + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + + p, err := fetchPayment(bucket) + if err != nil { + return err + } + + payments = append(payments, p) + + // For older versions of lnd, duplicate payments to a + // payment has was possible. These will be found in a + // sub-bucket indexed by their sequence number if + // available. + dup := bucket.Bucket(paymentDuplicateBucket) + if dup == nil { + return nil + } + + return dup.ForEach(func(k, v []byte) error { + subBucket := dup.Bucket(k) + if subBucket == nil { + // We one bucket for each duplicate to + // be found. + return fmt.Errorf("non bucket element" + + "in duplicate bucket") + } + + p, err := fetchPayment(subBucket) + if err != nil { + return err + } + + payments = append(payments, p) + return nil + }) + }) + }) + if err != nil { + return nil, err + } + + // Before returning, sort the payments by their sequence number. + sort.Slice(payments, func(i, j int) bool { + return payments[i].sequenceNum < payments[j].sequenceNum + }) + + return payments, nil +} + +func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) { + var ( + err error + p = &Payment{} + ) + + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return nil, fmt.Errorf("sequence number not found") + } + + p.sequenceNum = binary.BigEndian.Uint64(seqBytes) + + // Get the payment status. + p.Status = fetchPaymentStatus(bucket) + + // Get the PaymentCreationInfo. + b := bucket.Get(paymentCreationInfoKey) + if b == nil { + return nil, fmt.Errorf("creation info not found") + } + + r := bytes.NewReader(b) + p.Info, err = deserializePaymentCreationInfo(r) + if err != nil { + return nil, err + + } + + // Get the PaymentAttemptInfo. This can be unset. + b = bucket.Get(paymentAttemptInfoKey) + if b != nil { + r = bytes.NewReader(b) + p.Attempt, err = deserializePaymentAttemptInfo(r) + if err != nil { + return nil, err + } + } + + // Get the payment preimage. This is only found for + // completed payments. + b = bucket.Get(paymentSettleInfoKey) + if b != nil { + var preimg lntypes.Preimage + copy(preimg[:], b[:]) + p.PaymentPreimage = &preimg + } + + // Get failure reason if available. + b = bucket.Get(paymentFailInfoKey) + if b != nil { + reason := FailureReason(b[0]) + p.Failure = &reason + } + + return p, nil +} + +// DeletePayments deletes all completed and failed payments from the DB. +func (db *DB) DeletePayments() error { + return db.Update(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + var deleteBuckets [][]byte + err := payments.ForEach(func(k, _ []byte) error { + bucket := payments.Bucket(k) + if bucket == nil { + // We only expect sub-buckets to be found in + // this top-level bucket. + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + + // If the status is InFlight, we cannot safely delete + // the payment information, so we return early. + paymentStatus := fetchPaymentStatus(bucket) + if paymentStatus == StatusInFlight { + return nil + } + + deleteBuckets = append(deleteBuckets, k) + return nil + }) + if err != nil { + return err + } + + for _, k := range deleteBuckets { + if err := payments.DeleteBucket(k); err != nil { + return err + } + } + + return nil + }) +} + +func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { + var scratch [8]byte + + if _, err := w.Write(c.PaymentHash[:]); err != nil { + return err + } + + byteOrder.PutUint64(scratch[:], uint64(c.Value)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + byteOrder.PutUint64(scratch[:], uint64(c.CreationDate.Unix())) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + byteOrder.PutUint32(scratch[:4], uint32(len(c.PaymentRequest))) + if _, err := w.Write(scratch[:4]); err != nil { + return err + } + + if _, err := w.Write(c.PaymentRequest[:]); err != nil { + return err + } + + return nil +} + +func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) { + var scratch [8]byte + + c := &PaymentCreationInfo{} + + if _, err := io.ReadFull(r, c.PaymentHash[:]); err != nil { + return nil, err + } + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return nil, err + } + c.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return nil, err + } + c.CreationDate = time.Unix(int64(byteOrder.Uint64(scratch[:])), 0) + + if _, err := io.ReadFull(r, scratch[:4]); err != nil { + return nil, err + } + + reqLen := uint32(byteOrder.Uint32(scratch[:4])) + payReq := make([]byte, reqLen) + if reqLen > 0 { + if _, err := io.ReadFull(r, payReq[:]); err != nil { + return nil, err + } + } + c.PaymentRequest = payReq + + return c, nil +} + +func serializePaymentAttemptInfo(w io.Writer, a *PaymentAttemptInfo) error { + if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { + return err + } + + if err := SerializeRoute(w, a.Route); err != nil { + return err + } + + return nil +} + +func deserializePaymentAttemptInfo(r io.Reader) (*PaymentAttemptInfo, error) { + a := &PaymentAttemptInfo{} + err := ReadElements(r, &a.PaymentID, &a.SessionKey) + if err != nil { + return nil, err + } + a.Route, err = DeserializeRoute(r) + if err != nil { + return nil, err + } + return a, nil +} + +func serializeHop(w io.Writer, h *route.Hop) error { + if err := WriteElements(w, + h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, + h.AmtToForward, + ); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil { + return err + } + + // For legacy payloads, we don't need to write any TLV records, so + // we'll write a zero indicating the our serialized TLV map has no + // records. + if h.LegacyPayload { + return WriteElements(w, uint32(0)) + } + + // Otherwise, we'll transform our slice of records into a map of the + // raw bytes, then serialize them in-line with a length (number of + // elements) prefix. + mapRecords, err := tlv.RecordsToMap(h.TLVRecords) + if err != nil { + return err + } + + numRecords := uint32(len(mapRecords)) + if err := WriteElements(w, numRecords); err != nil { + return err + } + + for recordType, rawBytes := range mapRecords { + if err := WriteElements(w, recordType); err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil { + return err + } + } + + return nil +} + +// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need +// to read/write a TLV stream larger than this. +const maxOnionPayloadSize = 1300 + +func deserializeHop(r io.Reader) (*route.Hop, error) { + h := &route.Hop{} + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.PubKeyBytes[:], pub) + + if err := ReadElements(r, + &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, + ); err != nil { + return nil, err + } + + // TODO(roasbeef): change field to allow LegacyPayload false to be the + // legacy default? + err := binary.Read(r, byteOrder, &h.LegacyPayload) + if err != nil { + return nil, err + } + + var numElements uint32 + if err := ReadElements(r, &numElements); err != nil { + return nil, err + } + + // If there're no elements, then we can return early. + if numElements == 0 { + return h, nil + } + + tlvMap := make(map[uint64][]byte) + for i := uint32(0); i < numElements; i++ { + var tlvType uint64 + if err := ReadElements(r, &tlvType); err != nil { + return nil, err + } + + rawRecordBytes, err := wire.ReadVarBytes( + r, 0, maxOnionPayloadSize, "tlv", + ) + if err != nil { + return nil, err + } + + tlvMap[tlvType] = rawRecordBytes + } + + tlvRecords, err := tlv.MapToRecords(tlvMap) + if err != nil { + return nil, err + } + + h.TLVRecords = tlvRecords + + return h, nil +} + +// SerializeRoute serializes a route. +func SerializeRoute(w io.Writer, r route.Route) error { + if err := WriteElements(w, + r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.Hops))); err != nil { + return err + } + + for _, h := range r.Hops { + if err := serializeHop(w, h); err != nil { + return err + } + } + + return nil +} + +// DeserializeRoute deserializes a route. +func DeserializeRoute(r io.Reader) (route.Route, error) { + rt := route.Route{} + if err := ReadElements(r, + &rt.TotalTimeLock, &rt.TotalAmount, + ); err != nil { + return rt, err + } + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return rt, err + } + copy(rt.SourcePubKey[:], pub) + + var numHops uint32 + if err := ReadElements(r, &numHops); err != nil { + return rt, err + } + + var hops []*route.Hop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHop(r) + if err != nil { + return rt, err + } + hops = append(hops, hop) + } + rt.Hops = hops + + return rt, nil +} diff --git a/channeldb/migration_01_to_11/payments_test.go b/channeldb/migration_01_to_11/payments_test.go new file mode 100644 index 00000000..07307941 --- /dev/null +++ b/channeldb/migration_01_to_11/payments_test.go @@ -0,0 +1,324 @@ +package migration_01_to_11 + +import ( + "bytes" + "errors" + "fmt" + "math/rand" + "reflect" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + priv, _ = btcec.NewPrivateKey(btcec.S256()) + pub = priv.PubKey() + + tlvBytes = []byte{1, 2, 3} + tlvEncoder = tlv.StubEncoder(tlvBytes) + testHop1 = &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + AmtToForward: 555, + TLVRecords: []tlv.Record{ + tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil), + tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil), + }, + } + + testHop2 = &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + AmtToForward: 555, + LegacyPayload: true, + } + + testRoute = route.Route{ + TotalTimeLock: 123, + TotalAmount: 1234567, + SourcePubKey: route.NewVertex(pub), + Hops: []*route.Hop{ + testHop1, + testHop2, + }, + } +) + +func makeFakePayment() *outgoingPayment { + fakeInvoice := &Invoice{ + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + Memo: []byte("fake memo"), + Receipt: []byte("fake receipt"), + PaymentRequest: []byte(""), + } + + copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) + fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) + + fakePath := make([][33]byte, 3) + for i := 0; i < 3; i++ { + copy(fakePath[i][:], bytes.Repeat([]byte{byte(i)}, 33)) + } + + fakePayment := &outgoingPayment{ + Invoice: *fakeInvoice, + Fee: 101, + Path: fakePath, + TimeLockLength: 1000, + } + copy(fakePayment.PaymentPreimage[:], rev[:]) + return fakePayment +} + +func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) { + var preimg lntypes.Preimage + copy(preimg[:], rev[:]) + + c := &PaymentCreationInfo{ + PaymentHash: preimg.Hash(), + Value: 1000, + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte(""), + } + + a := &PaymentAttemptInfo{ + PaymentID: 44, + SessionKey: priv, + Route: testRoute, + } + return c, a +} + +// randomBytes creates random []byte with length in range [minLen, maxLen) +func randomBytes(minLen, maxLen int) ([]byte, error) { + randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen)) + + if _, err := rand.Read(randBuf); err != nil { + return nil, fmt.Errorf("Internal error. "+ + "Cannot generate random string: %v", err) + } + + return randBuf, nil +} + +func makeRandomFakePayment() (*outgoingPayment, error) { + var err error + fakeInvoice := &Invoice{ + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + } + + fakeInvoice.Memo, err = randomBytes(1, 50) + if err != nil { + return nil, err + } + + fakeInvoice.Receipt, err = randomBytes(1, 50) + if err != nil { + return nil, err + } + + fakeInvoice.PaymentRequest, err = randomBytes(1, 50) + if err != nil { + return nil, err + } + + preImg, err := randomBytes(32, 33) + if err != nil { + return nil, err + } + copy(fakeInvoice.Terms.PaymentPreimage[:], preImg) + + fakeInvoice.Terms.Value = lnwire.MilliSatoshi(rand.Intn(10000)) + + fakePathLen := 1 + rand.Intn(5) + fakePath := make([][33]byte, fakePathLen) + for i := 0; i < fakePathLen; i++ { + b, err := randomBytes(33, 34) + if err != nil { + return nil, err + } + copy(fakePath[i][:], b) + } + + fakePayment := &outgoingPayment{ + Invoice: *fakeInvoice, + Fee: lnwire.MilliSatoshi(rand.Intn(1001)), + Path: fakePath, + TimeLockLength: uint32(rand.Intn(10000)), + } + copy(fakePayment.PaymentPreimage[:], fakeInvoice.Terms.PaymentPreimage[:]) + + return fakePayment, nil +} + +func TestSentPaymentSerialization(t *testing.T) { + t.Parallel() + + c, s := makeFakeInfo() + + var b bytes.Buffer + if err := serializePaymentCreationInfo(&b, c); err != nil { + t.Fatalf("unable to serialize creation info: %v", err) + } + + newCreationInfo, err := deserializePaymentCreationInfo(&b) + if err != nil { + t.Fatalf("unable to deserialize creation info: %v", err) + } + + if !reflect.DeepEqual(c, newCreationInfo) { + t.Fatalf("Payments do not match after "+ + "serialization/deserialization %v vs %v", + spew.Sdump(c), spew.Sdump(newCreationInfo), + ) + } + + b.Reset() + if err := serializePaymentAttemptInfo(&b, s); err != nil { + t.Fatalf("unable to serialize info: %v", err) + } + + newAttemptInfo, err := deserializePaymentAttemptInfo(&b) + if err != nil { + t.Fatalf("unable to deserialize info: %v", err) + } + + // First we verify all the records match up porperly, as they aren't + // able to be properly compared using reflect.DeepEqual. + err = assertRouteEqual(&s.Route, &newAttemptInfo.Route) + if err != nil { + t.Fatalf("Routes do not match after "+ + "serialization/deserialization: %v", err) + } + + // Clear routes to allow DeepEqual to compare the remaining fields. + newAttemptInfo.Route = route.Route{} + s.Route = route.Route{} + + if !reflect.DeepEqual(s, newAttemptInfo) { + s.SessionKey.Curve = nil + newAttemptInfo.SessionKey.Curve = nil + t.Fatalf("Payments do not match after "+ + "serialization/deserialization %v vs %v", + spew.Sdump(s), spew.Sdump(newAttemptInfo), + ) + } +} + +// assertRouteEquals compares to routes for equality and returns an error if +// they are not equal. +func assertRouteEqual(a, b *route.Route) error { + err := assertRouteHopRecordsEqual(a, b) + if err != nil { + return err + } + + // TLV records have already been compared and need to be cleared to + // properly compare the remaining fields using DeepEqual. + copyRouteNoHops := func(r *route.Route) *route.Route { + copy := *r + copy.Hops = make([]*route.Hop, len(r.Hops)) + for i, hop := range r.Hops { + hopCopy := *hop + hopCopy.TLVRecords = nil + copy.Hops[i] = &hopCopy + } + return © + } + + if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) { + return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", + spew.Sdump(a), spew.Sdump(b)) + } + + return nil +} + +func assertRouteHopRecordsEqual(r1, r2 *route.Route) error { + if len(r1.Hops) != len(r2.Hops) { + return errors.New("route hop count mismatch") + } + + for i := 0; i < len(r1.Hops); i++ { + records1 := r1.Hops[i].TLVRecords + records2 := r2.Hops[i].TLVRecords + if len(records1) != len(records2) { + return fmt.Errorf("route record count for hop %v "+ + "mismatch", i) + } + + for j := 0; j < len(records1); j++ { + expectedRecord := records1[j] + newRecord := records2[j] + + err := assertHopRecordsEqual(expectedRecord, newRecord) + if err != nil { + return fmt.Errorf("route record mismatch: %v", err) + } + } + } + + return nil +} + +func assertHopRecordsEqual(h1, h2 tlv.Record) error { + if h1.Type() != h2.Type() { + return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(), + h2.Type()) + } + + var b bytes.Buffer + if err := h2.Encode(&b); err != nil { + return fmt.Errorf("unable to encode record: %v", err) + } + + if !bytes.Equal(b.Bytes(), tlvBytes) { + return fmt.Errorf("wrong raw record: expected %x, got %x", + tlvBytes, b.Bytes()) + } + + if h1.Size() != h2.Size() { + return fmt.Errorf("wrong size: expected %v, "+ + "got %v", h1.Size(), h2.Size()) + } + + return nil +} + +func TestRouteSerialization(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + if err := SerializeRoute(&b, testRoute); err != nil { + t.Fatal(err) + } + + r := bytes.NewReader(b.Bytes()) + route2, err := DeserializeRoute(r) + if err != nil { + t.Fatal(err) + } + + // First we verify all the records match up porperly, as they aren't + // able to be properly compared using reflect.DeepEqual. + err = assertRouteEqual(&testRoute, &route2) + if err != nil { + t.Fatalf("routes not equal: \n%v vs \n%v", + spew.Sdump(testRoute), spew.Sdump(route2)) + } +} diff --git a/channeldb/migration_01_to_11/reject_cache.go b/channeldb/migration_01_to_11/reject_cache.go new file mode 100644 index 00000000..c54d78a8 --- /dev/null +++ b/channeldb/migration_01_to_11/reject_cache.go @@ -0,0 +1,95 @@ +package migration_01_to_11 + +// rejectFlags is a compact representation of various metadata stored by the +// reject cache about a particular channel. +type rejectFlags uint8 + +const ( + // rejectFlagExists is a flag indicating whether the channel exists, + // i.e. the channel is open and has a recent channel update. If this + // flag is not set, the channel is either a zombie or unknown. + rejectFlagExists rejectFlags = 1 << iota + + // rejectFlagZombie is a flag indicating whether the channel is a + // zombie, i.e. the channel is open but has no recent channel updates. + rejectFlagZombie +) + +// packRejectFlags computes the rejectFlags corresponding to the passed boolean +// values indicating whether the edge exists or is a zombie. +func packRejectFlags(exists, isZombie bool) rejectFlags { + var flags rejectFlags + if exists { + flags |= rejectFlagExists + } + if isZombie { + flags |= rejectFlagZombie + } + + return flags +} + +// unpack returns the booleans packed into the rejectFlags. The first indicates +// if the edge exists in our graph, the second indicates if the edge is a +// zombie. +func (f rejectFlags) unpack() (bool, bool) { + return f&rejectFlagExists == rejectFlagExists, + f&rejectFlagZombie == rejectFlagZombie +} + +// rejectCacheEntry caches frequently accessed information about a channel, +// including the timestamps of its latest edge policies and whether or not the +// channel exists in the graph. +type rejectCacheEntry struct { + upd1Time int64 + upd2Time int64 + flags rejectFlags +} + +// rejectCache is an in-memory cache used to improve the performance of +// HasChannelEdge. It caches information about the whether or channel exists, as +// well as the most recent timestamps for each policy (if they exists). +type rejectCache struct { + n int + edges map[uint64]rejectCacheEntry +} + +// newRejectCache creates a new rejectCache with maximum capacity of n entries. +func newRejectCache(n int) *rejectCache { + return &rejectCache{ + n: n, + edges: make(map[uint64]rejectCacheEntry, n), + } +} + +// get returns the entry from the cache for chanid, if it exists. +func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) { + entry, ok := c.edges[chanid] + return entry, ok +} + +// insert adds the entry to the reject cache. If an entry for chanid already +// exists, it will be replaced with the new entry. If the entry doesn't exists, +// it will be inserted to the cache, performing a random eviction if the cache +// is at capacity. +func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) { + // If entry exists, replace it. + if _, ok := c.edges[chanid]; ok { + c.edges[chanid] = entry + return + } + + // Otherwise, evict an entry at random and insert. + if len(c.edges) == c.n { + for id := range c.edges { + delete(c.edges, id) + break + } + } + c.edges[chanid] = entry +} + +// remove deletes an entry for chanid from the cache, if it exists. +func (c *rejectCache) remove(chanid uint64) { + delete(c.edges, chanid) +} diff --git a/channeldb/migration_01_to_11/reject_cache_test.go b/channeldb/migration_01_to_11/reject_cache_test.go new file mode 100644 index 00000000..e15e0a10 --- /dev/null +++ b/channeldb/migration_01_to_11/reject_cache_test.go @@ -0,0 +1,107 @@ +package migration_01_to_11 + +import ( + "reflect" + "testing" +) + +// TestRejectCache checks the behavior of the rejectCache with respect to insertion, +// eviction, and removal of cache entries. +func TestRejectCache(t *testing.T) { + const cacheSize = 100 + + // Create a new reject cache with the configured max size. + c := newRejectCache(cacheSize) + + // As a sanity check, assert that querying the empty cache does not + // return an entry. + _, ok := c.get(0) + if ok { + t.Fatalf("reject cache should be empty") + } + + // Now, fill up the cache entirely. + for i := uint64(0); i < cacheSize; i++ { + c.insert(i, entryForInt(i)) + } + + // Assert that the cache has all of the entries just inserted, since no + // eviction should occur until we try to surpass the max size. + assertHasEntries(t, c, 0, cacheSize) + + // Now, insert a new element that causes the cache to evict an element. + c.insert(cacheSize, entryForInt(cacheSize)) + + // Assert that the cache has this last entry, as the cache should evict + // some prior element and not the newly inserted one. + assertHasEntries(t, c, cacheSize, cacheSize) + + // Iterate over all inserted elements and construct a set of the evicted + // elements. + evicted := make(map[uint64]struct{}) + for i := uint64(0); i < cacheSize+1; i++ { + _, ok := c.get(i) + if !ok { + evicted[i] = struct{}{} + } + } + + // Assert that exactly one element has been evicted. + numEvicted := len(evicted) + if numEvicted != 1 { + t.Fatalf("expected one evicted entry, got: %d", numEvicted) + } + + // Remove the highest item which initially caused the eviction and + // reinsert the element that was evicted prior. + c.remove(cacheSize) + for i := range evicted { + c.insert(i, entryForInt(i)) + } + + // Since the removal created an extra slot, the last insertion should + // not have caused an eviction and the entries for all channels in the + // original set that filled the cache should be present. + assertHasEntries(t, c, 0, cacheSize) + + // Finally, reinsert the existing set back into the cache and test that + // the cache still has all the entries. If the randomized eviction were + // happening on inserts for existing cache items, we expect this to fail + // with high probability. + for i := uint64(0); i < cacheSize; i++ { + c.insert(i, entryForInt(i)) + } + assertHasEntries(t, c, 0, cacheSize) + +} + +// assertHasEntries queries the reject cache for all channels in the range [start, +// end), asserting that they exist and their value matches the entry produced by +// entryForInt. +func assertHasEntries(t *testing.T, c *rejectCache, start, end uint64) { + t.Helper() + + for i := start; i < end; i++ { + entry, ok := c.get(i) + if !ok { + t.Fatalf("reject cache should contain chan %d", i) + } + + expEntry := entryForInt(i) + if !reflect.DeepEqual(entry, expEntry) { + t.Fatalf("entry mismatch, want: %v, got: %v", + expEntry, entry) + } + } +} + +// entryForInt generates a unique rejectCacheEntry given an integer. +func entryForInt(i uint64) rejectCacheEntry { + exists := i%2 == 0 + isZombie := i%3 == 0 + return rejectCacheEntry{ + upd1Time: int64(2 * i), + upd2Time: int64(2*i + 1), + flags: packRejectFlags(exists, isZombie), + } +} diff --git a/channeldb/migration_01_to_11/waitingproof.go b/channeldb/migration_01_to_11/waitingproof.go new file mode 100644 index 00000000..64729116 --- /dev/null +++ b/channeldb/migration_01_to_11/waitingproof.go @@ -0,0 +1,251 @@ +package migration_01_to_11 + +import ( + "encoding/binary" + "sync" + + "io" + + "bytes" + + "github.com/coreos/bbolt" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // waitingProofsBucketKey byte string name of the waiting proofs store. + waitingProofsBucketKey = []byte("waitingproofs") + + // ErrWaitingProofNotFound is returned if waiting proofs haven't been + // found by db. + ErrWaitingProofNotFound = errors.New("waiting proofs haven't been " + + "found") + + // ErrWaitingProofAlreadyExist is returned if waiting proofs haven't been + // found by db. + ErrWaitingProofAlreadyExist = errors.New("waiting proof with such " + + "key already exist") +) + +// WaitingProofStore is the bold db map-like storage for half announcement +// signatures. The one responsibility of this storage is to be able to +// retrieve waiting proofs after client restart. +type WaitingProofStore struct { + // cache is used in order to reduce the number of redundant get + // calls, when object isn't stored in it. + cache map[WaitingProofKey]struct{} + db *DB + mu sync.RWMutex +} + +// NewWaitingProofStore creates new instance of proofs storage. +func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { + s := &WaitingProofStore{ + db: db, + cache: make(map[WaitingProofKey]struct{}), + } + + if err := s.ForAll(func(proof *WaitingProof) error { + s.cache[proof.Key()] = struct{}{} + return nil + }); err != nil && err != ErrWaitingProofNotFound { + return nil, err + } + + return s, nil +} + +// Add adds new waiting proof in the storage. +func (s *WaitingProofStore) Add(proof *WaitingProof) error { + s.mu.Lock() + defer s.mu.Unlock() + + err := s.db.Update(func(tx *bbolt.Tx) error { + var err error + var b bytes.Buffer + + // Get or create the bucket. + bucket, err := tx.CreateBucketIfNotExists(waitingProofsBucketKey) + if err != nil { + return err + } + + // Encode the objects and place it in the bucket. + if err := proof.Encode(&b); err != nil { + return err + } + + key := proof.Key() + + return bucket.Put(key[:], b.Bytes()) + }) + if err != nil { + return err + } + + // Knowing that the write succeeded, we can now update the in-memory + // cache with the proof's key. + s.cache[proof.Key()] = struct{}{} + + return nil +} + +// Remove removes the proof from storage by its key. +func (s *WaitingProofStore) Remove(key WaitingProofKey) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.cache[key]; !ok { + return ErrWaitingProofNotFound + } + + err := s.db.Update(func(tx *bbolt.Tx) error { + // Get or create the top bucket. + bucket := tx.Bucket(waitingProofsBucketKey) + if bucket == nil { + return ErrWaitingProofNotFound + } + + return bucket.Delete(key[:]) + }) + if err != nil { + return err + } + + // Since the proof was successfully deleted from the store, we can now + // remove it from the in-memory cache. + delete(s.cache, key) + + return nil +} + +// ForAll iterates thought all waiting proofs and passing the waiting proof +// in the given callback. +func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { + return s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(waitingProofsBucketKey) + if bucket == nil { + return ErrWaitingProofNotFound + } + + // Iterate over objects buckets. + return bucket.ForEach(func(k, v []byte) error { + // Skip buckets fields. + if v == nil { + return nil + } + + r := bytes.NewReader(v) + proof := &WaitingProof{} + if err := proof.Decode(r); err != nil { + return err + } + + return cb(proof) + }) + }) +} + +// Get returns the object which corresponds to the given index. +func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { + proof := &WaitingProof{} + + s.mu.RLock() + defer s.mu.RUnlock() + + if _, ok := s.cache[key]; !ok { + return nil, ErrWaitingProofNotFound + } + + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(waitingProofsBucketKey) + if bucket == nil { + return ErrWaitingProofNotFound + } + + // Iterate over objects buckets. + v := bucket.Get(key[:]) + if v == nil { + return ErrWaitingProofNotFound + } + + r := bytes.NewReader(v) + return proof.Decode(r) + }) + + return proof, err +} + +// WaitingProofKey is the proof key which uniquely identifies the waiting +// proof object. The goal of this key is distinguish the local and remote +// proof for the same channel id. +type WaitingProofKey [9]byte + +// WaitingProof is the storable object, which encapsulate the half proof and +// the information about from which side this proof came. This structure is +// needed to make channel proof exchange persistent, so that after client +// restart we may receive remote/local half proof and process it. +type WaitingProof struct { + *lnwire.AnnounceSignatures + isRemote bool +} + +// NewWaitingProof constructs a new waiting prof instance. +func NewWaitingProof(isRemote bool, proof *lnwire.AnnounceSignatures) *WaitingProof { + return &WaitingProof{ + AnnounceSignatures: proof, + isRemote: isRemote, + } +} + +// OppositeKey returns the key which uniquely identifies opposite waiting proof. +func (p *WaitingProof) OppositeKey() WaitingProofKey { + var key [9]byte + binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64()) + + if !p.isRemote { + key[8] = 1 + } + return key +} + +// Key returns the key which uniquely identifies waiting proof. +func (p *WaitingProof) Key() WaitingProofKey { + var key [9]byte + binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64()) + + if p.isRemote { + key[8] = 1 + } + return key +} + +// Encode writes the internal representation of waiting proof in byte stream. +func (p *WaitingProof) Encode(w io.Writer) error { + if err := binary.Write(w, byteOrder, p.isRemote); err != nil { + return err + } + + if err := p.AnnounceSignatures.Encode(w, 0); err != nil { + return err + } + + return nil +} + +// Decode reads the data from the byte stream and initializes the +// waiting proof object with it. +func (p *WaitingProof) Decode(r io.Reader) error { + if err := binary.Read(r, byteOrder, &p.isRemote); err != nil { + return err + } + + msg := &lnwire.AnnounceSignatures{} + if err := msg.Decode(r, 0); err != nil { + return err + } + + (*p).AnnounceSignatures = msg + return nil +} diff --git a/channeldb/migration_01_to_11/waitingproof_test.go b/channeldb/migration_01_to_11/waitingproof_test.go new file mode 100644 index 00000000..968f1157 --- /dev/null +++ b/channeldb/migration_01_to_11/waitingproof_test.go @@ -0,0 +1,59 @@ +package migration_01_to_11 + +import ( + "testing" + + "reflect" + + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lnwire" +) + +// TestWaitingProofStore tests add/get/remove functions of the waiting proof +// storage. +func TestWaitingProofStore(t *testing.T) { + t.Parallel() + + db, cleanup, err := makeTestDB() + if err != nil { + t.Fatalf("failed to make test database: %s", err) + } + defer cleanup() + + proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ + NodeSignature: wireSig, + BitcoinSignature: wireSig, + }) + + store, err := NewWaitingProofStore(db) + if err != nil { + t.Fatalf("unable to create the waiting proofs storage: %v", + err) + } + + if err := store.Add(proof1); err != nil { + t.Fatalf("unable add proof to storage: %v", err) + } + + proof2, err := store.Get(proof1.Key()) + if err != nil { + t.Fatalf("unable retrieve proof from storage: %v", err) + } + if !reflect.DeepEqual(proof1, proof2) { + t.Fatal("wrong proof retrieved") + } + + if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound { + t.Fatalf("proof shouldn't be found: %v", err) + } + + if err := store.Remove(proof1.Key()); err != nil { + t.Fatalf("unable remove proof from storage: %v", err) + } + + if err := store.ForAll(func(proof *WaitingProof) error { + return errors.New("storage should be empty") + }); err != nil && err != ErrWaitingProofNotFound { + t.Fatal(err) + } +} diff --git a/channeldb/migration_01_to_11/witness_cache.go b/channeldb/migration_01_to_11/witness_cache.go new file mode 100644 index 00000000..69de1054 --- /dev/null +++ b/channeldb/migration_01_to_11/witness_cache.go @@ -0,0 +1,229 @@ +package migration_01_to_11 + +import ( + "fmt" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" +) + +var ( + // ErrNoWitnesses is an error that's returned when no new witnesses have + // been added to the WitnessCache. + ErrNoWitnesses = fmt.Errorf("no witnesses") + + // ErrUnknownWitnessType is returned if a caller attempts to + ErrUnknownWitnessType = fmt.Errorf("unknown witness type") +) + +// WitnessType is enum that denotes what "type" of witness is being +// stored/retrieved. As the WitnessCache itself is agnostic and doesn't enforce +// any structure on added witnesses, we use this type to partition the +// witnesses on disk, and also to know how to map a witness to its look up key. +type WitnessType uint8 + +var ( + // Sha256HashWitness is a witness that is simply the pre image to a + // hash image. In order to map to its key, we'll use sha256. + Sha256HashWitness WitnessType = 1 +) + +// toDBKey is a helper method that maps a witness type to the key that we'll +// use to store it within the database. +func (w WitnessType) toDBKey() ([]byte, error) { + switch w { + + case Sha256HashWitness: + return []byte{byte(w)}, nil + + default: + return nil, ErrUnknownWitnessType + } +} + +var ( + // witnessBucketKey is the name of the bucket that we use to store all + // witnesses encountered. Within this bucket, we'll create a sub-bucket for + // each witness type. + witnessBucketKey = []byte("byte") +) + +// WitnessCache is a persistent cache of all witnesses we've encountered on the +// network. In the case of multi-hop, multi-step contracts, a cache of all +// witnesses can be useful in the case of partial contract resolution. If +// negotiations break down, we may be forced to locate the witness for a +// portion of the contract on-chain. In this case, we'll then add that witness +// to the cache so the incoming contract can fully resolve witness. +// Additionally, as one MUST always use a unique witness on the network, we may +// use this cache to detect duplicate witnesses. +// +// TODO(roasbeef): need expiry policy? +// * encrypt? +type WitnessCache struct { + db *DB +} + +// NewWitnessCache returns a new instance of the witness cache. +func (d *DB) NewWitnessCache() *WitnessCache { + return &WitnessCache{ + db: d, + } +} + +// witnessEntry is a key-value struct that holds each key -> witness pair, used +// when inserting records into the cache. +type witnessEntry struct { + key []byte + witness []byte +} + +// AddSha256Witnesses adds a batch of new sha256 preimages into the witness +// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the +// preimages' witness type. +func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error { + // Optimistically compute the preimages' hashes before attempting to + // start the db transaction. + entries := make([]witnessEntry, 0, len(preimages)) + for i := range preimages { + hash := preimages[i].Hash() + entries = append(entries, witnessEntry{ + key: hash[:], + witness: preimages[i][:], + }) + } + + return w.addWitnessEntries(Sha256HashWitness, entries) +} + +// addWitnessEntries inserts the witnessEntry key-value pairs into the cache, +// using the appropriate witness type to segment the namespace of possible +// witness types. +func (w *WitnessCache) addWitnessEntries(wType WitnessType, + entries []witnessEntry) error { + + // Exit early if there are no witnesses to add. + if len(entries) == 0 { + return nil + } + + return w.db.Batch(func(tx *bbolt.Tx) error { + witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) + if err != nil { + return err + } + + witnessTypeBucketKey, err := wType.toDBKey() + if err != nil { + return err + } + witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists( + witnessTypeBucketKey, + ) + if err != nil { + return err + } + + for _, entry := range entries { + err = witnessTypeBucket.Put(entry.key, entry.witness) + if err != nil { + return err + } + } + + return nil + }) +} + +// LookupSha256Witness attempts to lookup the preimage for a sha256 hash. If +// the witness isn't found, ErrNoWitnesses will be returned. +func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) { + witness, err := w.lookupWitness(Sha256HashWitness, hash[:]) + if err != nil { + return lntypes.Preimage{}, err + } + + return lntypes.MakePreimage(witness) +} + +// lookupWitness attempts to lookup a witness according to its type and also +// its witness key. In the case that the witness isn't found, ErrNoWitnesses +// will be returned. +func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) { + var witness []byte + err := w.db.View(func(tx *bbolt.Tx) error { + witnessBucket := tx.Bucket(witnessBucketKey) + if witnessBucket == nil { + return ErrNoWitnesses + } + + witnessTypeBucketKey, err := wType.toDBKey() + if err != nil { + return err + } + witnessTypeBucket := witnessBucket.Bucket(witnessTypeBucketKey) + if witnessTypeBucket == nil { + return ErrNoWitnesses + } + + dbWitness := witnessTypeBucket.Get(witnessKey) + if dbWitness == nil { + return ErrNoWitnesses + } + + witness = make([]byte, len(dbWitness)) + copy(witness[:], dbWitness) + + return nil + }) + if err != nil { + return nil, err + } + + return witness, nil +} + +// DeleteSha256Witness attempts to delete a sha256 preimage identified by hash. +func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error { + return w.deleteWitness(Sha256HashWitness, hash[:]) +} + +// deleteWitness attempts to delete a particular witness from the database. +func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error { + return w.db.Batch(func(tx *bbolt.Tx) error { + witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) + if err != nil { + return err + } + + witnessTypeBucketKey, err := wType.toDBKey() + if err != nil { + return err + } + witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists( + witnessTypeBucketKey, + ) + if err != nil { + return err + } + + return witnessTypeBucket.Delete(witnessKey) + }) +} + +// DeleteWitnessClass attempts to delete an *entire* class of witnesses. After +// this function return with a non-nil error, +func (w *WitnessCache) DeleteWitnessClass(wType WitnessType) error { + return w.db.Batch(func(tx *bbolt.Tx) error { + witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) + if err != nil { + return err + } + + witnessTypeBucketKey, err := wType.toDBKey() + if err != nil { + return err + } + + return witnessBucket.DeleteBucket(witnessTypeBucketKey) + }) +} diff --git a/channeldb/migration_01_to_11/witness_cache_test.go b/channeldb/migration_01_to_11/witness_cache_test.go new file mode 100644 index 00000000..92836abe --- /dev/null +++ b/channeldb/migration_01_to_11/witness_cache_test.go @@ -0,0 +1,238 @@ +package migration_01_to_11 + +import ( + "crypto/sha256" + "testing" + + "github.com/lightningnetwork/lnd/lntypes" +) + +// TestWitnessCacheSha256Retrieval tests that we're able to add and lookup new +// sha256 preimages to the witness cache. +func TestWitnessCacheSha256Retrieval(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + wCache := cdb.NewWitnessCache() + + // We'll be attempting to add then lookup two simple sha256 preimages + // within this test. + preimage1 := lntypes.Preimage(rev) + preimage2 := lntypes.Preimage(key) + + preimages := []lntypes.Preimage{preimage1, preimage2} + hashes := []lntypes.Hash{preimage1.Hash(), preimage2.Hash()} + + // First, we'll attempt to add the preimages to the database. + err = wCache.AddSha256Witnesses(preimages...) + if err != nil { + t.Fatalf("unable to add witness: %v", err) + } + + // With the preimages stored, we'll now attempt to look them up. + for i, hash := range hashes { + preimage := preimages[i] + + // We should get back the *exact* same preimage as we originally + // stored. + dbPreimage, err := wCache.LookupSha256Witness(hash) + if err != nil { + t.Fatalf("unable to look up witness: %v", err) + } + + if preimage != dbPreimage { + t.Fatalf("witnesses don't match: expected %x, got %x", + preimage[:], dbPreimage[:]) + } + } +} + +// TestWitnessCacheSha256Deletion tests that we're able to delete a single +// sha256 preimage, and also a class of witnesses from the cache. +func TestWitnessCacheSha256Deletion(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + wCache := cdb.NewWitnessCache() + + // We'll start by adding two preimages to the cache. + preimage1 := lntypes.Preimage(key) + hash1 := preimage1.Hash() + + preimage2 := lntypes.Preimage(rev) + hash2 := preimage2.Hash() + + if err := wCache.AddSha256Witnesses(preimage1); err != nil { + t.Fatalf("unable to add witness: %v", err) + } + + if err := wCache.AddSha256Witnesses(preimage2); err != nil { + t.Fatalf("unable to add witness: %v", err) + } + + // We'll now delete the first preimage. If we attempt to look it up, we + // should get ErrNoWitnesses. + err = wCache.DeleteSha256Witness(hash1) + if err != nil { + t.Fatalf("unable to delete witness: %v", err) + } + _, err = wCache.LookupSha256Witness(hash1) + if err != ErrNoWitnesses { + t.Fatalf("expected ErrNoWitnesses instead got: %v", err) + } + + // Next, we'll attempt to delete the entire witness class itself. When + // we try to lookup the second preimage, we should again get + // ErrNoWitnesses. + if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil { + t.Fatalf("unable to delete witness class: %v", err) + } + _, err = wCache.LookupSha256Witness(hash2) + if err != ErrNoWitnesses { + t.Fatalf("expected ErrNoWitnesses instead got: %v", err) + } +} + +// TestWitnessCacheUnknownWitness tests that we get an error if we attempt to +// query/add/delete an unknown witness. +func TestWitnessCacheUnknownWitness(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + wCache := cdb.NewWitnessCache() + + // We'll attempt to add a new, undefined witness type to the database. + // We should get an error. + err = wCache.legacyAddWitnesses(234, key[:]) + if err != ErrUnknownWitnessType { + t.Fatalf("expected ErrUnknownWitnessType, got %v", err) + } +} + +// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves +// identically to the insertion via the generalized interface. +func TestAddSha256Witnesses(t *testing.T) { + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + wCache := cdb.NewWitnessCache() + + // We'll start by adding a witnesses to the cache using the generic + // AddWitnesses method. + witness1 := rev[:] + preimage1 := lntypes.Preimage(rev) + hash1 := preimage1.Hash() + + witness2 := key[:] + preimage2 := lntypes.Preimage(key) + hash2 := preimage2.Hash() + + var ( + witnesses = [][]byte{witness1, witness2} + preimages = []lntypes.Preimage{preimage1, preimage2} + hashes = []lntypes.Hash{hash1, hash2} + ) + + err = wCache.legacyAddWitnesses(Sha256HashWitness, witnesses...) + if err != nil { + t.Fatalf("unable to add witness: %v", err) + } + + for i, hash := range hashes { + preimage := preimages[i] + + dbPreimage, err := wCache.LookupSha256Witness(hash) + if err != nil { + t.Fatalf("unable to lookup witness: %v", err) + } + + // Assert that the retrieved witness matches the original. + if dbPreimage != preimage { + t.Fatalf("retrieved witness mismatch, want: %x, "+ + "got: %x", preimage, dbPreimage) + } + + // We'll now delete the witness, as we'll be reinserting it + // using the specialized AddSha256Witnesses method. + err = wCache.DeleteSha256Witness(hash) + if err != nil { + t.Fatalf("unable to delete witness: %v", err) + } + } + + // Now, add the same witnesses using the type-safe interface for + // lntypes.Preimages.. + err = wCache.AddSha256Witnesses(preimages...) + if err != nil { + t.Fatalf("unable to add sha256 preimage: %v", err) + } + + // Finally, iterate over the keys and assert that the returned witnesses + // match the original witnesses. This asserts that the specialized + // insertion method behaves identically to the generalized interface. + for i, hash := range hashes { + preimage := preimages[i] + + dbPreimage, err := wCache.LookupSha256Witness(hash) + if err != nil { + t.Fatalf("unable to lookup witness: %v", err) + } + + // Assert that the retrieved witness matches the original. + if dbPreimage != preimage { + t.Fatalf("retrieved witness mismatch, want: %x, "+ + "got: %x", preimage, dbPreimage) + } + } +} + +// legacyAddWitnesses adds a batch of new witnesses of wType to the witness +// cache. The type of the witness will be used to map each witness to the key +// that will be used to look it up. All witnesses should be of the same +// WitnessType. +// +// NOTE: Previously this method exposed a generic interface for adding +// witnesses, which has since been deprecated in favor of a strongly typed +// interface for each witness class. We keep this method around to assert the +// correctness of specialized witness adding methods. +func (w *WitnessCache) legacyAddWitnesses(wType WitnessType, + witnesses ...[]byte) error { + + // Optimistically compute the witness keys before attempting to start + // the db transaction. + entries := make([]witnessEntry, 0, len(witnesses)) + for _, witness := range witnesses { + // Map each witness to its key by applying the appropriate + // transformation for the given witness type. + switch wType { + case Sha256HashWitness: + key := sha256.Sum256(witness) + entries = append(entries, witnessEntry{ + key: key[:], + witness: witness, + }) + default: + return ErrUnknownWitnessType + } + } + + return w.addWitnessEntries(wType, entries) +}