From 6e463c1634061d595953f20813860207e5d485ce Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Thu, 24 Oct 2019 12:04:26 +0200 Subject: [PATCH 1/6] channeldb: make copy for migrations This commit is a direct copy of the complete channeldb package. It only changes the package declaration at the top of every file. We make this full copy so that review can be focused on the actual changes made. Otherwise changes may drown in all the file moves. Linting for the new package is disabled, as it contains lots of pre-existing issues. --- .golangci.yml | 3 + channeldb/migration_01_to_11/README.md | 24 + channeldb/migration_01_to_11/addr.go | 221 + channeldb/migration_01_to_11/addr_test.go | 149 + channeldb/migration_01_to_11/channel.go | 2817 ++++++++++++ channeldb/migration_01_to_11/channel_cache.go | 50 + .../migration_01_to_11/channel_cache_test.go | 105 + channeldb/migration_01_to_11/channel_test.go | 1041 +++++ channeldb/migration_01_to_11/codec.go | 454 ++ channeldb/migration_01_to_11/db.go | 1185 +++++ channeldb/migration_01_to_11/db_test.go | 471 ++ channeldb/migration_01_to_11/doc.go | 1 + channeldb/migration_01_to_11/error.go | 128 + channeldb/migration_01_to_11/fees.go | 1 + .../migration_01_to_11/forwarding_log.go | 274 ++ .../migration_01_to_11/forwarding_log_test.go | 265 ++ .../migration_01_to_11/forwarding_package.go | 928 ++++ .../forwarding_package_test.go | 815 ++++ channeldb/migration_01_to_11/graph.go | 4060 +++++++++++++++++ channeldb/migration_01_to_11/graph_test.go | 3197 +++++++++++++ channeldb/migration_01_to_11/invoice_test.go | 694 +++ channeldb/migration_01_to_11/invoices.go | 1320 ++++++ .../legacy_serialization.go | 55 + channeldb/migration_01_to_11/log.go | 28 + channeldb/migration_01_to_11/meta.go | 78 + channeldb/migration_01_to_11/meta_test.go | 442 ++ .../migration_09_legacy_serialization.go | 497 ++ .../migration_10_route_tlv_records.go | 236 + .../migration_11_invoices.go | 230 + .../migration_11_invoices_test.go | 193 + channeldb/migration_01_to_11/migrations.go | 939 ++++ .../migration_01_to_11/migrations_test.go | 952 ++++ channeldb/migration_01_to_11/nodes.go | 316 ++ channeldb/migration_01_to_11/nodes_test.go | 140 + channeldb/migration_01_to_11/options.go | 62 + .../migration_01_to_11/payment_control.go | 497 ++ .../payment_control_test.go | 550 +++ channeldb/migration_01_to_11/payments.go | 669 +++ channeldb/migration_01_to_11/payments_test.go | 324 ++ channeldb/migration_01_to_11/reject_cache.go | 95 + .../migration_01_to_11/reject_cache_test.go | 107 + channeldb/migration_01_to_11/waitingproof.go | 251 + .../migration_01_to_11/waitingproof_test.go | 59 + channeldb/migration_01_to_11/witness_cache.go | 229 + .../migration_01_to_11/witness_cache_test.go | 238 + 45 files changed, 25390 insertions(+) create mode 100644 channeldb/migration_01_to_11/README.md create mode 100644 channeldb/migration_01_to_11/addr.go create mode 100644 channeldb/migration_01_to_11/addr_test.go create mode 100644 channeldb/migration_01_to_11/channel.go create mode 100644 channeldb/migration_01_to_11/channel_cache.go create mode 100644 channeldb/migration_01_to_11/channel_cache_test.go create mode 100644 channeldb/migration_01_to_11/channel_test.go create mode 100644 channeldb/migration_01_to_11/codec.go create mode 100644 channeldb/migration_01_to_11/db.go create mode 100644 channeldb/migration_01_to_11/db_test.go create mode 100644 channeldb/migration_01_to_11/doc.go create mode 100644 channeldb/migration_01_to_11/error.go create mode 100644 channeldb/migration_01_to_11/fees.go create mode 100644 channeldb/migration_01_to_11/forwarding_log.go create mode 100644 channeldb/migration_01_to_11/forwarding_log_test.go create mode 100644 channeldb/migration_01_to_11/forwarding_package.go create mode 100644 channeldb/migration_01_to_11/forwarding_package_test.go create mode 100644 channeldb/migration_01_to_11/graph.go create mode 100644 channeldb/migration_01_to_11/graph_test.go create mode 100644 channeldb/migration_01_to_11/invoice_test.go create mode 100644 channeldb/migration_01_to_11/invoices.go create mode 100644 channeldb/migration_01_to_11/legacy_serialization.go create mode 100644 channeldb/migration_01_to_11/log.go create mode 100644 channeldb/migration_01_to_11/meta.go create mode 100644 channeldb/migration_01_to_11/meta_test.go create mode 100644 channeldb/migration_01_to_11/migration_09_legacy_serialization.go create mode 100644 channeldb/migration_01_to_11/migration_10_route_tlv_records.go create mode 100644 channeldb/migration_01_to_11/migration_11_invoices.go create mode 100644 channeldb/migration_01_to_11/migration_11_invoices_test.go create mode 100644 channeldb/migration_01_to_11/migrations.go create mode 100644 channeldb/migration_01_to_11/migrations_test.go create mode 100644 channeldb/migration_01_to_11/nodes.go create mode 100644 channeldb/migration_01_to_11/nodes_test.go create mode 100644 channeldb/migration_01_to_11/options.go create mode 100644 channeldb/migration_01_to_11/payment_control.go create mode 100644 channeldb/migration_01_to_11/payment_control_test.go create mode 100644 channeldb/migration_01_to_11/payments.go create mode 100644 channeldb/migration_01_to_11/payments_test.go create mode 100644 channeldb/migration_01_to_11/reject_cache.go create mode 100644 channeldb/migration_01_to_11/reject_cache_test.go create mode 100644 channeldb/migration_01_to_11/waitingproof.go create mode 100644 channeldb/migration_01_to_11/waitingproof_test.go create mode 100644 channeldb/migration_01_to_11/witness_cache.go create mode 100644 channeldb/migration_01_to_11/witness_cache_test.go diff --git a/.golangci.yml b/.golangci.yml index ca3e6881..23b24810 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -10,6 +10,9 @@ run: skip-files: - "mobile\\/.*generated\\.go" + skip-dirs: + - channeldb/migration_01_to_11 + build-tags: - autopilotrpc - chainrpc diff --git a/channeldb/migration_01_to_11/README.md b/channeldb/migration_01_to_11/README.md new file mode 100644 index 00000000..7e3a81ef --- /dev/null +++ b/channeldb/migration_01_to_11/README.md @@ -0,0 +1,24 @@ +channeldb +========== + +[![Build Status](http://img.shields.io/travis/lightningnetwork/lnd.svg)](https://travis-ci.org/lightningnetwork/lnd) +[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/lightningnetwork/lnd/blob/master/LICENSE) +[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/lightningnetwork/lnd/channeldb) + +The channeldb implements the persistent storage engine for `lnd` and +generically a data storage layer for the required state within the Lightning +Network. The backing storage engine is +[boltdb](https://github.com/coreos/bbolt), an embedded pure-go key-value store +based off of LMDB. + +The package implements an object-oriented storage model with queries and +mutations flowing through a particular object instance rather than the database +itself. The storage implemented by the objects includes: open channels, past +commitment revocation states, the channel graph which includes authenticated +node and channel announcements, outgoing payments, and invoices + +## Installation and Updating + +```bash +$ go get -u github.com/lightningnetwork/lnd/channeldb +``` diff --git a/channeldb/migration_01_to_11/addr.go b/channeldb/migration_01_to_11/addr.go new file mode 100644 index 00000000..2e7def07 --- /dev/null +++ b/channeldb/migration_01_to_11/addr.go @@ -0,0 +1,221 @@ +package migration_01_to_11 + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + + "github.com/lightningnetwork/lnd/tor" +) + +// addressType specifies the network protocol and version that should be used +// when connecting to a node at a particular address. +type addressType uint8 + +const ( + // tcp4Addr denotes an IPv4 TCP address. + tcp4Addr addressType = 0 + + // tcp6Addr denotes an IPv6 TCP address. + tcp6Addr addressType = 1 + + // v2OnionAddr denotes a version 2 Tor onion service address. + v2OnionAddr addressType = 2 + + // v3OnionAddr denotes a version 3 Tor (prop224) onion service address. + v3OnionAddr addressType = 3 +) + +// encodeTCPAddr serializes a TCP address into its compact raw bytes +// representation. +func encodeTCPAddr(w io.Writer, addr *net.TCPAddr) error { + var ( + addrType byte + ip []byte + ) + + if addr.IP.To4() != nil { + addrType = byte(tcp4Addr) + ip = addr.IP.To4() + } else { + addrType = byte(tcp6Addr) + ip = addr.IP.To16() + } + + if ip == nil { + return fmt.Errorf("unable to encode IP %v", addr.IP) + } + + if _, err := w.Write([]byte{addrType}); err != nil { + return err + } + + if _, err := w.Write(ip); err != nil { + return err + } + + var port [2]byte + byteOrder.PutUint16(port[:], uint16(addr.Port)) + if _, err := w.Write(port[:]); err != nil { + return err + } + + return nil +} + +// encodeOnionAddr serializes an onion address into its compact raw bytes +// representation. +func encodeOnionAddr(w io.Writer, addr *tor.OnionAddr) error { + var suffixIndex int + hostLen := len(addr.OnionService) + switch hostLen { + case tor.V2Len: + if _, err := w.Write([]byte{byte(v2OnionAddr)}); err != nil { + return err + } + suffixIndex = tor.V2Len - tor.OnionSuffixLen + case tor.V3Len: + if _, err := w.Write([]byte{byte(v3OnionAddr)}); err != nil { + return err + } + suffixIndex = tor.V3Len - tor.OnionSuffixLen + default: + return errors.New("unknown onion service length") + } + + suffix := addr.OnionService[suffixIndex:] + if suffix != tor.OnionSuffix { + return fmt.Errorf("invalid suffix \"%v\"", suffix) + } + + host, err := tor.Base32Encoding.DecodeString( + addr.OnionService[:suffixIndex], + ) + if err != nil { + return err + } + + // Sanity check the decoded length. + switch { + case hostLen == tor.V2Len && len(host) != tor.V2DecodedLen: + return fmt.Errorf("onion service %v decoded to invalid host %x", + addr.OnionService, host) + + case hostLen == tor.V3Len && len(host) != tor.V3DecodedLen: + return fmt.Errorf("onion service %v decoded to invalid host %x", + addr.OnionService, host) + } + + if _, err := w.Write(host); err != nil { + return err + } + + var port [2]byte + byteOrder.PutUint16(port[:], uint16(addr.Port)) + if _, err := w.Write(port[:]); err != nil { + return err + } + + return nil +} + +// deserializeAddr reads the serialized raw representation of an address and +// deserializes it into the actual address. This allows us to avoid address +// resolution within the channeldb package. +func deserializeAddr(r io.Reader) (net.Addr, error) { + var addrType [1]byte + if _, err := r.Read(addrType[:]); err != nil { + return nil, err + } + + var address net.Addr + switch addressType(addrType[0]) { + case tcp4Addr: + var ip [4]byte + if _, err := r.Read(ip[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + address = &net.TCPAddr{ + IP: net.IP(ip[:]), + Port: int(binary.BigEndian.Uint16(port[:])), + } + case tcp6Addr: + var ip [16]byte + if _, err := r.Read(ip[:]); err != nil { + return nil, err + } + + var port [2]byte + if _, err := r.Read(port[:]); err != nil { + return nil, err + } + + address = &net.TCPAddr{ + IP: net.IP(ip[:]), + Port: int(binary.BigEndian.Uint16(port[:])), + } + case v2OnionAddr: + var h [tor.V2DecodedLen]byte + if _, err := r.Read(h[:]); err != nil { + return nil, err + } + + var p [2]byte + if _, err := r.Read(p[:]); err != nil { + return nil, err + } + + onionService := tor.Base32Encoding.EncodeToString(h[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + + address = &tor.OnionAddr{ + OnionService: onionService, + Port: port, + } + case v3OnionAddr: + var h [tor.V3DecodedLen]byte + if _, err := r.Read(h[:]); err != nil { + return nil, err + } + + var p [2]byte + if _, err := r.Read(p[:]); err != nil { + return nil, err + } + + onionService := tor.Base32Encoding.EncodeToString(h[:]) + onionService += tor.OnionSuffix + port := int(binary.BigEndian.Uint16(p[:])) + + address = &tor.OnionAddr{ + OnionService: onionService, + Port: port, + } + default: + return nil, ErrUnknownAddressType + } + + return address, nil +} + +// serializeAddr serializes an address into its raw bytes representation so that +// it can be deserialized without requiring address resolution. +func serializeAddr(w io.Writer, address net.Addr) error { + switch addr := address.(type) { + case *net.TCPAddr: + return encodeTCPAddr(w, addr) + case *tor.OnionAddr: + return encodeOnionAddr(w, addr) + default: + return ErrUnknownAddressType + } +} diff --git a/channeldb/migration_01_to_11/addr_test.go b/channeldb/migration_01_to_11/addr_test.go new file mode 100644 index 00000000..8cdf99c3 --- /dev/null +++ b/channeldb/migration_01_to_11/addr_test.go @@ -0,0 +1,149 @@ +package migration_01_to_11 + +import ( + "bytes" + "net" + "strings" + "testing" + + "github.com/lightningnetwork/lnd/tor" +) + +type unknownAddrType struct{} + +func (t unknownAddrType) Network() string { return "unknown" } +func (t unknownAddrType) String() string { return "unknown" } + +var testIP4 = net.ParseIP("192.168.1.1") +var testIP6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329") + +var addrTests = []struct { + expAddr net.Addr + serErr string +}{ + // Valid addresses. + { + expAddr: &net.TCPAddr{ + IP: testIP4, + Port: 12345, + }, + }, + { + expAddr: &net.TCPAddr{ + IP: testIP6, + Port: 65535, + }, + }, + { + expAddr: &tor.OnionAddr{ + OnionService: "3g2upl4pq6kufc4m.onion", + Port: 9735, + }, + }, + { + expAddr: &tor.OnionAddr{ + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion", + Port: 80, + }, + }, + + // Invalid addresses. + { + expAddr: unknownAddrType{}, + serErr: ErrUnknownAddressType.Error(), + }, + { + expAddr: &net.TCPAddr{ + // Remove last byte of IPv4 address. + IP: testIP4[:len(testIP4)-1], + Port: 12345, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Add an extra byte of IPv4 address. + IP: append(testIP4, 0xff), + Port: 12345, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Remove last byte of IPv6 address. + IP: testIP6[:len(testIP6)-1], + Port: 65535, + }, + serErr: "unable to encode", + }, + { + expAddr: &net.TCPAddr{ + // Add an extra byte to the IPv6 address. + IP: append(testIP6, 0xff), + Port: 65535, + }, + serErr: "unable to encode", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid suffix. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.inion", + Port: 80, + }, + serErr: "invalid suffix", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid length. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy.onion", + Port: 80, + }, + serErr: "unknown onion service length", + }, + { + expAddr: &tor.OnionAddr{ + // Invalid encoding. + OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyA.onion", + Port: 80, + }, + serErr: "illegal base32", + }, +} + +// TestAddrSerialization tests that the serialization method used by channeldb +// for net.Addr's works as intended. +func TestAddrSerialization(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + for _, test := range addrTests { + err := serializeAddr(&b, test.expAddr) + switch { + case err == nil && test.serErr != "": + t.Fatalf("expected serialization err for addr %v", + test.expAddr) + + case err != nil && test.serErr == "": + t.Fatalf("unexpected serialization err for addr %v: %v", + test.expAddr, err) + + case err != nil && !strings.Contains(err.Error(), test.serErr): + t.Fatalf("unexpected serialization err for addr %v, "+ + "want: %v, got %v", test.expAddr, test.serErr, + err) + + case err != nil: + continue + } + + addr, err := deserializeAddr(&b) + if err != nil { + t.Fatalf("unable to deserialize address: %v", err) + } + + if addr.String() != test.expAddr.String() { + t.Fatalf("expected address %v after serialization, "+ + "got %v", addr, test.expAddr) + } + } +} diff --git a/channeldb/migration_01_to_11/channel.go b/channeldb/migration_01_to_11/channel.go new file mode 100644 index 00000000..23d66852 --- /dev/null +++ b/channeldb/migration_01_to_11/channel.go @@ -0,0 +1,2817 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +var ( + // closedChannelBucket stores summarization information concerning + // previously open, but now closed channels. + closedChannelBucket = []byte("closed-chan-bucket") + + // openChanBucket stores all the currently open channels. This bucket + // has a second, nested bucket which is keyed by a node's ID. Within + // that node ID bucket, all attributes required to track, update, and + // close a channel are stored. + // + // openChan -> nodeID -> chanPoint + // + // TODO(roasbeef): flesh out comment + openChannelBucket = []byte("open-chan-bucket") + + // chanInfoKey can be accessed within the bucket for a channel + // (identified by its chanPoint). This key stores all the static + // information for a channel which is decided at the end of the + // funding flow. + chanInfoKey = []byte("chan-info-key") + + // chanCommitmentKey can be accessed within the sub-bucket for a + // particular channel. This key stores the up to date commitment state + // for a particular channel party. Appending a 0 to the end of this key + // indicates it's the commitment for the local party, and appending a 1 + // to the end of this key indicates it's the commitment for the remote + // party. + chanCommitmentKey = []byte("chan-commitment-key") + + // revocationStateKey stores their current revocation hash, our + // preimage producer and their preimage store. + revocationStateKey = []byte("revocation-state-key") + + // dataLossCommitPointKey stores the commitment point received from the + // remote peer during a channel sync in case we have lost channel state. + dataLossCommitPointKey = []byte("data-loss-commit-point-key") + + // closingTxKey points to a the closing tx that we broadcasted when + // moving the channel to state CommitBroadcasted. + closingTxKey = []byte("closing-tx-key") + + // commitDiffKey stores the current pending commitment state we've + // extended to the remote party (if any). Each time we propose a new + // state, we store the information necessary to reconstruct this state + // from the prior commitment. This allows us to resync the remote party + // to their expected state in the case of message loss. + // + // TODO(roasbeef): rename to commit chain? + commitDiffKey = []byte("commit-diff-key") + + // revocationLogBucket is dedicated for storing the necessary delta + // state between channel updates required to re-construct a past state + // in order to punish a counterparty attempting a non-cooperative + // channel closure. This key should be accessed from within the + // sub-bucket of a target channel, identified by its channel point. + revocationLogBucket = []byte("revocation-log-key") +) + +var ( + // ErrNoCommitmentsFound is returned when a channel has not set + // commitment states. + ErrNoCommitmentsFound = fmt.Errorf("no commitments found") + + // ErrNoChanInfoFound is returned when a particular channel does not + // have any channels state. + ErrNoChanInfoFound = fmt.Errorf("no chan info found") + + // ErrNoRevocationsFound is returned when revocation state for a + // particular channel cannot be found. + ErrNoRevocationsFound = fmt.Errorf("no revocations found") + + // ErrNoPendingCommit is returned when there is not a pending + // commitment for a remote party. A new commitment is written to disk + // each time we write a new state in order to be properly fault + // tolerant. + ErrNoPendingCommit = fmt.Errorf("no pending commits found") + + // ErrInvalidCircuitKeyLen signals that a circuit key could not be + // decoded because the byte slice is of an invalid length. + ErrInvalidCircuitKeyLen = fmt.Errorf( + "length of serialized circuit key must be 16 bytes") + + // ErrNoCommitPoint is returned when no data loss commit point is found + // in the database. + ErrNoCommitPoint = fmt.Errorf("no commit point found") + + // ErrNoCloseTx is returned when no closing tx is found for a channel + // in the state CommitBroadcasted. + ErrNoCloseTx = fmt.Errorf("no closing tx found") + + // ErrNoRestoredChannelMutation is returned when a caller attempts to + // mutate a channel that's been recovered. + ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " + + "channel state") + + // ErrChanBorked is returned when a caller attempts to mutate a borked + // channel. + ErrChanBorked = fmt.Errorf("cannot mutate borked channel") +) + +// ChannelType is an enum-like type that describes one of several possible +// channel types. Each open channel is associated with a particular type as the +// channel type may determine how higher level operations are conducted such as +// fee negotiation, channel closing, the format of HTLCs, etc. +// TODO(roasbeef): split up per-chain? +type ChannelType uint8 + +const ( + // NOTE: iota isn't used here for this enum needs to be stable + // long-term as it will be persisted to the database. + + // SingleFunder represents a channel wherein one party solely funds the + // entire capacity of the channel. + SingleFunder ChannelType = 0 + + // DualFunder represents a channel wherein both parties contribute + // funds towards the total capacity of the channel. The channel may be + // funded symmetrically or asymmetrically. + DualFunder ChannelType = 1 + + // SingleFunderTweakless is similar to the basic SingleFunder channel + // type, but it omits the tweak for one's key in the commitment + // transaction of the remote party. + SingleFunderTweakless ChannelType = 2 +) + +// IsSingleFunder returns true if the channel type if one of the known single +// funder variants. +func (c ChannelType) IsSingleFunder() bool { + return c == SingleFunder || c == SingleFunderTweakless +} + +// IsTweakless returns true if the target channel uses a commitment that +// doesn't tweak the key for the remote party. +func (c ChannelType) IsTweakless() bool { + return c == SingleFunderTweakless +} + +// ChannelConstraints represents a set of constraints meant to allow a node to +// limit their exposure, enact flow control and ensure that all HTLCs are +// economically relevant. This struct will be mirrored for both sides of the +// channel, as each side will enforce various constraints that MUST be adhered +// to for the life time of the channel. The parameters for each of these +// constraints are static for the duration of the channel, meaning the channel +// must be torn down for them to change. +type ChannelConstraints struct { + // DustLimit is the threshold (in satoshis) below which any outputs + // should be trimmed. When an output is trimmed, it isn't materialized + // as an actual output, but is instead burned to miner's fees. + DustLimit btcutil.Amount + + // ChanReserve is an absolute reservation on the channel for the + // owner of this set of constraints. This means that the current + // settled balance for this node CANNOT dip below the reservation + // amount. This acts as a defense against costless attacks when + // either side no longer has any skin in the game. + ChanReserve btcutil.Amount + + // MaxPendingAmount is the maximum pending HTLC value that the + // owner of these constraints can offer the remote node at a + // particular time. + MaxPendingAmount lnwire.MilliSatoshi + + // MinHTLC is the minimum HTLC value that the owner of these + // constraints can offer the remote node. If any HTLCs below this + // amount are offered, then the HTLC will be rejected. This, in + // tandem with the dust limit allows a node to regulate the + // smallest HTLC that it deems economically relevant. + MinHTLC lnwire.MilliSatoshi + + // MaxAcceptedHtlcs is the maximum number of HTLCs that the owner of + // this set of constraints can offer the remote node. This allows each + // node to limit their over all exposure to HTLCs that may need to be + // acted upon in the case of a unilateral channel closure or a contract + // breach. + MaxAcceptedHtlcs uint16 + + // CsvDelay is the relative time lock delay expressed in blocks. Any + // settled outputs that pay to the owner of this channel configuration + // MUST ensure that the delay branch uses this value as the relative + // time lock. Similarly, any HTLC's offered by this node should use + // this value as well. + CsvDelay uint16 +} + +// ChannelConfig is a struct that houses the various configuration opens for +// channels. Each side maintains an instance of this configuration file as it +// governs: how the funding and commitment transaction to be created, the +// nature of HTLC's allotted, the keys to be used for delivery, and relative +// time lock parameters. +type ChannelConfig struct { + // ChannelConstraints is the set of constraints that must be upheld for + // the duration of the channel for the owner of this channel + // configuration. Constraints govern a number of flow control related + // parameters, also including the smallest HTLC that will be accepted + // by a participant. + ChannelConstraints + + // MultiSigKey is the key to be used within the 2-of-2 output script + // for the owner of this channel config. + MultiSigKey keychain.KeyDescriptor + + // RevocationBasePoint is the base public key to be used when deriving + // revocation keys for the remote node's commitment transaction. This + // will be combined along with a per commitment secret to derive a + // unique revocation key for each state. + RevocationBasePoint keychain.KeyDescriptor + + // PaymentBasePoint is the base public key to be used when deriving + // the key used within the non-delayed pay-to-self output on the + // commitment transaction for a node. This will be combined with a + // tweak derived from the per-commitment point to ensure unique keys + // for each commitment transaction. + PaymentBasePoint keychain.KeyDescriptor + + // DelayBasePoint is the base public key to be used when deriving the + // key used within the delayed pay-to-self output on the commitment + // transaction for a node. This will be combined with a tweak derived + // from the per-commitment point to ensure unique keys for each + // commitment transaction. + DelayBasePoint keychain.KeyDescriptor + + // HtlcBasePoint is the base public key to be used when deriving the + // local HTLC key. The derived key (combined with the tweak derived + // from the per-commitment point) is used within the "to self" clause + // within any HTLC output scripts. + HtlcBasePoint keychain.KeyDescriptor +} + +// ChannelCommitment is a snapshot of the commitment state at a particular +// point in the commitment chain. With each state transition, a snapshot of the +// current state along with all non-settled HTLCs are recorded. These snapshots +// detail the state of the _remote_ party's commitment at a particular state +// number. For ourselves (the local node) we ONLY store our most recent +// (unrevoked) state for safety purposes. +type ChannelCommitment struct { + // CommitHeight is the update number that this ChannelDelta represents + // the total number of commitment updates to this point. This can be + // viewed as sort of a "commitment height" as this number is + // monotonically increasing. + CommitHeight uint64 + + // LocalLogIndex is the cumulative log index index of the local node at + // this point in the commitment chain. This value will be incremented + // for each _update_ added to the local update log. + LocalLogIndex uint64 + + // LocalHtlcIndex is the current local running HTLC index. This value + // will be incremented for each outgoing HTLC the local node offers. + LocalHtlcIndex uint64 + + // RemoteLogIndex is the cumulative log index index of the remote node + // at this point in the commitment chain. This value will be + // incremented for each _update_ added to the remote update log. + RemoteLogIndex uint64 + + // RemoteHtlcIndex is the current remote running HTLC index. This value + // will be incremented for each outgoing HTLC the remote node offers. + RemoteHtlcIndex uint64 + + // LocalBalance is the current available settled balance within the + // channel directly spendable by us. + LocalBalance lnwire.MilliSatoshi + + // RemoteBalance is the current available settled balance within the + // channel directly spendable by the remote node. + RemoteBalance lnwire.MilliSatoshi + + // CommitFee is the amount calculated to be paid in fees for the + // current set of commitment transactions. The fee amount is persisted + // with the channel in order to allow the fee amount to be removed and + // recalculated with each channel state update, including updates that + // happen after a system restart. + CommitFee btcutil.Amount + + // FeePerKw is the min satoshis/kilo-weight that should be paid within + // the commitment transaction for the entire duration of the channel's + // lifetime. This field may be updated during normal operation of the + // channel as on-chain conditions change. + // + // TODO(halseth): make this SatPerKWeight. Cannot be done atm because + // this will cause the import cycle lnwallet<->channeldb. Fee + // estimation stuff should be in its own package. + FeePerKw btcutil.Amount + + // CommitTx is the latest version of the commitment state, broadcast + // able by us. + CommitTx *wire.MsgTx + + // CommitSig is one half of the signature required to fully complete + // the script for the commitment transaction above. This is the + // signature signed by the remote party for our version of the + // commitment transactions. + CommitSig []byte + + // Htlcs is the set of HTLC's that are pending at this particular + // commitment height. + Htlcs []HTLC + + // TODO(roasbeef): pending commit pointer? + // * lets just walk through +} + +// ChannelStatus is a bit vector used to indicate whether an OpenChannel is in +// the default usable state, or a state where it shouldn't be used. +type ChannelStatus uint8 + +var ( + // ChanStatusDefault is the normal state of an open channel. + ChanStatusDefault ChannelStatus + + // ChanStatusBorked indicates that the channel has entered an + // irreconcilable state, triggered by a state desynchronization or + // channel breach. Channels in this state should never be added to the + // htlc switch. + ChanStatusBorked ChannelStatus = 1 + + // ChanStatusCommitBroadcasted indicates that a commitment for this + // channel has been broadcasted. + ChanStatusCommitBroadcasted ChannelStatus = 1 << 1 + + // ChanStatusLocalDataLoss indicates that we have lost channel state + // for this channel, and broadcasting our latest commitment might be + // considered a breach. + // + // TODO(halseh): actually enforce that we are not force closing such a + // channel. + ChanStatusLocalDataLoss ChannelStatus = 1 << 2 + + // ChanStatusRestored is a status flag that signals that the channel + // has been restored, and doesn't have all the fields a typical channel + // will have. + ChanStatusRestored ChannelStatus = 1 << 3 +) + +// chanStatusStrings maps a ChannelStatus to a human friendly string that +// describes that status. +var chanStatusStrings = map[ChannelStatus]string{ + ChanStatusDefault: "ChanStatusDefault", + ChanStatusBorked: "ChanStatusBorked", + ChanStatusCommitBroadcasted: "ChanStatusCommitBroadcasted", + ChanStatusLocalDataLoss: "ChanStatusLocalDataLoss", + ChanStatusRestored: "ChanStatusRestored", +} + +// orderedChanStatusFlags is an in-order list of all that channel status flags. +var orderedChanStatusFlags = []ChannelStatus{ + ChanStatusDefault, + ChanStatusBorked, + ChanStatusCommitBroadcasted, + ChanStatusLocalDataLoss, + ChanStatusRestored, +} + +// String returns a human-readable representation of the ChannelStatus. +func (c ChannelStatus) String() string { + // If no flags are set, then this is the default case. + if c == 0 { + return chanStatusStrings[ChanStatusDefault] + } + + // Add individual bit flags. + statusStr := "" + for _, flag := range orderedChanStatusFlags { + if c&flag == flag { + statusStr += chanStatusStrings[flag] + "|" + c -= flag + } + } + + // Remove anything to the right of the final bar, including it as well. + statusStr = strings.TrimRight(statusStr, "|") + + // Add any remaining flags which aren't accounted for as hex. + if c != 0 { + statusStr += "|0x" + strconv.FormatUint(uint64(c), 16) + } + + // If this was purely an unknown flag, then remove the extra bar at the + // start of the string. + statusStr = strings.TrimLeft(statusStr, "|") + + return statusStr +} + +// OpenChannel encapsulates the persistent and dynamic state of an open channel +// with a remote node. An open channel supports several options for on-disk +// serialization depending on the exact context. Full (upon channel creation) +// state commitments, and partial (due to a commitment update) writes are +// supported. Each partial write due to a state update appends the new update +// to an on-disk log, which can then subsequently be queried in order to +// "time-travel" to a prior state. +type OpenChannel struct { + // ChanType denotes which type of channel this is. + ChanType ChannelType + + // ChainHash is a hash which represents the blockchain that this + // channel will be opened within. This value is typically the genesis + // hash. In the case that the original chain went through a contentious + // hard-fork, then this value will be tweaked using the unique fork + // point on each branch. + ChainHash chainhash.Hash + + // FundingOutpoint is the outpoint of the final funding transaction. + // This value uniquely and globally identifies the channel within the + // target blockchain as specified by the chain hash parameter. + FundingOutpoint wire.OutPoint + + // ShortChannelID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + ShortChannelID lnwire.ShortChannelID + + // IsPending indicates whether a channel's funding transaction has been + // confirmed. + IsPending bool + + // IsInitiator is a bool which indicates if we were the original + // initiator for the channel. This value may affect how higher levels + // negotiate fees, or close the channel. + IsInitiator bool + + // chanStatus is the current status of this channel. If it is not in + // the state Default, it should not be used for forwarding payments. + chanStatus ChannelStatus + + // FundingBroadcastHeight is the height in which the funding + // transaction was broadcast. This value can be used by higher level + // sub-systems to determine if a channel is stale and/or should have + // been confirmed before a certain height. + FundingBroadcastHeight uint32 + + // NumConfsRequired is the number of confirmations a channel's funding + // transaction must have received in order to be considered available + // for normal transactional use. + NumConfsRequired uint16 + + // ChannelFlags holds the flags that were sent as part of the + // open_channel message. + ChannelFlags lnwire.FundingFlag + + // IdentityPub is the identity public key of the remote node this + // channel has been established with. + IdentityPub *btcec.PublicKey + + // Capacity is the total capacity of this channel. + Capacity btcutil.Amount + + // TotalMSatSent is the total number of milli-satoshis we've sent + // within this channel. + TotalMSatSent lnwire.MilliSatoshi + + // TotalMSatReceived is the total number of milli-satoshis we've + // received within this channel. + TotalMSatReceived lnwire.MilliSatoshi + + // LocalChanCfg is the channel configuration for the local node. + LocalChanCfg ChannelConfig + + // RemoteChanCfg is the channel configuration for the remote node. + RemoteChanCfg ChannelConfig + + // LocalCommitment is the current local commitment state for the local + // party. This is stored distinct from the state of the remote party + // as there are certain asymmetric parameters which affect the + // structure of each commitment. + LocalCommitment ChannelCommitment + + // RemoteCommitment is the current remote commitment state for the + // remote party. This is stored distinct from the state of the local + // party as there are certain asymmetric parameters which affect the + // structure of each commitment. + RemoteCommitment ChannelCommitment + + // RemoteCurrentRevocation is the current revocation for their + // commitment transaction. However, since this the derived public key, + // we don't yet have the private key so we aren't yet able to verify + // that it's actually in the hash chain. + RemoteCurrentRevocation *btcec.PublicKey + + // RemoteNextRevocation is the revocation key to be used for the *next* + // commitment transaction we create for the local node. Within the + // specification, this value is referred to as the + // per-commitment-point. + RemoteNextRevocation *btcec.PublicKey + + // RevocationProducer is used to generate the revocation in such a way + // that remote side might store it efficiently and have the ability to + // restore the revocation by index if needed. Current implementation of + // secret producer is shachain producer. + RevocationProducer shachain.Producer + + // RevocationStore is used to efficiently store the revocations for + // previous channels states sent to us by remote side. Current + // implementation of secret store is shachain store. + RevocationStore shachain.Store + + // Packager is used to create and update forwarding packages for this + // channel, which encodes all necessary information to recover from + // failures and reforward HTLCs that were not fully processed. + Packager FwdPackager + + // FundingTxn is the transaction containing this channel's funding + // outpoint. Upon restarts, this txn will be rebroadcast if the channel + // is found to be pending. + // + // NOTE: This value will only be populated for single-funder channels + // for which we are the initiator. + FundingTxn *wire.MsgTx + + // TODO(roasbeef): eww + Db *DB + + // TODO(roasbeef): just need to store local and remote HTLC's? + + sync.RWMutex +} + +// ShortChanID returns the current ShortChannelID of this channel. +func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID { + c.RLock() + defer c.RUnlock() + + return c.ShortChannelID +} + +// ChanStatus returns the current ChannelStatus of this channel. +func (c *OpenChannel) ChanStatus() ChannelStatus { + c.RLock() + defer c.RUnlock() + + return c.chanStatus +} + +// ApplyChanStatus allows the caller to modify the internal channel state in a +// thead-safe manner. +func (c *OpenChannel) ApplyChanStatus(status ChannelStatus) error { + c.Lock() + defer c.Unlock() + + return c.putChanStatus(status) +} + +// ClearChanStatus allows the caller to clear a particular channel status from +// the primary channel status bit field. After this method returns, a call to +// HasChanStatus(status) should return false. +func (c *OpenChannel) ClearChanStatus(status ChannelStatus) error { + c.Lock() + defer c.Unlock() + + return c.clearChanStatus(status) +} + +// HasChanStatus returns true if the internal bitfield channel status of the +// target channel has the specified status bit set. +func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool { + c.RLock() + defer c.RUnlock() + + return c.hasChanStatus(status) +} + +func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool { + return c.chanStatus&status == status +} + +// RefreshShortChanID updates the in-memory short channel ID using the latest +// value observed on disk. +func (c *OpenChannel) RefreshShortChanID() error { + c.Lock() + defer c.Unlock() + + var sid lnwire.ShortChannelID + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return err + } + + sid = channel.ShortChannelID + + return nil + }) + if err != nil { + return err + } + + c.ShortChannelID = sid + c.Packager = NewChannelPackager(sid) + + return nil +} + +// fetchChanBucket is a helper function that returns the bucket where a +// channel's data resides in given: the public key for the node, the outpoint, +// and the chainhash that the channel resides on. +func fetchChanBucket(tx *bbolt.Tx, nodeKey *btcec.PublicKey, + outPoint *wire.OutPoint, chainHash chainhash.Hash) (*bbolt.Bucket, error) { + + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return nil, ErrNoChanDBExists + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := nodeKey.SerializeCompressed() + nodeChanBucket := openChanBucket.Bucket(nodePub) + if nodeChanBucket == nil { + return nil, ErrNoActiveChannels + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket := nodeChanBucket.Bucket(chainHash[:]) + if chainBucket == nil { + return nil, ErrNoActiveChannels + } + + // With the bucket for the node and chain fetched, we can now go down + // another level, for this channel itself. + var chanPointBuf bytes.Buffer + if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { + return nil, err + } + chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) + if chanBucket == nil { + return nil, ErrChannelNotFound + } + + return chanBucket, nil +} + +// fullSync syncs the contents of an OpenChannel while re-using an existing +// database transaction. +func (c *OpenChannel) fullSync(tx *bbolt.Tx) error { + // First fetch the top level bucket which stores all data related to + // current, active channels. + openChanBucket, err := tx.CreateBucketIfNotExists(openChannelBucket) + if err != nil { + return err + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + nodePub := c.IdentityPub.SerializeCompressed() + nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub) + if err != nil { + return err + } + + // We'll then recurse down an additional layer in order to fetch the + // bucket for this particular chain. + chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(c.ChainHash[:]) + if err != nil { + return err + } + + // With the bucket for the node fetched, we can now go down another + // level, creating the bucket for this channel itself. + var chanPointBuf bytes.Buffer + if err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint); err != nil { + return err + } + chanBucket, err := chainBucket.CreateBucket( + chanPointBuf.Bytes(), + ) + switch { + case err == bbolt.ErrBucketExists: + // If this channel already exists, then in order to avoid + // overriding it, we'll return an error back up to the caller. + return ErrChanAlreadyExists + case err != nil: + return err + } + + return putOpenChannel(chanBucket, c) +} + +// MarkAsOpen marks a channel as fully open given a locator that uniquely +// describes its location within the chain. +func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { + c.Lock() + defer c.Unlock() + + if err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return err + } + + channel.IsPending = false + channel.ShortChannelID = openLoc + + return putOpenChannel(chanBucket, channel) + }); err != nil { + return err + } + + c.IsPending = false + c.ShortChannelID = openLoc + c.Packager = NewChannelPackager(openLoc) + + return nil +} + +// MarkDataLoss marks sets the channel status to LocalDataLoss and stores the +// passed commitPoint for use to retrieve funds in case the remote force closes +// the channel. +func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { + c.Lock() + defer c.Unlock() + + var b bytes.Buffer + if err := WriteElement(&b, commitPoint); err != nil { + return err + } + + putCommitPoint := func(chanBucket *bbolt.Bucket) error { + return chanBucket.Put(dataLossCommitPointKey, b.Bytes()) + } + + return c.putChanStatus(ChanStatusLocalDataLoss, putCommitPoint) +} + +// DataLossCommitPoint retrieves the stored commit point set during +// MarkDataLoss. If not found ErrNoCommitPoint is returned. +func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { + var commitPoint *btcec.PublicKey + + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: + return ErrNoCommitPoint + default: + return err + } + + bs := chanBucket.Get(dataLossCommitPointKey) + if bs == nil { + return ErrNoCommitPoint + } + r := bytes.NewReader(bs) + if err := ReadElements(r, &commitPoint); err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err + } + + return commitPoint, nil +} + +// MarkBorked marks the event when the channel as reached an irreconcilable +// state, such as a channel breach or state desynchronization. Borked channels +// should never be added to the switch. +func (c *OpenChannel) MarkBorked() error { + c.Lock() + defer c.Unlock() + + return c.putChanStatus(ChanStatusBorked) +} + +// ChanSyncMsg returns the ChannelReestablish message that should be sent upon +// reconnection with the remote peer that we're maintaining this channel with. +// The information contained within this message is necessary to re-sync our +// commitment chains in the case of a last or only partially processed message. +// When the remote party receiver this message one of three things may happen: +// +// 1. We're fully synced and no messages need to be sent. +// 2. We didn't get the last CommitSig message they sent, to they'll re-send +// it. +// 3. We didn't get the last RevokeAndAck message they sent, so they'll +// re-send it. +// +// If this is a restored channel, having status ChanStatusRestored, then we'll +// modify our typical chan sync message to ensure they force close even if +// we're on the very first state. +func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { + c.Lock() + defer c.Unlock() + + // The remote commitment height that we'll send in the + // ChannelReestablish message is our current commitment height plus + // one. If the receiver thinks that our commitment height is actually + // *equal* to this value, then they'll re-send the last commitment that + // they sent but we never fully processed. + localHeight := c.LocalCommitment.CommitHeight + nextLocalCommitHeight := localHeight + 1 + + // The second value we'll send is the height of the remote commitment + // from our PoV. If the receiver thinks that their height is actually + // *one plus* this value, then they'll re-send their last revocation. + remoteChainTipHeight := c.RemoteCommitment.CommitHeight + + // If this channel has undergone a commitment update, then in order to + // prove to the remote party our knowledge of their prior commitment + // state, we'll also send over the last commitment secret that the + // remote party sent. + var lastCommitSecret [32]byte + if remoteChainTipHeight != 0 { + remoteSecret, err := c.RevocationStore.LookUp( + remoteChainTipHeight - 1, + ) + if err != nil { + return nil, err + } + lastCommitSecret = [32]byte(*remoteSecret) + } + + // Additionally, we'll send over the current unrevoked commitment on + // our local commitment transaction. + currentCommitSecret, err := c.RevocationProducer.AtIndex( + localHeight, + ) + if err != nil { + return nil, err + } + + // If we've restored this channel, then we'll purposefully give them an + // invalid LocalUnrevokedCommitPoint so they'll force close the channel + // allowing us to sweep our funds. + if c.hasChanStatus(ChanStatusRestored) { + currentCommitSecret[0] ^= 1 + + // If this is a tweakless channel, then we'll purposefully send + // a next local height taht's invalid to trigger a force close + // on their end. We do this as tweakless channels don't require + // that the commitment point is valid, only that it's present. + if c.ChanType.IsTweakless() { + nextLocalCommitHeight = 0 + } + } + + return &lnwire.ChannelReestablish{ + ChanID: lnwire.NewChanIDFromOutPoint( + &c.FundingOutpoint, + ), + NextLocalCommitHeight: nextLocalCommitHeight, + RemoteCommitTailHeight: remoteChainTipHeight, + LastRemoteCommitSecret: lastCommitSecret, + LocalUnrevokedCommitPoint: input.ComputeCommitmentPoint( + currentCommitSecret[:], + ), + }, nil +} + +// isBorked returns true if the channel has been marked as borked in the +// database. This requires an existing database transaction to already be +// active. +// +// NOTE: The primary mutex should already be held before this method is called. +func (c *OpenChannel) isBorked(chanBucket *bbolt.Bucket) (bool, error) { + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return false, err + } + + return channel.chanStatus != ChanStatusDefault, nil +} + +// MarkCommitmentBroadcasted marks the channel as a commitment transaction has +// been broadcast, either our own or the remote, and we should watch the chain +// for it to confirm before taking any further action. It takes as argument the +// closing tx _we believe_ will appear in the chain. This is only used to +// republish this tx at startup to ensure propagation, and we should still +// handle the case where a different tx actually hits the chain. +func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx) error { + c.Lock() + defer c.Unlock() + + var b bytes.Buffer + if err := WriteElement(&b, closeTx); err != nil { + return err + } + + putClosingTx := func(chanBucket *bbolt.Bucket) error { + return chanBucket.Put(closingTxKey, b.Bytes()) + } + + return c.putChanStatus(ChanStatusCommitBroadcasted, putClosingTx) +} + +// BroadcastedCommitment retrieves the stored closing tx set during +// MarkCommitmentBroadcasted. If not found ErrNoCloseTx is returned. +func (c *OpenChannel) BroadcastedCommitment() (*wire.MsgTx, error) { + var closeTx *wire.MsgTx + + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: + return ErrNoCloseTx + default: + return err + } + + bs := chanBucket.Get(closingTxKey) + if bs == nil { + return ErrNoCloseTx + } + r := bytes.NewReader(bs) + return ReadElement(r, &closeTx) + }) + if err != nil { + return nil, err + } + + return closeTx, nil +} + +// putChanStatus appends the given status to the channel. fs is an optional +// list of closures that are given the chanBucket in order to atomically add +// extra information together with the new status. +func (c *OpenChannel) putChanStatus(status ChannelStatus, + fs ...func(*bbolt.Bucket) error) error { + + if err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return err + } + + // Add this status to the existing bitvector found in the DB. + status = channel.chanStatus | status + channel.chanStatus = status + + if err := putOpenChannel(chanBucket, channel); err != nil { + return err + } + + for _, f := range fs { + if err := f(chanBucket); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + + // Update the in-memory representation to keep it in sync with the DB. + c.chanStatus = status + + return nil +} + +func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { + if err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) + if err != nil { + return err + } + + // Unset this bit in the bitvector on disk. + status = channel.chanStatus & ^status + channel.chanStatus = status + + return putOpenChannel(chanBucket, channel) + }); err != nil { + return err + } + + // Update the in-memory representation to keep it in sync with the DB. + c.chanStatus = status + + return nil +} + +// putChannel serializes, and stores the current state of the channel in its +// entirety. +func putOpenChannel(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + // First, we'll write out all the relatively static fields, that are + // decided upon initial channel creation. + if err := putChanInfo(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan info: %v", err) + } + + // With the static channel info written out, we'll now write out the + // current commitment state for both parties. + if err := putChanCommitments(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan commitments: %v", err) + } + + // Finally, we'll write out the revocation state for both parties + // within a distinct key space. + if err := putChanRevocationState(chanBucket, channel); err != nil { + return fmt.Errorf("unable to store chan revocations: %v", err) + } + + return nil +} + +// fetchOpenChannel retrieves, and deserializes (including decrypting +// sensitive) the complete channel currently active with the passed nodeID. +func fetchOpenChannel(chanBucket *bbolt.Bucket, + chanPoint *wire.OutPoint) (*OpenChannel, error) { + + channel := &OpenChannel{ + FundingOutpoint: *chanPoint, + } + + // First, we'll read all the static information that changes less + // frequently from disk. + if err := fetchChanInfo(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan info: %v", err) + } + + // With the static information read, we'll now read the current + // commitment state for both sides of the channel. + if err := fetchChanCommitments(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan commitments: %v", err) + } + + // Finally, we'll retrieve the current revocation state so we can + // properly + if err := fetchChanRevocationState(chanBucket, channel); err != nil { + return nil, fmt.Errorf("unable to fetch chan revocations: %v", err) + } + + channel.Packager = NewChannelPackager(channel.ShortChannelID) + + return channel, nil +} + +// SyncPending writes the contents of the channel to the database while it's in +// the pending (waiting for funding confirmation) state. The IsPending flag +// will be set to true. When the channel's funding transaction is confirmed, +// the channel should be marked as "open" and the IsPending flag set to false. +// Note that this function also creates a LinkNode relationship between this +// newly created channel and a new LinkNode instance. This allows listing all +// channels in the database globally, or according to the LinkNode they were +// created with. +// +// TODO(roasbeef): addr param should eventually be an lnwire.NetAddress type +// that includes service bits. +func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { + c.Lock() + defer c.Unlock() + + c.FundingBroadcastHeight = pendingHeight + + return c.Db.Update(func(tx *bbolt.Tx) error { + return syncNewChannel(tx, c, []net.Addr{addr}) + }) +} + +// syncNewChannel will write the passed channel to disk, and also create a +// LinkNode (if needed) for the channel peer. +func syncNewChannel(tx *bbolt.Tx, c *OpenChannel, addrs []net.Addr) error { + // First, sync all the persistent channel state to disk. + if err := c.fullSync(tx); err != nil { + return err + } + + nodeInfoBucket, err := tx.CreateBucketIfNotExists(nodeInfoBucket) + if err != nil { + return err + } + + // If a LinkNode for this identity public key already exists, + // then we can exit early. + nodePub := c.IdentityPub.SerializeCompressed() + if nodeInfoBucket.Get(nodePub) != nil { + return nil + } + + // Next, we need to establish a (possibly) new LinkNode relationship + // for this channel. The LinkNode metadata contains reachability, + // up-time, and service bits related information. + linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...) + + // TODO(roasbeef): do away with link node all together? + + return putLinkNode(nodeInfoBucket, linkNode) +} + +// UpdateCommitment updates the commitment state for the specified party +// (remote or local). The commitment stat completely describes the balance +// state at this point in the commitment chain. This method its to be called on +// two occasions: when we revoke our prior commitment state, and when the +// remote party revokes their prior commitment state. +func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment) error { + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state as all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return ErrNoRestoredChannelMutation + } + + err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := c.isBorked(chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + if err = putChanInfo(chanBucket, c); err != nil { + return fmt.Errorf("unable to store chan info: %v", err) + } + + // With the proper bucket fetched, we'll now write the latest + // commitment state to disk for the target party. + err = putChanCommitment( + chanBucket, newCommitment, true, + ) + if err != nil { + return fmt.Errorf("unable to store chan "+ + "revocations: %v", err) + } + + return nil + }) + if err != nil { + return err + } + + c.LocalCommitment = *newCommitment + + return nil +} + +// HTLC is the on-disk representation of a hash time-locked contract. HTLCs are +// contained within ChannelDeltas which encode the current state of the +// commitment between state updates. +// +// TODO(roasbeef): save space by using smaller ints at tail end? +type HTLC struct { + // Signature is the signature for the second level covenant transaction + // for this HTLC. The second level transaction is a timeout tx in the + // case that this is an outgoing HTLC, and a success tx in the case + // that this is an incoming HTLC. + // + // TODO(roasbeef): make [64]byte instead? + Signature []byte + + // RHash is the payment hash of the HTLC. + RHash [32]byte + + // Amt is the amount of milli-satoshis this HTLC escrows. + Amt lnwire.MilliSatoshi + + // RefundTimeout is the absolute timeout on the HTLC that the sender + // must wait before reclaiming the funds in limbo. + RefundTimeout uint32 + + // OutputIndex is the output index for this particular HTLC output + // within the commitment transaction. + OutputIndex int32 + + // Incoming denotes whether we're the receiver or the sender of this + // HTLC. + Incoming bool + + // OnionBlob is an opaque blob which is used to complete multi-hop + // routing. + OnionBlob []byte + + // HtlcIndex is the HTLC counter index of this active, outstanding + // HTLC. This differs from the LogIndex, as the HtlcIndex is only + // incremented for each offered HTLC, while they LogIndex is + // incremented for each update (includes settle+fail). + HtlcIndex uint64 + + // LogIndex is the cumulative log index of this HTLC. This differs + // from the HtlcIndex as this will be incremented for each new log + // update added. + LogIndex uint64 +} + +// SerializeHtlcs writes out the passed set of HTLC's into the passed writer +// using the current default on-disk serialization format. +// +// NOTE: This API is NOT stable, the on-disk format will likely change in the +// future. +func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { + numHtlcs := uint16(len(htlcs)) + if err := WriteElement(b, numHtlcs); err != nil { + return err + } + + for _, htlc := range htlcs { + if err := WriteElements(b, + htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, + htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:], + htlc.HtlcIndex, htlc.LogIndex, + ); err != nil { + return err + } + } + + return nil +} + +// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed +// io.Reader. The bytes within the passed reader MUST have been previously +// written to using the SerializeHtlcs function. +// +// NOTE: This API is NOT stable, the on-disk format will likely change in the +// future. +func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { + var numHtlcs uint16 + if err := ReadElement(r, &numHtlcs); err != nil { + return nil, err + } + + var htlcs []HTLC + if numHtlcs == 0 { + return htlcs, nil + } + + htlcs = make([]HTLC, numHtlcs) + for i := uint16(0); i < numHtlcs; i++ { + if err := ReadElements(r, + &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, + &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, + &htlcs[i].Incoming, &htlcs[i].OnionBlob, + &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, + ); err != nil { + return htlcs, err + } + } + + return htlcs, nil +} + +// Copy returns a full copy of the target HTLC. +func (h *HTLC) Copy() HTLC { + clone := HTLC{ + Incoming: h.Incoming, + Amt: h.Amt, + RefundTimeout: h.RefundTimeout, + OutputIndex: h.OutputIndex, + } + copy(clone.Signature[:], h.Signature) + copy(clone.RHash[:], h.RHash[:]) + + return clone +} + +// LogUpdate represents a pending update to the remote commitment chain. The +// log update may be an add, fail, or settle entry. We maintain this data in +// order to be able to properly retransmit our proposed +// state if necessary. +type LogUpdate struct { + // LogIndex is the log index of this proposed commitment update entry. + LogIndex uint64 + + // UpdateMsg is the update message that was included within the our + // local update log. The LogIndex value denotes the log index of this + // update which will be used when restoring our local update log if + // we're left with a dangling update on restart. + UpdateMsg lnwire.Message +} + +// Encode writes a log update to the provided io.Writer. +func (l *LogUpdate) Encode(w io.Writer) error { + return WriteElements(w, l.LogIndex, l.UpdateMsg) +} + +// Decode reads a log update from the provided io.Reader. +func (l *LogUpdate) Decode(r io.Reader) error { + return ReadElements(r, &l.LogIndex, &l.UpdateMsg) +} + +// CircuitKey is used by a channel to uniquely identify the HTLCs it receives +// from the switch, and is used to purge our in-memory state of HTLCs that have +// already been processed by a link. Two list of CircuitKeys are included in +// each CommitDiff to allow a link to determine which in-memory htlcs directed +// the opening and closing of circuits in the switch's circuit map. +type CircuitKey struct { + // ChanID is the short chanid indicating the HTLC's origin. + // + // NOTE: It is fine for this value to be blank, as this indicates a + // locally-sourced payment. + ChanID lnwire.ShortChannelID + + // HtlcID is the unique htlc index predominately assigned by links, + // though can also be assigned by switch in the case of locally-sourced + // payments. + HtlcID uint64 +} + +// SetBytes deserializes the given bytes into this CircuitKey. +func (k *CircuitKey) SetBytes(bs []byte) error { + if len(bs) != 16 { + return ErrInvalidCircuitKeyLen + } + + k.ChanID = lnwire.NewShortChanIDFromInt( + binary.BigEndian.Uint64(bs[:8])) + k.HtlcID = binary.BigEndian.Uint64(bs[8:]) + + return nil +} + +// Bytes returns the serialized bytes for this circuit key. +func (k CircuitKey) Bytes() []byte { + var bs = make([]byte, 16) + binary.BigEndian.PutUint64(bs[:8], k.ChanID.ToUint64()) + binary.BigEndian.PutUint64(bs[8:], k.HtlcID) + return bs +} + +// Encode writes a CircuitKey to the provided io.Writer. +func (k *CircuitKey) Encode(w io.Writer) error { + var scratch [16]byte + binary.BigEndian.PutUint64(scratch[:8], k.ChanID.ToUint64()) + binary.BigEndian.PutUint64(scratch[8:], k.HtlcID) + + _, err := w.Write(scratch[:]) + return err +} + +// Decode reads a CircuitKey from the provided io.Reader. +func (k *CircuitKey) Decode(r io.Reader) error { + var scratch [16]byte + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return err + } + k.ChanID = lnwire.NewShortChanIDFromInt( + binary.BigEndian.Uint64(scratch[:8])) + k.HtlcID = binary.BigEndian.Uint64(scratch[8:]) + + return nil +} + +// String returns a string representation of the CircuitKey. +func (k CircuitKey) String() string { + return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID) +} + +// CommitDiff represents the delta needed to apply the state transition between +// two subsequent commitment states. Given state N and state N+1, one is able +// to apply the set of messages contained within the CommitDiff to N to arrive +// at state N+1. Each time a new commitment is extended, we'll write a new +// commitment (along with the full commitment state) to disk so we can +// re-transmit the state in the case of a connection loss or message drop. +type CommitDiff struct { + // ChannelCommitment is the full commitment state that one would arrive + // at by applying the set of messages contained in the UpdateDiff to + // the prior accepted commitment. + Commitment ChannelCommitment + + // LogUpdates is the set of messages sent prior to the commitment state + // transition in question. Upon reconnection, if we detect that they + // don't have the commitment, then we re-send this along with the + // proper signature. + LogUpdates []LogUpdate + + // CommitSig is the exact CommitSig message that should be sent after + // the set of LogUpdates above has been retransmitted. The signatures + // within this message should properly cover the new commitment state + // and also the HTLC's within the new commitment state. + CommitSig *lnwire.CommitSig + + // OpenedCircuitKeys is a set of unique identifiers for any downstream + // Add packets included in this commitment txn. After a restart, this + // set of htlcs is acked from the link's incoming mailbox to ensure + // there isn't an attempt to re-add them to this commitment txn. + OpenedCircuitKeys []CircuitKey + + // ClosedCircuitKeys records the unique identifiers for any settle/fail + // packets that were resolved by this commitment txn. After a restart, + // this is used to ensure those circuits are removed from the circuit + // map, and the downstream packets in the link's mailbox are removed. + ClosedCircuitKeys []CircuitKey + + // AddAcks specifies the locations (commit height, pkg index) of any + // Adds that were failed/settled in this commit diff. This will ack + // entries in *this* channel's forwarding packages. + // + // NOTE: This value is not serialized, it is used to atomically mark the + // resolution of adds, such that they will not be reprocessed after a + // restart. + AddAcks []AddRef + + // SettleFailAcks specifies the locations (chan id, commit height, pkg + // index) of any Settles or Fails that were locked into this commit + // diff, and originate from *another* channel, i.e. the outgoing link. + // + // NOTE: This value is not serialized, it is used to atomically acks + // settles and fails from the forwarding packages of other channels, + // such that they will not be reforwarded internally after a restart. + SettleFailAcks []SettleFailRef +} + +func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { + if err := serializeChanCommit(w, &diff.Commitment); err != nil { + return err + } + + if err := diff.CommitSig.Encode(w, 0); err != nil { + return err + } + + numUpdates := uint16(len(diff.LogUpdates)) + if err := binary.Write(w, byteOrder, numUpdates); err != nil { + return err + } + + for _, diff := range diff.LogUpdates { + err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) + if err != nil { + return err + } + } + + numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) + if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { + return err + } + + for _, openRef := range diff.OpenedCircuitKeys { + err := WriteElements(w, openRef.ChanID, openRef.HtlcID) + if err != nil { + return err + } + } + + numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) + if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { + return err + } + + for _, closedRef := range diff.ClosedCircuitKeys { + err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) + if err != nil { + return err + } + } + + return nil +} + +func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { + var ( + d CommitDiff + err error + ) + + d.Commitment, err = deserializeChanCommit(r) + if err != nil { + return nil, err + } + + d.CommitSig = &lnwire.CommitSig{} + if err := d.CommitSig.Decode(r, 0); err != nil { + return nil, err + } + + var numUpdates uint16 + if err := binary.Read(r, byteOrder, &numUpdates); err != nil { + return nil, err + } + + d.LogUpdates = make([]LogUpdate, numUpdates) + for i := 0; i < int(numUpdates); i++ { + err := ReadElements(r, + &d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg, + ) + if err != nil { + return nil, err + } + } + + var numOpenRefs uint16 + if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { + return nil, err + } + + d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs) + for i := 0; i < int(numOpenRefs); i++ { + err := ReadElements(r, + &d.OpenedCircuitKeys[i].ChanID, + &d.OpenedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + var numClosedRefs uint16 + if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { + return nil, err + } + + d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs) + for i := 0; i < int(numClosedRefs); i++ { + err := ReadElements(r, + &d.ClosedCircuitKeys[i].ChanID, + &d.ClosedCircuitKeys[i].HtlcID) + if err != nil { + return nil, err + } + } + + return &d, nil +} + +// AppendRemoteCommitChain appends a new CommitDiff to the end of the +// commitment chain for the remote party. This method is to be used once we +// have prepared a new commitment state for the remote party, but before we +// transmit it to the remote party. The contents of the argument should be +// sufficient to retransmit the updates and signature needed to reconstruct the +// state in full, in the case that we need to retransmit. +func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state at all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return ErrNoRestoredChannelMutation + } + + return c.Db.Update(func(tx *bbolt.Tx) error { + // First, we'll grab the writable bucket where this channel's + // data resides. + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := c.isBorked(chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + // Any outgoing settles and fails necessarily have a + // corresponding adds in this channel's forwarding packages. + // Mark all of these as being fully processed in our forwarding + // package, which prevents us from reprocessing them after + // startup. + err = c.Packager.AckAddHtlcs(tx, diff.AddAcks...) + if err != nil { + return err + } + + // Additionally, we ack from any fails or settles that are + // persisted in another channel's forwarding package. This + // prevents the same fails and settles from being retransmitted + // after restarts. The actual fail or settle we need to + // propagate to the remote party is now in the commit diff. + err = c.Packager.AckSettleFails(tx, diff.SettleFailAcks...) + if err != nil { + return err + } + + // TODO(roasbeef): use seqno to derive key for later LCP + + // With the bucket retrieved, we'll now serialize the commit + // diff itself, and write it to disk. + var b bytes.Buffer + if err := serializeCommitDiff(&b, diff); err != nil { + return err + } + return chanBucket.Put(commitDiffKey, b.Bytes()) + }) +} + +// RemoteCommitChainTip returns the "tip" of the current remote commitment +// chain. This value will be non-nil iff, we've created a new commitment for +// the remote party that they haven't yet ACK'd. In this case, their commitment +// chain will have a length of two: their current unrevoked commitment, and +// this new pending commitment. Once they revoked their prior state, we'll swap +// these pointers, causing the tip and the tail to point to the same entry. +func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { + var cd *CommitDiff + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + switch err { + case nil: + case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: + return ErrNoPendingCommit + default: + return err + } + + tipBytes := chanBucket.Get(commitDiffKey) + if tipBytes == nil { + return ErrNoPendingCommit + } + + tipReader := bytes.NewReader(tipBytes) + dcd, err := deserializeCommitDiff(tipReader) + if err != nil { + return err + } + + cd = dcd + return nil + }) + if err != nil { + return nil, err + } + + return cd, err +} + +// InsertNextRevocation inserts the _next_ commitment point (revocation) into +// the database, and also modifies the internal RemoteNextRevocation attribute +// to point to the passed key. This method is to be using during final channel +// set up, _after_ the channel has been fully confirmed. +// +// NOTE: If this method isn't called, then the target channel won't be able to +// propose new states for the commitment state of the remote party. +func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { + c.Lock() + defer c.Unlock() + + c.RemoteNextRevocation = revKey + + err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + return putChanRevocationState(chanBucket, c) + }) + if err != nil { + return err + } + + return nil +} + +// AdvanceCommitChainTail records the new state transition within an on-disk +// append-only log which records all state transitions by the remote peer. In +// the case of an uncooperative broadcast of a prior state by the remote peer, +// this log can be consulted in order to reconstruct the state needed to +// rectify the situation. This method will add the current commitment for the +// remote party to the revocation log, and promote the current pending +// commitment to the current remote commitment. +func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error { + c.Lock() + defer c.Unlock() + + // If this is a restored channel, then we want to avoid mutating the + // state at all, as it's impossible to do so in a protocol compliant + // manner. + if c.hasChanStatus(ChanStatusRestored) { + return ErrNoRestoredChannelMutation + } + + var newRemoteCommit *ChannelCommitment + + err := c.Db.Update(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + // If the channel is marked as borked, then for safety reasons, + // we shouldn't attempt any further updates. + isBorked, err := c.isBorked(chanBucket) + if err != nil { + return err + } + if isBorked { + return ErrChanBorked + } + + // Persist the latest preimage state to disk as the remote peer + // has just added to our local preimage store, and given us a + // new pending revocation key. + if err := putChanRevocationState(chanBucket, c); err != nil { + return err + } + + // With the current preimage producer/store state updated, + // append a new log entry recording this the delta of this + // state transition. + // + // TODO(roasbeef): could make the deltas relative, would save + // space, but then tradeoff for more disk-seeks to recover the + // full state. + logKey := revocationLogBucket + logBucket, err := chanBucket.CreateBucketIfNotExists(logKey) + if err != nil { + return err + } + + // Before we append this revoked state to the revocation log, + // we'll swap out what's currently the tail of the commit tip, + // with the current locked-in commitment for the remote party. + tipBytes := chanBucket.Get(commitDiffKey) + tipReader := bytes.NewReader(tipBytes) + newCommit, err := deserializeCommitDiff(tipReader) + if err != nil { + return err + } + err = putChanCommitment( + chanBucket, &newCommit.Commitment, false, + ) + if err != nil { + return err + } + if err := chanBucket.Delete(commitDiffKey); err != nil { + return err + } + + // With the commitment pointer swapped, we can now add the + // revoked (prior) state to the revocation log. + // + // TODO(roasbeef): store less + err = appendChannelLogEntry(logBucket, &c.RemoteCommitment) + if err != nil { + return err + } + + // Lastly, we write the forwarding package to disk so that we + // can properly recover from failures and reforward HTLCs that + // have not received a corresponding settle/fail. + if err := c.Packager.AddFwdPkg(tx, fwdPkg); err != nil { + return err + } + + newRemoteCommit = &newCommit.Commitment + + return nil + }) + if err != nil { + return err + } + + // With the db transaction complete, we'll swap over the in-memory + // pointer of the new remote commitment, which was previously the tip + // of the commit chain. + c.RemoteCommitment = *newRemoteCommit + + return nil +} + +// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure +// this always returns the next index that has been not been allocated, this +// will first try to examine any pending commitments, before falling back to the +// last locked-in local commitment. +func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { + // First, load the most recent commit diff that we initiated for the + // remote party. If no pending commit is found, this is not treated as + // a critical error, since we can always fall back. + pendingRemoteCommit, err := c.RemoteCommitChainTip() + if err != nil && err != ErrNoPendingCommit { + return 0, err + } + + // If a pending commit was found, its local htlc index will be at least + // as large as the one on our local commitment. + if pendingRemoteCommit != nil { + return pendingRemoteCommit.Commitment.LocalHtlcIndex, nil + } + + // Otherwise, fallback to using the local htlc index of our commitment. + return c.LocalCommitment.LocalHtlcIndex, nil +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in map indexed by the +// remote commitment height at which the updates were locked in. +func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { + c.RLock() + defer c.RUnlock() + + var fwdPkgs []*FwdPkg + if err := c.Db.View(func(tx *bbolt.Tx) error { + var err error + fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) + return err + }); err != nil { + return nil, err + } + + return fwdPkgs, nil +} + +// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs +// indicating that a response to this Add has been committed to the remote party. +// Doing so will prevent these Add HTLCs from being reforwarded internally. +func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + return c.Packager.AckAddHtlcs(tx, addRefs...) + }) +} + +// AckSettleFails updates the SettleFailFilter containing any of the provided +// SettleFailRefs, indicating that the response has been delivered to the +// incoming link, corresponding to a particular AddRef. Doing so will prevent +// the responses from being retransmitted internally. +func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + return c.Packager.AckSettleFails(tx, settleFailRefs...) + }) +} + +// SetFwdFilter atomically sets the forwarding filter for the forwarding package +// identified by `height`. +func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + return c.Packager.SetFwdFilter(tx, height, fwdFilter) + }) +} + +// RemoveFwdPkg atomically removes a forwarding package specified by the remote +// commitment height. +// +// NOTE: This method should only be called on packages marked FwdStateCompleted. +func (c *OpenChannel) RemoveFwdPkg(height uint64) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + return c.Packager.RemovePkg(tx, height) + }) +} + +// RevocationLogTail returns the "tail", or the end of the current revocation +// log. This entry represents the last previous state for the remote node's +// commitment chain. The ChannelDelta returned by this method will always lag +// one state behind the most current (unrevoked) state of the remote node's +// commitment chain. +func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { + c.RLock() + defer c.RUnlock() + + // If we haven't created any state updates yet, then we'll exit early as + // there's nothing to be found on disk in the revocation bucket. + if c.RemoteCommitment.CommitHeight == 0 { + return nil, nil + } + + var commit ChannelCommitment + if err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + logBucket := chanBucket.Bucket(revocationLogBucket) + if logBucket == nil { + return ErrNoPastDeltas + } + + // Once we have the bucket that stores the revocation log from + // this channel, we'll jump to the _last_ key in bucket. As we + // store the update number on disk in a big-endian format, + // this will retrieve the latest entry. + cursor := logBucket.Cursor() + _, tailLogEntry := cursor.Last() + logEntryReader := bytes.NewReader(tailLogEntry) + + // Once we have the entry, we'll decode it into the channel + // delta pointer we created above. + var dbErr error + commit, dbErr = deserializeChanCommit(logEntryReader) + if dbErr != nil { + return dbErr + } + + return nil + }); err != nil { + return nil, err + } + + return &commit, nil +} + +// CommitmentHeight returns the current commitment height. The commitment +// height represents the number of updates to the commitment state to date. +// This value is always monotonically increasing. This method is provided in +// order to allow multiple instances of a particular open channel to obtain a +// consistent view of the number of channel updates to date. +func (c *OpenChannel) CommitmentHeight() (uint64, error) { + c.RLock() + defer c.RUnlock() + + var height uint64 + err := c.Db.View(func(tx *bbolt.Tx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + commit, err := fetchChanCommitment(chanBucket, true) + if err != nil { + return err + } + + height = commit.CommitHeight + return nil + }) + if err != nil { + return 0, err + } + + return height, nil +} + +// FindPreviousState scans through the append-only log in an attempt to recover +// the previous channel state indicated by the update number. This method is +// intended to be used for obtaining the relevant data needed to claim all +// funds rightfully spendable in the case of an on-chain broadcast of the +// commitment transaction. +func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, error) { + c.RLock() + defer c.RUnlock() + + var commit ChannelCommitment + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + logBucket := chanBucket.Bucket(revocationLogBucket) + if logBucket == nil { + return ErrNoPastDeltas + } + + c, err := fetchChannelLogEntry(logBucket, updateNum) + if err != nil { + return err + } + + commit = c + return nil + }) + if err != nil { + return nil, err + } + + return &commit, nil +} + +// ClosureType is an enum like structure that details exactly _how_ a channel +// was closed. Three closure types are currently possible: none, cooperative, +// local force close, remote force close, and (remote) breach. +type ClosureType uint8 + +const ( + // CooperativeClose indicates that a channel has been closed + // cooperatively. This means that both channel peers were online and + // signed a new transaction paying out the settled balance of the + // contract. + CooperativeClose ClosureType = 0 + + // LocalForceClose indicates that we have unilaterally broadcast our + // current commitment state on-chain. + LocalForceClose ClosureType = 1 + + // RemoteForceClose indicates that the remote peer has unilaterally + // broadcast their current commitment state on-chain. + RemoteForceClose ClosureType = 4 + + // BreachClose indicates that the remote peer attempted to broadcast a + // prior _revoked_ channel state. + BreachClose ClosureType = 2 + + // FundingCanceled indicates that the channel never was fully opened + // before it was marked as closed in the database. This can happen if + // we or the remote fail at some point during the opening workflow, or + // we timeout waiting for the funding transaction to be confirmed. + FundingCanceled ClosureType = 3 + + // Abandoned indicates that the channel state was removed without + // any further actions. This is intended to clean up unusable + // channels during development. + Abandoned ClosureType = 5 +) + +// ChannelCloseSummary contains the final state of a channel at the point it +// was closed. Once a channel is closed, all the information pertaining to that +// channel within the openChannelBucket is deleted, and a compact summary is +// put in place instead. +type ChannelCloseSummary struct { + // ChanPoint is the outpoint for this channel's funding transaction, + // and is used as a unique identifier for the channel. + ChanPoint wire.OutPoint + + // ShortChanID encodes the exact location in the chain in which the + // channel was initially confirmed. This includes: the block height, + // transaction index, and the output within the target transaction. + ShortChanID lnwire.ShortChannelID + + // ChainHash is the hash of the genesis block that this channel resides + // within. + ChainHash chainhash.Hash + + // ClosingTXID is the txid of the transaction which ultimately closed + // this channel. + ClosingTXID chainhash.Hash + + // RemotePub is the public key of the remote peer that we formerly had + // a channel with. + RemotePub *btcec.PublicKey + + // Capacity was the total capacity of the channel. + Capacity btcutil.Amount + + // CloseHeight is the height at which the funding transaction was + // spent. + CloseHeight uint32 + + // SettledBalance is our total balance settled balance at the time of + // channel closure. This _does not_ include the sum of any outputs that + // have been time-locked as a result of the unilateral channel closure. + SettledBalance btcutil.Amount + + // TimeLockedBalance is the sum of all the time-locked outputs at the + // time of channel closure. If we triggered the force closure of this + // channel, then this value will be non-zero if our settled output is + // above the dust limit. If we were on the receiving side of a channel + // force closure, then this value will be non-zero if we had any + // outstanding outgoing HTLC's at the time of channel closure. + TimeLockedBalance btcutil.Amount + + // CloseType details exactly _how_ the channel was closed. Five closure + // types are possible: cooperative, local force, remote force, breach + // and funding canceled. + CloseType ClosureType + + // IsPending indicates whether this channel is in the 'pending close' + // state, which means the channel closing transaction has been + // confirmed, but not yet been fully resolved. In the case of a channel + // that has been cooperatively closed, it will go straight into the + // fully resolved state as soon as the closing transaction has been + // confirmed. However, for channels that have been force closed, they'll + // stay marked as "pending" until _all_ the pending funds have been + // swept. + IsPending bool + + // RemoteCurrentRevocation is the current revocation for their + // commitment transaction. However, since this is the derived public key, + // we don't yet have the private key so we aren't yet able to verify + // that it's actually in the hash chain. + RemoteCurrentRevocation *btcec.PublicKey + + // RemoteNextRevocation is the revocation key to be used for the *next* + // commitment transaction we create for the local node. Within the + // specification, this value is referred to as the + // per-commitment-point. + RemoteNextRevocation *btcec.PublicKey + + // LocalChanCfg is the channel configuration for the local node. + LocalChanConfig ChannelConfig + + // LastChanSyncMsg is the ChannelReestablish message for this channel + // for the state at the point where it was closed. + LastChanSyncMsg *lnwire.ChannelReestablish +} + +// CloseChannel closes a previously active Lightning channel. Closing a channel +// entails deleting all saved state within the database concerning this +// channel. This method also takes a struct that summarizes the state of the +// channel at closing, this compact representation will be the only component +// of a channel left over after a full closing. +func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary) error { + c.Lock() + defer c.Unlock() + + return c.Db.Update(func(tx *bbolt.Tx) error { + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoChanDBExists + } + + nodePub := c.IdentityPub.SerializeCompressed() + nodeChanBucket := openChanBucket.Bucket(nodePub) + if nodeChanBucket == nil { + return ErrNoActiveChannels + } + + chainBucket := nodeChanBucket.Bucket(c.ChainHash[:]) + if chainBucket == nil { + return ErrNoActiveChannels + } + + var chanPointBuf bytes.Buffer + err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint) + if err != nil { + return err + } + chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) + if chanBucket == nil { + return ErrNoActiveChannels + } + + // Before we delete the channel state, we'll read out the full + // details, as we'll also store portions of this information + // for record keeping. + chanState, err := fetchOpenChannel( + chanBucket, &c.FundingOutpoint, + ) + if err != nil { + return err + } + + // Now that the index to this channel has been deleted, purge + // the remaining channel metadata from the database. + err = deleteOpenChannel(chanBucket, chanPointBuf.Bytes()) + if err != nil { + return err + } + + // With the base channel data deleted, attempt to delete the + // information stored within the revocation log. + logBucket := chanBucket.Bucket(revocationLogBucket) + if logBucket != nil { + err = chanBucket.DeleteBucket(revocationLogBucket) + if err != nil { + return err + } + } + + err = chainBucket.DeleteBucket(chanPointBuf.Bytes()) + if err != nil { + return err + } + + // Finally, create a summary of this channel in the closed + // channel bucket for this node. + return putChannelCloseSummary( + tx, chanPointBuf.Bytes(), summary, chanState, + ) + }) +} + +// ChannelSnapshot is a frozen snapshot of the current channel state. A +// snapshot is detached from the original channel that generated it, providing +// read-only access to the current or prior state of an active channel. +// +// TODO(roasbeef): remove all together? pretty much just commitment +type ChannelSnapshot struct { + // RemoteIdentity is the identity public key of the remote node that we + // are maintaining the open channel with. + RemoteIdentity btcec.PublicKey + + // ChanPoint is the outpoint that created the channel. This output is + // found within the funding transaction and uniquely identified the + // channel on the resident chain. + ChannelPoint wire.OutPoint + + // ChainHash is the genesis hash of the chain that the channel resides + // within. + ChainHash chainhash.Hash + + // Capacity is the total capacity of the channel. + Capacity btcutil.Amount + + // TotalMSatSent is the total number of milli-satoshis we've sent + // within this channel. + TotalMSatSent lnwire.MilliSatoshi + + // TotalMSatReceived is the total number of milli-satoshis we've + // received within this channel. + TotalMSatReceived lnwire.MilliSatoshi + + // ChannelCommitment is the current up-to-date commitment for the + // target channel. + ChannelCommitment +} + +// Snapshot returns a read-only snapshot of the current channel state. This +// snapshot includes information concerning the current settled balance within +// the channel, metadata detailing total flows, and any outstanding HTLCs. +func (c *OpenChannel) Snapshot() *ChannelSnapshot { + c.RLock() + defer c.RUnlock() + + localCommit := c.LocalCommitment + snapshot := &ChannelSnapshot{ + RemoteIdentity: *c.IdentityPub, + ChannelPoint: c.FundingOutpoint, + Capacity: c.Capacity, + TotalMSatSent: c.TotalMSatSent, + TotalMSatReceived: c.TotalMSatReceived, + ChainHash: c.ChainHash, + ChannelCommitment: ChannelCommitment{ + LocalBalance: localCommit.LocalBalance, + RemoteBalance: localCommit.RemoteBalance, + CommitHeight: localCommit.CommitHeight, + CommitFee: localCommit.CommitFee, + }, + } + + // Copy over the current set of HTLCs to ensure the caller can't mutate + // our internal state. + snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) + for i, h := range localCommit.Htlcs { + snapshot.Htlcs[i] = h.Copy() + } + + return snapshot +} + +// LatestCommitments returns the two latest commitments for both the local and +// remote party. These commitments are read from disk to ensure that only the +// latest fully committed state is returned. The first commitment returned is +// the local commitment, and the second returned is the remote commitment. +func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + return fetchChanCommitments(chanBucket, c) + }) + if err != nil { + return nil, nil, err + } + + return &c.LocalCommitment, &c.RemoteCommitment, nil +} + +// RemoteRevocationStore returns the most up to date commitment version of the +// revocation storage tree for the remote party. This method can be used when +// acting on a possible contract breach to ensure, that the caller has the most +// up to date information required to deliver justice. +func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { + err := c.Db.View(func(tx *bbolt.Tx) error { + chanBucket, err := fetchChanBucket( + tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, + ) + if err != nil { + return err + } + + return fetchChanRevocationState(chanBucket, c) + }) + if err != nil { + return nil, err + } + + return c.RevocationStore, nil +} + +func putChannelCloseSummary(tx *bbolt.Tx, chanID []byte, + summary *ChannelCloseSummary, lastChanState *OpenChannel) error { + + closedChanBucket, err := tx.CreateBucketIfNotExists(closedChannelBucket) + if err != nil { + return err + } + + summary.RemoteCurrentRevocation = lastChanState.RemoteCurrentRevocation + summary.RemoteNextRevocation = lastChanState.RemoteNextRevocation + summary.LocalChanConfig = lastChanState.LocalChanCfg + + var b bytes.Buffer + if err := serializeChannelCloseSummary(&b, summary); err != nil { + return err + } + + return closedChanBucket.Put(chanID, b.Bytes()) +} + +func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { + err := WriteElements(w, + cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, + cs.CloseHeight, cs.RemotePub, cs.Capacity, cs.SettledBalance, + cs.TimeLockedBalance, cs.CloseType, cs.IsPending, + ) + if err != nil { + return err + } + + // If this is a close channel summary created before the addition of + // the new fields, then we can exit here. + if cs.RemoteCurrentRevocation == nil { + return WriteElements(w, false) + } + + // If fields are present, write boolean to indicate this, and continue. + if err := WriteElements(w, true); err != nil { + return err + } + + if err := WriteElements(w, cs.RemoteCurrentRevocation); err != nil { + return err + } + + if err := writeChanConfig(w, &cs.LocalChanConfig); err != nil { + return err + } + + // The RemoteNextRevocation field is optional, as it's possible for a + // channel to be closed before we learn of the next unrevoked + // revocation point for the remote party. Write a boolen indicating + // whether this field is present or not. + if err := WriteElements(w, cs.RemoteNextRevocation != nil); err != nil { + return err + } + + // Write the field, if present. + if cs.RemoteNextRevocation != nil { + if err = WriteElements(w, cs.RemoteNextRevocation); err != nil { + return err + } + } + + // Write whether the channel sync message is present. + if err := WriteElements(w, cs.LastChanSyncMsg != nil); err != nil { + return err + } + + // Write the channel sync message, if present. + if cs.LastChanSyncMsg != nil { + if err := WriteElements(w, cs.LastChanSyncMsg); err != nil { + return err + } + } + + return nil +} + +func deserializeCloseChannelSummary(r io.Reader) (*ChannelCloseSummary, error) { + c := &ChannelCloseSummary{} + + err := ReadElements(r, + &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, + &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, + &c.TimeLockedBalance, &c.CloseType, &c.IsPending, + ) + if err != nil { + return nil, err + } + + // We'll now check to see if the channel close summary was encoded with + // any of the additional optional fields. + var hasNewFields bool + err = ReadElements(r, &hasNewFields) + if err != nil { + return nil, err + } + + // If fields are not present, we can return. + if !hasNewFields { + return c, nil + } + + // Otherwise read the new fields. + if err := ReadElements(r, &c.RemoteCurrentRevocation); err != nil { + return nil, err + } + + if err := readChanConfig(r, &c.LocalChanConfig); err != nil { + return nil, err + } + + // Finally, we'll attempt to read the next unrevoked commitment point + // for the remote party. If we closed the channel before receiving a + // funding locked message then this might not be present. A boolean + // indicating whether the field is present will come first. + var hasRemoteNextRevocation bool + err = ReadElements(r, &hasRemoteNextRevocation) + if err != nil { + return nil, err + } + + // If this field was written, read it. + if hasRemoteNextRevocation { + err = ReadElements(r, &c.RemoteNextRevocation) + if err != nil { + return nil, err + } + } + + // Check if we have a channel sync message to read. + var hasChanSyncMsg bool + err = ReadElements(r, &hasChanSyncMsg) + if err == io.EOF { + return c, nil + } else if err != nil { + return nil, err + } + + // If a chan sync message is present, read it. + if hasChanSyncMsg { + // We must pass in reference to a lnwire.Message for the codec + // to support it. + var msg lnwire.Message + if err := ReadElements(r, &msg); err != nil { + return nil, err + } + + chanSync, ok := msg.(*lnwire.ChannelReestablish) + if !ok { + return nil, errors.New("unable cast db Message to " + + "ChannelReestablish") + } + c.LastChanSyncMsg = chanSync + } + + return c, nil +} + +func writeChanConfig(b io.Writer, c *ChannelConfig) error { + return WriteElements(b, + c.DustLimit, c.MaxPendingAmount, c.ChanReserve, c.MinHTLC, + c.MaxAcceptedHtlcs, c.CsvDelay, c.MultiSigKey, + c.RevocationBasePoint, c.PaymentBasePoint, c.DelayBasePoint, + c.HtlcBasePoint, + ) +} + +func putChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + var w bytes.Buffer + if err := WriteElements(&w, + channel.ChanType, channel.ChainHash, channel.FundingOutpoint, + channel.ShortChannelID, channel.IsPending, channel.IsInitiator, + channel.chanStatus, channel.FundingBroadcastHeight, + channel.NumConfsRequired, channel.ChannelFlags, + channel.IdentityPub, channel.Capacity, channel.TotalMSatSent, + channel.TotalMSatReceived, + ); err != nil { + return err + } + + // For single funder channels that we initiated, write the funding txn. + if channel.ChanType.IsSingleFunder() && channel.IsInitiator && + !channel.hasChanStatus(ChanStatusRestored) { + + if err := WriteElement(&w, channel.FundingTxn); err != nil { + return err + } + } + + if err := writeChanConfig(&w, &channel.LocalChanCfg); err != nil { + return err + } + if err := writeChanConfig(&w, &channel.RemoteChanCfg); err != nil { + return err + } + + return chanBucket.Put(chanInfoKey, w.Bytes()) +} + +func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { + if err := WriteElements(w, + c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, + c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, + c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, + c.CommitSig, + ); err != nil { + return err + } + + return SerializeHtlcs(w, c.Htlcs...) +} + +func putChanCommitment(chanBucket *bbolt.Bucket, c *ChannelCommitment, + local bool) error { + + var commitKey []byte + if local { + commitKey = append(chanCommitmentKey, byte(0x00)) + } else { + commitKey = append(chanCommitmentKey, byte(0x01)) + } + + var b bytes.Buffer + if err := serializeChanCommit(&b, c); err != nil { + return err + } + + return chanBucket.Put(commitKey, b.Bytes()) +} + +func putChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + // If this is a restored channel, then we don't have any commitments to + // write. + if channel.hasChanStatus(ChanStatusRestored) { + return nil + } + + err := putChanCommitment( + chanBucket, &channel.LocalCommitment, true, + ) + if err != nil { + return err + } + + return putChanCommitment( + chanBucket, &channel.RemoteCommitment, false, + ) +} + +func putChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + + var b bytes.Buffer + err := WriteElements( + &b, channel.RemoteCurrentRevocation, channel.RevocationProducer, + channel.RevocationStore, + ) + if err != nil { + return err + } + + // TODO(roasbeef): don't keep producer on disk + + // If the next revocation is present, which is only the case after the + // FundingLocked message has been sent, then we'll write it to disk. + if channel.RemoteNextRevocation != nil { + err = WriteElements(&b, channel.RemoteNextRevocation) + if err != nil { + return err + } + } + + return chanBucket.Put(revocationStateKey, b.Bytes()) +} + +func readChanConfig(b io.Reader, c *ChannelConfig) error { + return ReadElements(b, + &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, + &c.MinHTLC, &c.MaxAcceptedHtlcs, &c.CsvDelay, + &c.MultiSigKey, &c.RevocationBasePoint, + &c.PaymentBasePoint, &c.DelayBasePoint, + &c.HtlcBasePoint, + ) +} + +func fetchChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + infoBytes := chanBucket.Get(chanInfoKey) + if infoBytes == nil { + return ErrNoChanInfoFound + } + r := bytes.NewReader(infoBytes) + + if err := ReadElements(r, + &channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint, + &channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator, + &channel.chanStatus, &channel.FundingBroadcastHeight, + &channel.NumConfsRequired, &channel.ChannelFlags, + &channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent, + &channel.TotalMSatReceived, + ); err != nil { + return err + } + + // For single funder channels that we initiated, read the funding txn. + if channel.ChanType.IsSingleFunder() && channel.IsInitiator && + !channel.hasChanStatus(ChanStatusRestored) { + + if err := ReadElement(r, &channel.FundingTxn); err != nil { + return err + } + } + + if err := readChanConfig(r, &channel.LocalChanCfg); err != nil { + return err + } + if err := readChanConfig(r, &channel.RemoteChanCfg); err != nil { + return err + } + + channel.Packager = NewChannelPackager(channel.ShortChannelID) + + return nil +} + +func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { + var c ChannelCommitment + + err := ReadElements(r, + &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, + &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, + &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, + ) + if err != nil { + return c, err + } + + c.Htlcs, err = DeserializeHtlcs(r) + if err != nil { + return c, err + } + + return c, nil +} + +func fetchChanCommitment(chanBucket *bbolt.Bucket, local bool) (ChannelCommitment, error) { + var commitKey []byte + if local { + commitKey = append(chanCommitmentKey, byte(0x00)) + } else { + commitKey = append(chanCommitmentKey, byte(0x01)) + } + + commitBytes := chanBucket.Get(commitKey) + if commitBytes == nil { + return ChannelCommitment{}, ErrNoCommitmentsFound + } + + r := bytes.NewReader(commitBytes) + return deserializeChanCommit(r) +} + +func fetchChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + var err error + + // If this is a restored channel, then we don't have any commitments to + // read. + if channel.hasChanStatus(ChanStatusRestored) { + return nil + } + + channel.LocalCommitment, err = fetchChanCommitment(chanBucket, true) + if err != nil { + return err + } + channel.RemoteCommitment, err = fetchChanCommitment(chanBucket, false) + if err != nil { + return err + } + + return nil +} + +func fetchChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error { + revBytes := chanBucket.Get(revocationStateKey) + if revBytes == nil { + return ErrNoRevocationsFound + } + r := bytes.NewReader(revBytes) + + err := ReadElements( + r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer, + &channel.RevocationStore, + ) + if err != nil { + return err + } + + // If there aren't any bytes left in the buffer, then we don't yet have + // the next remote revocation, so we can exit early here. + if r.Len() == 0 { + return nil + } + + // Otherwise we'll read the next revocation for the remote party which + // is always the last item within the buffer. + return ReadElements(r, &channel.RemoteNextRevocation) +} + +func deleteOpenChannel(chanBucket *bbolt.Bucket, chanPointBytes []byte) error { + + if err := chanBucket.Delete(chanInfoKey); err != nil { + return err + } + + err := chanBucket.Delete(append(chanCommitmentKey, byte(0x00))) + if err != nil { + return err + } + err = chanBucket.Delete(append(chanCommitmentKey, byte(0x01))) + if err != nil { + return err + } + + if err := chanBucket.Delete(revocationStateKey); err != nil { + return err + } + + if diff := chanBucket.Get(commitDiffKey); diff != nil { + return chanBucket.Delete(commitDiffKey) + } + + return nil + +} + +// makeLogKey converts a uint64 into an 8 byte array. +func makeLogKey(updateNum uint64) [8]byte { + var key [8]byte + byteOrder.PutUint64(key[:], updateNum) + return key +} + +func appendChannelLogEntry(log *bbolt.Bucket, + commit *ChannelCommitment) error { + + var b bytes.Buffer + if err := serializeChanCommit(&b, commit); err != nil { + return err + } + + logEntrykey := makeLogKey(commit.CommitHeight) + return log.Put(logEntrykey[:], b.Bytes()) +} + +func fetchChannelLogEntry(log *bbolt.Bucket, + updateNum uint64) (ChannelCommitment, error) { + + logEntrykey := makeLogKey(updateNum) + commitBytes := log.Get(logEntrykey[:]) + if commitBytes == nil { + return ChannelCommitment{}, fmt.Errorf("log entry not found") + } + + commitReader := bytes.NewReader(commitBytes) + return deserializeChanCommit(commitReader) +} diff --git a/channeldb/migration_01_to_11/channel_cache.go b/channeldb/migration_01_to_11/channel_cache.go new file mode 100644 index 00000000..5d391e00 --- /dev/null +++ b/channeldb/migration_01_to_11/channel_cache.go @@ -0,0 +1,50 @@ +package migration_01_to_11 + +// channelCache is an in-memory cache used to improve the performance of +// ChanUpdatesInHorizon. It caches the chan info and edge policies for a +// particular channel. +type channelCache struct { + n int + channels map[uint64]ChannelEdge +} + +// newChannelCache creates a new channelCache with maximum capacity of n +// channels. +func newChannelCache(n int) *channelCache { + return &channelCache{ + n: n, + channels: make(map[uint64]ChannelEdge), + } +} + +// get returns the channel from the cache, if it exists. +func (c *channelCache) get(chanid uint64) (ChannelEdge, bool) { + channel, ok := c.channels[chanid] + return channel, ok +} + +// insert adds the entry to the channel cache. If an entry for chanid already +// exists, it will be replaced with the new entry. If the entry doesn't exist, +// it will be inserted to the cache, performing a random eviction if the cache +// is at capacity. +func (c *channelCache) insert(chanid uint64, channel ChannelEdge) { + // If entry exists, replace it. + if _, ok := c.channels[chanid]; ok { + c.channels[chanid] = channel + return + } + + // Otherwise, evict an entry at random and insert. + if len(c.channels) == c.n { + for id := range c.channels { + delete(c.channels, id) + break + } + } + c.channels[chanid] = channel +} + +// remove deletes an edge for chanid from the cache, if it exists. +func (c *channelCache) remove(chanid uint64) { + delete(c.channels, chanid) +} diff --git a/channeldb/migration_01_to_11/channel_cache_test.go b/channeldb/migration_01_to_11/channel_cache_test.go new file mode 100644 index 00000000..b2929635 --- /dev/null +++ b/channeldb/migration_01_to_11/channel_cache_test.go @@ -0,0 +1,105 @@ +package migration_01_to_11 + +import ( + "reflect" + "testing" +) + +// TestChannelCache checks the behavior of the channelCache with respect to +// insertion, eviction, and removal of cache entries. +func TestChannelCache(t *testing.T) { + const cacheSize = 100 + + // Create a new channel cache with the configured max size. + c := newChannelCache(cacheSize) + + // As a sanity check, assert that querying the empty cache does not + // return an entry. + _, ok := c.get(0) + if ok { + t.Fatalf("channel cache should be empty") + } + + // Now, fill up the cache entirely. + for i := uint64(0); i < cacheSize; i++ { + c.insert(i, channelForInt(i)) + } + + // Assert that the cache has all of the entries just inserted, since no + // eviction should occur until we try to surpass the max size. + assertHasChanEntries(t, c, 0, cacheSize) + + // Now, insert a new element that causes the cache to evict an element. + c.insert(cacheSize, channelForInt(cacheSize)) + + // Assert that the cache has this last entry, as the cache should evict + // some prior element and not the newly inserted one. + assertHasChanEntries(t, c, cacheSize, cacheSize) + + // Iterate over all inserted elements and construct a set of the evicted + // elements. + evicted := make(map[uint64]struct{}) + for i := uint64(0); i < cacheSize+1; i++ { + _, ok := c.get(i) + if !ok { + evicted[i] = struct{}{} + } + } + + // Assert that exactly one element has been evicted. + numEvicted := len(evicted) + if numEvicted != 1 { + t.Fatalf("expected one evicted entry, got: %d", numEvicted) + } + + // Remove the highest item which initially caused the eviction and + // reinsert the element that was evicted prior. + c.remove(cacheSize) + for i := range evicted { + c.insert(i, channelForInt(i)) + } + + // Since the removal created an extra slot, the last insertion should + // not have caused an eviction and the entries for all channels in the + // original set that filled the cache should be present. + assertHasChanEntries(t, c, 0, cacheSize) + + // Finally, reinsert the existing set back into the cache and test that + // the cache still has all the entries. If the randomized eviction were + // happening on inserts for existing cache items, we expect this to fail + // with high probability. + for i := uint64(0); i < cacheSize; i++ { + c.insert(i, channelForInt(i)) + } + assertHasChanEntries(t, c, 0, cacheSize) + +} + +// assertHasEntries queries the edge cache for all channels in the range [start, +// end), asserting that they exist and their value matches the entry produced by +// entryForInt. +func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) { + t.Helper() + + for i := start; i < end; i++ { + entry, ok := c.get(i) + if !ok { + t.Fatalf("channel cache should contain chan %d", i) + } + + expEntry := channelForInt(i) + if !reflect.DeepEqual(entry, expEntry) { + t.Fatalf("entry mismatch, want: %v, got: %v", + expEntry, entry) + } + } +} + +// channelForInt generates a unique ChannelEdge given an integer. +func channelForInt(i uint64) ChannelEdge { + return ChannelEdge{ + Info: &ChannelEdgeInfo{ + ChannelID: i, + }, + } +} diff --git a/channeldb/migration_01_to_11/channel_test.go b/channeldb/migration_01_to_11/channel_test.go new file mode 100644 index 00000000..53fb39d7 --- /dev/null +++ b/channeldb/migration_01_to_11/channel_test.go @@ -0,0 +1,1041 @@ +package migration_01_to_11 + +import ( + "bytes" + "io/ioutil" + "math/rand" + "net" + "os" + "reflect" + "runtime" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + _ "github.com/btcsuite/btcwallet/walletdb/bdb" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +var ( + key = [chainhash.HashSize]byte{ + 0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d, + 0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9, + } + rev = [chainhash.HashSize]byte{ + 0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda, + 0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17, + 0x2d, 0xe7, 0x93, 0xe4, + } + testTx = &wire.MsgTx{ + Version: 1, + TxIn: []*wire.TxIn{ + { + PreviousOutPoint: wire.OutPoint{ + Hash: chainhash.Hash{}, + Index: 0xffffffff, + }, + SignatureScript: []byte{0x04, 0x31, 0xdc, 0x00, 0x1b, 0x01, 0x62}, + Sequence: 0xffffffff, + }, + }, + TxOut: []*wire.TxOut{ + { + Value: 5000000000, + PkScript: []byte{ + 0x41, // OP_DATA_65 + 0x04, 0xd6, 0x4b, 0xdf, 0xd0, 0x9e, 0xb1, 0xc5, + 0xfe, 0x29, 0x5a, 0xbd, 0xeb, 0x1d, 0xca, 0x42, + 0x81, 0xbe, 0x98, 0x8e, 0x2d, 0xa0, 0xb6, 0xc1, + 0xc6, 0xa5, 0x9d, 0xc2, 0x26, 0xc2, 0x86, 0x24, + 0xe1, 0x81, 0x75, 0xe8, 0x51, 0xc9, 0x6b, 0x97, + 0x3d, 0x81, 0xb0, 0x1c, 0xc3, 0x1f, 0x04, 0x78, + 0x34, 0xbc, 0x06, 0xd6, 0xd6, 0xed, 0xf6, 0x20, + 0xd1, 0x84, 0x24, 0x1a, 0x6a, 0xed, 0x8b, 0x63, + 0xa6, // 65-byte signature + 0xac, // OP_CHECKSIG + }, + }, + }, + LockTime: 5, + } + privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + + wireSig, _ = lnwire.NewSigFromSignature(testSig) +) + +// makeTestDB creates a new instance of the ChannelDB for testing purposes. A +// callback which cleans up the created temporary directories is also returned +// and intended to be executed after the test completes. +func makeTestDB() (*DB, func(), error) { + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + return nil, nil, err + } + + // Next, create channeldb for the first time. + cdb, err := Open(tempDirName) + if err != nil { + return nil, nil, err + } + + cleanUp := func() { + cdb.Close() + os.RemoveAll(tempDirName) + } + + return cdb, cleanUp, nil +} + +func createTestChannelState(cdb *DB) (*OpenChannel, error) { + // Simulate 1000 channel updates. + producer, err := shachain.NewRevocationProducerFromBytes(key[:]) + if err != nil { + return nil, err + } + store := shachain.NewRevocationStore() + for i := 0; i < 1; i++ { + preImage, err := producer.AtIndex(uint64(i)) + if err != nil { + return nil, err + } + + if err := store.AddNextEntry(preImage); err != nil { + return nil, err + } + } + + localCfg := ChannelConfig{ + ChannelConstraints: ChannelConstraints{ + DustLimit: btcutil.Amount(rand.Int63()), + MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()), + ChanReserve: btcutil.Amount(rand.Int63()), + MinHTLC: lnwire.MilliSatoshi(rand.Int63()), + MaxAcceptedHtlcs: uint16(rand.Int31()), + CsvDelay: uint16(rand.Int31()), + }, + MultiSigKey: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + RevocationBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + PaymentBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + DelayBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + HtlcBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + }, + } + remoteCfg := ChannelConfig{ + ChannelConstraints: ChannelConstraints{ + DustLimit: btcutil.Amount(rand.Int63()), + MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()), + ChanReserve: btcutil.Amount(rand.Int63()), + MinHTLC: lnwire.MilliSatoshi(rand.Int63()), + MaxAcceptedHtlcs: uint16(rand.Int31()), + CsvDelay: uint16(rand.Int31()), + }, + MultiSigKey: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyMultiSig, + Index: 9, + }, + }, + RevocationBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyRevocationBase, + Index: 8, + }, + }, + PaymentBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyPaymentBase, + Index: 7, + }, + }, + DelayBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyDelayBase, + Index: 6, + }, + }, + HtlcBasePoint: keychain.KeyDescriptor{ + PubKey: privKey.PubKey(), + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyHtlcBase, + Index: 5, + }, + }, + } + + chanID := lnwire.NewShortChanIDFromInt(uint64(rand.Int63())) + + return &OpenChannel{ + ChanType: SingleFunder, + ChainHash: key, + FundingOutpoint: wire.OutPoint{Hash: key, Index: rand.Uint32()}, + ShortChannelID: chanID, + IsInitiator: true, + IsPending: true, + IdentityPub: pubKey, + Capacity: btcutil.Amount(10000), + LocalChanCfg: localCfg, + RemoteChanCfg: remoteCfg, + TotalMSatSent: 8, + TotalMSatReceived: 2, + LocalCommitment: ChannelCommitment{ + CommitHeight: 0, + LocalBalance: lnwire.MilliSatoshi(9000), + RemoteBalance: lnwire.MilliSatoshi(3000), + CommitFee: btcutil.Amount(rand.Int63()), + FeePerKw: btcutil.Amount(5000), + CommitTx: testTx, + CommitSig: bytes.Repeat([]byte{1}, 71), + }, + RemoteCommitment: ChannelCommitment{ + CommitHeight: 0, + LocalBalance: lnwire.MilliSatoshi(3000), + RemoteBalance: lnwire.MilliSatoshi(9000), + CommitFee: btcutil.Amount(rand.Int63()), + FeePerKw: btcutil.Amount(5000), + CommitTx: testTx, + CommitSig: bytes.Repeat([]byte{1}, 71), + }, + NumConfsRequired: 4, + RemoteCurrentRevocation: privKey.PubKey(), + RemoteNextRevocation: privKey.PubKey(), + RevocationProducer: producer, + RevocationStore: store, + Db: cdb, + Packager: NewChannelPackager(chanID), + FundingTxn: testTx, + }, nil +} + +func TestOpenChannelPutGetDelete(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create the test channel state, then add an additional fake HTLC + // before syncing to disk. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + state.LocalCommitment.Htlcs = []HTLC{ + { + Signature: testSig.Serialize(), + Incoming: true, + Amt: 10, + RHash: key, + RefundTimeout: 1, + OnionBlob: []byte("onionblob"), + }, + } + state.RemoteCommitment.Htlcs = []HTLC{ + { + Signature: testSig.Serialize(), + Incoming: false, + Amt: 10, + RHash: key, + RefundTimeout: 1, + OnionBlob: []byte("onionblob"), + }, + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + if err := state.SyncPending(addr, 101); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch open channel: %v", err) + } + + newState := openChannels[0] + + // The decoded channel state should be identical to what we stored + // above. + if !reflect.DeepEqual(state, newState) { + t.Fatalf("channel state doesn't match:: %v vs %v", + spew.Sdump(state), spew.Sdump(newState)) + } + + // We'll also test that the channel is properly able to hot swap the + // next revocation for the state machine. This tests the initial + // post-funding revocation exchange. + nextRevKey, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to create new private key: %v", err) + } + if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil { + t.Fatalf("unable to update revocation: %v", err) + } + + openChannels, err = cdb.FetchOpenChannels(state.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch open channel: %v", err) + } + updatedChan := openChannels[0] + + // Ensure that the revocation was set properly. + if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) { + t.Fatalf("next revocation wasn't updated") + } + + // Finally to wrap up the test, delete the state of the channel within + // the database. This involves "closing" the channel which removes all + // written state, and creates a small "summary" elsewhere within the + // database. + closeSummary := &ChannelCloseSummary{ + ChanPoint: state.FundingOutpoint, + RemotePub: state.IdentityPub, + SettledBalance: btcutil.Amount(500), + TimeLockedBalance: btcutil.Amount(10000), + IsPending: false, + CloseType: CooperativeClose, + } + if err := state.CloseChannel(closeSummary); err != nil { + t.Fatalf("unable to close channel: %v", err) + } + + // As the channel is now closed, attempting to fetch all open channels + // for our fake node ID should return an empty slice. + openChans, err := cdb.FetchOpenChannels(state.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch open channels: %v", err) + } + if len(openChans) != 0 { + t.Fatalf("all channels not deleted, found %v", len(openChans)) + } + + // Additionally, attempting to fetch all the open channels globally + // should yield no results. + openChans, err = cdb.FetchAllChannels() + if err != nil { + t.Fatal("unable to fetch all open chans") + } + if len(openChans) != 0 { + t.Fatalf("all channels not deleted, found %v", len(openChans)) + } +} + +func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { + if !reflect.DeepEqual(a, b) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: commitments don't match: %v vs %v", + line, spew.Sdump(a), spew.Sdump(b)) + } +} + +func TestChannelStateTransition(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First create a minimal channel, then perform a full sync in order to + // persist the data. + channel, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + if err := channel.SyncPending(addr, 101); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + // Add some HTLCs which were added during this new state transition. + // Half of the HTLCs are incoming, while the other half are outgoing. + var ( + htlcs []HTLC + htlcAmt lnwire.MilliSatoshi + ) + for i := uint32(0); i < 10; i++ { + var incoming bool + if i > 5 { + incoming = true + } + htlc := HTLC{ + Signature: testSig.Serialize(), + Incoming: incoming, + Amt: 10, + RHash: key, + RefundTimeout: i, + OutputIndex: int32(i * 3), + LogIndex: uint64(i * 2), + HtlcIndex: uint64(i), + } + htlc.OnionBlob = make([]byte, 10) + copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10)) + htlcs = append(htlcs, htlc) + htlcAmt += htlc.Amt + } + + // Create a new channel delta which includes the above HTLCs, some + // balance updates, and an increment of the current commitment height. + // Additionally, modify the signature and commitment transaction. + newSequence := uint32(129498) + newSig := bytes.Repeat([]byte{3}, 71) + newTx := channel.LocalCommitment.CommitTx.Copy() + newTx.TxIn[0].Sequence = newSequence + commitment := ChannelCommitment{ + CommitHeight: 1, + LocalLogIndex: 2, + LocalHtlcIndex: 1, + RemoteLogIndex: 2, + RemoteHtlcIndex: 1, + LocalBalance: lnwire.MilliSatoshi(1e8), + RemoteBalance: lnwire.MilliSatoshi(1e8), + CommitFee: 55, + FeePerKw: 99, + CommitTx: newTx, + CommitSig: newSig, + Htlcs: htlcs, + } + + // First update the local node's broadcastable state and also add a + // CommitDiff remote node's as well in order to simulate a proper state + // transition. + if err := channel.UpdateCommitment(&commitment); err != nil { + t.Fatalf("unable to update commitment: %v", err) + } + + // The balances, new update, the HTLCs and the changes to the fake + // commitment transaction along with the modified signature should all + // have been updated. + updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch updated channel: %v", err) + } + assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) + numDiskUpdates, err := updatedChannel[0].CommitmentHeight() + if err != nil { + t.Fatalf("unable to read commitment height from disk: %v", err) + } + if numDiskUpdates != uint64(commitment.CommitHeight) { + t.Fatalf("num disk updates doesn't match: %v vs %v", + numDiskUpdates, commitment.CommitHeight) + } + + // Attempting to query for a commitment diff should return + // ErrNoPendingCommit as we haven't yet created a new state for them. + _, err = channel.RemoteCommitChainTip() + if err != ErrNoPendingCommit { + t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) + } + + // To simulate us extending a new state to the remote party, we'll also + // create a new commit diff for them. + remoteCommit := commitment + remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8) + remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8) + remoteCommit.CommitHeight = 1 + commitDiff := &CommitDiff{ + Commitment: remoteCommit, + CommitSig: &lnwire.CommitSig{ + ChanID: lnwire.ChannelID(key), + CommitSig: wireSig, + HtlcSigs: []lnwire.Sig{ + wireSig, + wireSig, + }, + }, + LogUpdates: []LogUpdate{ + { + LogIndex: 1, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ID: 1, + Amount: lnwire.NewMSatFromSatoshis(100), + Expiry: 25, + }, + }, + { + LogIndex: 2, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ID: 2, + Amount: lnwire.NewMSatFromSatoshis(200), + Expiry: 50, + }, + }, + }, + OpenedCircuitKeys: []CircuitKey{}, + ClosedCircuitKeys: []CircuitKey{}, + } + copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], + bytes.Repeat([]byte{1}, 32)) + copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], + bytes.Repeat([]byte{2}, 32)) + if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { + t.Fatalf("unable to add to commit chain: %v", err) + } + + // The commitment tip should now match the commitment that we just + // inserted. + diskCommitDiff, err := channel.RemoteCommitChainTip() + if err != nil { + t.Fatalf("unable to fetch commit diff: %v", err) + } + if !reflect.DeepEqual(commitDiff, diskCommitDiff) { + t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit), + spew.Sdump(diskCommitDiff)) + } + + // We'll save the old remote commitment as this will be added to the + // revocation log shortly. + oldRemoteCommit := channel.RemoteCommitment + + // Next, write to the log which tracks the necessary revocation state + // needed to rectify any fishy behavior by the remote party. Modify the + // current uncollapsed revocation state to simulate a state transition + // by the remote party. + channel.RemoteCurrentRevocation = channel.RemoteNextRevocation + newPriv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to generate key: %v", err) + } + channel.RemoteNextRevocation = newPriv.PubKey() + + fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, + diskCommitDiff.LogUpdates, nil) + + err = channel.AdvanceCommitChainTail(fwdPkg) + if err != nil { + t.Fatalf("unable to append to revocation log: %v", err) + } + + // At this point, the remote commit chain should be nil, and the posted + // remote commitment should match the one we added as a diff above. + if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit { + t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) + } + + // We should be able to fetch the channel delta created above by its + // update number with all the state properly reconstructed. + diskPrevCommit, err := channel.FindPreviousState( + oldRemoteCommit.CommitHeight, + ) + if err != nil { + t.Fatalf("unable to fetch past delta: %v", err) + } + + // The two deltas (the original vs the on-disk version) should + // identical, and all HTLC data should properly be retained. + assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit) + + // The state number recovered from the tail of the revocation log + // should be identical to this current state. + logTail, err := channel.RevocationLogTail() + if err != nil { + t.Fatalf("unable to retrieve log: %v", err) + } + if logTail.CommitHeight != oldRemoteCommit.CommitHeight { + t.Fatal("update number doesn't match") + } + + oldRemoteCommit = channel.RemoteCommitment + + // Next modify the posted diff commitment slightly, then create a new + // commitment diff and advance the tail. + commitDiff.Commitment.CommitHeight = 2 + commitDiff.Commitment.LocalBalance -= htlcAmt + commitDiff.Commitment.RemoteBalance += htlcAmt + commitDiff.LogUpdates = []LogUpdate{} + if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { + t.Fatalf("unable to add to commit chain: %v", err) + } + + fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil) + + err = channel.AdvanceCommitChainTail(fwdPkg) + if err != nil { + t.Fatalf("unable to append to revocation log: %v", err) + } + + // Once again, fetch the state and ensure it has been properly updated. + prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight) + if err != nil { + t.Fatalf("unable to fetch past delta: %v", err) + } + assertCommitmentEqual(t, &oldRemoteCommit, prevCommit) + + // Once again, state number recovered from the tail of the revocation + // log should be identical to this current state. + logTail, err = channel.RevocationLogTail() + if err != nil { + t.Fatalf("unable to retrieve log: %v", err) + } + if logTail.CommitHeight != oldRemoteCommit.CommitHeight { + t.Fatal("update number doesn't match") + } + + // The revocation state stored on-disk should now also be identical. + updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch updated channel: %v", err) + } + if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) { + t.Fatalf("revocation state was not synced") + } + if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) { + t.Fatalf("revocation state was not synced") + } + + // Now attempt to delete the channel from the database. + closeSummary := &ChannelCloseSummary{ + ChanPoint: channel.FundingOutpoint, + RemotePub: channel.IdentityPub, + SettledBalance: btcutil.Amount(500), + TimeLockedBalance: btcutil.Amount(10000), + IsPending: false, + CloseType: RemoteForceClose, + } + if err := updatedChannel[0].CloseChannel(closeSummary); err != nil { + t.Fatalf("unable to delete updated channel: %v", err) + } + + // If we attempt to fetch the target channel again, it shouldn't be + // found. + channels, err := cdb.FetchOpenChannels(channel.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch updated channels: %v", err) + } + if len(channels) != 0 { + t.Fatalf("%v channels, found, but none should be", + len(channels)) + } + + // Attempting to find previous states on the channel should fail as the + // revocation log has been deleted. + _, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight) + if err == nil { + t.Fatal("revocation log search should have failed") + } +} + +func TestFetchPendingChannels(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create first test channel state + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + + const broadcastHeight = 99 + if err := state.SyncPending(addr, broadcastHeight); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + pendingChannels, err := cdb.FetchPendingChannels() + if err != nil { + t.Fatalf("unable to list pending channels: %v", err) + } + + if len(pendingChannels) != 1 { + t.Fatalf("incorrect number of pending channels: expecting %v,"+ + "got %v", 1, len(pendingChannels)) + } + + // The broadcast height of the pending channel should have been set + // properly. + if pendingChannels[0].FundingBroadcastHeight != broadcastHeight { + t.Fatalf("broadcast height mismatch: expected %v, got %v", + pendingChannels[0].FundingBroadcastHeight, + broadcastHeight) + } + + chanOpenLoc := lnwire.ShortChannelID{ + BlockHeight: 5, + TxIndex: 10, + TxPosition: 15, + } + err = pendingChannels[0].MarkAsOpen(chanOpenLoc) + if err != nil { + t.Fatalf("unable to mark channel as open: %v", err) + } + + if pendingChannels[0].IsPending { + t.Fatalf("channel marked open should no longer be pending") + } + + if pendingChannels[0].ShortChanID() != chanOpenLoc { + t.Fatalf("channel opening height not updated: expected %v, "+ + "got %v", spew.Sdump(pendingChannels[0].ShortChanID()), + chanOpenLoc) + } + + // Next, we'll re-fetch the channel to ensure that the open height was + // properly set. + openChans, err := cdb.FetchAllChannels() + if err != nil { + t.Fatalf("unable to fetch channels: %v", err) + } + if openChans[0].ShortChanID() != chanOpenLoc { + t.Fatalf("channel opening heights don't match: expected %v, "+ + "got %v", spew.Sdump(openChans[0].ShortChanID()), + chanOpenLoc) + } + if openChans[0].FundingBroadcastHeight != broadcastHeight { + t.Fatalf("broadcast height mismatch: expected %v, got %v", + openChans[0].FundingBroadcastHeight, + broadcastHeight) + } + + pendingChannels, err = cdb.FetchPendingChannels() + if err != nil { + t.Fatalf("unable to list pending channels: %v", err) + } + + if len(pendingChannels) != 0 { + t.Fatalf("incorrect number of pending channels: expecting %v,"+ + "got %v", 0, len(pendingChannels)) + } +} + +func TestFetchClosedChannels(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First create a test channel, that we'll be closing within this pull + // request. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + // Next sync the channel to disk, marking it as being in a pending open + // state. + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + const broadcastHeight = 99 + if err := state.SyncPending(addr, broadcastHeight); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + // Next, simulate the confirmation of the channel by marking it as + // pending within the database. + chanOpenLoc := lnwire.ShortChannelID{ + BlockHeight: 5, + TxIndex: 10, + TxPosition: 15, + } + err = state.MarkAsOpen(chanOpenLoc) + if err != nil { + t.Fatalf("unable to mark channel as open: %v", err) + } + + // Next, close the channel by including a close channel summary in the + // database. + summary := &ChannelCloseSummary{ + ChanPoint: state.FundingOutpoint, + ClosingTXID: rev, + RemotePub: state.IdentityPub, + Capacity: state.Capacity, + SettledBalance: state.LocalCommitment.LocalBalance.ToSatoshis(), + TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000, + CloseType: RemoteForceClose, + IsPending: true, + LocalChanConfig: state.LocalChanCfg, + } + if err := state.CloseChannel(summary); err != nil { + t.Fatalf("unable to close channel: %v", err) + } + + // Query the database to ensure that the channel has now been properly + // closed. We should get the same result whether querying for pending + // channels only, or not. + pendingClosed, err := cdb.FetchClosedChannels(true) + if err != nil { + t.Fatalf("failed fetching closed channels: %v", err) + } + if len(pendingClosed) != 1 { + t.Fatalf("incorrect number of pending closed channels: expecting %v,"+ + "got %v", 1, len(pendingClosed)) + } + if !reflect.DeepEqual(summary, pendingClosed[0]) { + t.Fatalf("database summaries don't match: expected %v got %v", + spew.Sdump(summary), spew.Sdump(pendingClosed[0])) + } + closed, err := cdb.FetchClosedChannels(false) + if err != nil { + t.Fatalf("failed fetching all closed channels: %v", err) + } + if len(closed) != 1 { + t.Fatalf("incorrect number of closed channels: expecting %v, "+ + "got %v", 1, len(closed)) + } + if !reflect.DeepEqual(summary, closed[0]) { + t.Fatalf("database summaries don't match: expected %v got %v", + spew.Sdump(summary), spew.Sdump(closed[0])) + } + + // Mark the channel as fully closed. + err = cdb.MarkChanFullyClosed(&state.FundingOutpoint) + if err != nil { + t.Fatalf("failed fully closing channel: %v", err) + } + + // The channel should no longer be considered pending, but should still + // be retrieved when fetching all the closed channels. + closed, err = cdb.FetchClosedChannels(false) + if err != nil { + t.Fatalf("failed fetching closed channels: %v", err) + } + if len(closed) != 1 { + t.Fatalf("incorrect number of closed channels: expecting %v, "+ + "got %v", 1, len(closed)) + } + pendingClose, err := cdb.FetchClosedChannels(true) + if err != nil { + t.Fatalf("failed fetching channels pending close: %v", err) + } + if len(pendingClose) != 0 { + t.Fatalf("incorrect number of closed channels: expecting %v, "+ + "got %v", 0, len(closed)) + } +} + +// TestFetchWaitingCloseChannels ensures that the correct channels that are +// waiting to be closed are returned. +func TestFetchWaitingCloseChannels(t *testing.T) { + t.Parallel() + + const numChannels = 2 + const broadcastHeight = 99 + addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555} + + // We'll start by creating two channels within our test database. One of + // them will have their funding transaction confirmed on-chain, while + // the other one will remain unconfirmed. + db, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + channels := make([]*OpenChannel, numChannels) + for i := 0; i < numChannels; i++ { + channel, err := createTestChannelState(db) + if err != nil { + t.Fatalf("unable to create channel: %v", err) + } + err = channel.SyncPending(addr, broadcastHeight) + if err != nil { + t.Fatalf("unable to sync channel: %v", err) + } + channels[i] = channel + } + + // We'll only confirm the first one. + channelConf := lnwire.ShortChannelID{ + BlockHeight: broadcastHeight + 1, + TxIndex: 10, + TxPosition: 15, + } + if err := channels[0].MarkAsOpen(channelConf); err != nil { + t.Fatalf("unable to mark channel as open: %v", err) + } + + // Then, we'll mark the channels as if their commitments were broadcast. + // This would happen in the event of a force close and should make the + // channels enter a state of waiting close. + for _, channel := range channels { + closeTx := wire.NewMsgTx(2) + closeTx.AddTxIn( + &wire.TxIn{ + PreviousOutPoint: channel.FundingOutpoint, + }, + ) + if err := channel.MarkCommitmentBroadcasted(closeTx); err != nil { + t.Fatalf("unable to mark commitment broadcast: %v", err) + } + } + + // Now, we'll fetch all the channels waiting to be closed from the + // database. We should expect to see both channels above, even if any of + // them haven't had their funding transaction confirm on-chain. + waitingCloseChannels, err := db.FetchWaitingCloseChannels() + if err != nil { + t.Fatalf("unable to fetch all waiting close channels: %v", err) + } + if len(waitingCloseChannels) != 2 { + t.Fatalf("expected %d channels waiting to be closed, got %d", 2, + len(waitingCloseChannels)) + } + expectedChannels := make(map[wire.OutPoint]struct{}) + for _, channel := range channels { + expectedChannels[channel.FundingOutpoint] = struct{}{} + } + for _, channel := range waitingCloseChannels { + if _, ok := expectedChannels[channel.FundingOutpoint]; !ok { + t.Fatalf("expected channel %v to be waiting close", + channel.FundingOutpoint) + } + + // Finally, make sure we can retrieve the closing tx for the + // channel. + closeTx, err := channel.BroadcastedCommitment() + if err != nil { + t.Fatalf("Unable to retrieve commitment: %v", err) + } + + if closeTx.TxIn[0].PreviousOutPoint != channel.FundingOutpoint { + t.Fatalf("expected outpoint %v, got %v", + channel.FundingOutpoint, + closeTx.TxIn[0].PreviousOutPoint) + } + } +} + +// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory +// short channel ID of another OpenChannel to reflect a preceding call to +// MarkOpen on a different OpenChannel. +func TestRefreshShortChanID(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First create a test channel. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + + // Mark the channel as pending within the channeldb. + const broadcastHeight = 99 + if err := state.SyncPending(addr, broadcastHeight); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + + // Next, locate the pending channel with the database. + pendingChannels, err := cdb.FetchPendingChannels() + if err != nil { + t.Fatalf("unable to load pending channels; %v", err) + } + + var pendingChannel *OpenChannel + for _, channel := range pendingChannels { + if channel.FundingOutpoint == state.FundingOutpoint { + pendingChannel = channel + break + } + } + if pendingChannel == nil { + t.Fatalf("unable to find pending channel with funding "+ + "outpoint=%v: %v", state.FundingOutpoint, err) + } + + // Next, simulate the confirmation of the channel by marking it as + // pending within the database. + chanOpenLoc := lnwire.ShortChannelID{ + BlockHeight: 105, + TxIndex: 10, + TxPosition: 15, + } + + err = state.MarkAsOpen(chanOpenLoc) + if err != nil { + t.Fatalf("unable to mark channel open: %v", err) + } + + // The short_chan_id of the receiver to MarkAsOpen should reflect the + // open location, but the other pending channel should remain unchanged. + if state.ShortChanID() == pendingChannel.ShortChanID() { + t.Fatalf("pending channel short_chan_ID should not have been " + + "updated before refreshing short_chan_id") + } + + // Now that the receiver's short channel id has been updated, check to + // ensure that the channel packager's source has been updated as well. + // This ensures that the packager will read and write to buckets + // corresponding to the new short chan id, instead of the prior. + if state.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + state.Packager.(*ChannelPackager).source) + } + + // Now, refresh the short channel ID of the pending channel. + err = pendingChannel.RefreshShortChanID() + if err != nil { + t.Fatalf("unable to refresh short_chan_id: %v", err) + } + + // This should result in both OpenChannel's now having the same + // ShortChanID. + if state.ShortChanID() != pendingChannel.ShortChanID() { + t.Fatalf("expected pending channel short_chan_id to be "+ + "refreshed: want %v, got %v", state.ShortChanID(), + pendingChannel.ShortChanID()) + } + + // Check to ensure that the _other_ OpenChannel channel packager's + // source has also been updated after the refresh. This ensures that the + // other packagers will read and write to buckets corresponding to the + // updated short chan id. + if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { + t.Fatalf("channel packager source was not updated: want %v, "+ + "got %v", chanOpenLoc, + pendingChannel.Packager.(*ChannelPackager).source) + } +} diff --git a/channeldb/migration_01_to_11/codec.go b/channeldb/migration_01_to_11/codec.go new file mode 100644 index 00000000..cfef35e0 --- /dev/null +++ b/channeldb/migration_01_to_11/codec.go @@ -0,0 +1,454 @@ +package migration_01_to_11 + +import ( + "encoding/binary" + "fmt" + "io" + "net" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +// writeOutpoint writes an outpoint to the passed writer using the minimal +// amount of bytes possible. +func writeOutpoint(w io.Writer, o *wire.OutPoint) error { + if _, err := w.Write(o.Hash[:]); err != nil { + return err + } + if err := binary.Write(w, byteOrder, o.Index); err != nil { + return err + } + + return nil +} + +// readOutpoint reads an outpoint from the passed reader that was previously +// written using the writeOutpoint struct. +func readOutpoint(r io.Reader, o *wire.OutPoint) error { + if _, err := io.ReadFull(r, o.Hash[:]); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &o.Index); err != nil { + return err + } + + return nil +} + +// UnknownElementType is an error returned when the codec is unable to encode or +// decode a particular type. +type UnknownElementType struct { + method string + element interface{} +} + +// NewUnknownElementType creates a new UnknownElementType error from the passed +// method name and element. +func NewUnknownElementType(method string, el interface{}) UnknownElementType { + return UnknownElementType{method: method, element: el} +} + +// Error returns the name of the method that encountered the error, as well as +// the type that was unsupported. +func (e UnknownElementType) Error() string { + return fmt.Sprintf("Unknown type in %s: %T", e.method, e.element) +} + +// WriteElement is a one-stop shop to write the big endian representation of +// any element which is to be serialized for storage on disk. The passed +// io.Writer should be backed by an appropriately sized byte slice, or be able +// to dynamically expand to accommodate additional data. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case keychain.KeyDescriptor: + if err := binary.Write(w, byteOrder, e.Family); err != nil { + return err + } + if err := binary.Write(w, byteOrder, e.Index); err != nil { + return err + } + + if e.PubKey != nil { + if err := binary.Write(w, byteOrder, true); err != nil { + return fmt.Errorf("error writing serialized element: %s", err) + } + + return WriteElement(w, e.PubKey) + } + + return binary.Write(w, byteOrder, false) + case ChannelType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case chainhash.Hash: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wire.OutPoint: + return writeOutpoint(w, &e) + + case lnwire.ShortChannelID: + if err := binary.Write(w, byteOrder, e.ToUint64()); err != nil { + return err + } + + case lnwire.ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case int64, uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case int32: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint16: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case uint8: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case bool: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case btcutil.Amount: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case lnwire.MilliSatoshi: + if err := binary.Write(w, byteOrder, uint64(e)); err != nil { + return err + } + + case *btcec.PrivateKey: + b := e.Serialize() + if _, err := w.Write(b); err != nil { + return err + } + + case *btcec.PublicKey: + b := e.SerializeCompressed() + if _, err := w.Write(b); err != nil { + return err + } + + case shachain.Producer: + return e.Encode(w) + + case shachain.Store: + return e.Encode(w) + + case *wire.MsgTx: + return e.Serialize(w) + + case [32]byte: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + case lnwire.Message: + if _, err := lnwire.WriteMessage(w, e, 0); err != nil { + return err + } + + case ChannelStatus: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case ClosureType: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case lnwire.FundingFlag: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case net.Addr: + if err := serializeAddr(w, e); err != nil { + return err + } + + case []net.Addr: + if err := WriteElement(w, uint32(len(e))); err != nil { + return err + } + + for _, addr := range e { + if err := serializeAddr(w, addr); err != nil { + return err + } + } + + default: + return UnknownElementType{"WriteElement", e} + } + + return nil +} + +// WriteElements is writes each element in the elements slice to the passed +// io.Writer using WriteElement. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + err := WriteElement(w, element) + if err != nil { + return err + } + } + return nil +} + +// ReadElement is a one-stop utility function to deserialize any datastructure +// encoded using the serialization format of the database. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *keychain.KeyDescriptor: + if err := binary.Read(r, byteOrder, &e.Family); err != nil { + return err + } + if err := binary.Read(r, byteOrder, &e.Index); err != nil { + return err + } + + var hasPubKey bool + if err := binary.Read(r, byteOrder, &hasPubKey); err != nil { + return err + } + + if hasPubKey { + return ReadElement(r, &e.PubKey) + } + + case *ChannelType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *chainhash.Hash: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wire.OutPoint: + return readOutpoint(r, e) + + case *lnwire.ShortChannelID: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + *e = lnwire.NewShortChanIDFromInt(a) + + case *lnwire.ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *int64, *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *int32: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint16: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *uint8: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *bool: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *btcutil.Amount: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = btcutil.Amount(a) + + case *lnwire.MilliSatoshi: + var a uint64 + if err := binary.Read(r, byteOrder, &a); err != nil { + return err + } + + *e = lnwire.MilliSatoshi(a) + + case **btcec.PrivateKey: + var b [btcec.PrivKeyBytesLen]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + priv, _ := btcec.PrivKeyFromBytes(btcec.S256(), b[:]) + *e = priv + + case **btcec.PublicKey: + var b [btcec.PubKeyBytesLenCompressed]byte + if _, err := io.ReadFull(r, b[:]); err != nil { + return err + } + + pubKey, err := btcec.ParsePubKey(b[:], btcec.S256()) + if err != nil { + return err + } + *e = pubKey + + case *shachain.Producer: + var root [32]byte + if _, err := io.ReadFull(r, root[:]); err != nil { + return err + } + + // TODO(roasbeef): remove + producer, err := shachain.NewRevocationProducerFromBytes(root[:]) + if err != nil { + return err + } + + *e = producer + + case *shachain.Store: + store, err := shachain.NewRevocationStoreFromBytes(r) + if err != nil { + return err + } + + *e = store + + case **wire.MsgTx: + tx := wire.NewMsgTx(2) + if err := tx.Deserialize(r); err != nil { + return err + } + + *e = tx + + case *[32]byte: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + case *lnwire.Message: + msg, err := lnwire.ReadMessage(r, 0) + if err != nil { + return err + } + + *e = msg + + case *ChannelStatus: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *ClosureType: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *lnwire.FundingFlag: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *net.Addr: + addr, err := deserializeAddr(r) + if err != nil { + return err + } + *e = addr + + case *[]net.Addr: + var numAddrs uint32 + if err := ReadElement(r, &numAddrs); err != nil { + return err + } + + *e = make([]net.Addr, numAddrs) + for i := uint32(0); i < numAddrs; i++ { + addr, err := deserializeAddr(r) + if err != nil { + return err + } + (*e)[i] = addr + } + + default: + return UnknownElementType{"ReadElement", e} + } + + return nil +} + +// ReadElements deserializes a variable number of elements into the passed +// io.Reader, with each element being deserialized according to the ReadElement +// function. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + err := ReadElement(r, element) + if err != nil { + return err + } + } + return nil +} diff --git a/channeldb/migration_01_to_11/db.go b/channeldb/migration_01_to_11/db.go new file mode 100644 index 00000000..c4306400 --- /dev/null +++ b/channeldb/migration_01_to_11/db.go @@ -0,0 +1,1185 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "fmt" + "net" + "os" + "path/filepath" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lnwire" +) + +const ( + dbName = "channel.db" + dbFilePermission = 0600 +) + +// migration is a function which takes a prior outdated version of the database +// instances and mutates the key/bucket structure to arrive at a more +// up-to-date version of the database. +type migration func(tx *bbolt.Tx) error + +type version struct { + number uint32 + migration migration +} + +var ( + // dbVersions is storing all versions of database. If current version + // of database don't match with latest version this list will be used + // for retrieving all migration function that are need to apply to the + // current db. + dbVersions = []version{ + { + // The base DB version requires no migration. + number: 0, + migration: nil, + }, + { + // The version of the database where two new indexes + // for the update time of node and channel updates were + // added. + number: 1, + migration: migrateNodeAndEdgeUpdateIndex, + }, + { + // The DB version that added the invoice event time + // series. + number: 2, + migration: migrateInvoiceTimeSeries, + }, + { + // The DB version that updated the embedded invoice in + // outgoing payments to match the new format. + number: 3, + migration: migrateInvoiceTimeSeriesOutgoingPayments, + }, + { + // The version of the database where every channel + // always has two entries in the edges bucket. If + // a policy is unknown, this will be represented + // by a special byte sequence. + number: 4, + migration: migrateEdgePolicies, + }, + { + // The DB version where we persist each attempt to send + // an HTLC to a payment hash, and track whether the + // payment is in-flight, succeeded, or failed. + number: 5, + migration: paymentStatusesMigration, + }, + { + // The DB version that properly prunes stale entries + // from the edge update index. + number: 6, + migration: migratePruneEdgeUpdateIndex, + }, + { + // The DB version that migrates the ChannelCloseSummary + // to a format where optional fields are indicated with + // boolean flags. + number: 7, + migration: migrateOptionalChannelCloseSummaryFields, + }, + { + // The DB version that changes the gossiper's message + // store keys to account for the message's type and + // ShortChannelID. + number: 8, + migration: migrateGossipMessageStoreKeys, + }, + { + // The DB version where the payments and payment + // statuses are moved to being stored in a combined + // bucket. + number: 9, + migration: migrateOutgoingPayments, + }, + { + // The DB version where we started to store legacy + // payload information for all routes, as well as the + // optional TLV records. + number: 10, + migration: migrateRouteSerialization, + }, + { + // Add invoice htlc and cltv delta fields. + number: 11, + migration: migrateInvoices, + }, + } + + // Big endian is the preferred byte order, due to cursor scans over + // integer keys iterating in order. + byteOrder = binary.BigEndian +) + +// DB is the primary datastore for the lnd daemon. The database stores +// information related to nodes, routing data, open/closed channels, fee +// schedules, and reputation data. +type DB struct { + *bbolt.DB + dbPath string + graph *ChannelGraph + now func() time.Time +} + +// Open opens an existing channeldb. Any necessary schemas migrations due to +// updates will take place as necessary. +func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { + path := filepath.Join(dbPath, dbName) + + if !fileExists(path) { + if err := createChannelDB(dbPath); err != nil { + return nil, err + } + } + + opts := DefaultOptions() + for _, modifier := range modifiers { + modifier(&opts) + } + + // Specify bbolt freelist options to reduce heap pressure in case the + // freelist grows to be very large. + options := &bbolt.Options{ + NoFreelistSync: opts.NoFreelistSync, + FreelistType: bbolt.FreelistMapType, + } + + bdb, err := bbolt.Open(path, dbFilePermission, options) + if err != nil { + return nil, err + } + + chanDB := &DB{ + DB: bdb, + dbPath: dbPath, + now: time.Now, + } + chanDB.graph = newChannelGraph( + chanDB, opts.RejectCacheSize, opts.ChannelCacheSize, + ) + + // Synchronize the version of database and apply migrations if needed. + if err := chanDB.syncVersions(dbVersions); err != nil { + bdb.Close() + return nil, err + } + + return chanDB, nil +} + +// Path returns the file path to the channel database. +func (d *DB) Path() string { + return d.dbPath +} + +// Wipe completely deletes all saved state within all used buckets within the +// database. The deletion is done in a single transaction, therefore this +// operation is fully atomic. +func (d *DB) Wipe() error { + return d.Update(func(tx *bbolt.Tx) error { + err := tx.DeleteBucket(openChannelBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + err = tx.DeleteBucket(closedChannelBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + err = tx.DeleteBucket(invoiceBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + err = tx.DeleteBucket(nodeInfoBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + err = tx.DeleteBucket(nodeBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + err = tx.DeleteBucket(edgeBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + err = tx.DeleteBucket(edgeIndexBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + err = tx.DeleteBucket(graphMetaBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + return nil + }) +} + +// createChannelDB creates and initializes a fresh version of channeldb. In +// the case that the target path has not yet been created or doesn't yet exist, +// then the path is created. Additionally, all required top-level buckets used +// within the database are created. +func createChannelDB(dbPath string) error { + if !fileExists(dbPath) { + if err := os.MkdirAll(dbPath, 0700); err != nil { + return err + } + } + + path := filepath.Join(dbPath, dbName) + bdb, err := bbolt.Open(path, dbFilePermission, nil) + if err != nil { + return err + } + + err = bdb.Update(func(tx *bbolt.Tx) error { + if _, err := tx.CreateBucket(openChannelBucket); err != nil { + return err + } + if _, err := tx.CreateBucket(closedChannelBucket); err != nil { + return err + } + + if _, err := tx.CreateBucket(forwardingLogBucket); err != nil { + return err + } + + if _, err := tx.CreateBucket(fwdPackagesKey); err != nil { + return err + } + + if _, err := tx.CreateBucket(invoiceBucket); err != nil { + return err + } + + if _, err := tx.CreateBucket(paymentBucket); err != nil { + return err + } + + if _, err := tx.CreateBucket(nodeInfoBucket); err != nil { + return err + } + + nodes, err := tx.CreateBucket(nodeBucket) + if err != nil { + return err + } + _, err = nodes.CreateBucket(aliasIndexBucket) + if err != nil { + return err + } + _, err = nodes.CreateBucket(nodeUpdateIndexBucket) + if err != nil { + return err + } + + edges, err := tx.CreateBucket(edgeBucket) + if err != nil { + return err + } + if _, err := edges.CreateBucket(edgeIndexBucket); err != nil { + return err + } + if _, err := edges.CreateBucket(edgeUpdateIndexBucket); err != nil { + return err + } + if _, err := edges.CreateBucket(channelPointBucket); err != nil { + return err + } + if _, err := edges.CreateBucket(zombieBucket); err != nil { + return err + } + + graphMeta, err := tx.CreateBucket(graphMetaBucket) + if err != nil { + return err + } + _, err = graphMeta.CreateBucket(pruneLogBucket) + if err != nil { + return err + } + + if _, err := tx.CreateBucket(metaBucket); err != nil { + return err + } + + meta := &Meta{ + DbVersionNumber: getLatestDBVersion(dbVersions), + } + return putMeta(meta, tx) + }) + if err != nil { + return fmt.Errorf("unable to create new channeldb") + } + + return bdb.Close() +} + +// fileExists returns true if the file exists, and false otherwise. +func fileExists(path string) bool { + if _, err := os.Stat(path); err != nil { + if os.IsNotExist(err) { + return false + } + } + + return true +} + +// FetchOpenChannels starts a new database transaction and returns all stored +// currently active/open channels associated with the target nodeID. In the case +// that no active channels are known to have been created with this node, then a +// zero-length slice is returned. +func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { + var channels []*OpenChannel + err := d.View(func(tx *bbolt.Tx) error { + var err error + channels, err = d.fetchOpenChannels(tx, nodeID) + return err + }) + + return channels, err +} + +// fetchOpenChannels uses and existing database transaction and returns all +// stored currently active/open channels associated with the target nodeID. In +// the case that no active channels are known to have been created with this +// node, then a zero-length slice is returned. +func (d *DB) fetchOpenChannels(tx *bbolt.Tx, + nodeID *btcec.PublicKey) ([]*OpenChannel, error) { + + // Get the bucket dedicated to storing the metadata for open channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return nil, nil + } + + // Within this top level bucket, fetch the bucket dedicated to storing + // open channel data specific to the remote node. + pub := nodeID.SerializeCompressed() + nodeChanBucket := openChanBucket.Bucket(pub) + if nodeChanBucket == nil { + return nil, nil + } + + // Next, we'll need to go down an additional layer in order to retrieve + // the channels for each chain the node knows of. + var channels []*OpenChannel + err := nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so ignore it. + if v != nil { + return nil + } + + // If we've found a valid chainhash bucket, then we'll retrieve + // that so we can extract all the channels. + chainBucket := nodeChanBucket.Bucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read bucket for chain=%x", + chainHash[:]) + } + + // Finally, we both of the necessary buckets retrieved, fetch + // all the active channels related to this node. + nodeChannels, err := d.fetchNodeChannels(chainBucket) + if err != nil { + return fmt.Errorf("unable to read channel for "+ + "chain_hash=%x, node_key=%x: %v", + chainHash[:], pub, err) + } + + channels = append(channels, nodeChannels...) + return nil + }) + + return channels, err +} + +// fetchNodeChannels retrieves all active channels from the target chainBucket +// which is under a node's dedicated channel bucket. This function is typically +// used to fetch all the active channels related to a particular node. +func (d *DB) fetchNodeChannels(chainBucket *bbolt.Bucket) ([]*OpenChannel, error) { + + var channels []*OpenChannel + + // A node may have channels on several chains, so for each known chain, + // we'll extract all the channels. + err := chainBucket.ForEach(func(chanPoint, v []byte) error { + // If there's a value, it's not a bucket so ignore it. + if v != nil { + return nil + } + + // Once we've found a valid channel bucket, we'll extract it + // from the node's chain bucket. + chanBucket := chainBucket.Bucket(chanPoint) + + var outPoint wire.OutPoint + err := readOutpoint(bytes.NewReader(chanPoint), &outPoint) + if err != nil { + return err + } + oChannel, err := fetchOpenChannel(chanBucket, &outPoint) + if err != nil { + return fmt.Errorf("unable to read channel data for "+ + "chan_point=%v: %v", outPoint, err) + } + oChannel.Db = d + + channels = append(channels, oChannel) + + return nil + }) + if err != nil { + return nil, err + } + + return channels, nil +} + +// FetchChannel attempts to locate a channel specified by the passed channel +// point. If the channel cannot be found, then an error will be returned. +func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { + var ( + targetChan *OpenChannel + targetChanPoint bytes.Buffer + ) + + if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { + return nil, err + } + + // chanScan will traverse the following bucket structure: + // * nodePub => chainHash => chanPoint + // + // At each level we go one further, ensuring that we're traversing the + // proper key (that's actually a bucket). By only reading the bucket + // structure and skipping fully decoding each channel, we save a good + // bit of CPU as we don't need to do things like decompress public + // keys. + chanScan := func(tx *bbolt.Tx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoActiveChannels + } + + // Within the node channel bucket, are the set of node pubkeys + // we have channels with, we don't know the entire set, so + // we'll check them all. + return openChanBucket.ForEach(func(nodePub, v []byte) error { + // Ensure that this is a key the same size as a pubkey, + // and also that it leads directly to a bucket. + if len(nodePub) != 33 || v != nil { + return nil + } + + nodeChanBucket := openChanBucket.Bucket(nodePub) + if nodeChanBucket == nil { + return nil + } + + // The next layer down is all the chains that this node + // has channels on with us. + return nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so + // ignore it. + if v != nil { + return nil + } + + chainBucket := nodeChanBucket.Bucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read "+ + "bucket for chain=%x", chainHash[:]) + } + + // Finally we reach the leaf bucket that stores + // all the chanPoints for this node. + chanBucket := chainBucket.Bucket( + targetChanPoint.Bytes(), + ) + if chanBucket == nil { + return nil + } + + channel, err := fetchOpenChannel( + chanBucket, &chanPoint, + ) + if err != nil { + return err + } + + targetChan = channel + targetChan.Db = d + + return nil + }) + }) + } + + err := d.View(chanScan) + if err != nil { + return nil, err + } + + if targetChan != nil { + return targetChan, nil + } + + // If we can't find the channel, then we return with an error, as we + // have nothing to backup. + return nil, ErrChannelNotFound +} + +// FetchAllChannels attempts to retrieve all open channels currently stored +// within the database, including pending open, fully open and channels waiting +// for a closing transaction to confirm. +func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { + var channels []*OpenChannel + + // TODO(halseth): fetch all in one db tx. + openChannels, err := d.FetchAllOpenChannels() + if err != nil { + return nil, err + } + channels = append(channels, openChannels...) + + pendingChannels, err := d.FetchPendingChannels() + if err != nil { + return nil, err + } + channels = append(channels, pendingChannels...) + + waitingClose, err := d.FetchWaitingCloseChannels() + if err != nil { + return nil, err + } + channels = append(channels, waitingClose...) + + return channels, nil +} + +// FetchAllOpenChannels will return all channels that have the funding +// transaction confirmed, and is not waiting for a closing transaction to be +// confirmed. +func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { + return fetchChannels(d, false, false) +} + +// FetchPendingChannels will return channels that have completed the process of +// generating and broadcasting funding transactions, but whose funding +// transactions have yet to be confirmed on the blockchain. +func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { + return fetchChannels(d, true, false) +} + +// FetchWaitingCloseChannels will return all channels that have been opened, +// but are now waiting for a closing transaction to be confirmed. +// +// NOTE: This includes channels that are also pending to be opened. +func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { + waitingClose, err := fetchChannels(d, false, true) + if err != nil { + return nil, err + } + pendingWaitingClose, err := fetchChannels(d, true, true) + if err != nil { + return nil, err + } + + return append(waitingClose, pendingWaitingClose...), nil +} + +// fetchChannels attempts to retrieve channels currently stored in the +// database. The pending parameter determines whether only pending channels +// will be returned, or only open channels will be returned. The waitingClose +// parameter determines whether only channels waiting for a closing transaction +// to be confirmed should be returned. If no active channels exist within the +// network, then ErrNoActiveChannels is returned. +func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { + var channels []*OpenChannel + + err := d.View(func(tx *bbolt.Tx) error { + // Get the bucket dedicated to storing the metadata for open + // channels. + openChanBucket := tx.Bucket(openChannelBucket) + if openChanBucket == nil { + return ErrNoActiveChannels + } + + // Next, fetch the bucket dedicated to storing metadata related + // to all nodes. All keys within this bucket are the serialized + // public keys of all our direct counterparties. + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return fmt.Errorf("node bucket not created") + } + + // Finally for each node public key in the bucket, fetch all + // the channels related to this particular node. + return nodeMetaBucket.ForEach(func(k, v []byte) error { + nodeChanBucket := openChanBucket.Bucket(k) + if nodeChanBucket == nil { + return nil + } + + return nodeChanBucket.ForEach(func(chainHash, v []byte) error { + // If there's a value, it's not a bucket so + // ignore it. + if v != nil { + return nil + } + + // If we've found a valid chainhash bucket, + // then we'll retrieve that so we can extract + // all the channels. + chainBucket := nodeChanBucket.Bucket(chainHash) + if chainBucket == nil { + return fmt.Errorf("unable to read "+ + "bucket for chain=%x", chainHash[:]) + } + + nodeChans, err := d.fetchNodeChannels(chainBucket) + if err != nil { + return fmt.Errorf("unable to read "+ + "channel for chain_hash=%x, "+ + "node_key=%x: %v", chainHash[:], k, err) + } + for _, channel := range nodeChans { + if channel.IsPending != pending { + continue + } + + // If the channel is in any other state + // than Default, then it means it is + // waiting to be closed. + channelWaitingClose := + channel.ChanStatus() != ChanStatusDefault + + // Only include it if we requested + // channels with the same waitingClose + // status. + if channelWaitingClose != waitingClose { + continue + } + + channels = append(channels, channel) + } + return nil + }) + + }) + }) + if err != nil { + return nil, err + } + + return channels, nil +} + +// FetchClosedChannels attempts to fetch all closed channels from the database. +// The pendingOnly bool toggles if channels that aren't yet fully closed should +// be returned in the response or not. When a channel was cooperatively closed, +// it becomes fully closed after a single confirmation. When a channel was +// forcibly closed, it will become fully closed after _all_ the pending funds +// (if any) have been swept. +func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, error) { + var chanSummaries []*ChannelCloseSummary + + if err := d.View(func(tx *bbolt.Tx) error { + closeBucket := tx.Bucket(closedChannelBucket) + if closeBucket == nil { + return ErrNoClosedChannels + } + + return closeBucket.ForEach(func(chanID []byte, summaryBytes []byte) error { + summaryReader := bytes.NewReader(summaryBytes) + chanSummary, err := deserializeCloseChannelSummary(summaryReader) + if err != nil { + return err + } + + // If the query specified to only include pending + // channels, then we'll skip any channels which aren't + // currently pending. + if !chanSummary.IsPending && pendingOnly { + return nil + } + + chanSummaries = append(chanSummaries, chanSummary) + return nil + }) + }); err != nil { + return nil, err + } + + return chanSummaries, nil +} + +// ErrClosedChannelNotFound signals that a closed channel could not be found in +// the channeldb. +var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary") + +// FetchClosedChannel queries for a channel close summary using the channel +// point of the channel in question. +func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) { + var chanSummary *ChannelCloseSummary + if err := d.View(func(tx *bbolt.Tx) error { + closeBucket := tx.Bucket(closedChannelBucket) + if closeBucket == nil { + return ErrClosedChannelNotFound + } + + var b bytes.Buffer + var err error + if err = writeOutpoint(&b, chanID); err != nil { + return err + } + + summaryBytes := closeBucket.Get(b.Bytes()) + if summaryBytes == nil { + return ErrClosedChannelNotFound + } + + summaryReader := bytes.NewReader(summaryBytes) + chanSummary, err = deserializeCloseChannelSummary(summaryReader) + + return err + }); err != nil { + return nil, err + } + + return chanSummary, nil +} + +// FetchClosedChannelForID queries for a channel close summary using the +// channel ID of the channel in question. +func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( + *ChannelCloseSummary, error) { + + var chanSummary *ChannelCloseSummary + if err := d.View(func(tx *bbolt.Tx) error { + closeBucket := tx.Bucket(closedChannelBucket) + if closeBucket == nil { + return ErrClosedChannelNotFound + } + + // The first 30 bytes of the channel ID and outpoint will be + // equal. + cursor := closeBucket.Cursor() + op, c := cursor.Seek(cid[:30]) + + // We scan over all possible candidates for this channel ID. + for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() { + var outPoint wire.OutPoint + err := readOutpoint(bytes.NewReader(op), &outPoint) + if err != nil { + return err + } + + // If the found outpoint does not correspond to this + // channel ID, we continue. + if !cid.IsChanPoint(&outPoint) { + continue + } + + // Deserialize the close summary and return. + r := bytes.NewReader(c) + chanSummary, err = deserializeCloseChannelSummary(r) + if err != nil { + return err + } + + return nil + } + return ErrClosedChannelNotFound + }); err != nil { + return nil, err + } + + return chanSummary, nil +} + +// MarkChanFullyClosed marks a channel as fully closed within the database. A +// channel should be marked as fully closed if the channel was initially +// cooperatively closed and it's reached a single confirmation, or after all +// the pending funds in a channel that has been forcibly closed have been +// swept. +func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { + return d.Update(func(tx *bbolt.Tx) error { + var b bytes.Buffer + if err := writeOutpoint(&b, chanPoint); err != nil { + return err + } + + chanID := b.Bytes() + + closedChanBucket, err := tx.CreateBucketIfNotExists( + closedChannelBucket, + ) + if err != nil { + return err + } + + chanSummaryBytes := closedChanBucket.Get(chanID) + if chanSummaryBytes == nil { + return fmt.Errorf("no closed channel for "+ + "chan_point=%v found", chanPoint) + } + + chanSummaryReader := bytes.NewReader(chanSummaryBytes) + chanSummary, err := deserializeCloseChannelSummary( + chanSummaryReader, + ) + if err != nil { + return err + } + + chanSummary.IsPending = false + + var newSummary bytes.Buffer + err = serializeChannelCloseSummary(&newSummary, chanSummary) + if err != nil { + return err + } + + err = closedChanBucket.Put(chanID, newSummary.Bytes()) + if err != nil { + return err + } + + // Now that the channel is closed, we'll check if we have any + // other open channels with this peer. If we don't we'll + // garbage collect it to ensure we don't establish persistent + // connections to peers without open channels. + return d.pruneLinkNode(tx, chanSummary.RemotePub) + }) +} + +// pruneLinkNode determines whether we should garbage collect a link node from +// the database due to no longer having any open channels with it. If there are +// any left, then this acts as a no-op. +func (d *DB) pruneLinkNode(tx *bbolt.Tx, remotePub *btcec.PublicKey) error { + openChannels, err := d.fetchOpenChannels(tx, remotePub) + if err != nil { + return fmt.Errorf("unable to fetch open channels for peer %x: "+ + "%v", remotePub.SerializeCompressed(), err) + } + + if len(openChannels) > 0 { + return nil + } + + log.Infof("Pruning link node %x with zero open channels from database", + remotePub.SerializeCompressed()) + + return d.deleteLinkNode(tx, remotePub) +} + +// PruneLinkNodes attempts to prune all link nodes found within the databse with +// whom we no longer have any open channels with. +func (d *DB) PruneLinkNodes() error { + return d.Update(func(tx *bbolt.Tx) error { + linkNodes, err := d.fetchAllLinkNodes(tx) + if err != nil { + return err + } + + for _, linkNode := range linkNodes { + err := d.pruneLinkNode(tx, linkNode.IdentityPub) + if err != nil { + return err + } + } + + return nil + }) +} + +// ChannelShell is a shell of a channel that is meant to be used for channel +// recovery purposes. It contains a minimal OpenChannel instance along with +// addresses for that target node. +type ChannelShell struct { + // NodeAddrs the set of addresses that this node has known to be + // reachable at in the past. + NodeAddrs []net.Addr + + // Chan is a shell of an OpenChannel, it contains only the items + // required to restore the channel on disk. + Chan *OpenChannel +} + +// RestoreChannelShells is a method that allows the caller to reconstruct the +// state of an OpenChannel from the ChannelShell. We'll attempt to write the +// new channel to disk, create a LinkNode instance with the passed node +// addresses, and finally create an edge within the graph for the channel as +// well. This method is idempotent, so repeated calls with the same set of +// channel shells won't modify the database after the initial call. +func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { + chanGraph := d.ChannelGraph() + + // TODO(conner): find way to do this w/o accessing internal members? + chanGraph.cacheMu.Lock() + defer chanGraph.cacheMu.Unlock() + + var chansRestored []uint64 + err := d.Update(func(tx *bbolt.Tx) error { + for _, channelShell := range channelShells { + channel := channelShell.Chan + + // When we make a channel, we mark that the channel has + // been restored, this will signal to other sub-systems + // to not attempt to use the channel as if it was a + // regular one. + channel.chanStatus |= ChanStatusRestored + + // First, we'll attempt to create a new open channel + // and link node for this channel. If the channel + // already exists, then in order to ensure this method + // is idempotent, we'll continue to the next step. + channel.Db = d + err := syncNewChannel( + tx, channel, channelShell.NodeAddrs, + ) + if err != nil { + return err + } + + // Next, we'll create an active edge in the graph + // database for this channel in order to restore our + // partial view of the network. + // + // TODO(roasbeef): if we restore *after* the channel + // has been closed on chain, then need to inform the + // router that it should try and prune these values as + // we can detect them + edgeInfo := ChannelEdgeInfo{ + ChannelID: channel.ShortChannelID.ToUint64(), + ChainHash: channel.ChainHash, + ChannelPoint: channel.FundingOutpoint, + Capacity: channel.Capacity, + } + + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + selfNode, err := chanGraph.sourceNode(nodes) + if err != nil { + return err + } + + // Depending on which pub key is smaller, we'll assign + // our roles as "node1" and "node2". + chanPeer := channel.IdentityPub.SerializeCompressed() + selfIsSmaller := bytes.Compare( + selfNode.PubKeyBytes[:], chanPeer, + ) == -1 + if selfIsSmaller { + copy(edgeInfo.NodeKey1Bytes[:], selfNode.PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], chanPeer) + } else { + copy(edgeInfo.NodeKey1Bytes[:], chanPeer) + copy(edgeInfo.NodeKey2Bytes[:], selfNode.PubKeyBytes[:]) + } + + // With the edge info shell constructed, we'll now add + // it to the graph. + err = chanGraph.addChannelEdge(tx, &edgeInfo) + if err != nil && err != ErrEdgeAlreadyExist { + return err + } + + // Similarly, we'll construct a channel edge shell and + // add that itself to the graph. + chanEdge := ChannelEdgePolicy{ + ChannelID: edgeInfo.ChannelID, + LastUpdate: time.Now(), + } + + // If their pubkey is larger, then we'll flip the + // direction bit to indicate that us, the "second" node + // is updating their policy. + if !selfIsSmaller { + chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection + } + + _, err = updateEdgePolicy(tx, &chanEdge) + if err != nil { + return err + } + + chansRestored = append(chansRestored, edgeInfo.ChannelID) + } + + return nil + }) + if err != nil { + return err + } + + for _, chanid := range chansRestored { + chanGraph.rejectCache.remove(chanid) + chanGraph.chanCache.remove(chanid) + } + + return nil +} + +// AddrsForNode consults the graph and channel database for all addresses known +// to the passed node public key. +func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { + var ( + linkNode *LinkNode + graphNode LightningNode + ) + + dbErr := d.View(func(tx *bbolt.Tx) error { + var err error + + linkNode, err = fetchLinkNode(tx, nodePub) + if err != nil { + return err + } + + // We'll also query the graph for this peer to see if they have + // any addresses that we don't currently have stored within the + // link node database. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + compressedPubKey := nodePub.SerializeCompressed() + graphNode, err = fetchLightningNode(nodes, compressedPubKey) + if err != nil && err != ErrGraphNodeNotFound { + // If the node isn't found, then that's OK, as we still + // have the link node data. + return err + } + + return nil + }) + if dbErr != nil { + return nil, dbErr + } + + // Now that we have both sources of addrs for this node, we'll use a + // map to de-duplicate any addresses between the two sources, and + // produce a final list of the combined addrs. + addrs := make(map[string]net.Addr) + for _, addr := range linkNode.Addresses { + addrs[addr.String()] = addr + } + for _, addr := range graphNode.Addresses { + addrs[addr.String()] = addr + } + dedupedAddrs := make([]net.Addr, 0, len(addrs)) + for _, addr := range addrs { + dedupedAddrs = append(dedupedAddrs, addr) + } + + return dedupedAddrs, nil +} + +// syncVersions function is used for safe db version synchronization. It +// applies migration functions to the current database and recovers the +// previous state of db if at least one error/panic appeared during migration. +func (d *DB) syncVersions(versions []version) error { + meta, err := d.FetchMeta(nil) + if err != nil { + if err == ErrMetaNotFound { + meta = &Meta{} + } else { + return err + } + } + + latestVersion := getLatestDBVersion(versions) + log.Infof("Checking for schema update: latest_version=%v, "+ + "db_version=%v", latestVersion, meta.DbVersionNumber) + + switch { + + // If the database reports a higher version that we are aware of, the + // user is probably trying to revert to a prior version of lnd. We fail + // here to prevent reversions and unintended corruption. + case meta.DbVersionNumber > latestVersion: + log.Errorf("Refusing to revert from db_version=%d to "+ + "lower version=%d", meta.DbVersionNumber, + latestVersion) + return ErrDBReversion + + // If the current database version matches the latest version number, + // then we don't need to perform any migrations. + case meta.DbVersionNumber == latestVersion: + return nil + } + + log.Infof("Performing database schema migration") + + // Otherwise, we fetch the migrations which need to applied, and + // execute them serially within a single database transaction to ensure + // the migration is atomic. + migrations, migrationVersions := getMigrationsToApply( + versions, meta.DbVersionNumber, + ) + return d.Update(func(tx *bbolt.Tx) error { + for i, migration := range migrations { + if migration == nil { + continue + } + + log.Infof("Applying migration #%v", migrationVersions[i]) + + if err := migration(tx); err != nil { + log.Infof("Unable to apply migration #%v", + migrationVersions[i]) + return err + } + } + + meta.DbVersionNumber = latestVersion + return putMeta(meta, tx) + }) +} + +// ChannelGraph returns a new instance of the directed channel graph. +func (d *DB) ChannelGraph() *ChannelGraph { + return d.graph +} + +func getLatestDBVersion(versions []version) uint32 { + return versions[len(versions)-1].number +} + +// getMigrationsToApply retrieves the migration function that should be +// applied to the database. +func getMigrationsToApply(versions []version, version uint32) ([]migration, []uint32) { + migrations := make([]migration, 0, len(versions)) + migrationVersions := make([]uint32, 0, len(versions)) + + for _, v := range versions { + if v.number > version { + migrations = append(migrations, v.migration) + migrationVersions = append(migrationVersions, v.number) + } + } + + return migrations, migrationVersions +} diff --git a/channeldb/migration_01_to_11/db_test.go b/channeldb/migration_01_to_11/db_test.go new file mode 100644 index 00000000..721546e7 --- /dev/null +++ b/channeldb/migration_01_to_11/db_test.go @@ -0,0 +1,471 @@ +package migration_01_to_11 + +import ( + "io/ioutil" + "math" + "math/rand" + "net" + "os" + "path/filepath" + "reflect" + "testing" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/shachain" +) + +func TestOpenWithCreate(t *testing.T) { + t.Parallel() + + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + defer os.RemoveAll(tempDirName) + + // Next, open thereby creating channeldb for the first time. + dbPath := filepath.Join(tempDirName, "cdb") + cdb, err := Open(dbPath) + if err != nil { + t.Fatalf("unable to create channeldb: %v", err) + } + if err := cdb.Close(); err != nil { + t.Fatalf("unable to close channeldb: %v", err) + } + + // The path should have been successfully created. + if !fileExists(dbPath) { + t.Fatalf("channeldb failed to create data directory") + } +} + +// TestWipe tests that the database wipe operation completes successfully +// and that the buckets are deleted. It also checks that attempts to fetch +// information while the buckets are not set return the correct errors. +func TestWipe(t *testing.T) { + t.Parallel() + + // First, create a temporary directory to be used for the duration of + // this test. + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + defer os.RemoveAll(tempDirName) + + // Next, open thereby creating channeldb for the first time. + dbPath := filepath.Join(tempDirName, "cdb") + cdb, err := Open(dbPath) + if err != nil { + t.Fatalf("unable to create channeldb: %v", err) + } + defer cdb.Close() + + if err := cdb.Wipe(); err != nil { + t.Fatalf("unable to wipe channeldb: %v", err) + } + // Check correct errors are returned + _, err = cdb.FetchAllOpenChannels() + if err != ErrNoActiveChannels { + t.Fatalf("fetching open channels: expected '%v' instead got '%v'", + ErrNoActiveChannels, err) + } + _, err = cdb.FetchClosedChannels(false) + if err != ErrNoClosedChannels { + t.Fatalf("fetching closed channels: expected '%v' instead got '%v'", + ErrNoClosedChannels, err) + } +} + +// TestFetchClosedChannelForID tests that we are able to properly retrieve a +// ChannelCloseSummary from the DB given a ChannelID. +func TestFetchClosedChannelForID(t *testing.T) { + t.Parallel() + + const numChans = 101 + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create the test channel state, that we will mutate the index of the + // funding point. + state, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + // Now run through the number of channels, and modify the outpoint index + // to create new channel IDs. + for i := uint32(0); i < numChans; i++ { + // Save the open channel to disk. + state.FundingOutpoint.Index = i + + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18556, + } + if err := state.SyncPending(addr, 101); err != nil { + t.Fatalf("unable to save and serialize channel "+ + "state: %v", err) + } + + // Close the channel. To make sure we retrieve the correct + // summary later, we make them differ in the SettledBalance. + closeSummary := &ChannelCloseSummary{ + ChanPoint: state.FundingOutpoint, + RemotePub: state.IdentityPub, + SettledBalance: btcutil.Amount(500 + i), + } + if err := state.CloseChannel(closeSummary); err != nil { + t.Fatalf("unable to close channel: %v", err) + } + } + + // Now run though them all again and make sure we are able to retrieve + // summaries from the DB. + for i := uint32(0); i < numChans; i++ { + state.FundingOutpoint.Index = i + + // We calculate the ChannelID and use it to fetch the summary. + cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) + fetchedSummary, err := cdb.FetchClosedChannelForID(cid) + if err != nil { + t.Fatalf("unable to fetch close summary: %v", err) + } + + // Make sure we retrieved the correct one by checking the + // SettledBalance. + if fetchedSummary.SettledBalance != btcutil.Amount(500+i) { + t.Fatalf("summaries don't match: expected %v got %v", + btcutil.Amount(500+i), + fetchedSummary.SettledBalance) + } + } + + // As a final test we make sure that we get ErrClosedChannelNotFound + // for a ChannelID we didn't add to the DB. + state.FundingOutpoint.Index++ + cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) + _, err = cdb.FetchClosedChannelForID(cid) + if err != ErrClosedChannelNotFound { + t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err) + } +} + +// TestAddrsForNode tests the we're able to properly obtain all the addresses +// for a target node. +func TestAddrsForNode(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + graph := cdb.ChannelGraph() + + // We'll make a test vertex to insert into the database, as the source + // node, but this node will only have half the number of addresses it + // usually does. + testNode, err := createTestVertex(cdb) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + testNode.Addresses = []net.Addr{testAddr} + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // Next, we'll make a link node with the same pubkey, but with an + // additional address. + nodePub, err := testNode.PubKey() + if err != nil { + t.Fatalf("unable to recv node pub: %v", err) + } + linkNode := cdb.NewLinkNode( + wire.MainNet, nodePub, anotherAddr, + ) + if err := linkNode.Sync(); err != nil { + t.Fatalf("unable to sync link node: %v", err) + } + + // Now that we've created a link node, as well as a vertex for the + // node, we'll query for all its addresses. + nodeAddrs, err := cdb.AddrsForNode(nodePub) + if err != nil { + t.Fatalf("unable to obtain node addrs: %v", err) + } + + expectedAddrs := make(map[string]struct{}) + expectedAddrs[testAddr.String()] = struct{}{} + expectedAddrs[anotherAddr.String()] = struct{}{} + + // Finally, ensure that all the expected addresses are found. + if len(nodeAddrs) != len(expectedAddrs) { + t.Fatalf("expected %v addrs, got %v", + len(expectedAddrs), len(nodeAddrs)) + } + for _, addr := range nodeAddrs { + if _, ok := expectedAddrs[addr.String()]; !ok { + t.Fatalf("unexpected addr: %v", addr) + } + } +} + +// TestFetchChannel tests that we're able to fetch an arbitrary channel from +// disk. +func TestFetchChannel(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // Create the test channel state that we'll sync to the database + // shortly. + channelState, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + // Mark the channel as pending, then immediately mark it as open to it + // can be fully visible. + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + } + if err := channelState.SyncPending(addr, 9); err != nil { + t.Fatalf("unable to save and serialize channel state: %v", err) + } + err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99)) + if err != nil { + t.Fatalf("unable to mark channel open: %v", err) + } + + // Next, attempt to fetch the channel by its chan point. + dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint) + if err != nil { + t.Fatalf("unable to fetch channel: %v", err) + } + + // The decoded channel state should be identical to what we stored + // above. + if !reflect.DeepEqual(channelState, dbChannel) { + t.Fatalf("channel state doesn't match:: %v vs %v", + spew.Sdump(channelState), spew.Sdump(dbChannel)) + } + + // If we attempt to query for a non-exist ante channel, then we should + // get an error. + channelState2, err := createTestChannelState(cdb) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + channelState2.FundingOutpoint.Index ^= 1 + + _, err = cdb.FetchChannel(channelState2.FundingOutpoint) + if err == nil { + t.Fatalf("expected query to fail") + } +} + +func genRandomChannelShell() (*ChannelShell, error) { + var testPriv [32]byte + if _, err := rand.Read(testPriv[:]); err != nil { + return nil, err + } + + _, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:]) + + var chanPoint wire.OutPoint + if _, err := rand.Read(chanPoint.Hash[:]); err != nil { + return nil, err + } + + pub.Curve = nil + + chanPoint.Index = uint32(rand.Intn(math.MaxUint16)) + + chanStatus := ChanStatusDefault | ChanStatusRestored + + var shaChainPriv [32]byte + if _, err := rand.Read(testPriv[:]); err != nil { + return nil, err + } + revRoot, err := chainhash.NewHash(shaChainPriv[:]) + if err != nil { + return nil, err + } + shaChainProducer := shachain.NewRevocationProducer(*revRoot) + + return &ChannelShell{ + NodeAddrs: []net.Addr{&net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 18555, + }}, + Chan: &OpenChannel{ + chanStatus: chanStatus, + ChainHash: rev, + FundingOutpoint: chanPoint, + ShortChannelID: lnwire.NewShortChanIDFromInt( + uint64(rand.Int63()), + ), + IdentityPub: pub, + LocalChanCfg: ChannelConfig{ + ChannelConstraints: ChannelConstraints{ + CsvDelay: uint16(rand.Int63()), + }, + PaymentBasePoint: keychain.KeyDescriptor{ + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamily(rand.Int63()), + Index: uint32(rand.Int63()), + }, + }, + }, + RemoteCurrentRevocation: pub, + IsPending: false, + RevocationStore: shachain.NewRevocationStore(), + RevocationProducer: shaChainProducer, + }, + }, nil +} + +// TestRestoreChannelShells tests that we're able to insert a partially channel +// populated to disk. This is useful for channel recovery purposes. We should +// find the new channel shell on disk, and also the db should be populated with +// an edge for that channel. +func TestRestoreChannelShells(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First, we'll make our channel shell, it will only have the minimal + // amount of information required for us to initiate the data loss + // protection feature. + channelShell, err := genRandomChannelShell() + if err != nil { + t.Fatalf("unable to gen channel shell: %v", err) + } + + graph := cdb.ChannelGraph() + + // Before we can restore the channel, we'll need to make a source node + // in the graph as the channel edge we create will need to have a + // origin. + testNode, err := createTestVertex(cdb) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // With the channel shell constructed, we'll now insert it into the + // database with the restoration method. + if err := cdb.RestoreChannelShells(channelShell); err != nil { + t.Fatalf("unable to restore channel shell: %v", err) + } + + // Now that the channel has been inserted, we'll attempt to query for + // it to ensure we can properly locate it via various means. + // + // First, we'll attempt to query for all channels that we have with the + // node public key that was restored. + nodeChans, err := cdb.FetchOpenChannels(channelShell.Chan.IdentityPub) + if err != nil { + t.Fatalf("unable find channel: %v", err) + } + + // We should now find a single channel from the database. + if len(nodeChans) != 1 { + t.Fatalf("unable to find restored channel by node "+ + "pubkey: %v", err) + } + + // Ensure that it isn't possible to modify the commitment state machine + // of this restored channel. + channel := nodeChans[0] + err = channel.UpdateCommitment(nil) + if err != ErrNoRestoredChannelMutation { + t.Fatalf("able to mutate restored channel") + } + err = channel.AppendRemoteCommitChain(nil) + if err != ErrNoRestoredChannelMutation { + t.Fatalf("able to mutate restored channel") + } + err = channel.AdvanceCommitChainTail(nil) + if err != ErrNoRestoredChannelMutation { + t.Fatalf("able to mutate restored channel") + } + + // That single channel should have the proper channel point, and also + // the expected set of flags to indicate that it was a restored + // channel. + if nodeChans[0].FundingOutpoint != channelShell.Chan.FundingOutpoint { + t.Fatalf("wrong funding outpoint: expected %v, got %v", + nodeChans[0].FundingOutpoint, + channelShell.Chan.FundingOutpoint) + } + if !nodeChans[0].HasChanStatus(ChanStatusRestored) { + t.Fatalf("node has wrong status flags: %v", + nodeChans[0].chanStatus) + } + + // We should also be able to find the channel if we query for it + // directly. + _, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint) + if err != nil { + t.Fatalf("unable to fetch channel: %v", err) + } + + // We should also be able to find the link node that was inserted by + // its public key. + linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub) + if err != nil { + t.Fatalf("unable to fetch link node: %v", err) + } + + // The node should have the same address, as specified in the channel + // shell. + if reflect.DeepEqual(linkNode.Addresses, channelShell.NodeAddrs) { + t.Fatalf("addr mismach: expected %v, got %v", + linkNode.Addresses, channelShell.NodeAddrs) + } + + // Finally, we'll ensure that the edge for the channel was properly + // inserted. + chanInfos, err := graph.FetchChanInfos( + []uint64{channelShell.Chan.ShortChannelID.ToUint64()}, + ) + if err != nil { + t.Fatalf("unable to find edges: %v", err) + } + + if len(chanInfos) != 1 { + t.Fatalf("wrong amount of chan infos: expected %v got %v", + len(chanInfos), 1) + } + + // We should only find a single edge. + if chanInfos[0].Policy1 != nil && chanInfos[0].Policy2 != nil { + t.Fatalf("only a single edge should be inserted: %v", err) + } +} diff --git a/channeldb/migration_01_to_11/doc.go b/channeldb/migration_01_to_11/doc.go new file mode 100644 index 00000000..c90412f2 --- /dev/null +++ b/channeldb/migration_01_to_11/doc.go @@ -0,0 +1 @@ +package migration_01_to_11 diff --git a/channeldb/migration_01_to_11/error.go b/channeldb/migration_01_to_11/error.go new file mode 100644 index 00000000..f264fb70 --- /dev/null +++ b/channeldb/migration_01_to_11/error.go @@ -0,0 +1,128 @@ +package migration_01_to_11 + +import ( + "errors" + "fmt" +) + +var ( + // ErrNoChanDBExists is returned when a channel bucket hasn't been + // created. + ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created") + + // ErrDBReversion is returned when detecting an attempt to revert to a + // prior database version. + ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version") + + // ErrLinkNodesNotFound is returned when node info bucket hasn't been + // created. + ErrLinkNodesNotFound = fmt.Errorf("no link nodes exist") + + // ErrNoActiveChannels is returned when there is no active (open) + // channels within the database. + ErrNoActiveChannels = fmt.Errorf("no active channels exist") + + // ErrNoPastDeltas is returned when the channel delta bucket hasn't been + // created. + ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas") + + // ErrInvoiceNotFound is returned when a targeted invoice can't be + // found. + ErrInvoiceNotFound = fmt.Errorf("unable to locate invoice") + + // ErrNoInvoicesCreated is returned when we don't have invoices in + // our database to return. + ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices") + + // ErrDuplicateInvoice is returned when an invoice with the target + // payment hash already exists. + ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists") + + // ErrNoPaymentsCreated is returned when bucket of payments hasn't been + // created. + ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") + + // ErrNodeNotFound is returned when node bucket exists, but node with + // specific identity can't be found. + ErrNodeNotFound = fmt.Errorf("link node with target identity not found") + + // ErrChannelNotFound is returned when we attempt to locate a channel + // for a specific chain, but it is not found. + ErrChannelNotFound = fmt.Errorf("channel not found") + + // ErrMetaNotFound is returned when meta bucket hasn't been + // created. + ErrMetaNotFound = fmt.Errorf("unable to locate meta information") + + // ErrGraphNotFound is returned when at least one of the components of + // graph doesn't exist. + ErrGraphNotFound = fmt.Errorf("graph bucket not initialized") + + // ErrGraphNeverPruned is returned when graph was never pruned. + ErrGraphNeverPruned = fmt.Errorf("graph never pruned") + + // ErrSourceNodeNotSet is returned if the source node of the graph + // hasn't been added The source node is the center node within a + // star-graph. + ErrSourceNodeNotSet = fmt.Errorf("source node does not exist") + + // ErrGraphNodesNotFound is returned in case none of the nodes has + // been added in graph node bucket. + ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist") + + // ErrGraphNoEdgesFound is returned in case of none of the channel/edges + // has been added in graph edge bucket. + ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist") + + // ErrGraphNodeNotFound is returned when we're unable to find the target + // node. + ErrGraphNodeNotFound = fmt.Errorf("unable to find node") + + // ErrEdgeNotFound is returned when an edge for the target chanID + // can't be found. + ErrEdgeNotFound = fmt.Errorf("edge not found") + + // ErrZombieEdge is an error returned when we attempt to look up an edge + // but it is marked as a zombie within the zombie index. + ErrZombieEdge = errors.New("edge marked as zombie") + + // ErrEdgeAlreadyExist is returned when edge with specific + // channel id can't be added because it already exist. + ErrEdgeAlreadyExist = fmt.Errorf("edge already exist") + + // ErrNodeAliasNotFound is returned when alias for node can't be found. + ErrNodeAliasNotFound = fmt.Errorf("alias for node not found") + + // ErrUnknownAddressType is returned when a node's addressType is not + // an expected value. + ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved") + + // ErrNoClosedChannels is returned when a node is queries for all the + // channels it has closed, but it hasn't yet closed any channels. + ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet") + + // ErrNoForwardingEvents is returned in the case that a query fails due + // to the log not having any recorded events. + ErrNoForwardingEvents = fmt.Errorf("no recorded forwarding events") + + // ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel + // policy field is not found in the db even though its message flags + // indicate it should be. + ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " + + "present") + + // ErrChanAlreadyExists is return when the caller attempts to create a + // channel with a channel point that is already present in the + // database. + ErrChanAlreadyExists = fmt.Errorf("channel already exists") +) + +// ErrTooManyExtraOpaqueBytes creates an error which should be returned if the +// caller attempts to write an announcement message which bares too many extra +// opaque bytes. We limit this value in order to ensure that we don't waste +// disk space due to nodes unnecessarily padding out their announcements with +// garbage data. +func ErrTooManyExtraOpaqueBytes(numBytes int) error { + return fmt.Errorf("max allowed number of opaque bytes is %v, received "+ + "%v bytes", MaxAllowedExtraOpaqueBytes, numBytes) +} diff --git a/channeldb/migration_01_to_11/fees.go b/channeldb/migration_01_to_11/fees.go new file mode 100644 index 00000000..c90412f2 --- /dev/null +++ b/channeldb/migration_01_to_11/fees.go @@ -0,0 +1 @@ +package migration_01_to_11 diff --git a/channeldb/migration_01_to_11/forwarding_log.go b/channeldb/migration_01_to_11/forwarding_log.go new file mode 100644 index 00000000..6b9f8f5d --- /dev/null +++ b/channeldb/migration_01_to_11/forwarding_log.go @@ -0,0 +1,274 @@ +package migration_01_to_11 + +import ( + "bytes" + "io" + "sort" + "time" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // forwardingLogBucket is the bucket that we'll use to store the + // forwarding log. The forwarding log contains a time series database + // of the forwarding history of a lightning daemon. Each key within the + // bucket is a timestamp (in nano seconds since the unix epoch), and + // the value a slice of a forwarding event for that timestamp. + forwardingLogBucket = []byte("circuit-fwd-log") +) + +const ( + // forwardingEventSize is the size of a forwarding event. The breakdown + // is as follows: + // + // * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in + // || 8 byte value out + // + // From the value in and value out, callers can easily compute the + // total fee extract from a forwarding event. + forwardingEventSize = 32 + + // MaxResponseEvents is the max number of forwarding events that will + // be returned by a single query response. This size was selected to + // safely remain under gRPC's 4MiB message size response limit. As each + // full forwarding event (including the timestamp) is 40 bytes, we can + // safely return 50k entries in a single response. + MaxResponseEvents = 50000 +) + +// ForwardingLog returns an instance of the ForwardingLog object backed by the +// target database instance. +func (d *DB) ForwardingLog() *ForwardingLog { + return &ForwardingLog{ + db: d, + } +} + +// ForwardingLog is a time series database that logs the fulfilment of payment +// circuits by a lightning network daemon. The log contains a series of +// forwarding events which map a timestamp to a forwarding event. A forwarding +// event describes which channels were used to create+settle a circuit, and the +// amount involved. Subtracting the outgoing amount from the incoming amount +// reveals the fee charged for the forwarding service. +type ForwardingLog struct { + db *DB +} + +// ForwardingEvent is an event in the forwarding log's time series. Each +// forwarding event logs the creation and tear-down of a payment circuit. A +// circuit is created once an incoming HTLC has been fully forwarded, and +// destroyed once the payment has been settled. +type ForwardingEvent struct { + // Timestamp is the settlement time of this payment circuit. + Timestamp time.Time + + // IncomingChanID is the incoming channel ID of the payment circuit. + IncomingChanID lnwire.ShortChannelID + + // OutgoingChanID is the outgoing channel ID of the payment circuit. + OutgoingChanID lnwire.ShortChannelID + + // AmtIn is the amount of the incoming HTLC. Subtracting this from the + // outgoing amount gives the total fees of this payment circuit. + AmtIn lnwire.MilliSatoshi + + // AmtOut is the amount of the outgoing HTLC. Subtracting the incoming + // amount from this gives the total fees for this payment circuit. + AmtOut lnwire.MilliSatoshi +} + +// encodeForwardingEvent writes out the target forwarding event to the passed +// io.Writer, using the expected DB format. Note that the timestamp isn't +// serialized as this will be the key value within the bucket. +func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { + return WriteElements( + w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut, + ) +} + +// decodeForwardingEvent attempts to decode the raw bytes of a serialized +// forwarding event into the target ForwardingEvent. Note that the timestamp +// won't be decoded, as the caller is expected to set this due to the bucket +// structure of the forwarding log. +func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error { + return ReadElements( + r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut, + ) +} + +// AddForwardingEvents adds a series of forwarding events to the database. +// Before inserting, the set of events will be sorted according to their +// timestamp. This ensures that all writes to disk are sequential. +func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error { + // Before we create the database transaction, we'll ensure that the set + // of forwarding events are properly sorted according to their + // timestamp. + sort.Slice(events, func(i, j int) bool { + return events[i].Timestamp.Before(events[j].Timestamp) + }) + + var timestamp [8]byte + + return f.db.Batch(func(tx *bbolt.Tx) error { + // First, we'll fetch the bucket that stores our time series + // log. + logBucket, err := tx.CreateBucketIfNotExists( + forwardingLogBucket, + ) + if err != nil { + return err + } + + // With the bucket obtained, we can now begin to write out the + // series of events. + for _, event := range events { + var eventBytes [forwardingEventSize]byte + eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize]) + + // First, we'll serialize this timestamp into our + // timestamp buffer. + byteOrder.PutUint64( + timestamp[:], uint64(event.Timestamp.UnixNano()), + ) + + // With the key encoded, we'll then encode the event + // into our buffer, then write it out to disk. + err := encodeForwardingEvent(eventBuf, &event) + if err != nil { + return err + } + err = logBucket.Put(timestamp[:], eventBuf.Bytes()) + if err != nil { + return err + } + } + + return nil + }) +} + +// ForwardingEventQuery represents a query to the forwarding log payment +// circuit time series database. The query allows a caller to retrieve all +// records for a particular time slice, offset in that time slice, limiting the +// total number of responses returned. +type ForwardingEventQuery struct { + // StartTime is the start time of the time slice. + StartTime time.Time + + // EndTime is the end time of the time slice. + EndTime time.Time + + // IndexOffset is the offset within the time slice to start at. This + // can be used to start the response at a particular record. + IndexOffset uint32 + + // NumMaxEvents is the max number of events to return. + NumMaxEvents uint32 +} + +// ForwardingLogTimeSlice is the response to a forwarding query. It includes +// the original query, the set events that match the query, and an integer +// which represents the offset index of the last item in the set of retuned +// events. This integer allows callers to resume their query using this offset +// in the event that the query's response exceeds the max number of returnable +// events. +type ForwardingLogTimeSlice struct { + ForwardingEventQuery + + // ForwardingEvents is the set of events in our time series that answer + // the query embedded above. + ForwardingEvents []ForwardingEvent + + // LastIndexOffset is the index of the last element in the set of + // returned ForwardingEvents above. Callers can use this to resume + // their query in the event that the time slice has too many events to + // fit into a single response. + LastIndexOffset uint32 +} + +// Query allows a caller to query the forwarding event time series for a +// particular time slice. The caller can control the precise time as well as +// the number of events to be returned. +// +// TODO(roasbeef): rename? +func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) { + resp := ForwardingLogTimeSlice{ + ForwardingEventQuery: q, + } + + // If the user provided an index offset, then we'll not know how many + // records we need to skip. We'll also keep track of the record offset + // as that's part of the final return value. + recordsToSkip := q.IndexOffset + recordOffset := q.IndexOffset + + err := f.db.View(func(tx *bbolt.Tx) error { + // If the bucket wasn't found, then there aren't any events to + // be returned. + logBucket := tx.Bucket(forwardingLogBucket) + if logBucket == nil { + return ErrNoForwardingEvents + } + + // We'll be using a cursor to seek into the database, so we'll + // populate byte slices that represent the start of the key + // space we're interested in, and the end. + var startTime, endTime [8]byte + byteOrder.PutUint64(startTime[:], uint64(q.StartTime.UnixNano())) + byteOrder.PutUint64(endTime[:], uint64(q.EndTime.UnixNano())) + + // If we know that a set of log events exists, then we'll begin + // our seek through the log in order to satisfy the query. + // We'll continue until either we reach the end of the range, + // or reach our max number of events. + logCursor := logBucket.Cursor() + timestamp, events := logCursor.Seek(startTime[:]) + for ; timestamp != nil && bytes.Compare(timestamp, endTime[:]) <= 0; timestamp, events = logCursor.Next() { + // If our current return payload exceeds the max number + // of events, then we'll exit now. + if uint32(len(resp.ForwardingEvents)) >= q.NumMaxEvents { + return nil + } + + // If we're not yet past the user defined offset, then + // we'll continue to seek forward. + if recordsToSkip > 0 { + recordsToSkip-- + continue + } + + currentTime := time.Unix( + 0, int64(byteOrder.Uint64(timestamp)), + ) + + // At this point, we've skipped enough records to start + // to collate our query. For each record, we'll + // increment the final record offset so the querier can + // utilize pagination to seek further. + readBuf := bytes.NewReader(events) + for readBuf.Len() != 0 { + var event ForwardingEvent + err := decodeForwardingEvent(readBuf, &event) + if err != nil { + return err + } + + event.Timestamp = currentTime + resp.ForwardingEvents = append(resp.ForwardingEvents, event) + + recordOffset++ + } + } + + return nil + }) + if err != nil && err != ErrNoForwardingEvents { + return ForwardingLogTimeSlice{}, err + } + + resp.LastIndexOffset = recordOffset + + return resp, nil +} diff --git a/channeldb/migration_01_to_11/forwarding_log_test.go b/channeldb/migration_01_to_11/forwarding_log_test.go new file mode 100644 index 00000000..9e0de7c4 --- /dev/null +++ b/channeldb/migration_01_to_11/forwarding_log_test.go @@ -0,0 +1,265 @@ +package migration_01_to_11 + +import ( + "math/rand" + "reflect" + "testing" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lnwire" + + "time" +) + +// TestForwardingLogBasicStorageAndQuery tests that we're able to store and +// then query for items that have previously been added to the event log. +func TestForwardingLogBasicStorageAndQuery(t *testing.T) { + t.Parallel() + + // First, we'll set up a test database, and use that to instantiate the + // forwarding event log that we'll be using for the duration of the + // test. + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + log := ForwardingLog{ + db: db, + } + + initialTime := time.Unix(1234, 0) + timestamp := time.Unix(1234, 0) + + // We'll create 100 random events, which each event being spaced 10 + // minutes after the prior event. + numEvents := 100 + events := make([]ForwardingEvent, numEvents) + for i := 0; i < numEvents; i++ { + events[i] = ForwardingEvent{ + Timestamp: timestamp, + IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + AmtIn: lnwire.MilliSatoshi(rand.Int63()), + AmtOut: lnwire.MilliSatoshi(rand.Int63()), + } + + timestamp = timestamp.Add(time.Minute * 10) + } + + // Now that all of our set of events constructed, we'll add them to the + // database in a batch manner. + if err := log.AddForwardingEvents(events); err != nil { + t.Fatalf("unable to add events: %v", err) + } + + // With our events added we'll now construct a basic query to retrieve + // all of the events. + eventQuery := ForwardingEventQuery{ + StartTime: initialTime, + EndTime: timestamp, + IndexOffset: 0, + NumMaxEvents: 1000, + } + timeSlice, err := log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // The set of returned events should match identically, as they should + // be returned in sorted order. + if !reflect.DeepEqual(events, timeSlice.ForwardingEvents) { + t.Fatalf("event mismatch: expected %v vs %v", + spew.Sdump(events), spew.Sdump(timeSlice.ForwardingEvents)) + } + + // The offset index of the final entry should be numEvents, so the + // number of total events we've written. + if timeSlice.LastIndexOffset != uint32(numEvents) { + t.Fatalf("wrong final offset: expected %v, got %v", + timeSlice.LastIndexOffset, numEvents) + } +} + +// TestForwardingLogQueryOptions tests that the query offset works properly. So +// if we add a series of events, then we should be able to seek within the +// timeslice accordingly. This exercises the index offset and num max event +// field in the query, and also the last index offset field int he response. +func TestForwardingLogQueryOptions(t *testing.T) { + t.Parallel() + + // First, we'll set up a test database, and use that to instantiate the + // forwarding event log that we'll be using for the duration of the + // test. + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + log := ForwardingLog{ + db: db, + } + + initialTime := time.Unix(1234, 0) + endTime := time.Unix(1234, 0) + + // We'll create 20 random events, which each event being spaced 10 + // minutes after the prior event. + numEvents := 20 + events := make([]ForwardingEvent, numEvents) + for i := 0; i < numEvents; i++ { + events[i] = ForwardingEvent{ + Timestamp: endTime, + IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + AmtIn: lnwire.MilliSatoshi(rand.Int63()), + AmtOut: lnwire.MilliSatoshi(rand.Int63()), + } + + endTime = endTime.Add(time.Minute * 10) + } + + // Now that all of our set of events constructed, we'll add them to the + // database in a batch manner. + if err := log.AddForwardingEvents(events); err != nil { + t.Fatalf("unable to add events: %v", err) + } + + // With all of our events added, we should be able to query for the + // first 10 events using the max event query field. + eventQuery := ForwardingEventQuery{ + StartTime: initialTime, + EndTime: endTime, + IndexOffset: 0, + NumMaxEvents: 10, + } + timeSlice, err := log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // We should get exactly 10 events back. + if len(timeSlice.ForwardingEvents) != 10 { + t.Fatalf("wrong number of events: expected %v, got %v", 10, + len(timeSlice.ForwardingEvents)) + } + + // The set of events returned should be the first 10 events that we + // added. + if !reflect.DeepEqual(events[:10], timeSlice.ForwardingEvents) { + t.Fatalf("wrong response: expected %v, got %v", + spew.Sdump(events[:10]), + spew.Sdump(timeSlice.ForwardingEvents)) + } + + // The final offset should be the exact number of events returned. + if timeSlice.LastIndexOffset != 10 { + t.Fatalf("wrong index offset: expected %v, got %v", 10, + timeSlice.LastIndexOffset) + } + + // If we use the final offset to query again, then we should get 10 + // more events, that are the last 10 events we wrote. + eventQuery.IndexOffset = 10 + timeSlice, err = log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // We should get exactly 10 events back once again. + if len(timeSlice.ForwardingEvents) != 10 { + t.Fatalf("wrong number of events: expected %v, got %v", 10, + len(timeSlice.ForwardingEvents)) + } + + // The events that we got back should be the last 10 events that we + // wrote out. + if !reflect.DeepEqual(events[10:], timeSlice.ForwardingEvents) { + t.Fatalf("wrong response: expected %v, got %v", + spew.Sdump(events[10:]), + spew.Sdump(timeSlice.ForwardingEvents)) + } + + // Finally, the last index offset should be 20, or the number of + // records we've written out. + if timeSlice.LastIndexOffset != 20 { + t.Fatalf("wrong index offset: expected %v, got %v", 20, + timeSlice.LastIndexOffset) + } +} + +// TestForwardingLogQueryLimit tests that we're able to properly limit the +// number of events that are returned as part of a query. +func TestForwardingLogQueryLimit(t *testing.T) { + t.Parallel() + + // First, we'll set up a test database, and use that to instantiate the + // forwarding event log that we'll be using for the duration of the + // test. + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + log := ForwardingLog{ + db: db, + } + + initialTime := time.Unix(1234, 0) + endTime := time.Unix(1234, 0) + + // We'll create 200 random events, which each event being spaced 10 + // minutes after the prior event. + numEvents := 200 + events := make([]ForwardingEvent, numEvents) + for i := 0; i < numEvents; i++ { + events[i] = ForwardingEvent{ + Timestamp: endTime, + IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), + AmtIn: lnwire.MilliSatoshi(rand.Int63()), + AmtOut: lnwire.MilliSatoshi(rand.Int63()), + } + + endTime = endTime.Add(time.Minute * 10) + } + + // Now that all of our set of events constructed, we'll add them to the + // database in a batch manner. + if err := log.AddForwardingEvents(events); err != nil { + t.Fatalf("unable to add events: %v", err) + } + + // Once the events have been written out, we'll issue a query over the + // entire range, but restrict the number of events to the first 100. + eventQuery := ForwardingEventQuery{ + StartTime: initialTime, + EndTime: endTime, + IndexOffset: 0, + NumMaxEvents: 100, + } + timeSlice, err := log.Query(eventQuery) + if err != nil { + t.Fatalf("unable to query for events: %v", err) + } + + // We should get exactly 100 events back. + if len(timeSlice.ForwardingEvents) != 100 { + t.Fatalf("wrong number of events: expected %v, got %v", 10, + len(timeSlice.ForwardingEvents)) + } + + // The set of events returned should be the first 100 events that we + // added. + if !reflect.DeepEqual(events[:100], timeSlice.ForwardingEvents) { + t.Fatalf("wrong response: expected %v, got %v", + spew.Sdump(events[:100]), + spew.Sdump(timeSlice.ForwardingEvents)) + } + + // The final offset should be the exact number of events returned. + if timeSlice.LastIndexOffset != 100 { + t.Fatalf("wrong index offset: expected %v, got %v", 100, + timeSlice.LastIndexOffset) + } +} diff --git a/channeldb/migration_01_to_11/forwarding_package.go b/channeldb/migration_01_to_11/forwarding_package.go new file mode 100644 index 00000000..cbbf90cf --- /dev/null +++ b/channeldb/migration_01_to_11/forwarding_package.go @@ -0,0 +1,928 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" +) + +// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding +// package has potentially been mangled. +var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") + +// FwdState is an enum used to describe the lifecycle of a FwdPkg. +type FwdState byte + +const ( + // FwdStateLockedIn is the starting state for all forwarding packages. + // Packages in this state have not yet committed to the exact set of + // Adds to forward to the switch. + FwdStateLockedIn FwdState = iota + + // FwdStateProcessed marks the state in which all Adds have been + // locally processed and the forwarding decision to the switch has been + // persisted. + FwdStateProcessed + + // FwdStateCompleted signals that all Adds have been acked, and that all + // settles and fails have been delivered to their sources. Packages in + // this state can be removed permanently. + FwdStateCompleted +) + +var ( + // fwdPackagesKey is the root-level bucket that all forwarding packages + // are written. This bucket is further subdivided based on the short + // channel ID of each channel. + fwdPackagesKey = []byte("fwd-packages") + + // addBucketKey is the bucket to which all Add log updates are written. + addBucketKey = []byte("add-updates") + + // failSettleBucketKey is the bucket to which all Settle/Fail log + // updates are written. + failSettleBucketKey = []byte("fail-settle-updates") + + // fwdFilterKey is a key used to write the set of Adds that passed + // validation and are to be forwarded to the switch. + // NOTE: The presence of this key within a forwarding package indicates + // that the package has reached FwdStateProcessed. + fwdFilterKey = []byte("fwd-filter-key") + + // ackFilterKey is a key used to access the PkgFilter indicating which + // Adds have received a Settle/Fail. This response may come from a + // number of sources, including: exitHop settle/fails, switch failures, + // chain arbiter interjections, as well as settle/fails from the + // next hop in the route. + ackFilterKey = []byte("ack-filter-key") + + // settleFailFilterKey is a key used to access the PkgFilter indicating + // which Settles/Fails in have been received and processed by the link + // that originally received the Add. + settleFailFilterKey = []byte("settle-fail-filter-key") +) + +// PkgFilter is used to compactly represent a particular subset of the Adds in a +// forwarding package. Each filter is represented as a simple, statically-sized +// bitvector, where the elements are intended to be the indices of the Adds as +// they are written in the FwdPkg. +type PkgFilter struct { + count uint16 + filter []byte +} + +// NewPkgFilter initializes an empty PkgFilter supporting `count` elements. +func NewPkgFilter(count uint16) *PkgFilter { + // We add 7 to ensure that the integer division yields properly rounded + // values. + filterLen := (count + 7) / 8 + + return &PkgFilter{ + count: count, + filter: make([]byte, filterLen), + } +} + +// Count returns the number of elements represented by this PkgFilter. +func (f *PkgFilter) Count() uint16 { + return f.count +} + +// Set marks the `i`-th element as included by this filter. +// NOTE: It is assumed that i is always less than count. +func (f *PkgFilter) Set(i uint16) { + byt := i / 8 + bit := i % 8 + + // Set the i-th bit in the filter. + // TODO(conner): ignore if > count to prevent panic? + f.filter[byt] |= byte(1 << (7 - bit)) +} + +// Contains queries the filter for membership of index `i`. +// NOTE: It is assumed that i is always less than count. +func (f *PkgFilter) Contains(i uint16) bool { + byt := i / 8 + bit := i % 8 + + // Read the i-th bit in the filter. + // TODO(conner): ignore if > count to prevent panic? + return f.filter[byt]&(1<<(7-bit)) != 0 +} + +// Equal checks two PkgFilters for equality. +func (f *PkgFilter) Equal(f2 *PkgFilter) bool { + if f == f2 { + return true + } + if f.count != f2.count { + return false + } + + return bytes.Equal(f.filter, f2.filter) +} + +// IsFull returns true if every element in the filter has been Set, and false +// otherwise. +func (f *PkgFilter) IsFull() bool { + // Batch validate bytes that are fully used. + for i := uint16(0); i < f.count/8; i++ { + if f.filter[i] != 0xFF { + return false + } + } + + // If the count is not a multiple of 8, check that the filter contains + // all remaining bits. + rem := f.count % 8 + for idx := f.count - rem; idx < f.count; idx++ { + if !f.Contains(idx) { + return false + } + } + + return true +} + +// Size returns number of bytes produced when the PkgFilter is serialized. +func (f *PkgFilter) Size() uint16 { + // 2 bytes for uint16 `count`, then round up number of bytes required to + // represent `count` bits. + return 2 + (f.count+7)/8 +} + +// Encode writes the filter to the provided io.Writer. +func (f *PkgFilter) Encode(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, f.count); err != nil { + return err + } + + _, err := w.Write(f.filter) + + return err +} + +// Decode reads the filter from the provided io.Reader. +func (f *PkgFilter) Decode(r io.Reader) error { + if err := binary.Read(r, binary.BigEndian, &f.count); err != nil { + return err + } + + f.filter = make([]byte, f.Size()-2) + _, err := io.ReadFull(r, f.filter) + + return err +} + +// FwdPkg records all adds, settles, and fails that were locked in as a result +// of the remote peer sending us a revocation. Each package is identified by +// the short chanid and remote commitment height corresponding to the revocation +// that locked in the HTLCs. For everything except a locally initiated payment, +// settles and fails in a forwarding package must have a corresponding Add in +// another package, and can be removed individually once the source link has +// received the fail/settle. +// +// Adds cannot be removed, as we need to present the same batch of Adds to +// properly handle replay protection. Instead, we use a PkgFilter to mark that +// we have finished processing a particular Add. A FwdPkg should only be deleted +// after the AckFilter is full and all settles and fails have been persistently +// removed. +type FwdPkg struct { + // Source identifies the channel that wrote this forwarding package. + Source lnwire.ShortChannelID + + // Height is the height of the remote commitment chain that locked in + // this forwarding package. + Height uint64 + + // State signals the persistent condition of the package and directs how + // to reprocess the package in the event of failures. + State FwdState + + // Adds contains all add messages which need to be processed and + // forwarded to the switch. Adds does not change over the life of a + // forwarding package. + Adds []LogUpdate + + // FwdFilter is a filter containing the indices of all Adds that were + // forwarded to the switch. + FwdFilter *PkgFilter + + // AckFilter is a filter containing the indices of all Adds for which + // the source has received a settle or fail and is reflected in the next + // commitment txn. A package should not be removed until IsFull() + // returns true. + AckFilter *PkgFilter + + // SettleFails contains all settle and fail messages that should be + // forwarded to the switch. + SettleFails []LogUpdate + + // SettleFailFilter is a filter containing the indices of all Settle or + // Fails originating in this package that have been received and locked + // into the incoming link's commitment state. + SettleFailFilter *PkgFilter +} + +// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This +// should be used to create a package at the time we receive a revocation. +func NewFwdPkg(source lnwire.ShortChannelID, height uint64, + addUpdates, settleFailUpdates []LogUpdate) *FwdPkg { + + nAddUpdates := uint16(len(addUpdates)) + nSettleFailUpdates := uint16(len(settleFailUpdates)) + + return &FwdPkg{ + Source: source, + Height: height, + State: FwdStateLockedIn, + Adds: addUpdates, + FwdFilter: NewPkgFilter(nAddUpdates), + AckFilter: NewPkgFilter(nAddUpdates), + SettleFails: settleFailUpdates, + SettleFailFilter: NewPkgFilter(nSettleFailUpdates), + } +} + +// ID returns an unique identifier for this package, used to ensure that sphinx +// replay processing of this batch is idempotent. +func (f *FwdPkg) ID() []byte { + var id = make([]byte, 16) + byteOrder.PutUint64(id[:8], f.Source.ToUint64()) + byteOrder.PutUint64(id[8:], f.Height) + return id +} + +// String returns a human-readable description of the forwarding package. +func (f *FwdPkg) String() string { + return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)", + f, f.Source, f.Height, len(f.Adds), len(f.SettleFails)) +} + +// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID +// is assumed to be that of the packager. +type AddRef struct { + // Height is the remote commitment height that locked in the Add. + Height uint64 + + // Index is the index of the Add within the fwd pkg's Adds. + // + // NOTE: This index is static over the lifetime of a forwarding package. + Index uint16 +} + +// Encode serializes the AddRef to the given io.Writer. +func (a *AddRef) Encode(w io.Writer) error { + if err := binary.Write(w, binary.BigEndian, a.Height); err != nil { + return err + } + + return binary.Write(w, binary.BigEndian, a.Index) +} + +// Decode deserializes the AddRef from the given io.Reader. +func (a *AddRef) Decode(r io.Reader) error { + if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil { + return err + } + + return binary.Read(r, binary.BigEndian, &a.Index) +} + +// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A +// channel does not remove its own Settle/Fail htlcs, so the source is provided +// to locate a db bucket belonging to another channel. +type SettleFailRef struct { + // Source identifies the outgoing link that locked in the settle or + // fail. This is then used by the *incoming* link to find the settle + // fail in another link's forwarding packages. + Source lnwire.ShortChannelID + + // Height is the remote commitment height that locked in this + // Settle/Fail. + Height uint64 + + // Index is the index of the Add with the fwd pkg's SettleFails. + // + // NOTE: This index is static over the lifetime of a forwarding package. + Index uint16 +} + +// SettleFailAcker is a generic interface providing the ability to acknowledge +// settle/fail HTLCs stored in forwarding packages. +type SettleFailAcker interface { + // AckSettleFails atomically updates the settle-fail filters in *other* + // channels' forwarding packages. + AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error +} + +// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages +// of any active channel. +type GlobalFwdPkgReader interface { + // LoadChannelFwdPkgs loads all known forwarding packages for the given + // channel. + LoadChannelFwdPkgs(tx *bbolt.Tx, + source lnwire.ShortChannelID) ([]*FwdPkg, error) +} + +// FwdOperator defines the interfaces for managing forwarding packages that are +// external to a particular channel. This interface is used by the switch to +// read forwarding packages from arbitrary channels, and acknowledge settles and +// fails for locally-sourced payments. +type FwdOperator interface { + // GlobalFwdPkgReader provides read access to all known forwarding + // packages + GlobalFwdPkgReader + + // SettleFailAcker grants the ability to acknowledge settles or fails + // residing in arbitrary forwarding packages. + SettleFailAcker +} + +// SwitchPackager is a concrete implementation of the FwdOperator interface. +// A SwitchPackager offers the ability to read any forwarding package, and ack +// arbitrary settle and fail HTLCs. +type SwitchPackager struct{} + +// NewSwitchPackager instantiates a new SwitchPackager. +func NewSwitchPackager() *SwitchPackager { + return &SwitchPackager{} +} + +// AckSettleFails atomically updates the settle-fail filters in *other* +// channels' forwarding packages, to mark that the switch has received a settle +// or fail residing in the forwarding package of a link. +func (*SwitchPackager) AckSettleFails(tx *bbolt.Tx, + settleFailRefs ...SettleFailRef) error { + + return ackSettleFails(tx, settleFailRefs) +} + +// LoadChannelFwdPkgs loads all forwarding packages for a particular channel. +func (*SwitchPackager) LoadChannelFwdPkgs(tx *bbolt.Tx, + source lnwire.ShortChannelID) ([]*FwdPkg, error) { + + return loadChannelFwdPkgs(tx, source) +} + +// FwdPackager supports all operations required to modify fwd packages, such as +// creation, updates, reading, and removal. The interfaces are broken down in +// this way to support future delegation of the subinterfaces. +type FwdPackager interface { + // AddFwdPkg serializes and writes a FwdPkg for this channel at the + // remote commitment height included in the forwarding package. + AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error + + // SetFwdFilter looks up the forwarding package at the remote `height` + // and sets the `fwdFilter`, marking the Adds for which: + // 1) We are not the exit node + // 2) Passed all validation + // 3) Should be forwarded to the switch immediately after a failure + SetFwdFilter(tx *bbolt.Tx, height uint64, fwdFilter *PkgFilter) error + + // AckAddHtlcs atomically updates the add filters in this channel's + // forwarding packages to mark the resolution of an Add that was + // received from the remote party. + AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error + + // SettleFailAcker allows a link to acknowledge settle/fail HTLCs + // belonging to other channels. + SettleFailAcker + + // LoadFwdPkgs loads all known forwarding packages owned by this + // channel. + LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) + + // RemovePkg deletes a forwarding package owned by this channel at + // the provided remote `height`. + RemovePkg(tx *bbolt.Tx, height uint64) error +} + +// ChannelPackager is used by a channel to manage the lifecycle of its forwarding +// packages. The packager is tied to a particular source channel ID, allowing it +// to create and edit its own packages. Each packager also has the ability to +// remove fail/settle htlcs that correspond to an add contained in one of +// source's packages. +type ChannelPackager struct { + source lnwire.ShortChannelID +} + +// NewChannelPackager creates a new packager for a single channel. +func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { + return &ChannelPackager{ + source: source, + } +} + +// AddFwdPkg writes a newly locked in forwarding package to disk. +func (*ChannelPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error { + fwdPkgBkt, err := tx.CreateBucketIfNotExists(fwdPackagesKey) + if err != nil { + return err + } + + source := makeLogKey(fwdPkg.Source.ToUint64()) + sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) + if err != nil { + return err + } + + heightKey := makeLogKey(fwdPkg.Height) + heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) + if err != nil { + return err + } + + // Write ADD updates we received at this commit height. + addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) + if err != nil { + return err + } + + // Write SETTLE/FAIL updates we received at this commit height. + failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey) + if err != nil { + return err + } + + for i := range fwdPkg.Adds { + err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) + if err != nil { + return err + } + } + + // Persist the initialized pkg filter, which will be used to determine + // when we can remove this forwarding package from disk. + var ackFilterBuf bytes.Buffer + if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil { + return err + } + + for i := range fwdPkg.SettleFails { + err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i]) + if err != nil { + return err + } + } + + var settleFailFilterBuf bytes.Buffer + err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) + if err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. +func putLogUpdate(bkt *bbolt.Bucket, idx uint16, htlc *LogUpdate) error { + var b bytes.Buffer + if err := htlc.Encode(&b); err != nil { + return err + } + + return bkt.Put(uint16Key(idx), b.Bytes()) +} + +// LoadFwdPkgs scans the forwarding log for any packages that haven't been +// processed, and returns their deserialized log updates in a map indexed by the +// remote commitment height at which the updates were locked in. +func (p *ChannelPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) { + return loadChannelFwdPkgs(tx, p.source) +} + +// loadChannelFwdPkgs loads all forwarding packages owned by `source`. +func loadChannelFwdPkgs(tx *bbolt.Tx, source lnwire.ShortChannelID) ([]*FwdPkg, error) { + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil, nil + } + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) + if sourceBkt == nil { + return nil, nil + } + + var heights []uint64 + if err := sourceBkt.ForEach(func(k, _ []byte) error { + if len(k) != 8 { + return ErrCorruptedFwdPkg + } + + heights = append(heights, byteOrder.Uint64(k)) + + return nil + }); err != nil { + return nil, err + } + + // Load the forwarding package for each retrieved height. + fwdPkgs := make([]*FwdPkg, 0, len(heights)) + for _, height := range heights { + fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) + if err != nil { + return nil, err + } + + fwdPkgs = append(fwdPkgs, fwdPkg) + } + + return fwdPkgs, nil +} + +// loadFwPkg reads the packager's fwd pkg at a given height, and determines the +// appropriate FwdState. +func loadFwdPkg(fwdPkgBkt *bbolt.Bucket, source lnwire.ShortChannelID, + height uint64) (*FwdPkg, error) { + + sourceKey := makeLogKey(source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) + if sourceBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.Bucket(heightKey[:]) + if heightBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + // Load ADDs from disk. + addBkt := heightBkt.Bucket(addBucketKey) + if addBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + adds, err := loadHtlcs(addBkt) + if err != nil { + return nil, err + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + ackFilterReader := bytes.NewReader(ackFilterBytes) + + ackFilter := &PkgFilter{} + if err := ackFilter.Decode(ackFilterReader); err != nil { + return nil, err + } + + // Load SETTLE/FAILs from disk. + failSettleBkt := heightBkt.Bucket(failSettleBucketKey) + if failSettleBkt == nil { + return nil, ErrCorruptedFwdPkg + } + + failSettles, err := loadHtlcs(failSettleBkt) + if err != nil { + return nil, err + } + + // Load settle fail filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return nil, ErrCorruptedFwdPkg + } + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + + settleFailFilter := &PkgFilter{} + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return nil, err + } + + // Initialize the fwding package, which always starts in the + // FwdStateLockedIn. We can determine what state the package was left in + // by examining constraints on the information loaded from disk. + fwdPkg := &FwdPkg{ + Source: source, + State: FwdStateLockedIn, + Height: height, + Adds: adds, + AckFilter: ackFilter, + SettleFails: failSettles, + SettleFailFilter: settleFailFilter, + } + + // Check to see if we have written the set exported filter adds to + // disk. If we haven't, processing of this package was never started, or + // failed during the last attempt. + fwdFilterBytes := heightBkt.Get(fwdFilterKey) + if fwdFilterBytes == nil { + nAdds := uint16(len(adds)) + fwdPkg.FwdFilter = NewPkgFilter(nAdds) + return fwdPkg, nil + } + + fwdFilterReader := bytes.NewReader(fwdFilterBytes) + fwdPkg.FwdFilter = &PkgFilter{} + if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { + return nil, err + } + + // Otherwise, a complete round of processing was completed, and we + // advance the package to FwdStateProcessed. + fwdPkg.State = FwdStateProcessed + + // If every add, settle, and fail has been fully acknowledged, we can + // safely set the package's state to FwdStateCompleted, signalling that + // it can be garbage collected. + if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { + fwdPkg.State = FwdStateCompleted + } + + return fwdPkg, nil +} + +// loadHtlcs retrieves all serialized htlcs in a bucket, returning +// them in order of the indexes they were written under. +func loadHtlcs(bkt *bbolt.Bucket) ([]LogUpdate, error) { + var htlcs []LogUpdate + if err := bkt.ForEach(func(_, v []byte) error { + var htlc LogUpdate + if err := htlc.Decode(bytes.NewReader(v)); err != nil { + return err + } + + htlcs = append(htlcs, htlc) + + return nil + }); err != nil { + return nil, err + } + + return htlcs, nil +} + +// SetFwdFilter writes the set of indexes corresponding to Adds at the +// `height` that are to be forwarded to the switch. Calling this method causes +// the forwarding package at `height` to be in FwdStateProcessed. We write this +// forwarding decision so that we always arrive at the same behavior for HTLCs +// leaving this channel. After a restart, we skip validation of these Adds, +// since they are assumed to have already been validated, and make the switch or +// outgoing link responsible for handling replays. +func (p *ChannelPackager) SetFwdFilter(tx *bbolt.Tx, height uint64, + fwdFilter *PkgFilter) error { + + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + source := makeLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(source[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.Bucket(heightKey[:]) + if heightBkt == nil { + return ErrCorruptedFwdPkg + } + + // If the fwd filter has already been written, we return early to avoid + // modifying the persistent state. + forwardedAddsBytes := heightBkt.Get(fwdFilterKey) + if forwardedAddsBytes != nil { + return nil + } + + // Otherwise we serialize and write the provided fwd filter. + var b bytes.Buffer + if err := fwdFilter.Encode(&b); err != nil { + return err + } + + return heightBkt.Put(fwdFilterKey, b.Bytes()) +} + +// AckAddHtlcs accepts a list of references to add htlcs, and updates the +// AckAddFilter of those forwarding packages to indicate that a settle or fail +// has been received in response to the add. +func (p *ChannelPackager) AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error { + if len(addRefs) == 0 { + return nil + } + + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + sourceKey := makeLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + // Organize the forward references such that we just get a single slice + // of indexes for each unique height. + heightDiffs := make(map[uint64][]uint16) + for _, addRef := range addRefs { + heightDiffs[addRef.Height] = append( + heightDiffs[addRef.Height], + addRef.Index, + ) + } + + // Load each height bucket once and remove all acked htlcs at that + // height. + for height, indexes := range heightDiffs { + err := ackAddHtlcsAtHeight(sourceBkt, height, indexes) + if err != nil { + return err + } + } + + return nil +} + +// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package +// with a list of indexes, writing the resulting filter back in its place. +func ackAddHtlcsAtHeight(sourceBkt *bbolt.Bucket, height uint64, + indexes []uint16) error { + + heightKey := makeLogKey(height) + heightBkt := sourceBkt.Bucket(heightKey[:]) + if heightBkt == nil { + // If the height bucket isn't found, this could be because the + // forwarding package was already removed. We'll return nil to + // signal that the operation is successful, as there is nothing + // to ack. + return nil + } + + // Load ack filter from disk. + ackFilterBytes := heightBkt.Get(ackFilterKey) + if ackFilterBytes == nil { + return ErrCorruptedFwdPkg + } + + ackFilter := &PkgFilter{} + ackFilterReader := bytes.NewReader(ackFilterBytes) + if err := ackFilter.Decode(ackFilterReader); err != nil { + return err + } + + // Update the ack filter for this height. + for _, index := range indexes { + ackFilter.Set(index) + } + + // Write the resulting filter to disk. + var ackFilterBuf bytes.Buffer + if err := ackFilter.Encode(&ackFilterBuf); err != nil { + return err + } + + return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()) +} + +// AckSettleFails persistently acknowledges settles or fails from a remote forwarding +// package. This should only be called after the source of the Add has locked in +// the settle/fail, or it becomes otherwise safe to forgo retransmitting the +// settle/fail after a restart. +func (p *ChannelPackager) AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error { + return ackSettleFails(tx, settleFailRefs) +} + +// ackSettleFails persistently acknowledges a batch of settle fail references. +func ackSettleFails(tx *bbolt.Tx, settleFailRefs []SettleFailRef) error { + if len(settleFailRefs) == 0 { + return nil + } + + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return ErrCorruptedFwdPkg + } + + // Organize the forward references such that we just get a single slice + // of indexes for each unique destination-height pair. + destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16) + for _, settleFailRef := range settleFailRefs { + destHeights, ok := destHeightDiffs[settleFailRef.Source] + if !ok { + destHeights = make(map[uint64][]uint16) + destHeightDiffs[settleFailRef.Source] = destHeights + } + + destHeights[settleFailRef.Height] = append( + destHeights[settleFailRef.Height], + settleFailRef.Index, + ) + } + + // With the references organized by destination and height, we now load + // each remote bucket, and update the settle fail filter for any + // settle/fail htlcs. + for dest, destHeights := range destHeightDiffs { + destKey := makeLogKey(dest.ToUint64()) + destBkt := fwdPkgBkt.Bucket(destKey[:]) + if destBkt == nil { + // If the destination bucket is not found, this is + // likely the result of the destination channel being + // closed and having it's forwarding packages wiped. We + // won't treat this as an error, because the response + // will no longer be retransmitted internally. + continue + } + + for height, indexes := range destHeights { + err := ackSettleFailsAtHeight(destBkt, height, indexes) + if err != nil { + return err + } + } + } + + return nil +} + +// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes +// at particular a height by updating the settle fail filter. +func ackSettleFailsAtHeight(destBkt *bbolt.Bucket, height uint64, + indexes []uint16) error { + + heightKey := makeLogKey(height) + heightBkt := destBkt.Bucket(heightKey[:]) + if heightBkt == nil { + // If the height bucket isn't found, this could be because the + // forwarding package was already removed. We'll return nil to + // signal that the operation is as there is nothing to ack. + return nil + } + + // Load ack filter from disk. + settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) + if settleFailFilterBytes == nil { + return ErrCorruptedFwdPkg + } + + settleFailFilter := &PkgFilter{} + settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) + if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { + return err + } + + // Update the ack filter for this height. + for _, index := range indexes { + settleFailFilter.Set(index) + } + + // Write the resulting filter to disk. + var settleFailFilterBuf bytes.Buffer + if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil { + return err + } + + return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) +} + +// RemovePkg deletes the forwarding package at the given height from the +// packager's source bucket. +func (p *ChannelPackager) RemovePkg(tx *bbolt.Tx, height uint64) error { + fwdPkgBkt := tx.Bucket(fwdPackagesKey) + if fwdPkgBkt == nil { + return nil + } + + sourceBytes := makeLogKey(p.source.ToUint64()) + sourceBkt := fwdPkgBkt.Bucket(sourceBytes[:]) + if sourceBkt == nil { + return ErrCorruptedFwdPkg + } + + heightKey := makeLogKey(height) + + return sourceBkt.DeleteBucket(heightKey[:]) +} + +// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. +func uint16Key(i uint16) []byte { + key := make([]byte, 2) + byteOrder.PutUint16(key, i) + return key +} + +// Compile-time constraint to ensure that ChannelPackager implements the public +// FwdPackager interface. +var _ FwdPackager = (*ChannelPackager)(nil) + +// Compile-time constraint to ensure that SwitchPackager implements the public +// FwdOperator interface. +var _ FwdOperator = (*SwitchPackager)(nil) diff --git a/channeldb/migration_01_to_11/forwarding_package_test.go b/channeldb/migration_01_to_11/forwarding_package_test.go new file mode 100644 index 00000000..1128aad3 --- /dev/null +++ b/channeldb/migration_01_to_11/forwarding_package_test.go @@ -0,0 +1,815 @@ +package migration_01_to_11_test + +import ( + "bytes" + "io/ioutil" + "path/filepath" + "runtime" + "testing" + + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" +) + +// TestPkgFilterBruteForce tests the behavior of a pkg filter up to size 1000, +// which is greater than the number of HTLCs we permit on a commitment txn. +// This should encapsulate every potential filter used in practice. +func TestPkgFilterBruteForce(t *testing.T) { + t.Parallel() + + checkPkgFilterRange(t, 1000) +} + +// checkPkgFilterRange verifies the behavior of a pkg filter when doing a linear +// insertion of `high` elements. This is primarily to test that IsFull functions +// properly for all relevant sizes of `high`. +func checkPkgFilterRange(t *testing.T, high int) { + for i := uint16(0); i < uint16(high); i++ { + f := channeldb.NewPkgFilter(i) + + if f.Count() != i { + t.Fatalf("pkg filter count=%d is actually %d", + i, f.Count()) + } + checkPkgFilterEncodeDecode(t, i, f) + + for j := uint16(0); j < i; j++ { + if f.Contains(j) { + t.Fatalf("pkg filter count=%d contains %d "+ + "before being added", i, j) + } + + f.Set(j) + checkPkgFilterEncodeDecode(t, i, f) + + if !f.Contains(j) { + t.Fatalf("pkg filter count=%d missing %d "+ + "after being added", i, j) + } + + if j < i-1 && f.IsFull() { + t.Fatalf("pkg filter count=%d already full", i) + } + } + + if !f.IsFull() { + t.Fatalf("pkg filter count=%d not full", i) + } + checkPkgFilterEncodeDecode(t, i, f) + } +} + +// TestPkgFilterRand uses a random permutation to verify the proper behavior of +// the pkg filter if the entries are not inserted in-order. +func TestPkgFilterRand(t *testing.T) { + t.Parallel() + + checkPkgFilterRand(t, 3, 17) +} + +// checkPkgFilterRand checks the behavior of a pkg filter by randomly inserting +// indices and asserting the invariants. The order in which indices are inserted +// is parameterized by a base `b` coprime to `p`, and using modular +// exponentiation to generate all elements in [1,p). +func checkPkgFilterRand(t *testing.T, b, p uint16) { + f := channeldb.NewPkgFilter(p) + var j = b + for i := uint16(1); i < p; i++ { + if f.Contains(j) { + t.Fatalf("pkg filter contains %d-%d "+ + "before being added", i, j) + } + + f.Set(j) + checkPkgFilterEncodeDecode(t, i, f) + + if !f.Contains(j) { + t.Fatalf("pkg filter missing %d-%d "+ + "after being added", i, j) + } + + if i < p-1 && f.IsFull() { + t.Fatalf("pkg filter %d already full", i) + } + checkPkgFilterEncodeDecode(t, i, f) + + j = (b * j) % p + } + + // Set 0 independently, since it will never be emitted by the generator. + f.Set(0) + checkPkgFilterEncodeDecode(t, p, f) + + if !f.IsFull() { + t.Fatalf("pkg filter count=%d not full", p) + } + checkPkgFilterEncodeDecode(t, p, f) +} + +// checkPkgFilterEncodeDecode tests the serialization of a pkg filter by: +// 1) writing it to a buffer +// 2) verifying the number of bytes written matches the filter's Size() +// 3) reconstructing the filter decoding the bytes +// 4) checking that the two filters are the same according to Equal +func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) { + var b bytes.Buffer + if err := f.Encode(&b); err != nil { + t.Fatalf("unable to serialize pkg filter: %v", err) + } + + // +2 for uint16 length + size := uint16(len(b.Bytes())) + if size != f.Size() { + t.Fatalf("pkg filter count=%d serialized size differs, "+ + "Size(): %d, len(bytes): %v", i, f.Size(), size) + } + + reader := bytes.NewReader(b.Bytes()) + + f2 := &channeldb.PkgFilter{} + if err := f2.Decode(reader); err != nil { + t.Fatalf("unable to deserialize pkg filter: %v", err) + } + + if !f.Equal(f2) { + t.Fatalf("pkg filter count=%v does is not equal "+ + "after deserialization, want: %v, got %v", + i, f, f2) + } +} + +var ( + chanID = lnwire.NewChanIDFromOutPoint(&wire.OutPoint{}) + + adds = []channeldb.LogUpdate{ + { + LogIndex: 0, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ChanID: chanID, + ID: 1, + Amount: 100, + Expiry: 1000, + PaymentHash: [32]byte{0}, + }, + }, + { + LogIndex: 1, + UpdateMsg: &lnwire.UpdateAddHTLC{ + ChanID: chanID, + ID: 1, + Amount: 101, + Expiry: 1001, + PaymentHash: [32]byte{1}, + }, + }, + } + + settleFails = []channeldb.LogUpdate{ + { + LogIndex: 2, + UpdateMsg: &lnwire.UpdateFulfillHTLC{ + ChanID: chanID, + ID: 0, + PaymentPreimage: [32]byte{0}, + }, + }, + { + LogIndex: 3, + UpdateMsg: &lnwire.UpdateFailHTLC{ + ChanID: chanID, + ID: 1, + Reason: []byte{}, + }, + }, + } +) + +// TestPackagerEmptyFwdPkg checks that the state transitions exhibited by a +// forwarding package that contains no adds, fails or settles. We expect that +// the fwdpkg reaches FwdStateCompleted immediately after writing the forwarding +// decision via SetFwdFilter. +func TestPackagerEmptyFwdPkg(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package with no htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, nil) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. With no HTLCs, + // the ack filter will have no elements, and should always return true. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Now, write the forwarding decision. In this case, its just an empty + // fwd filter. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + // We should still have one package on disk. Since the forwarding + // decision has been written, it will minimally be in FwdStateProcessed. + // However with no htlcs, it should leap frog to FwdStateCompleted. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// TestPackagerOnlyAdds checks that the fwdpkg does not reach FwdStateCompleted +// as soon as all the adds in the package have been acked using AckAddHtlcs. +func TestPackagerOnlyAdds(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package that only has add + // htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, nil) + + nAdds := len(adds) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. The package + // has unacked add HTLCs, so the ack filter should not be full. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + // Now, write the forwarding decision. Since we have not explicitly + // added any adds to the fwdfilter, this would indicate that all of the + // adds were 1) settled locally by this link (exit hop), or 2) the htlc + // was failed locally. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + for i := range adds { + // We should still have one package on disk. Since the forwarding + // decision has been written, it will minimally be in FwdStateProcessed. + // However not allf of the HTLCs have been acked, so should not + // have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + addRef := channeldb.AddRef{ + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckAddHtlcs(tx, addRef) + }); err != nil { + t.Fatalf("unable to ack add htlc: %v", err) + } + } + + // We should still have one package on disk. Now that all adds have been + // acked, the ack filter should return true and the package should be + // FwdStateCompleted since there are no other settle/fail packets. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// TestPackagerOnlySettleFails asserts that the fwdpkg remains in +// FwdStateProcessed after writing the forwarding decision when there are no +// adds in the fwdpkg. We expect this because an empty FwdFilter will always +// return true, but we are still waiting for the remaining fails and settles to +// be deleted. +func TestPackagerOnlySettleFails(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package that only has add + // htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, settleFails) + + nSettleFails := len(settleFails) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. The package + // has unacked add HTLCs, so the ack filter should not be full. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Now, write the forwarding decision. Since we have not explicitly + // added any adds to the fwdfilter, this would indicate that all of the + // adds were 1) settled locally by this link (exit hop), or 2) the htlc + // was failed locally. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + for i := range settleFails { + // We should still have one package on disk. Since the + // forwarding decision has been written, it will minimally be in + // FwdStateProcessed. However, not all of the HTLCs have been + // acked, so should not have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], false) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + failSettleRef := channeldb.SettleFailRef{ + Source: shortChanID, + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckSettleFails(tx, failSettleRef) + }); err != nil { + t.Fatalf("unable to ack add htlc: %v", err) + } + } + + // We should still have one package on disk. Now that all settles and + // fails have been removed, package should be FwdStateCompleted since + // there are no other add packets. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], true) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// TestPackagerAddsThenSettleFails writes a fwdpkg containing both adds and +// settle/fails, then checks the behavior when the adds are acked before any of +// the settle fails. Here we expect pkg to remain in FwdStateProcessed while the +// remainder of the fail/settles are being deleted. +func TestPackagerAddsThenSettleFails(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package that only has add + // htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) + + nAdds := len(adds) + nSettleFails := len(settleFails) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. The package + // has unacked add HTLCs, so the ack filter should not be full. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + // Now, write the forwarding decision. Since we have not explicitly + // added any adds to the fwdfilter, this would indicate that all of the + // adds were 1) settled locally by this link (exit hop), or 2) the htlc + // was failed locally. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + for i := range adds { + // We should still have one package on disk. Since the forwarding + // decision has been written, it will minimally be in FwdStateProcessed. + // However not allf of the HTLCs have been acked, so should not + // have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], false) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + addRef := channeldb.AddRef{ + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckAddHtlcs(tx, addRef) + }); err != nil { + t.Fatalf("unable to ack add htlc: %v", err) + } + } + + for i := range settleFails { + // We should still have one package on disk. Since the + // forwarding decision has been written, it will minimally be in + // FwdStateProcessed. However not allf of the HTLCs have been + // acked, so should not have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], false) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + failSettleRef := channeldb.SettleFailRef{ + Source: shortChanID, + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckSettleFails(tx, failSettleRef) + }); err != nil { + t.Fatalf("unable to remove settle/fail htlc: %v", err) + } + } + + // We should still have one package on disk. Now that all settles and + // fails have been removed, package should be FwdStateCompleted since + // there are no other add packets. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], true) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// TestPackagerSettleFailsThenAdds writes a fwdpkg with both adds and +// settle/fails, then checks the behavior when the settle/fails are removed +// before any of the adds have been acked. This should cause the fwdpkg to +// remain in FwdStateProcessed until the final ack is recorded, at which point +// it should be promoted directly to FwdStateCompleted.since all adds have been +// removed. +func TestPackagerSettleFailsThenAdds(t *testing.T) { + t.Parallel() + + db := makeFwdPkgDB(t, "") + + shortChanID := lnwire.NewShortChanIDFromInt(1) + packager := channeldb.NewChannelPackager(shortChanID) + + // To begin, there should be no forwarding packages on disk. + fwdPkgs := loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } + + // Next, create and write a new forwarding package that has both add + // and settle/fail htlcs. + fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) + + nAdds := len(adds) + nSettleFails := len(settleFails) + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AddFwdPkg(tx, fwdPkg) + }); err != nil { + t.Fatalf("unable to add fwd pkg: %v", err) + } + + // There should now be one fwdpkg on disk. Since no forwarding decision + // has been written, we expect it to be FwdStateLockedIn. The package + // has unacked add HTLCs, so the ack filter should not be full. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + // Now, write the forwarding decision. Since we have not explicitly + // added any adds to the fwdfilter, this would indicate that all of the + // adds were 1) settled locally by this link (exit hop), or 2) the htlc + // was failed locally. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) + }); err != nil { + t.Fatalf("unable to set fwdfiter: %v", err) + } + + // Simulate another channel deleting the settle/fails it received from + // the original fwd pkg. + // TODO(conner): use different packager/s? + for i := range settleFails { + // We should still have one package on disk. Since the + // forwarding decision has been written, it will minimally be in + // FwdStateProcessed. However none all of the add HTLCs have + // been acked, so should not have advanced further. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], false) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + failSettleRef := channeldb.SettleFailRef{ + Source: shortChanID, + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckSettleFails(tx, failSettleRef) + }); err != nil { + t.Fatalf("unable to remove settle/fail htlc: %v", err) + } + } + + // Now simulate this channel receiving a fail/settle for the adds in the + // fwdpkg. + for i := range adds { + // Again, we should still have one package on disk and be in + // FwdStateProcessed. This should not change until all of the + // add htlcs have been acked. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], true) + assertAckFilterIsFull(t, fwdPkgs[0], false) + + addRef := channeldb.AddRef{ + Height: fwdPkg.Height, + Index: uint16(i), + } + + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.AckAddHtlcs(tx, addRef) + }); err != nil { + t.Fatalf("unable to ack add htlc: %v", err) + } + } + + // We should still have one package on disk. Now that all settles and + // fails have been removed, package should be FwdStateCompleted since + // there are no other add packets. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 1 { + t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) + } + assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) + assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) + assertSettleFailFilterIsFull(t, fwdPkgs[0], true) + assertAckFilterIsFull(t, fwdPkgs[0], true) + + // Lastly, remove the completed forwarding package from disk. + if err := db.Update(func(tx *bbolt.Tx) error { + return packager.RemovePkg(tx, fwdPkg.Height) + }); err != nil { + t.Fatalf("unable to remove fwdpkg: %v", err) + } + + // Check that the fwd package was actually removed. + fwdPkgs = loadFwdPkgs(t, db, packager) + if len(fwdPkgs) != 0 { + t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) + } +} + +// assertFwdPkgState checks the current state of a fwdpkg meets our +// expectations. +func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg, + state channeldb.FwdState) { + _, _, line, _ := runtime.Caller(1) + if fwdPkg.State != state { + t.Fatalf("line %d: expected fwdpkg in state %v, found %v", + line, state, fwdPkg.State) + } +} + +// assertFwdPkgNumAddsSettleFails checks that the number of adds and +// settle/fail log updates are correct. +func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg, + expectedNumAdds, expectedNumSettleFails int) { + _, _, line, _ := runtime.Caller(1) + if len(fwdPkg.Adds) != expectedNumAdds { + t.Fatalf("line %d: expected fwdpkg to have %d adds, found %d", + line, expectedNumAdds, len(fwdPkg.Adds)) + } + + if len(fwdPkg.SettleFails) != expectedNumSettleFails { + t.Fatalf("line %d: expected fwdpkg to have %d settle/fails, found %d", + line, expectedNumSettleFails, len(fwdPkg.SettleFails)) + } +} + +// assertAckFilterIsFull checks whether or not a fwdpkg's ack filter matches our +// expected full-ness. +func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { + _, _, line, _ := runtime.Caller(1) + if fwdPkg.AckFilter.IsFull() != expected { + t.Fatalf("line %d: expected fwdpkg ack filter IsFull to be %v, "+ + "found %v", line, expected, fwdPkg.AckFilter.IsFull()) + } +} + +// assertSettleFailFilterIsFull checks whether or not a fwdpkg's settle fail +// filter matches our expected full-ness. +func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { + _, _, line, _ := runtime.Caller(1) + if fwdPkg.SettleFailFilter.IsFull() != expected { + t.Fatalf("line %d: expected fwdpkg settle/fail filter IsFull to be %v, "+ + "found %v", line, expected, fwdPkg.SettleFailFilter.IsFull()) + } +} + +// loadFwdPkgs is a helper method that reads all forwarding packages for a +// particular packager. +func loadFwdPkgs(t *testing.T, db *bbolt.DB, + packager channeldb.FwdPackager) []*channeldb.FwdPkg { + + var fwdPkgs []*channeldb.FwdPkg + if err := db.View(func(tx *bbolt.Tx) error { + var err error + fwdPkgs, err = packager.LoadFwdPkgs(tx) + return err + }); err != nil { + t.Fatalf("unable to load fwd pkgs: %v", err) + } + + return fwdPkgs +} + +// makeFwdPkgDB initializes a test database for forwarding packages. If the +// provided path is an empty, it will create a temp dir/file to use. +func makeFwdPkgDB(t *testing.T, path string) *bbolt.DB { + if path == "" { + var err error + path, err = ioutil.TempDir("", "fwdpkgdb") + if err != nil { + t.Fatalf("unable to create temp path: %v", err) + } + + path = filepath.Join(path, "fwdpkg.db") + } + + db, err := bbolt.Open(path, 0600, nil) + if err != nil { + t.Fatalf("unable to open boltdb: %v", err) + } + + return db +} diff --git a/channeldb/migration_01_to_11/graph.go b/channeldb/migration_01_to_11/graph.go new file mode 100644 index 00000000..d90863c6 --- /dev/null +++ b/channeldb/migration_01_to_11/graph.go @@ -0,0 +1,4060 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "image/color" + "io" + "math" + "net" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // nodeBucket is a bucket which houses all the vertices or nodes within + // the channel graph. This bucket has a single-sub bucket which adds an + // additional index from pubkey -> alias. Within the top-level of this + // bucket, the key space maps a node's compressed public key to the + // serialized information for that node. Additionally, there's a + // special key "source" which stores the pubkey of the source node. The + // source node is used as the starting point for all graph/queries and + // traversals. The graph is formed as a star-graph with the source node + // at the center. + // + // maps: pubKey -> nodeInfo + // maps: source -> selfPubKey + nodeBucket = []byte("graph-node") + + // nodeUpdateIndexBucket is a sub-bucket of the nodeBucket. This bucket + // will be used to quickly look up the "freshness" of a node's last + // update to the network. The bucket only contains keys, and no values, + // it's mapping: + // + // maps: updateTime || nodeID -> nil + nodeUpdateIndexBucket = []byte("graph-node-update-index") + + // sourceKey is a special key that resides within the nodeBucket. The + // sourceKey maps a key to the public key of the "self node". + sourceKey = []byte("source") + + // aliasIndexBucket is a sub-bucket that's nested within the main + // nodeBucket. This bucket maps the public key of a node to its + // current alias. This bucket is provided as it can be used within a + // future UI layer to add an additional degree of confirmation. + aliasIndexBucket = []byte("alias") + + // edgeBucket is a bucket which houses all of the edge or channel + // information within the channel graph. This bucket essentially acts + // as an adjacency list, which in conjunction with a range scan, can be + // used to iterate over all the incoming and outgoing edges for a + // particular node. Key in the bucket use a prefix scheme which leads + // with the node's public key and sends with the compact edge ID. + // For each chanID, there will be two entries within the bucket, as the + // graph is directed: nodes may have different policies w.r.t to fees + // for their respective directions. + // + // maps: pubKey || chanID -> channel edge policy for node + edgeBucket = []byte("graph-edge") + + // unknownPolicy is represented as an empty slice. It is + // used as the value in edgeBucket for unknown channel edge policies. + // Unknown policies are still stored in the database to enable efficient + // lookup of incoming channel edges. + unknownPolicy = []byte{} + + // chanStart is an array of all zero bytes which is used to perform + // range scans within the edgeBucket to obtain all of the outgoing + // edges for a particular node. + chanStart [8]byte + + // edgeIndexBucket is an index which can be used to iterate all edges + // in the bucket, grouping them according to their in/out nodes. + // Additionally, the items in this bucket also contain the complete + // edge information for a channel. The edge information includes the + // capacity of the channel, the nodes that made the channel, etc. This + // bucket resides within the edgeBucket above. Creation of an edge + // proceeds in two phases: first the edge is added to the edge index, + // afterwards the edgeBucket can be updated with the latest details of + // the edge as they are announced on the network. + // + // maps: chanID -> pubKey1 || pubKey2 || restofEdgeInfo + edgeIndexBucket = []byte("edge-index") + + // edgeUpdateIndexBucket is a sub-bucket of the main edgeBucket. This + // bucket contains an index which allows us to gauge the "freshness" of + // a channel's last updates. + // + // maps: updateTime || chanID -> nil + edgeUpdateIndexBucket = []byte("edge-update-index") + + // channelPointBucket maps a channel's full outpoint (txid:index) to + // its short 8-byte channel ID. This bucket resides within the + // edgeBucket above, and can be used to quickly remove an edge due to + // the outpoint being spent, or to query for existence of a channel. + // + // maps: outPoint -> chanID + channelPointBucket = []byte("chan-index") + + // zombieBucket is a sub-bucket of the main edgeBucket bucket + // responsible for maintaining an index of zombie channels. Each entry + // exists within the bucket as follows: + // + // maps: chanID -> pubKey1 || pubKey2 + // + // The chanID represents the channel ID of the edge that is marked as a + // zombie and is used as the key, which maps to the public keys of the + // edge's participants. + zombieBucket = []byte("zombie-index") + + // disabledEdgePolicyBucket is a sub-bucket of the main edgeBucket bucket + // responsible for maintaining an index of disabled edge policies. Each + // entry exists within the bucket as follows: + // + // maps: -> []byte{} + // + // The chanID represents the channel ID of the edge and the direction is + // one byte representing the direction of the edge. The main purpose of + // this index is to allow pruning disabled channels in a fast way without + // the need to iterate all over the graph. + disabledEdgePolicyBucket = []byte("disabled-edge-policy-index") + + // graphMetaBucket is a top-level bucket which stores various meta-deta + // related to the on-disk channel graph. Data stored in this bucket + // includes the block to which the graph has been synced to, the total + // number of channels, etc. + graphMetaBucket = []byte("graph-meta") + + // pruneLogBucket is a bucket within the graphMetaBucket that stores + // a mapping from the block height to the hash for the blocks used to + // prune the graph. + // Once a new block is discovered, any channels that have been closed + // (by spending the outpoint) can safely be removed from the graph, and + // the block is added to the prune log. We need to keep such a log for + // the case where a reorg happens, and we must "rewind" the state of the + // graph by removing channels that were previously confirmed. In such a + // case we'll remove all entries from the prune log with a block height + // that no longer exists. + pruneLogBucket = []byte("prune-log") +) + +const ( + // MaxAllowedExtraOpaqueBytes is the largest amount of opaque bytes that + // we'll permit to be written to disk. We limit this as otherwise, it + // would be possible for a node to create a ton of updates and slowly + // fill our disk, and also waste bandwidth due to relaying. + MaxAllowedExtraOpaqueBytes = 10000 + + // feeRateParts is the total number of parts used to express fee rates. + feeRateParts = 1e6 +) + +// ChannelGraph is a persistent, on-disk graph representation of the Lightning +// Network. This struct can be used to implement path finding algorithms on top +// of, and also to update a node's view based on information received from the +// p2p network. Internally, the graph is stored using a modified adjacency list +// representation with some added object interaction possible with each +// serialized edge/node. The graph is stored is directed, meaning that are two +// edges stored for each channel: an inbound/outbound edge for each node pair. +// Nodes, edges, and edge information can all be added to the graph +// independently. Edge removal results in the deletion of all edge information +// for that edge. +type ChannelGraph struct { + db *DB + + cacheMu sync.RWMutex + rejectCache *rejectCache + chanCache *channelCache +} + +// newChannelGraph allocates a new ChannelGraph backed by a DB instance. The +// returned instance has its own unique reject cache and channel cache. +func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int) *ChannelGraph { + return &ChannelGraph{ + db: db, + rejectCache: newRejectCache(rejectCacheSize), + chanCache: newChannelCache(chanCacheSize), + } +} + +// Database returns a pointer to the underlying database. +func (c *ChannelGraph) Database() *DB { + return c.db +} + +// ForEachChannel iterates through all the channel edges stored within the +// graph and invokes the passed callback for each edge. The callback takes two +// edges as since this is a directed graph, both the in/out edges are visited. +// If the callback returns an error, then the transaction is aborted and the +// iteration stops early. +// +// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer +// for that particular channel edge routing policy will be passed into the +// callback. +func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + // TODO(roasbeef): ptr map to reduce # of allocs? no duplicates + + return c.db.View(func(tx *bbolt.Tx) error { + // First, grab the node bucket. This will be used to populate + // the Node pointers in each edge read from disk. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // Next, grab the edge bucket which stores the edges, and also + // the index itself so we can group the directed edges together + // logically. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // For each edge pair within the edge index, we fetch each edge + // itself and also the node information in order to fully + // populated the object. + return edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { + infoReader := bytes.NewReader(edgeInfoBytes) + edgeInfo, err := deserializeChanEdgeInfo(infoReader) + if err != nil { + return err + } + edgeInfo.db = c.db + + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { + return err + } + + // With both edges read, execute the call back. IF this + // function returns an error then the transaction will + // be aborted. + return cb(&edgeInfo, edge1, edge2) + }) + }) +} + +// ForEachNodeChannel iterates through all channels of a given node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// the connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal. +func (c *ChannelGraph) ForEachNodeChannel(tx *bbolt.Tx, nodePub []byte, + cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error) error { + + db := c.db + + return nodeTraversal(tx, nodePub, db, cb) +} + +// DisabledChannelIDs returns the channel ids of disabled channels. +// A channel is disabled when two of the associated ChanelEdgePolicies +// have their disabled bit on. +func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { + var disabledChanIDs []uint64 + chanEdgeFound := make(map[uint64]struct{}) + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + + disabledEdgePolicyIndex := edges.Bucket(disabledEdgePolicyBucket) + if disabledEdgePolicyIndex == nil { + return nil + } + + // We iterate over all disabled policies and we add each channel that + // has more than one disabled policy to disabledChanIDs array. + return disabledEdgePolicyIndex.ForEach(func(k, v []byte) error { + chanID := byteOrder.Uint64(k[:8]) + _, edgeFound := chanEdgeFound[chanID] + if edgeFound { + delete(chanEdgeFound, chanID) + disabledChanIDs = append(disabledChanIDs, chanID) + return nil + } + + chanEdgeFound[chanID] = struct{}{} + return nil + }) + }) + if err != nil { + return nil, err + } + + return disabledChanIDs, nil +} + +// ForEachNode iterates through all the stored vertices/nodes in the graph, +// executing the passed callback with each node encountered. If the callback +// returns an error, then the transaction is aborted and the iteration stops +// early. +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal +// +// TODO(roasbeef): add iterator interface to allow for memory efficient graph +// traversal when graph gets mega +func (c *ChannelGraph) ForEachNode(tx *bbolt.Tx, cb func(*bbolt.Tx, *LightningNode) error) error { + traversal := func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + return nodes.ForEach(func(pubKey, nodeBytes []byte) error { + // If this is the source key, then we skip this + // iteration as the value for this key is a pubKey + // rather than raw node information. + if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { + return nil + } + + nodeReader := bytes.NewReader(nodeBytes) + node, err := deserializeLightningNode(nodeReader) + if err != nil { + return err + } + node.db = c.db + + // Execute the callback, the transaction will abort if + // this returns an error. + return cb(tx, &node) + }) + } + + // If no transaction was provided, then we'll create a new transaction + // to execute the transaction within. + if tx == nil { + return c.db.View(traversal) + } + + // Otherwise, we re-use the existing transaction to execute the graph + // traversal. + return traversal(tx) +} + +// SourceNode returns the source node of the graph. The source node is treated +// as the center node within a star-graph. This method may be used to kick off +// a path finding algorithm in order to explore the reachability of another +// node based off the source node. +func (c *ChannelGraph) SourceNode() (*LightningNode, error) { + var source *LightningNode + err := c.db.View(func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + node, err := c.sourceNode(nodes) + if err != nil { + return err + } + source = node + + return nil + }) + if err != nil { + return nil, err + } + + return source, nil +} + +// sourceNode uses an existing database transaction and returns the source node +// of the graph. The source node is treated as the center node within a +// star-graph. This method may be used to kick off a path finding algorithm in +// order to explore the reachability of another node based off the source node. +func (c *ChannelGraph) sourceNode(nodes *bbolt.Bucket) (*LightningNode, error) { + selfPub := nodes.Get(sourceKey) + if selfPub == nil { + return nil, ErrSourceNodeNotSet + } + + // With the pubKey of the source node retrieved, we're able to + // fetch the full node information. + node, err := fetchLightningNode(nodes, selfPub) + if err != nil { + return nil, err + } + node.db = c.db + + return &node, nil +} + +// SetSourceNode sets the source node within the graph database. The source +// node is to be used as the center of a star-graph within path finding +// algorithms. +func (c *ChannelGraph) SetSourceNode(node *LightningNode) error { + nodePubBytes := node.PubKeyBytes[:] + + return c.db.Update(func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } + + // Next we create the mapping from source to the targeted + // public key. + if err := nodes.Put(sourceKey, nodePubBytes); err != nil { + return err + } + + // Finally, we commit the information of the lightning node + // itself. + return addLightningNode(tx, node) + }) +} + +// AddLightningNode adds a vertex/node to the graph database. If the node is not +// in the database from before, this will add a new, unconnected one to the +// graph. If it is present from before, this will update that node's +// information. Note that this method is expected to only be called to update +// an already present node from a node announcement, or to insert a node found +// in a channel update. +// +// TODO(roasbeef): also need sig of announcement +func (c *ChannelGraph) AddLightningNode(node *LightningNode) error { + return c.db.Update(func(tx *bbolt.Tx) error { + return addLightningNode(tx, node) + }) +} + +func addLightningNode(tx *bbolt.Tx, node *LightningNode) error { + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } + + aliases, err := nodes.CreateBucketIfNotExists(aliasIndexBucket) + if err != nil { + return err + } + + updateIndex, err := nodes.CreateBucketIfNotExists( + nodeUpdateIndexBucket, + ) + if err != nil { + return err + } + + return putLightningNode(nodes, aliases, updateIndex, node) +} + +// LookupAlias attempts to return the alias as advertised by the target node. +// TODO(roasbeef): currently assumes that aliases are unique... +func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) { + var alias string + + err := c.db.View(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + aliases := nodes.Bucket(aliasIndexBucket) + if aliases == nil { + return ErrGraphNodesNotFound + } + + nodePub := pub.SerializeCompressed() + a := aliases.Get(nodePub) + if a == nil { + return ErrNodeAliasNotFound + } + + // TODO(roasbeef): should actually be using the utf-8 + // package... + alias = string(a) + return nil + }) + if err != nil { + return "", err + } + + return alias, nil +} + +// DeleteLightningNode starts a new database transaction to remove a vertex/node +// from the database according to the node's public key. +func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error { + // TODO(roasbeef): ensure dangling edges are removed... + return c.db.Update(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodeNotFound + } + + return c.deleteLightningNode( + nodes, nodePub.SerializeCompressed(), + ) + }) +} + +// deleteLightningNode uses an existing database transaction to remove a +// vertex/node from the database according to the node's public key. +func (c *ChannelGraph) deleteLightningNode(nodes *bbolt.Bucket, + compressedPubKey []byte) error { + + aliases := nodes.Bucket(aliasIndexBucket) + if aliases == nil { + return ErrGraphNodesNotFound + } + + if err := aliases.Delete(compressedPubKey); err != nil { + return err + } + + // Before we delete the node, we'll fetch its current state so we can + // determine when its last update was to clear out the node update + // index. + node, err := fetchLightningNode(nodes, compressedPubKey) + if err != nil { + return err + } + + if err := nodes.Delete(compressedPubKey); err != nil { + + return err + } + + // Finally, we'll delete the index entry for the node within the + // nodeUpdateIndexBucket as this node is no longer active, so we don't + // need to track its last update. + nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) + if nodeUpdateIndex == nil { + return ErrGraphNodesNotFound + } + + // In order to delete the entry, we'll need to reconstruct the key for + // its last update. + updateUnix := uint64(node.LastUpdate.Unix()) + var indexKey [8 + 33]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + copy(indexKey[8:], compressedPubKey) + + return nodeUpdateIndex.Delete(indexKey[:]) +} + +// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An +// undirected edge from the two target nodes are created. The information +// stored denotes the static attributes of the channel, such as the channelID, +// the keys involved in creation of the channel, and the set of features that +// the channel supports. The chanPoint and chanID are used to uniquely identify +// the edge globally within the database. +func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo) error { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + err := c.db.Update(func(tx *bbolt.Tx) error { + return c.addChannelEdge(tx, edge) + }) + if err != nil { + return err + } + + c.rejectCache.remove(edge.ChannelID) + c.chanCache.remove(edge.ChannelID) + + return nil +} + +// addChannelEdge is the private form of AddChannelEdge that allows callers to +// utilize an existing db transaction. +func (c *ChannelGraph) addChannelEdge(tx *bbolt.Tx, edge *ChannelEdgeInfo) error { + // Construct the channel's primary key which is the 8-byte channel ID. + var chanKey [8]byte + binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) + + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return err + } + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + + // First, attempt to check if this edge has already been created. If + // so, then we can exit early as this method is meant to be idempotent. + if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo != nil { + return ErrEdgeAlreadyExist + } + + // Before we insert the channel into the database, we'll ensure that + // both nodes already exist in the channel graph. If either node + // doesn't, then we'll insert a "shell" node that just includes its + // public key, so subsequent validation and queries can work properly. + _, node1Err := fetchLightningNode(nodes, edge.NodeKey1Bytes[:]) + switch { + case node1Err == ErrGraphNodeNotFound: + node1Shell := LightningNode{ + PubKeyBytes: edge.NodeKey1Bytes, + HaveNodeAnnouncement: false, + } + err := addLightningNode(tx, &node1Shell) + if err != nil { + return fmt.Errorf("unable to create shell node "+ + "for: %x", edge.NodeKey1Bytes) + + } + case node1Err != nil: + return err + } + + _, node2Err := fetchLightningNode(nodes, edge.NodeKey2Bytes[:]) + switch { + case node2Err == ErrGraphNodeNotFound: + node2Shell := LightningNode{ + PubKeyBytes: edge.NodeKey2Bytes, + HaveNodeAnnouncement: false, + } + err := addLightningNode(tx, &node2Shell) + if err != nil { + return fmt.Errorf("unable to create shell node "+ + "for: %x", edge.NodeKey2Bytes) + + } + case node2Err != nil: + return err + } + + // If the edge hasn't been created yet, then we'll first add it to the + // edge index in order to associate the edge between two nodes and also + // store the static components of the channel. + if err := putChanEdgeInfo(edgeIndex, edge, chanKey); err != nil { + return err + } + + // Mark edge policies for both sides as unknown. This is to enable + // efficient incoming channel lookup for a node. + for _, key := range []*[33]byte{&edge.NodeKey1Bytes, + &edge.NodeKey2Bytes} { + + err := putChanEdgePolicyUnknown(edges, edge.ChannelID, + key[:]) + if err != nil { + return err + } + } + + // Finally we add it to the channel index which maps channel points + // (outpoints) to the shorter channel ID's. + var b bytes.Buffer + if err := writeOutpoint(&b, &edge.ChannelPoint); err != nil { + return err + } + return chanIndex.Put(b.Bytes(), chanKey[:]) +} + +// HasChannelEdge returns true if the database knows of a channel edge with the +// passed channel ID, and false otherwise. If an edge with that ID is found +// within the graph, then two time stamps representing the last time the edge +// was updated for both directed edges are returned along with the boolean. If +// it is not found, then the zombie index is checked and its result is returned +// as the second boolean. +func (c *ChannelGraph) HasChannelEdge( + chanID uint64) (time.Time, time.Time, bool, bool, error) { + + var ( + upd1Time time.Time + upd2Time time.Time + exists bool + isZombie bool + ) + + // We'll query the cache with the shared lock held to allow multiple + // readers to access values in the cache concurrently if they exist. + c.cacheMu.RLock() + if entry, ok := c.rejectCache.get(chanID); ok { + c.cacheMu.RUnlock() + upd1Time = time.Unix(entry.upd1Time, 0) + upd2Time = time.Unix(entry.upd2Time, 0) + exists, isZombie = entry.flags.unpack() + return upd1Time, upd2Time, exists, isZombie, nil + } + c.cacheMu.RUnlock() + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + // The item was not found with the shared lock, so we'll acquire the + // exclusive lock and check the cache again in case another method added + // the entry to the cache while no lock was held. + if entry, ok := c.rejectCache.get(chanID); ok { + upd1Time = time.Unix(entry.upd1Time, 0) + upd2Time = time.Unix(entry.upd2Time, 0) + exists, isZombie = entry.flags.unpack() + return upd1Time, upd2Time, exists, isZombie, nil + } + + if err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + var channelID [8]byte + byteOrder.PutUint64(channelID[:], chanID) + + // If the edge doesn't exist, then we'll also check our zombie + // index. + if edgeIndex.Get(channelID[:]) == nil { + exists = false + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex != nil { + isZombie, _, _ = isZombieEdge( + zombieIndex, chanID, + ) + } + + return nil + } + + exists = true + isZombie = false + + // If the channel has been found in the graph, then retrieve + // the edges itself so we can return the last updated + // timestamps. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodeNotFound + } + + e1, e2, err := fetchChanEdgePolicies(edgeIndex, edges, nodes, + channelID[:], c.db) + if err != nil { + return err + } + + // As we may have only one of the edges populated, only set the + // update time if the edge was found in the database. + if e1 != nil { + upd1Time = e1.LastUpdate + } + if e2 != nil { + upd2Time = e2.LastUpdate + } + + return nil + }); err != nil { + return time.Time{}, time.Time{}, exists, isZombie, err + } + + c.rejectCache.insert(chanID, rejectCacheEntry{ + upd1Time: upd1Time.Unix(), + upd2Time: upd2Time.Unix(), + flags: packRejectFlags(exists, isZombie), + }) + + return upd1Time, upd2Time, exists, isZombie, nil +} + +// UpdateChannelEdge retrieves and update edge of the graph database. Method +// only reserved for updating an edge info after its already been created. +// In order to maintain this constraints, we return an error in the scenario +// that an edge info hasn't yet been created yet, but someone attempts to update +// it. +func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { + // Construct the channel's primary key which is the 8-byte channel ID. + var chanKey [8]byte + binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) + + return c.db.Update(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edge == nil { + return ErrEdgeNotFound + } + + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrEdgeNotFound + } + + if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo == nil { + return ErrEdgeNotFound + } + + return putChanEdgeInfo(edgeIndex, edge, chanKey) + }) +} + +const ( + // pruneTipBytes is the total size of the value which stores a prune + // entry of the graph in the prune log. The "prune tip" is the last + // entry in the prune log, and indicates if the channel graph is in + // sync with the current UTXO state. The structure of the value + // is: blockHash, taking 32 bytes total. + pruneTipBytes = 32 +) + +// PruneGraph prunes newly closed channels from the channel graph in response +// to a new block being solved on the network. Any transactions which spend the +// funding output of any known channels within he graph will be deleted. +// Additionally, the "prune tip", or the last block which has been used to +// prune the graph is stored so callers can ensure the graph is fully in sync +// with the current UTXO state. A slice of channels that have been closed by +// the target block are returned if the function succeeds without error. +func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, + blockHash *chainhash.Hash, blockHeight uint32) ([]*ChannelEdgeInfo, error) { + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + var chansClosed []*ChannelEdgeInfo + + err := c.db.Update(func(tx *bbolt.Tx) error { + // First grab the edges bucket which houses the information + // we'd like to delete + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return err + } + + // Next grab the two edge indexes which will also need to be updated. + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrSourceNodeNotSet + } + zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) + if err != nil { + return err + } + + // For each of the outpoints that have been spent within the + // block, we attempt to delete them from the graph as if that + // outpoint was a channel, then it has now been closed. + for _, chanPoint := range spentOutputs { + // TODO(roasbeef): load channel bloom filter, continue + // if NOT if filter + + var opBytes bytes.Buffer + if err := writeOutpoint(&opBytes, chanPoint); err != nil { + return err + } + + // First attempt to see if the channel exists within + // the database, if not, then we can exit early. + chanID := chanIndex.Get(opBytes.Bytes()) + if chanID == nil { + continue + } + + // However, if it does, then we'll read out the full + // version so we can add it to the set of deleted + // channels. + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + + // Attempt to delete the channel, an ErrEdgeNotFound + // will be returned if that outpoint isn't known to be + // a channel. If no error is returned, then a channel + // was successfully pruned. + err = delChannelEdge( + edges, edgeIndex, chanIndex, zombieIndex, nodes, + chanID, false, + ) + if err != nil && err != ErrEdgeNotFound { + return err + } + + chansClosed = append(chansClosed, &edgeInfo) + } + + metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) + if err != nil { + return err + } + + pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) + if err != nil { + return err + } + + // With the graph pruned, add a new entry to the prune log, + // which can be used to check if the graph is fully synced with + // the current UTXO state. + var blockHeightBytes [4]byte + byteOrder.PutUint32(blockHeightBytes[:], blockHeight) + + var newTip [pruneTipBytes]byte + copy(newTip[:], blockHash[:]) + + err = pruneBucket.Put(blockHeightBytes[:], newTip[:]) + if err != nil { + return err + } + + // Now that the graph has been pruned, we'll also attempt to + // prune any nodes that have had a channel closed within the + // latest block. + return c.pruneGraphNodes(nodes, edgeIndex) + }) + if err != nil { + return nil, err + } + + for _, channel := range chansClosed { + c.rejectCache.remove(channel.ChannelID) + c.chanCache.remove(channel.ChannelID) + } + + return chansClosed, nil +} + +// PruneGraphNodes is a garbage collection method which attempts to prune out +// any nodes from the channel graph that are currently unconnected. This ensure +// that we only maintain a graph of reachable nodes. In the event that a pruned +// node gains more channels, it will be re-added back to the graph. +func (c *ChannelGraph) PruneGraphNodes() error { + return c.db.Update(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNotFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + return c.pruneGraphNodes(nodes, edgeIndex) + }) +} + +// pruneGraphNodes attempts to remove any nodes from the graph who have had a +// channel closed within the current block. If the node still has existing +// channels in the graph, this will act as a no-op. +func (c *ChannelGraph) pruneGraphNodes(nodes *bbolt.Bucket, + edgeIndex *bbolt.Bucket) error { + + log.Trace("Pruning nodes from graph with no open channels") + + // We'll retrieve the graph's source node to ensure we don't remove it + // even if it no longer has any open channels. + sourceNode, err := c.sourceNode(nodes) + if err != nil { + return err + } + + // We'll use this map to keep count the number of references to a node + // in the graph. A node should only be removed once it has no more + // references in the graph. + nodeRefCounts := make(map[[33]byte]int) + err = nodes.ForEach(func(pubKey, nodeBytes []byte) error { + // If this is the source key, then we skip this + // iteration as the value for this key is a pubKey + // rather than raw node information. + if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { + return nil + } + + var nodePub [33]byte + copy(nodePub[:], pubKey) + nodeRefCounts[nodePub] = 0 + + return nil + }) + if err != nil { + return err + } + + // To ensure we never delete the source node, we'll start off by + // bumping its ref count to 1. + nodeRefCounts[sourceNode.PubKeyBytes] = 1 + + // Next, we'll run through the edgeIndex which maps a channel ID to the + // edge info. We'll use this scan to populate our reference count map + // above. + err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { + // The first 66 bytes of the edge info contain the pubkeys of + // the nodes that this edge attaches. We'll extract them, and + // add them to the ref count map. + var node1, node2 [33]byte + copy(node1[:], edgeInfoBytes[:33]) + copy(node2[:], edgeInfoBytes[33:]) + + // With the nodes extracted, we'll increase the ref count of + // each of the nodes. + nodeRefCounts[node1]++ + nodeRefCounts[node2]++ + + return nil + }) + if err != nil { + return err + } + + // Finally, we'll make a second pass over the set of nodes, and delete + // any nodes that have a ref count of zero. + var numNodesPruned int + for nodePubKey, refCount := range nodeRefCounts { + // If the ref count of the node isn't zero, then we can safely + // skip it as it still has edges to or from it within the + // graph. + if refCount != 0 { + continue + } + + // If we reach this point, then there are no longer any edges + // that connect this node, so we can delete it. + if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { + log.Warnf("Unable to prune node %x from the "+ + "graph: %v", nodePubKey, err) + continue + } + + log.Infof("Pruned unconnected node %x from channel graph", + nodePubKey[:]) + + numNodesPruned++ + } + + if numNodesPruned > 0 { + log.Infof("Pruned %v unconnected nodes from the channel graph", + numNodesPruned) + } + + return nil +} + +// DisconnectBlockAtHeight is used to indicate that the block specified +// by the passed height has been disconnected from the main chain. This +// will "rewind" the graph back to the height below, deleting channels +// that are no longer confirmed from the graph. The prune log will be +// set to the last prune height valid for the remaining chain. +// Channels that were removed from the graph resulting from the +// disconnected block are returned. +func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo, + error) { + + // Every channel having a ShortChannelID starting at 'height' + // will no longer be confirmed. + startShortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + } + + // Delete everything after this height from the db. + endShortChanID := lnwire.ShortChannelID{ + BlockHeight: math.MaxUint32 & 0x00ffffff, + TxIndex: math.MaxUint32 & 0x00ffffff, + TxPosition: math.MaxUint16, + } + // The block height will be the 3 first bytes of the channel IDs. + var chanIDStart [8]byte + byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64()) + var chanIDEnd [8]byte + byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64()) + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + // Keep track of the channels that are removed from the graph. + var removedChans []*ChannelEdgeInfo + + if err := c.db.Update(func(tx *bbolt.Tx) error { + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return err + } + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return err + } + chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) + if err != nil { + return err + } + zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) + if err != nil { + return err + } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return err + } + + // Scan from chanIDStart to chanIDEnd, deleting every + // found edge. + // NOTE: we must delete the edges after the cursor loop, since + // modifying the bucket while traversing is not safe. + var keys [][]byte + cursor := edgeIndex.Cursor() + for k, v := cursor.Seek(chanIDStart[:]); k != nil && + bytes.Compare(k, chanIDEnd[:]) <= 0; k, v = cursor.Next() { + + edgeInfoReader := bytes.NewReader(v) + edgeInfo, err := deserializeChanEdgeInfo(edgeInfoReader) + if err != nil { + return err + } + + keys = append(keys, k) + removedChans = append(removedChans, &edgeInfo) + } + + for _, k := range keys { + err = delChannelEdge( + edges, edgeIndex, chanIndex, zombieIndex, nodes, + k, false, + ) + if err != nil && err != ErrEdgeNotFound { + return err + } + } + + // Delete all the entries in the prune log having a height + // greater or equal to the block disconnected. + metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) + if err != nil { + return err + } + + pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) + if err != nil { + return err + } + + var pruneKeyStart [4]byte + byteOrder.PutUint32(pruneKeyStart[:], height) + + var pruneKeyEnd [4]byte + byteOrder.PutUint32(pruneKeyEnd[:], math.MaxUint32) + + // To avoid modifying the bucket while traversing, we delete + // the keys in a second loop. + var pruneKeys [][]byte + pruneCursor := pruneBucket.Cursor() + for k, _ := pruneCursor.Seek(pruneKeyStart[:]); k != nil && + bytes.Compare(k, pruneKeyEnd[:]) <= 0; k, _ = pruneCursor.Next() { + + pruneKeys = append(pruneKeys, k) + } + + for _, k := range pruneKeys { + if err := pruneBucket.Delete(k); err != nil { + return err + } + } + + return nil + }); err != nil { + return nil, err + } + + for _, channel := range removedChans { + c.rejectCache.remove(channel.ChannelID) + c.chanCache.remove(channel.ChannelID) + } + + return removedChans, nil +} + +// PruneTip returns the block height and hash of the latest block that has been +// used to prune channels in the graph. Knowing the "prune tip" allows callers +// to tell if the graph is currently in sync with the current best known UTXO +// state. +func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { + var ( + tipHash chainhash.Hash + tipHeight uint32 + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + graphMeta := tx.Bucket(graphMetaBucket) + if graphMeta == nil { + return ErrGraphNotFound + } + pruneBucket := graphMeta.Bucket(pruneLogBucket) + if pruneBucket == nil { + return ErrGraphNeverPruned + } + + pruneCursor := pruneBucket.Cursor() + + // The prune key with the largest block height will be our + // prune tip. + k, v := pruneCursor.Last() + if k == nil { + return ErrGraphNeverPruned + } + + // Once we have the prune tip, the value will be the block hash, + // and the key the block height. + copy(tipHash[:], v[:]) + tipHeight = byteOrder.Uint32(k[:]) + + return nil + }) + if err != nil { + return nil, 0, err + } + + return &tipHash, tipHeight, nil +} + +// DeleteChannelEdges removes edges with the given channel IDs from the database +// and marks them as zombies. This ensures that we're unable to re-add it to our +// database once again. If an edge does not exist within the database, then +// ErrEdgeNotFound will be returned. +func (c *ChannelGraph) DeleteChannelEdges(chanIDs ...uint64) error { + // TODO(roasbeef): possibly delete from node bucket if node has no more + // channels + // TODO(roasbeef): don't delete both edges? + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + err := c.db.Update(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrEdgeNotFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrEdgeNotFound + } + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return ErrEdgeNotFound + } + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodeNotFound + } + zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) + if err != nil { + return err + } + + var rawChanID [8]byte + for _, chanID := range chanIDs { + byteOrder.PutUint64(rawChanID[:], chanID) + err := delChannelEdge( + edges, edgeIndex, chanIndex, zombieIndex, nodes, + rawChanID[:], true, + ) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return err + } + + for _, chanID := range chanIDs { + c.rejectCache.remove(chanID) + c.chanCache.remove(chanID) + } + + return nil +} + +// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the +// passed channel point (outpoint). If the passed channel doesn't exist within +// the database, then ErrEdgeNotFound is returned. +func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { + var chanID uint64 + if err := c.db.View(func(tx *bbolt.Tx) error { + var err error + chanID, err = getChanID(tx, chanPoint) + return err + }); err != nil { + return 0, err + } + + return chanID, nil +} + +// getChanID returns the assigned channel ID for a given channel point. +func getChanID(tx *bbolt.Tx, chanPoint *wire.OutPoint) (uint64, error) { + var b bytes.Buffer + if err := writeOutpoint(&b, chanPoint); err != nil { + return 0, err + } + + edges := tx.Bucket(edgeBucket) + if edges == nil { + return 0, ErrGraphNoEdgesFound + } + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return 0, ErrGraphNoEdgesFound + } + + chanIDBytes := chanIndex.Get(b.Bytes()) + if chanIDBytes == nil { + return 0, ErrEdgeNotFound + } + + chanID := byteOrder.Uint64(chanIDBytes) + + return chanID, nil +} + +// TODO(roasbeef): allow updates to use Batch? + +// HighestChanID returns the "highest" known channel ID in the channel graph. +// This represents the "newest" channel from the PoV of the chain. This method +// can be used by peers to quickly determine if they're graphs are in sync. +func (c *ChannelGraph) HighestChanID() (uint64, error) { + var cid uint64 + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // In order to find the highest chan ID, we'll fetch a cursor + // and use that to seek to the "end" of our known rage. + cidCursor := edgeIndex.Cursor() + + lastChanID, _ := cidCursor.Last() + + // If there's no key, then this means that we don't actually + // know of any channels, so we'll return a predicable error. + if lastChanID == nil { + return ErrGraphNoEdgesFound + } + + // Otherwise, we'll de serialize the channel ID and return it + // to the caller. + cid = byteOrder.Uint64(lastChanID) + return nil + }) + if err != nil && err != ErrGraphNoEdgesFound { + return 0, err + } + + return cid, nil +} + +// ChannelEdge represents the complete set of information for a channel edge in +// the known channel graph. This struct couples the core information of the +// edge as well as each of the known advertised edge policies. +type ChannelEdge struct { + // Info contains all the static information describing the channel. + Info *ChannelEdgeInfo + + // Policy1 points to the "first" edge policy of the channel containing + // the dynamic information required to properly route through the edge. + Policy1 *ChannelEdgePolicy + + // Policy2 points to the "second" edge policy of the channel containing + // the dynamic information required to properly route through the edge. + Policy2 *ChannelEdgePolicy +} + +// ChanUpdatesInHorizon returns all the known channel edges which have at least +// one edge that has an update timestamp within the specified horizon. +func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { + // To ensure we don't return duplicate ChannelEdges, we'll use an + // additional map to keep track of the edges already seen to prevent + // re-adding it. + edgesSeen := make(map[uint64]struct{}) + edgesToCache := make(map[uint64]ChannelEdge) + var edgesInHorizon []ChannelEdge + + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + var hits int + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) + if edgeUpdateIndex == nil { + return ErrGraphNoEdgesFound + } + + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + // We'll now obtain a cursor to perform a range query within + // the index to find all channels within the horizon. + updateCursor := edgeUpdateIndex.Cursor() + + var startTimeBytes, endTimeBytes [8 + 8]byte + byteOrder.PutUint64( + startTimeBytes[:8], uint64(startTime.Unix()), + ) + byteOrder.PutUint64( + endTimeBytes[:8], uint64(endTime.Unix()), + ) + + // With our start and end times constructed, we'll step through + // the index collecting the info and policy of each update of + // each channel that has a last update within the time range. + for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && + bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { + + // We have a new eligible entry, so we'll slice of the + // chan ID so we can query it in the DB. + chanID := indexKey[8:] + + // If we've already retrieved the info and policies for + // this edge, then we can skip it as we don't need to do + // so again. + chanIDInt := byteOrder.Uint64(chanID) + if _, ok := edgesSeen[chanIDInt]; ok { + continue + } + + if channel, ok := c.chanCache.get(chanIDInt); ok { + hits++ + edgesSeen[chanIDInt] = struct{}{} + edgesInHorizon = append(edgesInHorizon, channel) + continue + } + + // First, we'll fetch the static edge information. + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + chanID := byteOrder.Uint64(chanID) + return fmt.Errorf("unable to fetch info for "+ + "edge with chan_id=%v: %v", chanID, err) + } + edgeInfo.db = c.db + + // With the static information obtained, we'll now + // fetch the dynamic policy info. + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { + chanID := byteOrder.Uint64(chanID) + return fmt.Errorf("unable to fetch policies "+ + "for edge with chan_id=%v: %v", chanID, + err) + } + + // Finally, we'll collate this edge with the rest of + // edges to be returned. + edgesSeen[chanIDInt] = struct{}{} + channel := ChannelEdge{ + Info: &edgeInfo, + Policy1: edge1, + Policy2: edge2, + } + edgesInHorizon = append(edgesInHorizon, channel) + edgesToCache[chanIDInt] = channel + } + + return nil + }) + switch { + case err == ErrGraphNoEdgesFound: + fallthrough + case err == ErrGraphNodesNotFound: + break + + case err != nil: + return nil, err + } + + // Insert any edges loaded from disk into the cache. + for chanid, channel := range edgesToCache { + c.chanCache.insert(chanid, channel) + } + + log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)", + float64(hits)/float64(len(edgesInHorizon)), hits, + len(edgesInHorizon)) + + return edgesInHorizon, nil +} + +// NodeUpdatesInHorizon returns all the known lightning node which have an +// update timestamp within the passed range. This method can be used by two +// nodes to quickly determine if they have the same set of up to date node +// announcements. +func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { + var nodesInHorizon []LightningNode + + err := c.db.View(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + + nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) + if nodeUpdateIndex == nil { + return ErrGraphNodesNotFound + } + + // We'll now obtain a cursor to perform a range query within + // the index to find all node announcements within the horizon. + updateCursor := nodeUpdateIndex.Cursor() + + var startTimeBytes, endTimeBytes [8 + 33]byte + byteOrder.PutUint64( + startTimeBytes[:8], uint64(startTime.Unix()), + ) + byteOrder.PutUint64( + endTimeBytes[:8], uint64(endTime.Unix()), + ) + + // With our start and end times constructed, we'll step through + // the index collecting info for each node within the time + // range. + for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && + bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { + + nodePub := indexKey[8:] + node, err := fetchLightningNode(nodes, nodePub) + if err != nil { + return err + } + node.db = c.db + + nodesInHorizon = append(nodesInHorizon, node) + } + + return nil + }) + switch { + case err == ErrGraphNoEdgesFound: + fallthrough + case err == ErrGraphNodesNotFound: + break + + case err != nil: + return nil, err + } + + return nodesInHorizon, nil +} + +// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan +// ID's that we don't know and are not known zombies of the passed set. In other +// words, we perform a set difference of our set of chan ID's and the ones +// passed in. This method can be used by callers to determine the set of +// channels another peer knows of that we don't. +func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { + var newChanIDs []uint64 + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // Fetch the zombie index, it may not exist if no edges have + // ever been marked as zombies. If the index has been + // initialized, we will use it later to skip known zombie edges. + zombieIndex := edges.Bucket(zombieBucket) + + // We'll run through the set of chanIDs and collate only the + // set of channel that are unable to be found within our db. + var cidBytes [8]byte + for _, cid := range chanIDs { + byteOrder.PutUint64(cidBytes[:], cid) + + // If the edge is already known, skip it. + if v := edgeIndex.Get(cidBytes[:]); v != nil { + continue + } + + // If the edge is a known zombie, skip it. + if zombieIndex != nil { + isZombie, _, _ := isZombieEdge(zombieIndex, cid) + if isZombie { + continue + } + } + + newChanIDs = append(newChanIDs, cid) + } + + return nil + }) + switch { + // If we don't know of any edges yet, then we'll return the entire set + // of chan IDs specified. + case err == ErrGraphNoEdgesFound: + return chanIDs, nil + + case err != nil: + return nil, err + } + + return newChanIDs, nil +} + +// FilterChannelRange returns the channel ID's of all known channels which were +// mined in a block height within the passed range. This method can be used to +// quickly share with a peer the set of channels we know of within a particular +// range to catch them up after a period of time offline. +func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) { + var chanIDs []uint64 + + startChanID := &lnwire.ShortChannelID{ + BlockHeight: startHeight, + } + + endChanID := lnwire.ShortChannelID{ + BlockHeight: endHeight, + TxIndex: math.MaxUint32 & 0x00ffffff, + TxPosition: math.MaxUint16, + } + + // As we need to perform a range scan, we'll convert the starting and + // ending height to their corresponding values when encoded using short + // channel ID's. + var chanIDStart, chanIDEnd [8]byte + byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) + byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + cursor := edgeIndex.Cursor() + + // We'll now iterate through the database, and find each + // channel ID that resides within the specified range. + var cid uint64 + for k, _ := cursor.Seek(chanIDStart[:]); k != nil && + bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() { + + // This channel ID rests within the target range, so + // we'll convert it into an integer and add it to our + // returned set. + cid = byteOrder.Uint64(k) + chanIDs = append(chanIDs, cid) + } + + return nil + }) + switch { + // If we don't know of any channels yet, then there's nothing to + // filter, so we'll return an empty slice. + case err == ErrGraphNoEdgesFound: + return chanIDs, nil + + case err != nil: + return nil, err + } + + return chanIDs, nil +} + +// FetchChanInfos returns the set of channel edges that correspond to the passed +// channel ID's. If an edge is the query is unknown to the database, it will +// skipped and the result will contain only those edges that exist at the time +// of the query. This can be used to respond to peer queries that are seeking to +// fill in gaps in their view of the channel graph. +func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { + // TODO(roasbeef): sort cids? + + var ( + chanEdges []ChannelEdge + cidBytes [8]byte + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + for _, cid := range chanIDs { + byteOrder.PutUint64(cidBytes[:], cid) + + // First, we'll fetch the static edge information. If + // the edge is unknown, we will skip the edge and + // continue gathering all known edges. + edgeInfo, err := fetchChanEdgeInfo( + edgeIndex, cidBytes[:], + ) + switch { + case err == ErrEdgeNotFound: + continue + case err != nil: + return err + } + edgeInfo.db = c.db + + // With the static information obtained, we'll now + // fetch the dynamic policy info. + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, cidBytes[:], c.db, + ) + if err != nil { + return err + } + + chanEdges = append(chanEdges, ChannelEdge{ + Info: &edgeInfo, + Policy1: edge1, + Policy2: edge2, + }) + } + return nil + }) + if err != nil { + return nil, err + } + + return chanEdges, nil +} + +func delEdgeUpdateIndexEntry(edgesBucket *bbolt.Bucket, chanID uint64, + edge1, edge2 *ChannelEdgePolicy) error { + + // First, we'll fetch the edge update index bucket which currently + // stores an entry for the channel we're about to delete. + updateIndex := edgesBucket.Bucket(edgeUpdateIndexBucket) + if updateIndex == nil { + // No edges in bucket, return early. + return nil + } + + // Now that we have the bucket, we'll attempt to construct a template + // for the index key: updateTime || chanid. + var indexKey [8 + 8]byte + byteOrder.PutUint64(indexKey[8:], chanID) + + // With the template constructed, we'll attempt to delete an entry that + // would have been created by both edges: we'll alternate the update + // times, as one may had overridden the other. + if edge1 != nil { + byteOrder.PutUint64(indexKey[:8], uint64(edge1.LastUpdate.Unix())) + if err := updateIndex.Delete(indexKey[:]); err != nil { + return err + } + } + + // We'll also attempt to delete the entry that may have been created by + // the second edge. + if edge2 != nil { + byteOrder.PutUint64(indexKey[:8], uint64(edge2.LastUpdate.Unix())) + if err := updateIndex.Delete(indexKey[:]); err != nil { + return err + } + } + + return nil +} + +func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, + nodes *bbolt.Bucket, chanID []byte, isZombie bool) error { + + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + + // We'll also remove the entry in the edge update index bucket before + // we delete the edges themselves so we can access their last update + // times. + cid := byteOrder.Uint64(chanID) + edge1, edge2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, nil, + ) + if err != nil { + return err + } + err = delEdgeUpdateIndexEntry(edges, cid, edge1, edge2) + if err != nil { + return err + } + + // The edge key is of the format pubKey || chanID. First we construct + // the latter half, populating the channel ID. + var edgeKey [33 + 8]byte + copy(edgeKey[33:], chanID) + + // With the latter half constructed, copy over the first public key to + // delete the edge in this direction, then the second to delete the + // edge in the opposite direction. + copy(edgeKey[:33], edgeInfo.NodeKey1Bytes[:]) + if edges.Get(edgeKey[:]) != nil { + if err := edges.Delete(edgeKey[:]); err != nil { + return err + } + } + copy(edgeKey[:33], edgeInfo.NodeKey2Bytes[:]) + if edges.Get(edgeKey[:]) != nil { + if err := edges.Delete(edgeKey[:]); err != nil { + return err + } + } + + // As part of deleting the edge we also remove all disabled entries + // from the edgePolicyDisabledIndex bucket. We do that for both directions. + updateEdgePolicyDisabledIndex(edges, cid, false, false) + updateEdgePolicyDisabledIndex(edges, cid, true, false) + + // With the edge data deleted, we can purge the information from the two + // edge indexes. + if err := edgeIndex.Delete(chanID); err != nil { + return err + } + var b bytes.Buffer + if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + return err + } + if err := chanIndex.Delete(b.Bytes()); err != nil { + return err + } + + // Finally, we'll mark the edge as a zombie within our index if it's + // being removed due to the channel becoming a zombie. We do this to + // ensure we don't store unnecessary data for spent channels. + if !isZombie { + return nil + } + + return markEdgeZombie( + zombieIndex, byteOrder.Uint64(chanID), edgeInfo.NodeKey1Bytes, + edgeInfo.NodeKey2Bytes, + ) +} + +// UpdateEdgePolicy updates the edge routing policy for a single directed edge +// within the database for the referenced channel. The `flags` attribute within +// the ChannelEdgePolicy determines which of the directed edges are being +// updated. If the flag is 1, then the first node's information is being +// updated, otherwise it's the second node's information. The node ordering is +// determined by the lexicographical ordering of the identity public keys of +// the nodes on either side of the channel. +func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + var isUpdate1 bool + err := c.db.Update(func(tx *bbolt.Tx) error { + var err error + isUpdate1, err = updateEdgePolicy(tx, edge) + return err + }) + if err != nil { + return err + } + + // If an entry for this channel is found in reject cache, we'll modify + // the entry with the updated timestamp for the direction that was just + // written. If the edge doesn't exist, we'll load the cache entry lazily + // during the next query for this edge. + if entry, ok := c.rejectCache.get(edge.ChannelID); ok { + if isUpdate1 { + entry.upd1Time = edge.LastUpdate.Unix() + } else { + entry.upd2Time = edge.LastUpdate.Unix() + } + c.rejectCache.insert(edge.ChannelID, entry) + } + + // If an entry for this channel is found in channel cache, we'll modify + // the entry with the updated policy for the direction that was just + // written. If the edge doesn't exist, we'll defer loading the info and + // policies and lazily read from disk during the next query. + if channel, ok := c.chanCache.get(edge.ChannelID); ok { + if isUpdate1 { + channel.Policy1 = edge + } else { + channel.Policy2 = edge + } + c.chanCache.insert(edge.ChannelID, channel) + } + + return nil +} + +// updateEdgePolicy attempts to update an edge's policy within the relevant +// buckets using an existing database transaction. The returned boolean will be +// true if the updated policy belongs to node1, and false if the policy belonged +// to node2. +func updateEdgePolicy(tx *bbolt.Tx, edge *ChannelEdgePolicy) (bool, error) { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return false, ErrEdgeNotFound + + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return false, ErrEdgeNotFound + } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return false, err + } + + // Create the channelID key be converting the channel ID + // integer into a byte slice. + var chanID [8]byte + byteOrder.PutUint64(chanID[:], edge.ChannelID) + + // With the channel ID, we then fetch the value storing the two + // nodes which connect this channel edge. + nodeInfo := edgeIndex.Get(chanID[:]) + if nodeInfo == nil { + return false, ErrEdgeNotFound + } + + // Depending on the flags value passed above, either the first + // or second edge policy is being updated. + var fromNode, toNode []byte + var isUpdate1 bool + if edge.ChannelFlags&lnwire.ChanUpdateDirection == 0 { + fromNode = nodeInfo[:33] + toNode = nodeInfo[33:66] + isUpdate1 = true + } else { + fromNode = nodeInfo[33:66] + toNode = nodeInfo[:33] + isUpdate1 = false + } + + // Finally, with the direction of the edge being updated + // identified, we update the on-disk edge representation. + err = putChanEdgePolicy(edges, nodes, edge, fromNode, toNode) + if err != nil { + return false, err + } + + return isUpdate1, nil +} + +// LightningNode represents an individual vertex/node within the channel graph. +// A node is connected to other nodes by one or more channel edges emanating +// from it. As the graph is directed, a node will also have an incoming edge +// attached to it for each outgoing edge. +type LightningNode struct { + // PubKeyBytes is the raw bytes of the public key of the target node. + PubKeyBytes [33]byte + pubKey *btcec.PublicKey + + // HaveNodeAnnouncement indicates whether we received a node + // announcement for this particular node. If true, the remaining fields + // will be set, if false only the PubKey is known for this node. + HaveNodeAnnouncement bool + + // LastUpdate is the last time the vertex information for this node has + // been updated. + LastUpdate time.Time + + // Address is the TCP address this node is reachable over. + Addresses []net.Addr + + // Color is the selected color for the node. + Color color.RGBA + + // Alias is a nick-name for the node. The alias can be used to confirm + // a node's identity or to serve as a short ID for an address book. + Alias string + + // AuthSigBytes is the raw signature under the advertised public key + // which serves to authenticate the attributes announced by this node. + AuthSigBytes []byte + + // Features is the list of protocol features supported by this node. + Features *lnwire.FeatureVector + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte + + db *DB + + // TODO(roasbeef): discovery will need storage to keep it's last IP + // address and re-announce if interface changes? + + // TODO(roasbeef): add update method and fetch? +} + +// PubKey is the node's long-term identity public key. This key will be used to +// authenticated any advertisements/updates sent by the node. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (l *LightningNode) PubKey() (*btcec.PublicKey, error) { + if l.pubKey != nil { + return l.pubKey, nil + } + + key, err := btcec.ParsePubKey(l.PubKeyBytes[:], btcec.S256()) + if err != nil { + return nil, err + } + l.pubKey = key + + return key, nil +} + +// AuthSig is a signature under the advertised public key which serves to +// authenticate the attributes announced by this node. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (l *LightningNode) AuthSig() (*btcec.Signature, error) { + return btcec.ParseSignature(l.AuthSigBytes, btcec.S256()) +} + +// AddPubKey is a setter-link method that can be used to swap out the public +// key for a node. +func (l *LightningNode) AddPubKey(key *btcec.PublicKey) { + l.pubKey = key + copy(l.PubKeyBytes[:], key.SerializeCompressed()) +} + +// NodeAnnouncement retrieves the latest node announcement of the node. +func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement, + error) { + + if !l.HaveNodeAnnouncement { + return nil, fmt.Errorf("node does not have node announcement") + } + + alias, err := lnwire.NewNodeAlias(l.Alias) + if err != nil { + return nil, err + } + + nodeAnn := &lnwire.NodeAnnouncement{ + Features: l.Features.RawFeatureVector, + NodeID: l.PubKeyBytes, + RGBColor: l.Color, + Alias: alias, + Addresses: l.Addresses, + Timestamp: uint32(l.LastUpdate.Unix()), + ExtraOpaqueData: l.ExtraOpaqueData, + } + + if !signed { + return nodeAnn, nil + } + + sig, err := lnwire.NewSigFromRawSignature(l.AuthSigBytes) + if err != nil { + return nil, err + } + + nodeAnn.Signature = sig + + return nodeAnn, nil +} + +// isPublic determines whether the node is seen as public within the graph from +// the source node's point of view. An existing database transaction can also be +// specified. +func (l *LightningNode) isPublic(tx *bbolt.Tx, sourcePubKey []byte) (bool, error) { + // In order to determine whether this node is publicly advertised within + // the graph, we'll need to look at all of its edges and check whether + // they extend to any other node than the source node. errDone will be + // used to terminate the check early. + nodeIsPublic := false + errDone := errors.New("done") + err := l.ForEachChannel(tx, func(_ *bbolt.Tx, info *ChannelEdgeInfo, + _, _ *ChannelEdgePolicy) error { + + // If this edge doesn't extend to the source node, we'll + // terminate our search as we can now conclude that the node is + // publicly advertised within the graph due to the local node + // knowing of the current edge. + if !bytes.Equal(info.NodeKey1Bytes[:], sourcePubKey) && + !bytes.Equal(info.NodeKey2Bytes[:], sourcePubKey) { + + nodeIsPublic = true + return errDone + } + + // Since the edge _does_ extend to the source node, we'll also + // need to ensure that this is a public edge. + if info.AuthProof != nil { + nodeIsPublic = true + return errDone + } + + // Otherwise, we'll continue our search. + return nil + }) + if err != nil && err != errDone { + return false, err + } + + return nodeIsPublic, nil +} + +// FetchLightningNode attempts to look up a target node by its identity public +// key. If the node isn't found in the database, then ErrGraphNodeNotFound is +// returned. +func (c *ChannelGraph) FetchLightningNode(pub *btcec.PublicKey) (*LightningNode, error) { + var node *LightningNode + nodePub := pub.SerializeCompressed() + err := c.db.View(func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // If a key for this serialized public key isn't found, then + // the target node doesn't exist within the database. + nodeBytes := nodes.Get(nodePub) + if nodeBytes == nil { + return ErrGraphNodeNotFound + } + + // If the node is found, then we can de deserialize the node + // information to return to the user. + nodeReader := bytes.NewReader(nodeBytes) + n, err := deserializeLightningNode(nodeReader) + if err != nil { + return err + } + n.db = c.db + + node = &n + + return nil + }) + if err != nil { + return nil, err + } + + return node, nil +} + +// HasLightningNode determines if the graph has a vertex identified by the +// target node identity public key. If the node exists in the database, a +// timestamp of when the data for the node was lasted updated is returned along +// with a true boolean. Otherwise, an empty time.Time is returned with a false +// boolean. +func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, error) { + var ( + updateTime time.Time + exists bool + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // If a key for this serialized public key isn't found, we can + // exit early. + nodeBytes := nodes.Get(nodePub[:]) + if nodeBytes == nil { + exists = false + return nil + } + + // Otherwise we continue on to obtain the time stamp + // representing the last time the data for this node was + // updated. + nodeReader := bytes.NewReader(nodeBytes) + node, err := deserializeLightningNode(nodeReader) + if err != nil { + return err + } + + exists = true + updateTime = node.LastUpdate + return nil + }) + if err != nil { + return time.Time{}, exists, err + } + + return updateTime, exists, nil +} + +// nodeTraversal is used to traverse all channels of a node given by its +// public key and passes channel information into the specified callback. +func nodeTraversal(tx *bbolt.Tx, nodePub []byte, db *DB, + cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + + traversal := func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNotFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // In order to reach all the edges for this node, we take + // advantage of the construction of the key-space within the + // edge bucket. The keys are stored in the form: pubKey || + // chanID. Therefore, starting from a chanID of zero, we can + // scan forward in the bucket, grabbing all the edges for the + // node. Once the prefix no longer matches, then we know we're + // done. + var nodeStart [33 + 8]byte + copy(nodeStart[:], nodePub) + copy(nodeStart[33:], chanStart[:]) + + // Starting from the key pubKey || 0, we seek forward in the + // bucket until the retrieved key no longer has the public key + // as its prefix. This indicates that we've stepped over into + // another node's edges, so we can terminate our scan. + edgeCursor := edges.Cursor() + for nodeEdge, _ := edgeCursor.Seek(nodeStart[:]); bytes.HasPrefix(nodeEdge, nodePub); nodeEdge, _ = edgeCursor.Next() { + // If the prefix still matches, the channel id is + // returned in nodeEdge. Channel id is used to lookup + // the node at the other end of the channel and both + // edge policies. + chanID := nodeEdge[33:] + edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + edgeInfo.db = db + + outgoingPolicy, err := fetchChanEdgePolicy( + edges, chanID, nodePub, nodes, + ) + if err != nil { + return err + } + + otherNode, err := edgeInfo.OtherNodeKeyBytes(nodePub) + if err != nil { + return err + } + + incomingPolicy, err := fetchChanEdgePolicy( + edges, chanID, otherNode[:], nodes, + ) + if err != nil { + return err + } + + // Finally, we execute the callback. + err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy) + if err != nil { + return err + } + } + + return nil + } + + // If no transaction was provided, then we'll create a new transaction + // to execute the transaction within. + if tx == nil { + return db.View(traversal) + } + + // Otherwise, we re-use the existing transaction to execute the graph + // traversal. + return traversal(tx) +} + +// ForEachChannel iterates through all channels of this node, executing the +// passed callback with an edge info structure and the policies of each end +// of the channel. The first edge policy is the outgoing edge *to* the +// the connecting node, while the second is the incoming edge *from* the +// connecting node. If the callback returns an error, then the iteration is +// halted with the error propagated back up to the caller. +// +// Unknown policies are passed into the callback as nil values. +// +// If the caller wishes to re-use an existing boltdb transaction, then it +// should be passed as the first argument. Otherwise the first argument should +// be nil and a fresh transaction will be created to execute the graph +// traversal. +func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, + cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { + + nodePub := l.PubKeyBytes[:] + db := l.db + + return nodeTraversal(tx, nodePub, db, cb) +} + +// ChannelEdgeInfo represents a fully authenticated channel along with all its +// unique attributes. Once an authenticated channel announcement has been +// processed on the network, then an instance of ChannelEdgeInfo encapsulating +// the channels attributes is stored. The other portions relevant to routing +// policy of a channel are stored within a ChannelEdgePolicy for each direction +// of the channel. +type ChannelEdgeInfo struct { + // ChannelID is the unique channel ID for the channel. The first 3 + // bytes are the block height, the next 3 the index within the block, + // and the last 2 bytes are the output index for the channel. + ChannelID uint64 + + // ChainHash is the hash that uniquely identifies the chain that this + // channel was opened within. + // + // TODO(roasbeef): need to modify db keying for multi-chain + // * must add chain hash to prefix as well + ChainHash chainhash.Hash + + // NodeKey1Bytes is the raw public key of the first node. + NodeKey1Bytes [33]byte + nodeKey1 *btcec.PublicKey + + // NodeKey2Bytes is the raw public key of the first node. + NodeKey2Bytes [33]byte + nodeKey2 *btcec.PublicKey + + // BitcoinKey1Bytes is the raw public key of the first node. + BitcoinKey1Bytes [33]byte + bitcoinKey1 *btcec.PublicKey + + // BitcoinKey2Bytes is the raw public key of the first node. + BitcoinKey2Bytes [33]byte + bitcoinKey2 *btcec.PublicKey + + // Features is an opaque byte slice that encodes the set of channel + // specific features that this channel edge supports. + Features []byte + + // AuthProof is the authentication proof for this channel. This proof + // contains a set of signatures binding four identities, which attests + // to the legitimacy of the advertised channel. + AuthProof *ChannelAuthProof + + // ChannelPoint is the funding outpoint of the channel. This can be + // used to uniquely identify the channel within the channel graph. + ChannelPoint wire.OutPoint + + // Capacity is the total capacity of the channel, this is determined by + // the value output in the outpoint that created this channel. + Capacity btcutil.Amount + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte + + db *DB +} + +// AddNodeKeys is a setter-like method that can be used to replace the set of +// keys for the target ChannelEdgeInfo. +func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, + bitcoinKey2 *btcec.PublicKey) { + + c.nodeKey1 = nodeKey1 + copy(c.NodeKey1Bytes[:], c.nodeKey1.SerializeCompressed()) + + c.nodeKey2 = nodeKey2 + copy(c.NodeKey2Bytes[:], nodeKey2.SerializeCompressed()) + + c.bitcoinKey1 = bitcoinKey1 + copy(c.BitcoinKey1Bytes[:], c.bitcoinKey1.SerializeCompressed()) + + c.bitcoinKey2 = bitcoinKey2 + copy(c.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) +} + +// NodeKey1 is the identity public key of the "first" node that was involved in +// the creation of this channel. A node is considered "first" if the +// lexicographical ordering the its serialized public key is "smaller" than +// that of the other node involved in channel creation. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) { + if c.nodeKey1 != nil { + return c.nodeKey1, nil + } + + key, err := btcec.ParsePubKey(c.NodeKey1Bytes[:], btcec.S256()) + if err != nil { + return nil, err + } + c.nodeKey1 = key + + return key, nil +} + +// NodeKey2 is the identity public key of the "second" node that was +// involved in the creation of this channel. A node is considered +// "second" if the lexicographical ordering the its serialized public +// key is "larger" than that of the other node involved in channel +// creation. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) { + if c.nodeKey2 != nil { + return c.nodeKey2, nil + } + + key, err := btcec.ParsePubKey(c.NodeKey2Bytes[:], btcec.S256()) + if err != nil { + return nil, err + } + c.nodeKey2 = key + + return key, nil +} + +// BitcoinKey1 is the Bitcoin multi-sig key belonging to the first +// node, that was involved in the funding transaction that originally +// created the channel that this struct represents. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) { + if c.bitcoinKey1 != nil { + return c.bitcoinKey1, nil + } + + key, err := btcec.ParsePubKey(c.BitcoinKey1Bytes[:], btcec.S256()) + if err != nil { + return nil, err + } + c.bitcoinKey1 = key + + return key, nil +} + +// BitcoinKey2 is the Bitcoin multi-sig key belonging to the second +// node, that was involved in the funding transaction that originally +// created the channel that this struct represents. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the pubkey if absolutely necessary. +func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { + if c.bitcoinKey2 != nil { + return c.bitcoinKey2, nil + } + + key, err := btcec.ParsePubKey(c.BitcoinKey2Bytes[:], btcec.S256()) + if err != nil { + return nil, err + } + c.bitcoinKey2 = key + + return key, nil +} + +// OtherNodeKeyBytes returns the node key bytes of the other end of +// the channel. +func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( + [33]byte, error) { + + switch { + case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): + return c.NodeKey2Bytes, nil + case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): + return c.NodeKey1Bytes, nil + default: + return [33]byte{}, fmt.Errorf("node not participating in this channel") + } +} + +// FetchOtherNode attempts to fetch the full LightningNode that's opposite of +// the target node in the channel. This is useful when one knows the pubkey of +// one of the nodes, and wishes to obtain the full LightningNode for the other +// end of the channel. +func (c *ChannelEdgeInfo) FetchOtherNode(tx *bbolt.Tx, thisNodeKey []byte) (*LightningNode, error) { + + // Ensure that the node passed in is actually a member of the channel. + var targetNodeBytes [33]byte + switch { + case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): + targetNodeBytes = c.NodeKey2Bytes + case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): + targetNodeBytes = c.NodeKey1Bytes + default: + return nil, fmt.Errorf("node not participating in this channel") + } + + var targetNode *LightningNode + fetchNodeFunc := func(tx *bbolt.Tx) error { + // First grab the nodes bucket which stores the mapping from + // pubKey to node information. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + node, err := fetchLightningNode(nodes, targetNodeBytes[:]) + if err != nil { + return err + } + node.db = c.db + + targetNode = &node + + return nil + } + + // If the transaction is nil, then we'll need to create a new one, + // otherwise we can use the existing db transaction. + var err error + if tx == nil { + err = c.db.View(fetchNodeFunc) + } else { + err = fetchNodeFunc(tx) + } + + return targetNode, err +} + +// ChannelAuthProof is the authentication proof (the signature portion) for a +// channel. Using the four signatures contained in the struct, and some +// auxiliary knowledge (the funding script, node identities, and outpoint) nodes +// on the network are able to validate the authenticity and existence of a +// channel. Each of these signatures signs the following digest: chanID || +// nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len || +// features. +type ChannelAuthProof struct { + // nodeSig1 is a cached instance of the first node signature. + nodeSig1 *btcec.Signature + + // NodeSig1Bytes are the raw bytes of the first node signature encoded + // in DER format. + NodeSig1Bytes []byte + + // nodeSig2 is a cached instance of the second node signature. + nodeSig2 *btcec.Signature + + // NodeSig2Bytes are the raw bytes of the second node signature + // encoded in DER format. + NodeSig2Bytes []byte + + // bitcoinSig1 is a cached instance of the first bitcoin signature. + bitcoinSig1 *btcec.Signature + + // BitcoinSig1Bytes are the raw bytes of the first bitcoin signature + // encoded in DER format. + BitcoinSig1Bytes []byte + + // bitcoinSig2 is a cached instance of the second bitcoin signature. + bitcoinSig2 *btcec.Signature + + // BitcoinSig2Bytes are the raw bytes of the second bitcoin signature + // encoded in DER format. + BitcoinSig2Bytes []byte +} + +// Node1Sig is the signature using the identity key of the node that is first +// in a lexicographical ordering of the serialized public keys of the two nodes +// that created the channel. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelAuthProof) Node1Sig() (*btcec.Signature, error) { + if c.nodeSig1 != nil { + return c.nodeSig1, nil + } + + sig, err := btcec.ParseSignature(c.NodeSig1Bytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.nodeSig1 = sig + + return sig, nil +} + +// Node2Sig is the signature using the identity key of the node that is second +// in a lexicographical ordering of the serialized public keys of the two nodes +// that created the channel. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelAuthProof) Node2Sig() (*btcec.Signature, error) { + if c.nodeSig2 != nil { + return c.nodeSig2, nil + } + + sig, err := btcec.ParseSignature(c.NodeSig2Bytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.nodeSig2 = sig + + return sig, nil +} + +// BitcoinSig1 is the signature using the public key of the first node that was +// used in the channel's multi-sig output. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelAuthProof) BitcoinSig1() (*btcec.Signature, error) { + if c.bitcoinSig1 != nil { + return c.bitcoinSig1, nil + } + + sig, err := btcec.ParseSignature(c.BitcoinSig1Bytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.bitcoinSig1 = sig + + return sig, nil +} + +// BitcoinSig2 is the signature using the public key of the second node that +// was used in the channel's multi-sig output. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelAuthProof) BitcoinSig2() (*btcec.Signature, error) { + if c.bitcoinSig2 != nil { + return c.bitcoinSig2, nil + } + + sig, err := btcec.ParseSignature(c.BitcoinSig2Bytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.bitcoinSig2 = sig + + return sig, nil +} + +// IsEmpty check is the authentication proof is empty Proof is empty if at +// least one of the signatures are equal to nil. +func (c *ChannelAuthProof) IsEmpty() bool { + return len(c.NodeSig1Bytes) == 0 || + len(c.NodeSig2Bytes) == 0 || + len(c.BitcoinSig1Bytes) == 0 || + len(c.BitcoinSig2Bytes) == 0 +} + +// ChannelEdgePolicy represents a *directed* edge within the channel graph. For +// each channel in the database, there are two distinct edges: one for each +// possible direction of travel along the channel. The edges themselves hold +// information concerning fees, and minimum time-lock information which is +// utilized during path finding. +type ChannelEdgePolicy struct { + // SigBytes is the raw bytes of the signature of the channel edge + // policy. We'll only parse these if the caller needs to access the + // signature for validation purposes. Do not set SigBytes directly, but + // use SetSigBytes instead to make sure that the cache is invalidated. + SigBytes []byte + + // sig is a cached fully parsed signature. + sig *btcec.Signature + + // ChannelID is the unique channel ID for the channel. The first 3 + // bytes are the block height, the next 3 the index within the block, + // and the last 2 bytes are the output index for the channel. + ChannelID uint64 + + // LastUpdate is the last time an authenticated edge for this channel + // was received. + LastUpdate time.Time + + // MessageFlags is a bitfield which indicates the presence of optional + // fields (like max_htlc) in the policy. + MessageFlags lnwire.ChanUpdateMsgFlags + + // ChannelFlags is a bitfield which signals the capabilities of the + // channel as well as the directed edge this update applies to. + ChannelFlags lnwire.ChanUpdateChanFlags + + // TimeLockDelta is the number of blocks this node will subtract from + // the expiry of an incoming HTLC. This value expresses the time buffer + // the node would like to HTLC exchanges. + TimeLockDelta uint16 + + // MinHTLC is the smallest value HTLC this node will accept, expressed + // in millisatoshi. + MinHTLC lnwire.MilliSatoshi + + // MaxHTLC is the largest value HTLC this node will accept, expressed + // in millisatoshi. + MaxHTLC lnwire.MilliSatoshi + + // FeeBaseMSat is the base HTLC fee that will be charged for forwarding + // ANY HTLC, expressed in mSAT's. + FeeBaseMSat lnwire.MilliSatoshi + + // FeeProportionalMillionths is the rate that the node will charge for + // HTLCs for each millionth of a satoshi forwarded. + FeeProportionalMillionths lnwire.MilliSatoshi + + // Node is the LightningNode that this directed edge leads to. Using + // this pointer the channel graph can further be traversed. + Node *LightningNode + + // ExtraOpaqueData is the set of data that was appended to this + // message, some of which we may not actually know how to iterate or + // parse. By holding onto this data, we ensure that we're able to + // properly validate the set of signatures that cover these new fields, + // and ensure we're able to make upgrades to the network in a forwards + // compatible manner. + ExtraOpaqueData []byte + + db *DB +} + +// Signature is a channel announcement signature, which is needed for proper +// edge policy announcement. +// +// NOTE: By having this method to access an attribute, we ensure we only need +// to fully deserialize the signature if absolutely necessary. +func (c *ChannelEdgePolicy) Signature() (*btcec.Signature, error) { + if c.sig != nil { + return c.sig, nil + } + + sig, err := btcec.ParseSignature(c.SigBytes, btcec.S256()) + if err != nil { + return nil, err + } + + c.sig = sig + + return sig, nil +} + +// SetSigBytes updates the signature and invalidates the cached parsed +// signature. +func (c *ChannelEdgePolicy) SetSigBytes(sig []byte) { + c.SigBytes = sig + c.sig = nil +} + +// IsDisabled determines whether the edge has the disabled bit set. +func (c *ChannelEdgePolicy) IsDisabled() bool { + return c.ChannelFlags&lnwire.ChanUpdateDisabled == + lnwire.ChanUpdateDisabled +} + +// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over +// the passed active payment channel. This value is currently computed as +// specified in BOLT07, but will likely change in the near future. +func (c *ChannelEdgePolicy) ComputeFee( + amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts +} + +// divideCeil divides dividend by factor and rounds the result up. +func divideCeil(dividend, factor lnwire.MilliSatoshi) lnwire.MilliSatoshi { + return (dividend + factor - 1) / factor +} + +// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming +// amount. +func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( + incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { + + return incomingAmt - divideCeil( + feeRateParts*(incomingAmt-c.FeeBaseMSat), + feeRateParts+c.FeeProportionalMillionths, + ) +} + +// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for +// the channel identified by the funding outpoint. If the channel can't be +// found, then ErrEdgeNotFound is returned. A struct which houses the general +// information for the channel itself is returned as well as two structs that +// contain the routing policies for the channel in either direction. +func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, +) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { + + var ( + edgeInfo *ChannelEdgeInfo + policy1 *ChannelEdgePolicy + policy2 *ChannelEdgePolicy + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + // First, grab the node bucket. This will be used to populate + // the Node pointers in each edge read from disk. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // Next, grab the edge bucket which stores the edges, and also + // the index itself so we can group the directed edges together + // logically. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // If the channel's outpoint doesn't exist within the outpoint + // index, then the edge does not exist. + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return ErrGraphNoEdgesFound + } + var b bytes.Buffer + if err := writeOutpoint(&b, op); err != nil { + return err + } + chanID := chanIndex.Get(b.Bytes()) + if chanID == nil { + return ErrEdgeNotFound + } + + // If the channel is found to exists, then we'll first retrieve + // the general information for the channel. + edge, err := fetchChanEdgeInfo(edgeIndex, chanID) + if err != nil { + return err + } + edgeInfo = &edge + edgeInfo.db = c.db + + // Once we have the information about the channels' parameters, + // we'll fetch the routing policies for each for the directed + // edges. + e1, e2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, chanID, c.db, + ) + if err != nil { + return err + } + + policy1 = e1 + policy2 = e2 + return nil + }) + if err != nil { + return nil, nil, nil, err + } + + return edgeInfo, policy1, policy2, nil +} + +// FetchChannelEdgesByID attempts to lookup the two directed edges for the +// channel identified by the channel ID. If the channel can't be found, then +// ErrEdgeNotFound is returned. A struct which houses the general information +// for the channel itself is returned as well as two structs that contain the +// routing policies for the channel in either direction. +// +// ErrZombieEdge an be returned if the edge is currently marked as a zombie +// within the database. In this case, the ChannelEdgePolicy's will be nil, and +// the ChannelEdgeInfo will only include the public keys of each node. +func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, +) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { + + var ( + edgeInfo *ChannelEdgeInfo + policy1 *ChannelEdgePolicy + policy2 *ChannelEdgePolicy + channelID [8]byte + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + // First, grab the node bucket. This will be used to populate + // the Node pointers in each edge read from disk. + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + // Next, grab the edge bucket which stores the edges, and also + // the index itself so we can group the directed edges together + // logically. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + byteOrder.PutUint64(channelID[:], chanID) + + // Now, attempt to fetch edge. + edge, err := fetchChanEdgeInfo(edgeIndex, channelID[:]) + + // If it doesn't exist, we'll quickly check our zombie index to + // see if we've previously marked it as so. + if err == ErrEdgeNotFound { + // If the zombie index doesn't exist, or the edge is not + // marked as a zombie within it, then we'll return the + // original ErrEdgeNotFound error. + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex == nil { + return ErrEdgeNotFound + } + + isZombie, pubKey1, pubKey2 := isZombieEdge( + zombieIndex, chanID, + ) + if !isZombie { + return ErrEdgeNotFound + } + + // Otherwise, the edge is marked as a zombie, so we'll + // populate the edge info with the public keys of each + // party as this is the only information we have about + // it and return an error signaling so. + edgeInfo = &ChannelEdgeInfo{ + NodeKey1Bytes: pubKey1, + NodeKey2Bytes: pubKey2, + } + return ErrZombieEdge + } + + // Otherwise, we'll just return the error if any. + if err != nil { + return err + } + + edgeInfo = &edge + edgeInfo.db = c.db + + // Then we'll attempt to fetch the accompanying policies of this + // edge. + e1, e2, err := fetchChanEdgePolicies( + edgeIndex, edges, nodes, channelID[:], c.db, + ) + if err != nil { + return err + } + + policy1 = e1 + policy2 = e2 + return nil + }) + if err == ErrZombieEdge { + return edgeInfo, nil, nil, err + } + if err != nil { + return nil, nil, nil, err + } + + return edgeInfo, policy1, policy2, nil +} + +// IsPublicNode is a helper method that determines whether the node with the +// given public key is seen as a public node in the graph from the graph's +// source node's point of view. +func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { + var nodeIsPublic bool + err := c.db.View(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNodesNotFound + } + ourPubKey := nodes.Get(sourceKey) + if ourPubKey == nil { + return ErrSourceNodeNotSet + } + node, err := fetchLightningNode(nodes, pubKey[:]) + if err != nil { + return err + } + + nodeIsPublic, err = node.isPublic(tx, ourPubKey) + return err + }) + if err != nil { + return false, err + } + + return nodeIsPublic, nil +} + +// genMultiSigP2WSH generates the p2wsh'd multisig script for 2 of 2 pubkeys. +func genMultiSigP2WSH(aPub, bPub []byte) ([]byte, error) { + if len(aPub) != 33 || len(bPub) != 33 { + return nil, fmt.Errorf("Pubkey size error. Compressed " + + "pubkeys only") + } + + // Swap to sort pubkeys if needed. Keys are sorted in lexicographical + // order. The signatures within the scriptSig must also adhere to the + // order, ensuring that the signatures for each public key appears in + // the proper order on the stack. + if bytes.Compare(aPub, bPub) == 1 { + aPub, bPub = bPub, aPub + } + + // First, we'll generate the witness script for the multi-sig. + bldr := txscript.NewScriptBuilder() + bldr.AddOp(txscript.OP_2) + bldr.AddData(aPub) // Add both pubkeys (sorted). + bldr.AddData(bPub) + bldr.AddOp(txscript.OP_2) + bldr.AddOp(txscript.OP_CHECKMULTISIG) + witnessScript, err := bldr.Script() + if err != nil { + return nil, err + } + + // With the witness script generated, we'll now turn it into a p2sh + // script: + // * OP_0 + bldr = txscript.NewScriptBuilder() + bldr.AddOp(txscript.OP_0) + scriptHash := sha256.Sum256(witnessScript) + bldr.AddData(scriptHash[:]) + + return bldr.Script() +} + +// EdgePoint couples the outpoint of a channel with the funding script that it +// creates. The FilteredChainView will use this to watch for spends of this +// edge point on chain. We require both of these values as depending on the +// concrete implementation, either the pkScript, or the out point will be used. +type EdgePoint struct { + // FundingPkScript is the p2wsh multi-sig script of the target channel. + FundingPkScript []byte + + // OutPoint is the outpoint of the target channel. + OutPoint wire.OutPoint +} + +// String returns a human readable version of the target EdgePoint. We return +// the outpoint directly as it is enough to uniquely identify the edge point. +func (e *EdgePoint) String() string { + return e.OutPoint.String() +} + +// ChannelView returns the verifiable edge information for each active channel +// within the known channel graph. The set of UTXO's (along with their scripts) +// returned are the ones that need to be watched on chain to detect channel +// closes on the resident blockchain. +func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { + var edgePoints []EdgePoint + if err := c.db.View(func(tx *bbolt.Tx) error { + // We're going to iterate over the entire channel index, so + // we'll need to fetch the edgeBucket to get to the index as + // it's a sub-bucket. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + chanIndex := edges.Bucket(channelPointBucket) + if chanIndex == nil { + return ErrGraphNoEdgesFound + } + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrGraphNoEdgesFound + } + + // Once we have the proper bucket, we'll range over each key + // (which is the channel point for the channel) and decode it, + // accumulating each entry. + return chanIndex.ForEach(func(chanPointBytes, chanID []byte) error { + chanPointReader := bytes.NewReader(chanPointBytes) + + var chanPoint wire.OutPoint + err := readOutpoint(chanPointReader, &chanPoint) + if err != nil { + return err + } + + edgeInfo, err := fetchChanEdgeInfo( + edgeIndex, chanID, + ) + if err != nil { + return err + } + + pkScript, err := genMultiSigP2WSH( + edgeInfo.BitcoinKey1Bytes[:], + edgeInfo.BitcoinKey2Bytes[:], + ) + if err != nil { + return err + } + + edgePoints = append(edgePoints, EdgePoint{ + FundingPkScript: pkScript, + OutPoint: chanPoint, + }) + + return nil + }) + }); err != nil { + return nil, err + } + + return edgePoints, nil +} + +// NewChannelEdgePolicy returns a new blank ChannelEdgePolicy. +func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy { + return &ChannelEdgePolicy{db: c.db} +} + +// markEdgeZombie marks an edge as a zombie within our zombie index. The public +// keys should represent the node public keys of the two parties involved in the +// edge. +func markEdgeZombie(zombieIndex *bbolt.Bucket, chanID uint64, pubKey1, + pubKey2 [33]byte) error { + + var k [8]byte + byteOrder.PutUint64(k[:], chanID) + + var v [66]byte + copy(v[:33], pubKey1[:]) + copy(v[33:], pubKey2[:]) + + return zombieIndex.Put(k[:], v[:]) +} + +// MarkEdgeLive clears an edge from our zombie index, deeming it as live. +func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + + err := c.db.Update(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex == nil { + return nil + } + + var k [8]byte + byteOrder.PutUint64(k[:], chanID) + return zombieIndex.Delete(k[:]) + }) + if err != nil { + return err + } + + c.rejectCache.remove(chanID) + c.chanCache.remove(chanID) + + return nil +} + +// IsZombieEdge returns whether the edge is considered zombie. If it is a +// zombie, then the two node public keys corresponding to this edge are also +// returned. +func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) { + var ( + isZombie bool + pubKey1, pubKey2 [33]byte + ) + + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex == nil { + return nil + } + + isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID) + return nil + }) + if err != nil { + return false, [33]byte{}, [33]byte{} + } + + return isZombie, pubKey1, pubKey2 +} + +// isZombieEdge returns whether an entry exists for the given channel in the +// zombie index. If an entry exists, then the two node public keys corresponding +// to this edge are also returned. +func isZombieEdge(zombieIndex *bbolt.Bucket, + chanID uint64) (bool, [33]byte, [33]byte) { + + var k [8]byte + byteOrder.PutUint64(k[:], chanID) + + v := zombieIndex.Get(k[:]) + if v == nil { + return false, [33]byte{}, [33]byte{} + } + + var pubKey1, pubKey2 [33]byte + copy(pubKey1[:], v[:33]) + copy(pubKey2[:], v[33:]) + + return true, pubKey1, pubKey2 +} + +// NumZombies returns the current number of zombie channels in the graph. +func (c *ChannelGraph) NumZombies() (uint64, error) { + var numZombies uint64 + err := c.db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return nil + } + zombieIndex := edges.Bucket(zombieBucket) + if zombieIndex == nil { + return nil + } + + return zombieIndex.ForEach(func(_, _ []byte) error { + numZombies++ + return nil + }) + }) + if err != nil { + return 0, err + } + + return numZombies, nil +} + +func putLightningNode(nodeBucket *bbolt.Bucket, aliasBucket *bbolt.Bucket, + updateIndex *bbolt.Bucket, node *LightningNode) error { + + var ( + scratch [16]byte + b bytes.Buffer + ) + + pub, err := node.PubKey() + if err != nil { + return err + } + nodePub := pub.SerializeCompressed() + + // If the node has the update time set, write it, else write 0. + updateUnix := uint64(0) + if node.LastUpdate.Unix() > 0 { + updateUnix = uint64(node.LastUpdate.Unix()) + } + + byteOrder.PutUint64(scratch[:8], updateUnix) + if _, err := b.Write(scratch[:8]); err != nil { + return err + } + + if _, err := b.Write(nodePub); err != nil { + return err + } + + // If we got a node announcement for this node, we will have the rest + // of the data available. If not we don't have more data to write. + if !node.HaveNodeAnnouncement { + // Write HaveNodeAnnouncement=0. + byteOrder.PutUint16(scratch[:2], 0) + if _, err := b.Write(scratch[:2]); err != nil { + return err + } + + return nodeBucket.Put(nodePub, b.Bytes()) + } + + // Write HaveNodeAnnouncement=1. + byteOrder.PutUint16(scratch[:2], 1) + if _, err := b.Write(scratch[:2]); err != nil { + return err + } + + if err := binary.Write(&b, byteOrder, node.Color.R); err != nil { + return err + } + if err := binary.Write(&b, byteOrder, node.Color.G); err != nil { + return err + } + if err := binary.Write(&b, byteOrder, node.Color.B); err != nil { + return err + } + + if err := wire.WriteVarString(&b, 0, node.Alias); err != nil { + return err + } + + if err := node.Features.Encode(&b); err != nil { + return err + } + + numAddresses := uint16(len(node.Addresses)) + byteOrder.PutUint16(scratch[:2], numAddresses) + if _, err := b.Write(scratch[:2]); err != nil { + return err + } + + for _, address := range node.Addresses { + if err := serializeAddr(&b, address); err != nil { + return err + } + } + + sigLen := len(node.AuthSigBytes) + if sigLen > 80 { + return fmt.Errorf("max sig len allowed is 80, had %v", + sigLen) + } + + err = wire.WriteVarBytes(&b, 0, node.AuthSigBytes) + if err != nil { + return err + } + + if len(node.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(node.ExtraOpaqueData)) + } + err = wire.WriteVarBytes(&b, 0, node.ExtraOpaqueData) + if err != nil { + return err + } + + if err := aliasBucket.Put(nodePub, []byte(node.Alias)); err != nil { + return err + } + + // With the alias bucket updated, we'll now update the index that + // tracks the time series of node updates. + var indexKey [8 + 33]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + copy(indexKey[8:], nodePub) + + // If there was already an old index entry for this node, then we'll + // delete the old one before we write the new entry. + if nodeBytes := nodeBucket.Get(nodePub); nodeBytes != nil { + // Extract out the old update time to we can reconstruct the + // prior index key to delete it from the index. + oldUpdateTime := nodeBytes[:8] + + var oldIndexKey [8 + 33]byte + copy(oldIndexKey[:8], oldUpdateTime) + copy(oldIndexKey[8:], nodePub) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + return nodeBucket.Put(nodePub, b.Bytes()) +} + +func fetchLightningNode(nodeBucket *bbolt.Bucket, + nodePub []byte) (LightningNode, error) { + + nodeBytes := nodeBucket.Get(nodePub) + if nodeBytes == nil { + return LightningNode{}, ErrGraphNodeNotFound + } + + nodeReader := bytes.NewReader(nodeBytes) + return deserializeLightningNode(nodeReader) +} + +func deserializeLightningNode(r io.Reader) (LightningNode, error) { + var ( + node LightningNode + scratch [8]byte + err error + ) + + if _, err := r.Read(scratch[:]); err != nil { + return LightningNode{}, err + } + + unix := int64(byteOrder.Uint64(scratch[:])) + node.LastUpdate = time.Unix(unix, 0) + + if _, err := io.ReadFull(r, node.PubKeyBytes[:]); err != nil { + return LightningNode{}, err + } + + if _, err := r.Read(scratch[:2]); err != nil { + return LightningNode{}, err + } + + hasNodeAnn := byteOrder.Uint16(scratch[:2]) + if hasNodeAnn == 1 { + node.HaveNodeAnnouncement = true + } else { + node.HaveNodeAnnouncement = false + } + + // The rest of the data is optional, and will only be there if we got a node + // announcement for this node. + if !node.HaveNodeAnnouncement { + return node, nil + } + + // We did get a node announcement for this node, so we'll have the rest + // of the data available. + if err := binary.Read(r, byteOrder, &node.Color.R); err != nil { + return LightningNode{}, err + } + if err := binary.Read(r, byteOrder, &node.Color.G); err != nil { + return LightningNode{}, err + } + if err := binary.Read(r, byteOrder, &node.Color.B); err != nil { + return LightningNode{}, err + } + + node.Alias, err = wire.ReadVarString(r, 0) + if err != nil { + return LightningNode{}, err + } + + fv := lnwire.NewFeatureVector(nil, lnwire.GlobalFeatures) + err = fv.Decode(r) + if err != nil { + return LightningNode{}, err + } + node.Features = fv + + if _, err := r.Read(scratch[:2]); err != nil { + return LightningNode{}, err + } + numAddresses := int(byteOrder.Uint16(scratch[:2])) + + var addresses []net.Addr + for i := 0; i < numAddresses; i++ { + address, err := deserializeAddr(r) + if err != nil { + return LightningNode{}, err + } + addresses = append(addresses, address) + } + node.Addresses = addresses + + node.AuthSigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") + if err != nil { + return LightningNode{}, err + } + + // We'll try and see if there are any opaque bytes left, if not, then + // we'll ignore the EOF error and return the node as is. + node.ExtraOpaqueData, err = wire.ReadVarBytes( + r, 0, MaxAllowedExtraOpaqueBytes, "blob", + ) + switch { + case err == io.ErrUnexpectedEOF: + case err == io.EOF: + case err != nil: + return LightningNode{}, err + } + + return node, nil +} + +func putChanEdgeInfo(edgeIndex *bbolt.Bucket, edgeInfo *ChannelEdgeInfo, chanID [8]byte) error { + var b bytes.Buffer + + if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { + return err + } + if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { + return err + } + if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { + return err + } + if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { + return err + } + + if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil { + return err + } + + authProof := edgeInfo.AuthProof + var nodeSig1, nodeSig2, bitcoinSig1, bitcoinSig2 []byte + if authProof != nil { + nodeSig1 = authProof.NodeSig1Bytes + nodeSig2 = authProof.NodeSig2Bytes + bitcoinSig1 = authProof.BitcoinSig1Bytes + bitcoinSig2 = authProof.BitcoinSig2Bytes + } + + if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil { + return err + } + if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil { + return err + } + if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil { + return err + } + if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil { + return err + } + + if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { + return err + } + if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { + return err + } + if _, err := b.Write(chanID[:]); err != nil { + return err + } + if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil { + return err + } + + if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) + } + err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) + if err != nil { + return err + } + + return edgeIndex.Put(chanID[:], b.Bytes()) +} + +func fetchChanEdgeInfo(edgeIndex *bbolt.Bucket, + chanID []byte) (ChannelEdgeInfo, error) { + + edgeInfoBytes := edgeIndex.Get(chanID) + if edgeInfoBytes == nil { + return ChannelEdgeInfo{}, ErrEdgeNotFound + } + + edgeInfoReader := bytes.NewReader(edgeInfoBytes) + return deserializeChanEdgeInfo(edgeInfoReader) +} + +func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) { + var ( + err error + edgeInfo ChannelEdgeInfo + ) + + if _, err := io.ReadFull(r, edgeInfo.NodeKey1Bytes[:]); err != nil { + return ChannelEdgeInfo{}, err + } + if _, err := io.ReadFull(r, edgeInfo.NodeKey2Bytes[:]); err != nil { + return ChannelEdgeInfo{}, err + } + if _, err := io.ReadFull(r, edgeInfo.BitcoinKey1Bytes[:]); err != nil { + return ChannelEdgeInfo{}, err + } + if _, err := io.ReadFull(r, edgeInfo.BitcoinKey2Bytes[:]); err != nil { + return ChannelEdgeInfo{}, err + } + + edgeInfo.Features, err = wire.ReadVarBytes(r, 0, 900, "features") + if err != nil { + return ChannelEdgeInfo{}, err + } + + proof := &ChannelAuthProof{} + + proof.NodeSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return ChannelEdgeInfo{}, err + } + proof.NodeSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return ChannelEdgeInfo{}, err + } + proof.BitcoinSig1Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return ChannelEdgeInfo{}, err + } + proof.BitcoinSig2Bytes, err = wire.ReadVarBytes(r, 0, 80, "sigs") + if err != nil { + return ChannelEdgeInfo{}, err + } + + if !proof.IsEmpty() { + edgeInfo.AuthProof = proof + } + + edgeInfo.ChannelPoint = wire.OutPoint{} + if err := readOutpoint(r, &edgeInfo.ChannelPoint); err != nil { + return ChannelEdgeInfo{}, err + } + if err := binary.Read(r, byteOrder, &edgeInfo.Capacity); err != nil { + return ChannelEdgeInfo{}, err + } + if err := binary.Read(r, byteOrder, &edgeInfo.ChannelID); err != nil { + return ChannelEdgeInfo{}, err + } + + if _, err := io.ReadFull(r, edgeInfo.ChainHash[:]); err != nil { + return ChannelEdgeInfo{}, err + } + + // We'll try and see if there are any opaque bytes left, if not, then + // we'll ignore the EOF error and return the edge as is. + edgeInfo.ExtraOpaqueData, err = wire.ReadVarBytes( + r, 0, MaxAllowedExtraOpaqueBytes, "blob", + ) + switch { + case err == io.ErrUnexpectedEOF: + case err == io.EOF: + case err != nil: + return ChannelEdgeInfo{}, err + } + + return edgeInfo, nil +} + +func putChanEdgePolicy(edges, nodes *bbolt.Bucket, edge *ChannelEdgePolicy, + from, to []byte) error { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], edge.ChannelID) + + var b bytes.Buffer + if err := serializeChanEdgePolicy(&b, edge, to); err != nil { + return err + } + + // Before we write out the new edge, we'll create a new entry in the + // update index in order to keep it fresh. + updateUnix := uint64(edge.LastUpdate.Unix()) + var indexKey [8 + 8]byte + byteOrder.PutUint64(indexKey[:8], updateUnix) + byteOrder.PutUint64(indexKey[8:], edge.ChannelID) + + updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + + // If there was already an entry for this edge, then we'll need to + // delete the old one to ensure we don't leave around any after-images. + // An unknown policy value does not have a update time recorded, so + // it also does not need to be removed. + if edgeBytes := edges.Get(edgeKey[:]); edgeBytes != nil && + !bytes.Equal(edgeBytes[:], unknownPolicy) { + + // In order to delete the old entry, we'll need to obtain the + // *prior* update time in order to delete it. To do this, we'll + // need to deserialize the existing policy within the database + // (now outdated by the new one), and delete its corresponding + // entry within the update index. We'll ignore any + // ErrEdgePolicyOptionalFieldNotFound error, as we only need + // the channel ID and update time to delete the entry. + // TODO(halseth): get rid of these invalid policies in a + // migration. + oldEdgePolicy, err := deserializeChanEdgePolicy( + bytes.NewReader(edgeBytes), nodes, + ) + if err != nil && err != ErrEdgePolicyOptionalFieldNotFound { + return err + } + + oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix()) + + var oldIndexKey [8 + 8]byte + byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime) + byteOrder.PutUint64(oldIndexKey[8:], edge.ChannelID) + + if err := updateIndex.Delete(oldIndexKey[:]); err != nil { + return err + } + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + updateEdgePolicyDisabledIndex( + edges, edge.ChannelID, + edge.ChannelFlags&lnwire.ChanUpdateDirection > 0, + edge.IsDisabled(), + ) + + return edges.Put(edgeKey[:], b.Bytes()[:]) +} + +// updateEdgePolicyDisabledIndex is used to update the disabledEdgePolicyIndex +// bucket by either add a new disabled ChannelEdgePolicy or remove an existing +// one. +// The direction represents the direction of the edge and disabled is used for +// deciding whether to remove or add an entry to the bucket. +// In general a channel is disabled if two entries for the same chanID exist +// in this bucket. +// Maintaining the bucket this way allows a fast retrieval of disabled +// channels, for example when prune is needed. +func updateEdgePolicyDisabledIndex(edges *bbolt.Bucket, chanID uint64, + direction bool, disabled bool) error { + + var disabledEdgeKey [8 + 1]byte + byteOrder.PutUint64(disabledEdgeKey[0:], chanID) + if direction { + disabledEdgeKey[8] = 1 + } + + disabledEdgePolicyIndex, err := edges.CreateBucketIfNotExists( + disabledEdgePolicyBucket, + ) + if err != nil { + return err + } + + if disabled { + return disabledEdgePolicyIndex.Put(disabledEdgeKey[:], []byte{}) + } + + return disabledEdgePolicyIndex.Delete(disabledEdgeKey[:]) +} + +// putChanEdgePolicyUnknown marks the edge policy as unknown +// in the edges bucket. +func putChanEdgePolicyUnknown(edges *bbolt.Bucket, channelID uint64, + from []byte) error { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], channelID) + + if edges.Get(edgeKey[:]) != nil { + return fmt.Errorf("Cannot write unknown policy for channel %v "+ + " when there is already a policy present", channelID) + } + + return edges.Put(edgeKey[:], unknownPolicy) +} + +func fetchChanEdgePolicy(edges *bbolt.Bucket, chanID []byte, + nodePub []byte, nodes *bbolt.Bucket) (*ChannelEdgePolicy, error) { + + var edgeKey [33 + 8]byte + copy(edgeKey[:], nodePub) + copy(edgeKey[33:], chanID[:]) + + edgeBytes := edges.Get(edgeKey[:]) + if edgeBytes == nil { + return nil, ErrEdgeNotFound + } + + // No need to deserialize unknown policy. + if bytes.Equal(edgeBytes[:], unknownPolicy) { + return nil, nil + } + + edgeReader := bytes.NewReader(edgeBytes) + + ep, err := deserializeChanEdgePolicy(edgeReader, nodes) + switch { + // If the db policy was missing an expected optional field, we return + // nil as if the policy was unknown. + case err == ErrEdgePolicyOptionalFieldNotFound: + return nil, nil + + case err != nil: + return nil, err + } + + return ep, nil +} + +func fetchChanEdgePolicies(edgeIndex *bbolt.Bucket, edges *bbolt.Bucket, + nodes *bbolt.Bucket, chanID []byte, + db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { + + edgeInfo := edgeIndex.Get(chanID) + if edgeInfo == nil { + return nil, nil, ErrEdgeNotFound + } + + // The first node is contained within the first half of the edge + // information. We only propagate the error here and below if it's + // something other than edge non-existence. + node1Pub := edgeInfo[:33] + edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub, nodes) + if err != nil { + return nil, nil, err + } + + // As we may have a single direction of the edge but not the other, + // only fill in the database pointers if the edge is found. + if edge1 != nil { + edge1.db = db + edge1.Node.db = db + } + + // Similarly, the second node is contained within the latter + // half of the edge information. + node2Pub := edgeInfo[33:66] + edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub, nodes) + if err != nil { + return nil, nil, err + } + + if edge2 != nil { + edge2.db = db + edge2.Node.db = db + } + + return edge1, edge2, nil +} + +func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy, + to []byte) error { + + err := wire.WriteVarBytes(w, 0, edge.SigBytes) + if err != nil { + return err + } + + if err := binary.Write(w, byteOrder, edge.ChannelID); err != nil { + return err + } + + var scratch [8]byte + updateUnix := uint64(edge.LastUpdate.Unix()) + byteOrder.PutUint64(scratch[:], updateUnix) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, edge.MessageFlags); err != nil { + return err + } + if err := binary.Write(w, byteOrder, edge.ChannelFlags); err != nil { + return err + } + if err := binary.Write(w, byteOrder, edge.TimeLockDelta); err != nil { + return err + } + if err := binary.Write(w, byteOrder, uint64(edge.MinHTLC)); err != nil { + return err + } + if err := binary.Write(w, byteOrder, uint64(edge.FeeBaseMSat)); err != nil { + return err + } + if err := binary.Write(w, byteOrder, uint64(edge.FeeProportionalMillionths)); err != nil { + return err + } + + if _, err := w.Write(to); err != nil { + return err + } + + // If the max_htlc field is present, we write it. To be compatible with + // older versions that wasn't aware of this field, we write it as part + // of the opaque data. + // TODO(halseth): clean up when moving to TLV. + var opaqueBuf bytes.Buffer + if edge.MessageFlags.HasMaxHtlc() { + err := binary.Write(&opaqueBuf, byteOrder, uint64(edge.MaxHTLC)) + if err != nil { + return err + } + } + + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + if _, err := opaqueBuf.Write(edge.ExtraOpaqueData); err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, opaqueBuf.Bytes()); err != nil { + return err + } + return nil +} + +func deserializeChanEdgePolicy(r io.Reader, + nodes *bbolt.Bucket) (*ChannelEdgePolicy, error) { + + edge := &ChannelEdgePolicy{} + + var err error + edge.SigBytes, err = wire.ReadVarBytes(r, 0, 80, "sig") + if err != nil { + return nil, err + } + + if err := binary.Read(r, byteOrder, &edge.ChannelID); err != nil { + return nil, err + } + + var scratch [8]byte + if _, err := r.Read(scratch[:]); err != nil { + return nil, err + } + unix := int64(byteOrder.Uint64(scratch[:])) + edge.LastUpdate = time.Unix(unix, 0) + + if err := binary.Read(r, byteOrder, &edge.MessageFlags); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &edge.ChannelFlags); err != nil { + return nil, err + } + if err := binary.Read(r, byteOrder, &edge.TimeLockDelta); err != nil { + return nil, err + } + + var n uint64 + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.MinHTLC = lnwire.MilliSatoshi(n) + + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.FeeBaseMSat = lnwire.MilliSatoshi(n) + + if err := binary.Read(r, byteOrder, &n); err != nil { + return nil, err + } + edge.FeeProportionalMillionths = lnwire.MilliSatoshi(n) + + var pub [33]byte + if _, err := r.Read(pub[:]); err != nil { + return nil, err + } + + node, err := fetchLightningNode(nodes, pub[:]) + if err != nil { + return nil, fmt.Errorf("unable to fetch node: %x, %v", + pub[:], err) + } + edge.Node = &node + + // We'll try and see if there are any opaque bytes left, if not, then + // we'll ignore the EOF error and return the edge as is. + edge.ExtraOpaqueData, err = wire.ReadVarBytes( + r, 0, MaxAllowedExtraOpaqueBytes, "blob", + ) + switch { + case err == io.ErrUnexpectedEOF: + case err == io.EOF: + case err != nil: + return nil, err + } + + // See if optional fields are present. + if edge.MessageFlags.HasMaxHtlc() { + // The max_htlc field should be at the beginning of the opaque + // bytes. + opq := edge.ExtraOpaqueData + + // If the max_htlc field is not present, it might be old data + // stored before this field was validated. We'll return the + // edge along with an error. + if len(opq) < 8 { + return edge, ErrEdgePolicyOptionalFieldNotFound + } + + maxHtlc := byteOrder.Uint64(opq[:8]) + edge.MaxHTLC = lnwire.MilliSatoshi(maxHtlc) + + // Exclude the parsed field from the rest of the opaque data. + edge.ExtraOpaqueData = opq[8:] + } + + return edge, nil +} diff --git a/channeldb/migration_01_to_11/graph_test.go b/channeldb/migration_01_to_11/graph_test.go new file mode 100644 index 00000000..00a8a000 --- /dev/null +++ b/channeldb/migration_01_to_11/graph_test.go @@ -0,0 +1,3197 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/sha256" + "fmt" + "image/color" + "math" + "math/big" + prand "math/rand" + "net" + "reflect" + "runtime" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + testAddr = &net.TCPAddr{IP: (net.IP)([]byte{0xA, 0x0, 0x0, 0x1}), + Port: 9000} + anotherAddr, _ = net.ResolveTCPAddr("tcp", + "[2001:db8:85a3:0:0:8a2e:370:7334]:80") + testAddrs = []net.Addr{testAddr, anotherAddr} + + testSig = &btcec.Signature{ + R: new(big.Int), + S: new(big.Int), + } + _, _ = testSig.R.SetString("63724406601629180062774974542967536251589935445068131219452686511677818569431", 10) + _, _ = testSig.S.SetString("18801056069249825825291287104931333862866033135609736119018462340006816851118", 10) + + testFeatures = lnwire.NewFeatureVector(nil, lnwire.GlobalFeatures) +) + +func createLightningNode(db *DB, priv *btcec.PrivateKey) (*LightningNode, error) { + updateTime := prand.Int63() + + pub := priv.PubKey().SerializeCompressed() + n := &LightningNode{ + HaveNodeAnnouncement: true, + AuthSigBytes: testSig.Serialize(), + LastUpdate: time.Unix(updateTime, 0), + Color: color.RGBA{1, 2, 3, 0}, + Alias: "kek" + string(pub[:]), + Features: testFeatures, + Addresses: testAddrs, + db: db, + } + copy(n.PubKeyBytes[:], priv.PubKey().SerializeCompressed()) + + return n, nil +} + +func createTestVertex(db *DB) (*LightningNode, error) { + priv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + return nil, err + } + + return createLightningNode(db, priv) +} + +func TestNodeInsertionAndDeletion(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test basic insertion/deletion for vertexes from the + // graph, so we'll create a test vertex to start with. + _, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + node := &LightningNode{ + HaveNodeAnnouncement: true, + AuthSigBytes: testSig.Serialize(), + LastUpdate: time.Unix(1232342, 0), + Color: color.RGBA{1, 2, 3, 0}, + Alias: "kek", + Features: testFeatures, + Addresses: testAddrs, + ExtraOpaqueData: []byte("extra new data"), + db: db, + } + copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) + + // First, insert the node into the graph DB. This should succeed + // without any errors. + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, fetch the node from the database to ensure everything was + // serialized properly. + dbNode, err := graph.FetchLightningNode(testPub) + if err != nil { + t.Fatalf("unable to locate node: %v", err) + } + + if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { + t.Fatalf("unable to query for node: %v", err) + } else if !exists { + t.Fatalf("node should be found but wasn't") + } + + // The two nodes should match exactly! + if err := compareNodes(node, dbNode); err != nil { + t.Fatalf("nodes don't match: %v", err) + } + + // Next, delete the node from the graph, this should purge all data + // related to the node. + if err := graph.DeleteLightningNode(testPub); err != nil { + t.Fatalf("unable to delete node; %v", err) + } + + // Finally, attempt to fetch the node again. This should fail as the + // node should have been deleted from the database. + _, err = graph.FetchLightningNode(testPub) + if err != ErrGraphNodeNotFound { + t.Fatalf("fetch after delete should fail!") + } +} + +// TestPartialNode checks that we can add and retrieve a LightningNode where +// where only the pubkey is known to the database. +func TestPartialNode(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We want to be able to insert nodes into the graph that only has the + // PubKey set. + _, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + node := &LightningNode{ + HaveNodeAnnouncement: false, + } + copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) + + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, fetch the node from the database to ensure everything was + // serialized properly. + dbNode, err := graph.FetchLightningNode(testPub) + if err != nil { + t.Fatalf("unable to locate node: %v", err) + } + + if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { + t.Fatalf("unable to query for node: %v", err) + } else if !exists { + t.Fatalf("node should be found but wasn't") + } + + // The two nodes should match exactly! (with default values for + // LastUpdate and db set to satisfy compareNodes()) + node = &LightningNode{ + HaveNodeAnnouncement: false, + LastUpdate: time.Unix(0, 0), + db: db, + } + copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) + + if err := compareNodes(node, dbNode); err != nil { + t.Fatalf("nodes don't match: %v", err) + } + + // Next, delete the node from the graph, this should purge all data + // related to the node. + if err := graph.DeleteLightningNode(testPub); err != nil { + t.Fatalf("unable to delete node: %v", err) + } + + // Finally, attempt to fetch the node again. This should fail as the + // node should have been deleted from the database. + _, err = graph.FetchLightningNode(testPub) + if err != ErrGraphNodeNotFound { + t.Fatalf("fetch after delete should fail!") + } +} + +func TestAliasLookup(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the alias index within the database, so first + // create a new test node. + testNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // Add the node to the graph's database, this should also insert an + // entry into the alias index for this node. + if err := graph.AddLightningNode(testNode); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, attempt to lookup the alias. The alias should exactly match + // the one which the test node was assigned. + nodePub, err := testNode.PubKey() + if err != nil { + t.Fatalf("unable to generate pubkey: %v", err) + } + dbAlias, err := graph.LookupAlias(nodePub) + if err != nil { + t.Fatalf("unable to find alias: %v", err) + } + if dbAlias != testNode.Alias { + t.Fatalf("aliases don't match, expected %v got %v", + testNode.Alias, dbAlias) + } + + // Ensure that looking up a non-existent alias results in an error. + node, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + nodePub, err = node.PubKey() + if err != nil { + t.Fatalf("unable to generate pubkey: %v", err) + } + _, err = graph.LookupAlias(nodePub) + if err != ErrNodeAliasNotFound { + t.Fatalf("alias lookup should fail for non-existent pubkey") + } +} + +func TestSourceNode(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the setting/getting of the source node, so we + // first create a fake node to use within the test. + testNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // Attempt to fetch the source node, this should return an error as the + // source node hasn't yet been set. + if _, err := graph.SourceNode(); err != ErrSourceNodeNotSet { + t.Fatalf("source node shouldn't be set in new graph") + } + + // Set the source the source node, this should insert the node into the + // database in a special way indicating it's the source node. + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // Retrieve the source node from the database, it should exactly match + // the one we set above. + sourceNode, err := graph.SourceNode() + if err != nil { + t.Fatalf("unable to fetch source node: %v", err) + } + if err := compareNodes(testNode, sourceNode); err != nil { + t.Fatalf("nodes don't match: %v", err) + } +} + +func TestEdgeInsertionDeletion(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the insertion/deletion of edges, so we create two + // vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // In addition to the fake vertexes we create some fake channel + // identifiers. + chanID := uint64(prand.Int63()) + outpoint := wire.OutPoint{ + Hash: rev, + Index: 9, + } + + // Add the new edge to the database, this should proceed without any + // errors. + node1Pub, err := node1.PubKey() + if err != nil { + t.Fatalf("unable to generate node key: %v", err) + } + node2Pub, err := node2.PubKey() + if err != nil { + t.Fatalf("unable to generate node key: %v", err) + } + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: outpoint, + Capacity: 9000, + } + copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) + + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Ensure that both policies are returned as unknown (nil). + _, e1, e2, err := graph.FetchChannelEdgesByID(chanID) + if err != nil { + t.Fatalf("unable to fetch channel edge") + } + if e1 != nil || e2 != nil { + t.Fatalf("channel edges not unknown") + } + + // Next, attempt to delete the edge from the database, again this + // should proceed without any issues. + if err := graph.DeleteChannelEdges(chanID); err != nil { + t.Fatalf("unable to delete edge: %v", err) + } + + // Ensure that any query attempts to lookup the delete channel edge are + // properly deleted. + if _, _, _, err := graph.FetchChannelEdgesByOutpoint(&outpoint); err == nil { + t.Fatalf("channel edge not deleted") + } + if _, _, _, err := graph.FetchChannelEdgesByID(chanID); err == nil { + t.Fatalf("channel edge not deleted") + } + isZombie, _, _ := graph.IsZombieEdge(chanID) + if !isZombie { + t.Fatal("channel edge not marked as zombie") + } + + // Finally, attempt to delete a (now) non-existent edge within the + // database, this should result in an error. + err = graph.DeleteChannelEdges(chanID) + if err != ErrEdgeNotFound { + t.Fatalf("deleting a non-existent edge should fail!") + } +} + +func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, + node1, node2 *LightningNode) (ChannelEdgeInfo, lnwire.ShortChannelID) { + + shortChanID := lnwire.ShortChannelID{ + BlockHeight: height, + TxIndex: txIndex, + TxPosition: txPosition, + } + outpoint := wire.OutPoint{ + Hash: rev, + Index: outPointIndex, + } + + node1Pub, _ := node1.PubKey() + node2Pub, _ := node2.PubKey() + edgeInfo := ChannelEdgeInfo{ + ChannelID: shortChanID.ToUint64(), + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: outpoint, + Capacity: 9000, + } + + copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) + copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) + + return edgeInfo, shortChanID +} + +// TestDisconnectBlockAtHeight checks that the pruned state of the channel +// database is what we expect after calling DisconnectBlockAtHeight. +func TestDisconnectBlockAtHeight(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // We'd like to test the insertion/deletion of edges, so we create two + // vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // In addition to the fake vertexes we create some fake channel + // identifiers. + var spendOutputs []*wire.OutPoint + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) + + // Prune the graph a few times to make sure we have entries in the + // prune log. + _, err = graph.PruneGraph(spendOutputs, &blockHash, 155) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + var blockHash2 chainhash.Hash + copy(blockHash2[:], bytes.Repeat([]byte{2}, 32)) + + _, err = graph.PruneGraph(spendOutputs, &blockHash2, 156) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // We'll create 3 almost identical edges, so first create a helper + // method containing all logic for doing so. + + // Create an edge which has its block height at 156. + height := uint32(156) + edgeInfo, _ := createEdge(height, 0, 0, 0, node1, node2) + + // Create an edge with block height 157. We give it + // maximum values for tx index and position, to make + // sure our database range scan get edges from the + // entire range. + edgeInfo2, _ := createEdge( + height+1, math.MaxUint32&0x00ffffff, math.MaxUint16, 1, + node1, node2, + ) + + // Create a third edge, this with a block height of 155. + edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) + + // Now add all these new edges to the database. + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + if err := graph.AddChannelEdge(&edgeInfo2); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + if err := graph.AddChannelEdge(&edgeInfo3); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Call DisconnectBlockAtHeight, which should prune every channel + // that has a funding height of 'height' or greater. + removed, err := graph.DisconnectBlockAtHeight(uint32(height)) + if err != nil { + t.Fatalf("unable to prune %v", err) + } + + // The two edges should have been removed. + if len(removed) != 2 { + t.Fatalf("expected two edges to be removed from graph, "+ + "only %d were", len(removed)) + } + if removed[0].ChannelID != edgeInfo.ChannelID { + t.Fatalf("expected edge to be removed from graph") + } + if removed[1].ChannelID != edgeInfo2.ChannelID { + t.Fatalf("expected edge to be removed from graph") + } + + // The two first edges should be removed from the db. + _, _, has, isZombie, err := graph.HasChannelEdge(edgeInfo.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if has { + t.Fatalf("edge1 was not pruned from the graph") + } + if isZombie { + t.Fatal("reorged edge1 should not be marked as zombie") + } + _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo2.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if has { + t.Fatalf("edge2 was not pruned from the graph") + } + if isZombie { + t.Fatal("reorged edge2 should not be marked as zombie") + } + + // Edge 3 should not be removed. + _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo3.ChannelID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if !has { + t.Fatalf("edge3 was pruned from the graph") + } + if isZombie { + t.Fatal("edge3 was marked as zombie") + } + + // PruneTip should be set to the blockHash we specified for the block + // at height 155. + hash, h, err := graph.PruneTip() + if err != nil { + t.Fatalf("unable to get prune tip: %v", err) + } + if !blockHash.IsEqual(hash) { + t.Fatalf("expected best block to be %x, was %x", blockHash, hash) + } + if h != height-1 { + t.Fatalf("expected best block height to be %d, was %d", height-1, h) + } +} + +func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, + e2 *ChannelEdgeInfo) { + + if e1.ChannelID != e2.ChannelID { + t.Fatalf("chan id's don't match: %v vs %v", e1.ChannelID, + e2.ChannelID) + } + + if e1.ChainHash != e2.ChainHash { + t.Fatalf("chain hashes don't match: %v vs %v", e1.ChainHash, + e2.ChainHash) + } + + if !bytes.Equal(e1.NodeKey1Bytes[:], e2.NodeKey1Bytes[:]) { + t.Fatalf("nodekey1 doesn't match") + } + if !bytes.Equal(e1.NodeKey2Bytes[:], e2.NodeKey2Bytes[:]) { + t.Fatalf("nodekey2 doesn't match") + } + if !bytes.Equal(e1.BitcoinKey1Bytes[:], e2.BitcoinKey1Bytes[:]) { + t.Fatalf("bitcoinkey1 doesn't match") + } + if !bytes.Equal(e1.BitcoinKey2Bytes[:], e2.BitcoinKey2Bytes[:]) { + t.Fatalf("bitcoinkey2 doesn't match") + } + + if !bytes.Equal(e1.Features, e2.Features) { + t.Fatalf("features doesn't match: %x vs %x", e1.Features, + e2.Features) + } + + if !bytes.Equal(e1.AuthProof.NodeSig1Bytes, e2.AuthProof.NodeSig1Bytes) { + t.Fatalf("nodesig1 doesn't match: %v vs %v", + spew.Sdump(e1.AuthProof.NodeSig1Bytes), + spew.Sdump(e2.AuthProof.NodeSig1Bytes)) + } + if !bytes.Equal(e1.AuthProof.NodeSig2Bytes, e2.AuthProof.NodeSig2Bytes) { + t.Fatalf("nodesig2 doesn't match") + } + if !bytes.Equal(e1.AuthProof.BitcoinSig1Bytes, e2.AuthProof.BitcoinSig1Bytes) { + t.Fatalf("bitcoinsig1 doesn't match") + } + if !bytes.Equal(e1.AuthProof.BitcoinSig2Bytes, e2.AuthProof.BitcoinSig2Bytes) { + t.Fatalf("bitcoinsig2 doesn't match") + } + + if e1.ChannelPoint != e2.ChannelPoint { + t.Fatalf("channel point match: %v vs %v", e1.ChannelPoint, + e2.ChannelPoint) + } + + if e1.Capacity != e2.Capacity { + t.Fatalf("capacity doesn't match: %v vs %v", e1.Capacity, + e2.Capacity) + } + + if !bytes.Equal(e1.ExtraOpaqueData, e2.ExtraOpaqueData) { + t.Fatalf("extra data doesn't match: %v vs %v", + e2.ExtraOpaqueData, e2.ExtraOpaqueData) + } +} + +func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, + *ChannelEdgePolicy, *ChannelEdgePolicy) { + + var ( + firstNode *LightningNode + secondNode *LightningNode + ) + if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 { + firstNode = node1 + secondNode = node2 + } else { + firstNode = node2 + secondNode = node1 + } + + // In addition to the fake vertexes we create some fake channel + // identifiers. + chanID := uint64(prand.Int63()) + outpoint := wire.OutPoint{ + Hash: rev, + Index: 9, + } + + // Add the new edge to the database, this should proceed without any + // errors. + edgeInfo := &ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: outpoint, + Capacity: 1000, + ExtraOpaqueData: []byte("new unknown feature"), + } + copy(edgeInfo.NodeKey1Bytes[:], firstNode.PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], secondNode.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], firstNode.PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], secondNode.PubKeyBytes[:]) + + edge1 := &ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: chanID, + LastUpdate: time.Unix(433453, 0), + MessageFlags: 1, + ChannelFlags: 0, + TimeLockDelta: 99, + MinHTLC: 2342135, + MaxHTLC: 13928598, + FeeBaseMSat: 4352345, + FeeProportionalMillionths: 3452352, + Node: secondNode, + ExtraOpaqueData: []byte("new unknown feature2"), + db: db, + } + edge2 := &ChannelEdgePolicy{ + SigBytes: testSig.Serialize(), + ChannelID: chanID, + LastUpdate: time.Unix(124234, 0), + MessageFlags: 1, + ChannelFlags: 1, + TimeLockDelta: 99, + MinHTLC: 2342135, + MaxHTLC: 13928598, + FeeBaseMSat: 4352345, + FeeProportionalMillionths: 90392423, + Node: firstNode, + ExtraOpaqueData: []byte("new unknown feature1"), + db: db, + } + + return edgeInfo, edge1, edge2 +} + +func TestEdgeInfoUpdates(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the update of edges inserted into the database, so + // we create two vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Create an edge and add it to the db. + edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + + // Make sure inserting the policy at this point, before the edge info + // is added, will fail. + if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound { + t.Fatalf("expected ErrEdgeNotFound, got: %v", err) + } + + // Add the edge info. + if err := graph.AddChannelEdge(edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanID := edgeInfo.ChannelID + outpoint := edgeInfo.ChannelPoint + + // Next, insert both edge policies into the database, they should both + // be inserted without any issues. + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // Check for existence of the edge within the database, it should be + // found. + _, _, found, isZombie, err := graph.HasChannelEdge(chanID) + if err != nil { + t.Fatalf("unable to query for edge: %v", err) + } + if !found { + t.Fatalf("graph should have of inserted edge") + } + if isZombie { + t.Fatal("live edge should not be marked as zombie") + } + + // We should also be able to retrieve the channelID only knowing the + // channel point of the channel. + dbChanID, err := graph.ChannelID(&outpoint) + if err != nil { + t.Fatalf("unable to retrieve channel ID: %v", err) + } + if dbChanID != chanID { + t.Fatalf("chan ID's mismatch, expected %v got %v", dbChanID, + chanID) + } + + // With the edges inserted, perform some queries to ensure that they've + // been inserted properly. + dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID) + if err != nil { + t.Fatalf("unable to fetch channel by ID: %v", err) + } + if err := compareEdgePolicies(dbEdge1, edge1); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + if err := compareEdgePolicies(dbEdge2, edge2); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) + + // Next, attempt to query the channel edges according to the outpoint + // of the channel. + dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByOutpoint(&outpoint) + if err != nil { + t.Fatalf("unable to fetch channel by ID: %v", err) + } + if err := compareEdgePolicies(dbEdge1, edge1); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + if err := compareEdgePolicies(dbEdge2, edge2); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) +} + +func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy { + update := prand.Int63() + + return newEdgePolicy(chanID, op, db, update) +} + +func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, + updateTime int64) *ChannelEdgePolicy { + + return &ChannelEdgePolicy{ + ChannelID: chanID, + LastUpdate: time.Unix(updateTime, 0), + MessageFlags: 1, + ChannelFlags: 0, + TimeLockDelta: uint16(prand.Int63()), + MinHTLC: lnwire.MilliSatoshi(prand.Int63()), + MaxHTLC: lnwire.MilliSatoshi(prand.Int63()), + FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), + FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()), + db: db, + } +} + +func TestGraphTraversal(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test some of the graph traversal capabilities within + // the DB, so we'll create a series of fake nodes to insert into the + // graph. + const numNodes = 20 + nodes := make([]*LightningNode, numNodes) + nodeIndex := map[string]struct{}{} + for i := 0; i < numNodes; i++ { + node, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create node: %v", err) + } + + nodes[i] = node + nodeIndex[node.Alias] = struct{}{} + } + + // Add each of the nodes into the graph, they should be inserted + // without error. + for _, node := range nodes { + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + } + + // Iterate over each node as returned by the graph, if all nodes are + // reached, then the map created above should be empty. + err = graph.ForEachNode(nil, func(_ *bbolt.Tx, node *LightningNode) error { + delete(nodeIndex, node.Alias) + return nil + }) + if err != nil { + t.Fatalf("for each failure: %v", err) + } + if len(nodeIndex) != 0 { + t.Fatalf("all nodes not reached within ForEach") + } + + // Determine which node is "smaller", we'll need this in order to + // properly create the edges for the graph. + var firstNode, secondNode *LightningNode + if bytes.Compare(nodes[0].PubKeyBytes[:], nodes[1].PubKeyBytes[:]) == -1 { + firstNode = nodes[0] + secondNode = nodes[1] + } else { + firstNode = nodes[0] + secondNode = nodes[1] + } + + // Create 5 channels between the first two nodes we generated above. + const numChannels = 5 + chanIndex := map[uint64]struct{}{} + for i := 0; i < numChannels; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + chanID := uint64(i + 1) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: op, + Capacity: 1000, + } + copy(edgeInfo.NodeKey1Bytes[:], nodes[0].PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], nodes[1].PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], nodes[0].PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], nodes[1].PubKeyBytes[:]) + err := graph.AddChannelEdge(&edgeInfo) + if err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Create and add an edge with random data that points from + // node1 -> node2. + edge := randEdgePolicy(chanID, op, db) + edge.ChannelFlags = 0 + edge.Node = secondNode + edge.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // Create another random edge that points from node2 -> node1 + // this time. + edge = randEdgePolicy(chanID, op, db) + edge.ChannelFlags = 1 + edge.Node = firstNode + edge.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + chanIndex[chanID] = struct{}{} + } + + // Iterate through all the known channels within the graph DB, once + // again if the map is empty that indicates that all edges have + // properly been reached. + err = graph.ForEachChannel(func(ei *ChannelEdgeInfo, _ *ChannelEdgePolicy, + _ *ChannelEdgePolicy) error { + + delete(chanIndex, ei.ChannelID) + return nil + }) + if err != nil { + t.Fatalf("for each failure: %v", err) + } + if len(chanIndex) != 0 { + t.Fatalf("all edges not reached within ForEach") + } + + // Finally, we want to test the ability to iterate over all the + // outgoing channels for a particular node. + numNodeChans := 0 + err = firstNode.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo, + outEdge, inEdge *ChannelEdgePolicy) error { + + // All channels between first and second node should have fully + // (both sides) specified policies. + if inEdge == nil || outEdge == nil { + return fmt.Errorf("channel policy not present") + } + + // Each should indicate that it's outgoing (pointed + // towards the second node). + if !bytes.Equal(outEdge.Node.PubKeyBytes[:], secondNode.PubKeyBytes[:]) { + return fmt.Errorf("wrong outgoing edge") + } + + // The incoming edge should also indicate that it's pointing to + // the origin node. + if !bytes.Equal(inEdge.Node.PubKeyBytes[:], firstNode.PubKeyBytes[:]) { + return fmt.Errorf("wrong outgoing edge") + } + + numNodeChans++ + return nil + }) + if err != nil { + t.Fatalf("for each failure: %v", err) + } + if numNodeChans != numChannels { + t.Fatalf("all edges for node not reached within ForEach: "+ + "expected %v, got %v", numChannels, numNodeChans) + } +} + +func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash, + blockHeight uint32) { + + pruneHash, pruneHeight, err := graph.PruneTip() + if err != nil { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: unable to fetch prune tip: %v", line, err) + } + if !bytes.Equal(blockHash[:], pruneHash[:]) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line: %v, prune tips don't match, expected %x got %x", + line, blockHash, pruneHash) + } + if pruneHeight != blockHeight { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: prune heights don't match, expected %v "+ + "got %v", line, blockHeight, pruneHeight) + } +} + +func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { + numChans := 0 + if err := graph.ForEachChannel(func(*ChannelEdgeInfo, *ChannelEdgePolicy, + *ChannelEdgePolicy) error { + + numChans++ + return nil + }); err != nil { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: unable to scan channels: %v", line, err) + } + if numChans != n { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: expected %v chans instead have %v", line, + n, numChans) + } +} + +func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { + numNodes := 0 + err := graph.ForEachNode(nil, func(_ *bbolt.Tx, _ *LightningNode) error { + numNodes++ + return nil + }) + if err != nil { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: unable to scan nodes: %v", line, err) + } + + if numNodes != n { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: expected %v nodes, got %v", line, n, numNodes) + } +} + +func assertChanViewEqual(t *testing.T, a []EdgePoint, b []EdgePoint) { + if len(a) != len(b) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chan views don't match", line) + } + + chanViewSet := make(map[wire.OutPoint]struct{}) + for _, op := range a { + chanViewSet[op.OutPoint] = struct{}{} + } + + for _, op := range b { + if _, ok := chanViewSet[op.OutPoint]; !ok { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chanPoint(%v) not found in first "+ + "view", line, op) + } + } +} + +func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoint) { + if len(a) != len(b) { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chan views don't match", line) + } + + chanViewSet := make(map[wire.OutPoint]struct{}) + for _, op := range a { + chanViewSet[op.OutPoint] = struct{}{} + } + + for _, op := range b { + if _, ok := chanViewSet[*op]; !ok { + _, _, line, _ := runtime.Caller(1) + t.Fatalf("line %v: chanPoint(%v) not found in first "+ + "view", line, op) + } + } +} + +func TestGraphPruning(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // As initial set up for the test, we'll create a graph with 5 vertexes + // and enough edges to create a fully connected graph. The graph will + // be rather simple, representing a straight line. + const numNodes = 5 + graphNodes := make([]*LightningNode, numNodes) + for i := 0; i < numNodes; i++ { + node, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create node: %v", err) + } + + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + graphNodes[i] = node + } + + // With the vertexes created, we'll next create a series of channels + // between them. + channelPoints := make([]*wire.OutPoint, 0, numNodes-1) + edgePoints := make([]EdgePoint, 0, numNodes-1) + for i := 0; i < numNodes-1; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + chanID := uint64(i + 1) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channelPoints = append(channelPoints, &op) + + edgeInfo := ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: key, + AuthProof: &ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + }, + ChannelPoint: op, + Capacity: 1000, + } + copy(edgeInfo.NodeKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) + copy(edgeInfo.NodeKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) + copy(edgeInfo.BitcoinKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + pkScript, err := genMultiSigP2WSH( + edgeInfo.BitcoinKey1Bytes[:], edgeInfo.BitcoinKey2Bytes[:], + ) + if err != nil { + t.Fatalf("unable to gen multi-sig p2wsh: %v", err) + } + edgePoints = append(edgePoints, EdgePoint{ + FundingPkScript: pkScript, + OutPoint: op, + }) + + // Create and add an edge with random data that points from + // node_i -> node_i+1 + edge := randEdgePolicy(chanID, op, db) + edge.ChannelFlags = 0 + edge.Node = graphNodes[i] + edge.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // Create another random edge that points from node_i+1 -> + // node_i this time. + edge = randEdgePolicy(chanID, op, db) + edge.ChannelFlags = 1 + edge.Node = graphNodes[i] + edge.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + } + + // With all the channel points added, we'll consult the graph to ensure + // it has the same channel view as the one we just constructed. + channelView, err := graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + assertChanViewEqual(t, channelView, edgePoints) + + // Now with our test graph created, we can test the pruning + // capabilities of the channel graph. + + // First we create a mock block that ends up closing the first two + // channels. + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) + blockHeight := uint32(1) + block := channelPoints[:2] + prunedChans, err := graph.PruneGraph(block, &blockHash, blockHeight) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + if len(prunedChans) != 2 { + t.Fatalf("incorrect number of channels pruned: "+ + "expected %v, got %v", 2, prunedChans) + } + + // Now ensure that the prune tip has been updated. + assertPruneTip(t, graph, &blockHash, blockHeight) + + // Count up the number of channels known within the graph, only 2 + // should be remaining. + assertNumChans(t, graph, 2) + + // Those channels should also be missing from the channel view. + channelView, err = graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + assertChanViewEqualChanPoints(t, channelView, channelPoints[2:]) + + // Next we'll create a block that doesn't close any channels within the + // graph to test the negative error case. + fakeHash := sha256.Sum256([]byte("test prune")) + nonChannel := &wire.OutPoint{ + Hash: fakeHash, + Index: 9, + } + blockHash = sha256.Sum256(blockHash[:]) + blockHeight = 2 + prunedChans, err = graph.PruneGraph( + []*wire.OutPoint{nonChannel}, &blockHash, blockHeight, + ) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // No channels should have been detected as pruned. + if len(prunedChans) != 0 { + t.Fatalf("channels were pruned but shouldn't have been") + } + + // Once again, the prune tip should have been updated. We should still + // see both channels and their participants, along with the source node. + assertPruneTip(t, graph, &blockHash, blockHeight) + assertNumChans(t, graph, 2) + assertNumNodes(t, graph, 4) + + // Finally, create a block that prunes the remainder of the channels + // from the graph. + blockHash = sha256.Sum256(blockHash[:]) + blockHeight = 3 + prunedChans, err = graph.PruneGraph( + channelPoints[2:], &blockHash, blockHeight, + ) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // The remainder of the channels should have been pruned from the + // graph. + if len(prunedChans) != 2 { + t.Fatalf("incorrect number of channels pruned: "+ + "expected %v, got %v", 2, len(prunedChans)) + } + + // The prune tip should be updated, no channels should be found, and + // only the source node should remain within the current graph. + assertPruneTip(t, graph, &blockHash, blockHeight) + assertNumChans(t, graph, 0) + assertNumNodes(t, graph, 1) + + // Finally, the channel view at this point in the graph should now be + // completely empty. Those channels should also be missing from the + // channel view. + channelView, err = graph.ChannelView() + if err != nil { + t.Fatalf("unable to get graph channel view: %v", err) + } + if len(channelView) != 0 { + t.Fatalf("channel view should be empty, instead have: %v", + channelView) + } +} + +// TestHighestChanID tests that we're able to properly retrieve the highest +// known channel ID in the database. +func TestHighestChanID(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we don't yet have any channels in the database, then we should + // get a channel ID of zero if we ask for the highest channel ID. + bestID, err := graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + if bestID != 0 { + t.Fatalf("best ID w/ no chan should be zero, is instead: %v", + bestID) + } + + // Next, we'll insert two channels into the database, with each channel + // connecting the same two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // The first channel with be at height 10, while the other will be at + // height 100. + edge1, _ := createEdge(10, 0, 0, 0, node1, node2) + edge2, chanID2 := createEdge(100, 0, 0, 0, node1, node2) + + if err := graph.AddChannelEdge(&edge1); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + if err := graph.AddChannelEdge(&edge2); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Now that the edges has been inserted, we'll query for the highest + // known channel ID in the database. + bestID, err = graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + + if bestID != chanID2.ToUint64() { + t.Fatalf("expected %v got %v for best chan ID: ", + chanID2.ToUint64(), bestID) + } + + // If we add another edge, then the current best chan ID should be + // updated as well. + edge3, chanID3 := createEdge(1000, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edge3); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + bestID, err = graph.HighestChanID() + if err != nil { + t.Fatalf("unable to get highest ID: %v", err) + } + + if bestID != chanID3.ToUint64() { + t.Fatalf("expected %v got %v for best chan ID: ", + chanID3.ToUint64(), bestID) + } +} + +// TestChanUpdatesInHorizon tests the we're able to properly retrieve all known +// channel updates within a specific time horizon. It also tests that upon +// insertion of a new edge, the edge update index is updated properly. +func TestChanUpdatesInHorizon(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we issue an arbitrary query before any channel updates are + // inserted in the database, we should get zero results. + chanUpdates, err := graph.ChanUpdatesInHorizon( + time.Unix(999, 0), time.Unix(9999, 0), + ) + if err != nil { + t.Fatalf("unable to updates for updates: %v", err) + } + if len(chanUpdates) != 0 { + t.Fatalf("expected 0 chan updates, instead got %v", + len(chanUpdates)) + } + + // We'll start by creating two nodes which will seed our test graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll now create 10 channels between the two nodes, with update + // times 10 seconds after each other. + const numChans = 10 + startTime := time.Unix(1234, 0) + endTime := startTime + edges := make([]ChannelEdge, 0, numChans) + for i := 0; i < numChans; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + edge1UpdateTime := endTime + edge2UpdateTime := edge1UpdateTime.Add(time.Second) + endTime = endTime.Add(time.Second * 10) + + edge1 := newEdgePolicy( + chanID.ToUint64(), op, db, edge1UpdateTime.Unix(), + ) + edge1.ChannelFlags = 0 + edge1.Node = node2 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := newEdgePolicy( + chanID.ToUint64(), op, db, edge2UpdateTime.Unix(), + ) + edge2.ChannelFlags = 1 + edge2.Node = node1 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edges = append(edges, ChannelEdge{ + Info: &channel, + Policy1: edge1, + Policy2: edge2, + }) + } + + // With our channels loaded, we'll now start our series of queries. + queryCases := []struct { + start time.Time + end time.Time + + resp []ChannelEdge + }{ + // If we query for a time range that's strictly below our set + // of updates, then we'll get an empty result back. + { + start: time.Unix(100, 0), + end: time.Unix(200, 0), + }, + + // If we query for a time range that's well beyond our set of + // updates, we should get an empty set of results back. + { + start: time.Unix(99999, 0), + end: time.Unix(999999, 0), + }, + + // If we query for the start time, and 10 seconds directly + // after it, we should only get a single update, that first + // one. + { + start: time.Unix(1234, 0), + end: startTime.Add(time.Second * 10), + + resp: []ChannelEdge{edges[0]}, + }, + + // If we add 10 seconds past the first update, and then + // subtract 10 from the last update, then we should only get + // the 8 edges in the middle. + { + start: startTime.Add(time.Second * 10), + end: endTime.Add(-time.Second * 10), + + resp: edges[1:9], + }, + + // If we use the start and end time as is, we should get the + // entire range. + { + start: startTime, + end: endTime, + + resp: edges, + }, + } + for _, queryCase := range queryCases { + resp, err := graph.ChanUpdatesInHorizon( + queryCase.start, queryCase.end, + ) + if err != nil { + t.Fatalf("unable to query for updates: %v", err) + } + + if len(resp) != len(queryCase.resp) { + t.Fatalf("expected %v chans, got %v chans", + len(queryCase.resp), len(resp)) + + } + + for i := 0; i < len(resp); i++ { + chanExp := queryCase.resp[i] + chanRet := resp[i] + + assertEdgeInfoEqual(t, chanExp.Info, chanRet.Info) + + err := compareEdgePolicies(chanExp.Policy1, chanRet.Policy1) + if err != nil { + t.Fatal(err) + } + compareEdgePolicies(chanExp.Policy2, chanRet.Policy2) + if err != nil { + t.Fatal(err) + } + } + } +} + +// TestNodeUpdatesInHorizon tests that we're able to properly scan and retrieve +// the most recent node updates within a particular time horizon. +func TestNodeUpdatesInHorizon(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + startTime := time.Unix(1234, 0) + endTime := startTime + + // If we issue an arbitrary query before we insert any nodes into the + // database, then we shouldn't get any results back. + nodeUpdates, err := graph.NodeUpdatesInHorizon( + time.Unix(999, 0), time.Unix(9999, 0), + ) + if err != nil { + t.Fatalf("unable to query for node updates: %v", err) + } + if len(nodeUpdates) != 0 { + t.Fatalf("expected 0 node updates, instead got %v", + len(nodeUpdates)) + } + + // We'll create 10 node announcements, each with an update timestamp 10 + // seconds after the other. + const numNodes = 10 + nodeAnns := make([]LightningNode, 0, numNodes) + for i := 0; i < numNodes; i++ { + nodeAnn, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test vertex: %v", err) + } + + // The node ann will use the current end time as its last + // update them, then we'll add 10 seconds in order to create + // the proper update time for the next node announcement. + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + nodeAnn.LastUpdate = updateTime + + nodeAnns = append(nodeAnns, *nodeAnn) + + if err := graph.AddLightningNode(nodeAnn); err != nil { + t.Fatalf("unable to add lightning node: %v", err) + } + } + + queryCases := []struct { + start time.Time + end time.Time + + resp []LightningNode + }{ + // If we query for a time range that's strictly below our set + // of updates, then we'll get an empty result back. + { + start: time.Unix(100, 0), + end: time.Unix(200, 0), + }, + + // If we query for a time range that's well beyond our set of + // updates, we should get an empty set of results back. + { + start: time.Unix(99999, 0), + end: time.Unix(999999, 0), + }, + + // If we skip he first time epoch with out start time, then we + // should get back every now but the first. + { + start: startTime.Add(time.Second * 10), + end: endTime, + + resp: nodeAnns[1:], + }, + + // If we query for the range as is, we should get all 10 + // announcements back. + { + start: startTime, + end: endTime, + + resp: nodeAnns, + }, + + // If we reduce the ending time by 10 seconds, then we should + // get all but the last node we inserted. + { + start: startTime, + end: endTime.Add(-time.Second * 10), + + resp: nodeAnns[:9], + }, + } + for _, queryCase := range queryCases { + resp, err := graph.NodeUpdatesInHorizon(queryCase.start, queryCase.end) + if err != nil { + t.Fatalf("unable to query for nodes: %v", err) + } + + if len(resp) != len(queryCase.resp) { + t.Fatalf("expected %v nodes, got %v nodes", + len(queryCase.resp), len(resp)) + + } + + for i := 0; i < len(resp); i++ { + err := compareNodes(&queryCase.resp[i], &resp[i]) + if err != nil { + t.Fatal(err) + } + } + } +} + +// TestFilterKnownChanIDs tests that we're able to properly perform the set +// differences of an incoming set of channel ID's, and those that we already +// know of on disk. +func TestFilterKnownChanIDs(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // If we try to filter out a set of channel ID's before we even know of + // any channels, then we should get the entire set back. + preChanIDs := []uint64{1, 2, 3, 4} + filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs) + if err != nil { + t.Fatalf("unable to filter chan IDs: %v", err) + } + if !reflect.DeepEqual(preChanIDs, filteredIDs) { + t.Fatalf("chan IDs shouldn't have been filtered!") + } + + // We'll start by creating two nodes which will seed our test graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Next, we'll add 5 channel ID's to the graph, each of them having a + // block height 10 blocks after the previous. + const numChans = 5 + chanIDs := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanIDs = append(chanIDs, chanID.ToUint64()) + } + + const numZombies = 5 + zombieIDs := make([]uint64, 0, numZombies) + for i := 0; i < numZombies; i++ { + channel, chanID := createEdge( + uint32(i*10+1), 0, 0, 0, node1, node2, + ) + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + err := graph.DeleteChannelEdges(channel.ChannelID) + if err != nil { + t.Fatalf("unable to mark edge zombie: %v", err) + } + + zombieIDs = append(zombieIDs, chanID.ToUint64()) + } + + queryCases := []struct { + queryIDs []uint64 + + resp []uint64 + }{ + // If we attempt to filter out all chanIDs we know of, the + // response should be the empty set. + { + queryIDs: chanIDs, + }, + // If we attempt to filter out all zombies that we know of, the + // response should be the empty set. + { + queryIDs: zombieIDs, + }, + + // If we query for a set of ID's that we didn't insert, we + // should get the same set back. + { + queryIDs: []uint64{99, 100}, + resp: []uint64{99, 100}, + }, + + // If we query for a super-set of our the chan ID's inserted, + // we should only get those new chanIDs back. + { + queryIDs: append(chanIDs, []uint64{99, 101}...), + resp: []uint64{99, 101}, + }, + } + + for _, queryCase := range queryCases { + resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs) + if err != nil { + t.Fatalf("unable to filter chan IDs: %v", err) + } + + if !reflect.DeepEqual(resp, queryCase.resp) { + t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp), + spew.Sdump(resp)) + } + } +} + +// TestFilterChannelRange tests that we're able to properly retrieve the full +// set of short channel ID's for a given block range. +func TestFilterChannelRange(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // If we try to filter a channel range before we have any channels + // inserted, we should get an empty slice of results. + resp, err := graph.FilterChannelRange(10, 100) + if err != nil { + t.Fatalf("unable to filter channels: %v", err) + } + if len(resp) != 0 { + t.Fatalf("expected zero chans, instead got %v", len(resp)) + } + + // To start, we'll create a set of channels, each mined in a block 10 + // blocks after the prior one. + startHeight := uint32(100) + endHeight := startHeight + const numChans = 10 + chanIDs := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + chanHeight := endHeight + channel, chanID := createEdge( + uint32(chanHeight), uint32(i+1), 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanIDs = append(chanIDs, chanID.ToUint64()) + + endHeight += 10 + } + + // With our channels inserted, we'll construct a series of queries that + // we'll execute below in order to exercise the features of the + // FilterKnownChanIDs method. + queryCases := []struct { + startHeight uint32 + endHeight uint32 + + resp []uint64 + }{ + // If we query for the entire range, then we should get the same + // set of short channel IDs back. + { + startHeight: startHeight, + endHeight: endHeight, + + resp: chanIDs, + }, + + // If we query for a range of channels right before our range, we + // shouldn't get any results back. + { + startHeight: 0, + endHeight: 10, + }, + + // If we only query for the last height (range wise), we should + // only get that last channel. + { + startHeight: endHeight - 10, + endHeight: endHeight - 10, + + resp: chanIDs[9:], + }, + + // If we query for just the first height, we should only get a + // single channel back (the first one). + { + startHeight: startHeight, + endHeight: startHeight, + + resp: chanIDs[:1], + }, + } + for i, queryCase := range queryCases { + resp, err := graph.FilterChannelRange( + queryCase.startHeight, queryCase.endHeight, + ) + if err != nil { + t.Fatalf("unable to issue range query: %v", err) + } + + if !reflect.DeepEqual(resp, queryCase.resp) { + t.Fatalf("case #%v: expected %v, got %v", i, + queryCase.resp, resp) + } + } +} + +// TestFetchChanInfos tests that we're able to properly retrieve the full set +// of ChannelEdge structs for a given set of short channel ID's. +func TestFetchChanInfos(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll make 5 test channels, ensuring we keep track of which channel + // ID corresponds to a particular ChannelEdge. + const numChans = 5 + startTime := time.Unix(1234, 0) + endTime := startTime + edges := make([]ChannelEdge, 0, numChans) + edgeQuery := make([]uint64, 0, numChans) + for i := 0; i < numChans; i++ { + txHash := sha256.Sum256([]byte{byte(i)}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(i*10), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + updateTime := endTime + endTime = updateTime.Add(time.Second * 10) + + edge1 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge1.ChannelFlags = 0 + edge1.Node = node2 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edge2.ChannelFlags = 1 + edge2.Node = node1 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edges = append(edges, ChannelEdge{ + Info: &channel, + Policy1: edge1, + Policy2: edge2, + }) + + edgeQuery = append(edgeQuery, chanID.ToUint64()) + } + + // Add an additional edge that does not exist. The query should skip + // this channel and return only infos for the edges that exist. + edgeQuery = append(edgeQuery, 500) + + // Add an another edge to the query that has been marked as a zombie + // edge. The query should also skip this channel. + zombieChan, zombieChanID := createEdge( + 666, 0, 0, 0, node1, node2, + ) + if err := graph.AddChannelEdge(&zombieChan); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + err = graph.DeleteChannelEdges(zombieChan.ChannelID) + if err != nil { + t.Fatalf("unable to delete and mark edge zombie: %v", err) + } + edgeQuery = append(edgeQuery, zombieChanID.ToUint64()) + + // We'll now attempt to query for the range of channel ID's we just + // inserted into the database. We should get the exact same set of + // edges back. + resp, err := graph.FetchChanInfos(edgeQuery) + if err != nil { + t.Fatalf("unable to fetch chan edges: %v", err) + } + if len(resp) != len(edges) { + t.Fatalf("expected %v edges, instead got %v", len(edges), + len(resp)) + } + + for i := 0; i < len(resp); i++ { + err := compareEdgePolicies(resp[i].Policy1, edges[i].Policy1) + if err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + err = compareEdgePolicies(resp[i].Policy2, edges[i].Policy2) + if err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, resp[i].Info, edges[i].Info) + } +} + +// TestIncompleteChannelPolicies tests that a channel that only has a policy +// specified on one end is properly returned in ForEachChannel calls from +// both sides. +func TestIncompleteChannelPolicies(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // Create two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Create channel between nodes. + txHash := sha256.Sum256([]byte{0}) + op := wire.OutPoint{ + Hash: txHash, + Index: 0, + } + + channel, chanID := createEdge( + uint32(0), 0, 0, 0, node1, node2, + ) + + if err := graph.AddChannelEdge(&channel); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Ensure that channel is reported with unknown policies. + + checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { + calls := 0 + node.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo, + outEdge, inEdge *ChannelEdgePolicy) error { + + if !expectedOut && outEdge != nil { + t.Fatalf("Expected no outgoing policy") + } + + if expectedOut && outEdge == nil { + t.Fatalf("Expected an outgoing policy") + } + + if !expectedIn && inEdge != nil { + t.Fatalf("Expected no incoming policy") + } + + if expectedIn && inEdge == nil { + t.Fatalf("Expected an incoming policy") + } + + calls++ + + return nil + }) + + if calls != 1 { + t.Fatalf("Expected only one callback call") + } + } + + checkPolicies(node2, false, false) + + // Only create an edge policy for node1 and leave the policy for node2 + // unknown. + updateTime := time.Unix(1234, 0) + + edgePolicy := newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edgePolicy.ChannelFlags = 0 + edgePolicy.Node = node2 + edgePolicy.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + checkPolicies(node1, false, true) + checkPolicies(node2, true, false) + + // Create second policy and assert that both policies are reported + // as present. + edgePolicy = newEdgePolicy( + chanID.ToUint64(), op, db, updateTime.Unix(), + ) + edgePolicy.ChannelFlags = 1 + edgePolicy.Node = node1 + edgePolicy.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + checkPolicies(node1, true, true) + checkPolicies(node2, true, true) +} + +// TestChannelEdgePruningUpdateIndexDeletion tests that once edges are deleted +// from the graph, then their entries within the update index are also cleaned +// up. +func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // We'll first populate our graph with two nodes. All channels created + // below will be made between these two nodes. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // With the two nodes created, we'll now create a random channel, as + // well as two edges in the database with distinct update times. + edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge1.ChannelFlags = 0 + edge1.Node = node1 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + edge2 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge2.ChannelFlags = 1 + edge2.Node = node2 + edge2.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // checkIndexTimestamps is a helper function that checks the edge update + // index only includes the given timestamps. + checkIndexTimestamps := func(timestamps ...uint64) { + timestampSet := make(map[uint64]struct{}) + for _, t := range timestamps { + timestampSet[t] = struct{}{} + } + + err := db.View(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrGraphNoEdgesFound + } + edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) + if edgeUpdateIndex == nil { + return ErrGraphNoEdgesFound + } + + numEntries := edgeUpdateIndex.Stats().KeyN + expectedEntries := len(timestampSet) + if numEntries != expectedEntries { + return fmt.Errorf("expected %v entries in the "+ + "update index, got %v", expectedEntries, + numEntries) + } + + return edgeUpdateIndex.ForEach(func(k, _ []byte) error { + t := byteOrder.Uint64(k[:8]) + if _, ok := timestampSet[t]; !ok { + return fmt.Errorf("found unexpected "+ + "timestamp "+"%d", t) + } + + return nil + }) + }) + if err != nil { + t.Fatal(err) + } + } + + // With both edges policies added, we'll make sure to check they exist + // within the edge update index. + checkIndexTimestamps( + uint64(edge1.LastUpdate.Unix()), + uint64(edge2.LastUpdate.Unix()), + ) + + // Now, we'll update the edge policies to ensure the old timestamps are + // removed from the update index. + edge1.ChannelFlags = 2 + edge1.LastUpdate = time.Now() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + edge2.ChannelFlags = 3 + edge2.LastUpdate = edge1.LastUpdate.Add(time.Hour) + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // With the policies updated, we should now be able to find their + // updated entries within the update index. + checkIndexTimestamps( + uint64(edge1.LastUpdate.Unix()), + uint64(edge2.LastUpdate.Unix()), + ) + + // Now we'll prune the graph, removing the edges, and also the update + // index entries from the database all together. + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{2}, 32)) + _, err = graph.PruneGraph( + []*wire.OutPoint{&edgeInfo.ChannelPoint}, &blockHash, 101, + ) + if err != nil { + t.Fatalf("unable to prune graph: %v", err) + } + + // Finally, we'll check the database state one last time to conclude + // that we should no longer be able to locate _any_ entries within the + // edge update index. + checkIndexTimestamps() +} + +// TestPruneGraphNodes tests that unconnected vertexes are pruned via the +// PruneSyncState method. +func TestPruneGraphNodes(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + // We'll start off by inserting our source node, to ensure that it's + // the only node left after we prune the graph. + graph := db.ChannelGraph() + sourceNode, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create source node: %v", err) + } + if err := graph.SetSourceNode(sourceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + // With the source node inserted, we'll now add three nodes to the + // channel graph, at the end of the scenario, only two of these nodes + // should still be in the graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node3, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node3); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll now add a new edge to the graph, but only actually advertise + // the edge of *one* of the nodes. + edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + // We'll now insert an advertised edge, but it'll only be the edge that + // points from the first to the second node. + edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) + edge1.ChannelFlags = 0 + edge1.Node = node1 + edge1.SigBytes = testSig.Serialize() + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // We'll now initiate a around of graph pruning. + if err := graph.PruneGraphNodes(); err != nil { + t.Fatalf("unable to prune graph nodes: %v", err) + } + + // At this point, there should be 3 nodes left in the graph still: the + // source node (which can't be pruned), and node 1+2. Nodes 1 and two + // should still be left in the graph as there's half of an advertised + // edge between them. + assertNumNodes(t, graph, 3) + + // Finally, we'll ensure that node3, the only fully unconnected node as + // properly deleted from the graph and not another node in its place. + node3Pub, err := node3.PubKey() + if err != nil { + t.Fatalf("unable to fetch the pubkey of node3: %v", err) + } + if _, err := graph.FetchLightningNode(node3Pub); err == nil { + t.Fatalf("node 3 should have been deleted!") + } +} + +// TestAddChannelEdgeShellNodes tests that when we attempt to add a ChannelEdge +// to the graph, one or both of the nodes the edge involves aren't found in the +// database, then shell edges are created for each node if needed. +func TestAddChannelEdgeShellNodes(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // To start, we'll create two nodes, and only add one of them to the + // channel graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + // We'll now create an edge between the two nodes, as a result, node2 + // should be inserted into the database as a shell node. + edgeInfo, _ := createEdge(100, 0, 0, 0, node1, node2) + if err := graph.AddChannelEdge(&edgeInfo); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + + node1Pub, err := node1.PubKey() + if err != nil { + t.Fatalf("unable to parse node 1 pub: %v", err) + } + node2Pub, err := node2.PubKey() + if err != nil { + t.Fatalf("unable to parse node 2 pub: %v", err) + } + + // Ensure that node1 was inserted as a full node, while node2 only has + // a shell node present. + node1, err = graph.FetchLightningNode(node1Pub) + if err != nil { + t.Fatalf("unable to fetch node1: %v", err) + } + if !node1.HaveNodeAnnouncement { + t.Fatalf("have shell announcement for node1, shouldn't") + } + + node2, err = graph.FetchLightningNode(node2Pub) + if err != nil { + t.Fatalf("unable to fetch node2: %v", err) + } + if node2.HaveNodeAnnouncement { + t.Fatalf("should have shell announcement for node2, but is full") + } +} + +// TestNodePruningUpdateIndexDeletion tests that once a node has been removed +// from the channel graph, we also remove the entry from the update index as +// well. +func TestNodePruningUpdateIndexDeletion(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'll first populate our graph with a single node that will be + // removed shortly. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // We'll confirm that we can retrieve the node using + // NodeUpdatesInHorizon, using a time that's slightly beyond the last + // update time of our test node. + startTime := time.Unix(9, 0) + endTime := node1.LastUpdate.Add(time.Minute) + nodesInHorizon, err := graph.NodeUpdatesInHorizon(startTime, endTime) + if err != nil { + t.Fatalf("unable to fetch nodes in horizon: %v", err) + } + + // We should only have a single node, and that node should exactly + // match the node we just inserted. + if len(nodesInHorizon) != 1 { + t.Fatalf("should have 1 nodes instead have: %v", + len(nodesInHorizon)) + } + if err := compareNodes(node1, &nodesInHorizon[0]); err != nil { + t.Fatalf("nodes don't match: %v", err) + } + + // We'll now delete the node from the graph, this should result in it + // being removed from the update index as well. + nodePub, _ := node1.PubKey() + if err := graph.DeleteLightningNode(nodePub); err != nil { + t.Fatalf("unable to delete node: %v", err) + } + + // Now that the node has been deleted, we'll again query the nodes in + // the horizon. This time we should have no nodes at all. + nodesInHorizon, err = graph.NodeUpdatesInHorizon(startTime, endTime) + if err != nil { + t.Fatalf("unable to fetch nodes in horizon: %v", err) + } + + if len(nodesInHorizon) != 0 { + t.Fatalf("should have zero nodes instead have: %v", + len(nodesInHorizon)) + } +} + +// TestNodeIsPublic ensures that we properly detect nodes that are seen as +// public within the network graph. +func TestNodeIsPublic(t *testing.T) { + t.Parallel() + + // We'll start off the test by creating a small network of 3 + // participants with the following graph: + // + // Alice <-> Bob <-> Carol + // + // We'll need to create a separate database and channel graph for each + // participant to replicate real-world scenarios (private edges being in + // some graphs but not others, etc.). + aliceDB, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + aliceNode, err := createTestVertex(aliceDB) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + aliceGraph := aliceDB.ChannelGraph() + if err := aliceGraph.SetSourceNode(aliceNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + bobDB, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + bobNode, err := createTestVertex(bobDB) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + bobGraph := bobDB.ChannelGraph() + if err := bobGraph.SetSourceNode(bobNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + carolDB, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + carolNode, err := createTestVertex(carolDB) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + carolGraph := carolDB.ChannelGraph() + if err := carolGraph.SetSourceNode(carolNode); err != nil { + t.Fatalf("unable to set source node: %v", err) + } + + aliceBobEdge, _ := createEdge(10, 0, 0, 0, aliceNode, bobNode) + bobCarolEdge, _ := createEdge(10, 1, 0, 1, bobNode, carolNode) + + // After creating all of our nodes and edges, we'll add them to each + // participant's graph. + nodes := []*LightningNode{aliceNode, bobNode, carolNode} + edges := []*ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} + dbs := []*DB{aliceDB, bobDB, carolDB} + graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} + for i, graph := range graphs { + for _, node := range nodes { + node.db = dbs[i] + if err := graph.AddLightningNode(node); err != nil { + t.Fatalf("unable to add node: %v", err) + } + } + for _, edge := range edges { + edge.db = dbs[i] + if err := graph.AddChannelEdge(edge); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + } + } + + // checkNodes is a helper closure that will be used to assert that the + // given nodes are seen as public/private within the given graphs. + checkNodes := func(nodes []*LightningNode, graphs []*ChannelGraph, + public bool) { + + t.Helper() + + for _, node := range nodes { + for _, graph := range graphs { + isPublic, err := graph.IsPublicNode(node.PubKeyBytes) + if err != nil { + t.Fatalf("unable to determine if pivot "+ + "is public: %v", err) + } + + switch { + case isPublic && !public: + t.Fatalf("expected %x to be private", + node.PubKeyBytes) + case !isPublic && public: + t.Fatalf("expected %x to be public", + node.PubKeyBytes) + } + } + } + } + + // Due to the way the edges were set up above, we'll make sure each node + // can correctly determine that every other node is public. + checkNodes(nodes, graphs, true) + + // Now, we'll remove the edge between Alice and Bob from everyone's + // graph. This will make Alice be seen as a private node as it no longer + // has any advertised edges. + for _, graph := range graphs { + err := graph.DeleteChannelEdges(aliceBobEdge.ChannelID) + if err != nil { + t.Fatalf("unable to remove edge: %v", err) + } + } + checkNodes( + []*LightningNode{aliceNode}, + []*ChannelGraph{bobGraph, carolGraph}, + false, + ) + + // We'll also make the edge between Bob and Carol private. Within Bob's + // and Carol's graph, the edge will exist, but it will not have a proof + // that allows it to be advertised. Within Alice's graph, we'll + // completely remove the edge as it is not possible for her to know of + // it without it being advertised. + for i, graph := range graphs { + err := graph.DeleteChannelEdges(bobCarolEdge.ChannelID) + if err != nil { + t.Fatalf("unable to remove edge: %v", err) + } + + if graph == aliceGraph { + continue + } + + bobCarolEdge.AuthProof = nil + bobCarolEdge.db = dbs[i] + if err := graph.AddChannelEdge(&bobCarolEdge); err != nil { + t.Fatalf("unable to add edge: %v", err) + } + } + + // With the modifications above, Bob should now be seen as a private + // node from both Alice's and Carol's perspective. + checkNodes( + []*LightningNode{bobNode}, + []*ChannelGraph{aliceGraph, carolGraph}, + false, + ) +} + +// TestDisabledChannelIDs ensures that the disabled channels within the +// disabledEdgePolicyBucket are managed properly and the list returned from +// DisabledChannelIDs is correct. +func TestDisabledChannelIDs(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + graph := db.ChannelGraph() + + // Create first node and add it to the graph. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Create second node and add it to the graph. + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + // Adding a new channel edge to the graph. + edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + + if err := graph.AddChannelEdge(edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Ensure no disabled channels exist in the bucket on start. + disabledChanIds, err := graph.DisabledChannelIDs() + if err != nil { + t.Fatalf("unable to get disabled channel ids: %v", err) + } + if len(disabledChanIds) > 0 { + t.Fatalf("expected empty disabled channels, got %v disabled channels", + len(disabledChanIds)) + } + + // Add one disabled policy and ensure the channel is still not in the + // disabled list. + edge1.ChannelFlags |= lnwire.ChanUpdateDisabled + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + disabledChanIds, err = graph.DisabledChannelIDs() + if err != nil { + t.Fatalf("unable to get disabled channel ids: %v", err) + } + if len(disabledChanIds) > 0 { + t.Fatalf("expected empty disabled channels, got %v disabled channels", + len(disabledChanIds)) + } + + // Add second disabled policy and ensure the channel is now in the + // disabled list. + edge2.ChannelFlags |= lnwire.ChanUpdateDisabled + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + disabledChanIds, err = graph.DisabledChannelIDs() + if err != nil { + t.Fatalf("unable to get disabled channel ids: %v", err) + } + if len(disabledChanIds) != 1 || disabledChanIds[0] != edgeInfo.ChannelID { + t.Fatalf("expected disabled channel with id %v, "+ + "got %v", edgeInfo.ChannelID, disabledChanIds) + } + + // Delete the channel edge and ensure it is removed from the disabled list. + if err = graph.DeleteChannelEdges(edgeInfo.ChannelID); err != nil { + t.Fatalf("unable to delete channel edge: %v", err) + } + disabledChanIds, err = graph.DisabledChannelIDs() + if err != nil { + t.Fatalf("unable to get disabled channel ids: %v", err) + } + if len(disabledChanIds) > 0 { + t.Fatalf("expected empty disabled channels, got %v disabled channels", + len(disabledChanIds)) + } +} + +// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in +// the DB that indicates that it should support the htlc_maximum_value_msat +// field, but it is not part of the opaque data, then we'll handle it as it is +// unknown. It also checks that we are correctly able to overwrite it when we +// receive the proper update. +func TestEdgePolicyMissingMaxHtcl(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + + graph := db.ChannelGraph() + + // We'd like to test the update of edges inserted into the database, so + // we create two vertexes to connect. + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + if err := graph.AddLightningNode(node1); err != nil { + t.Fatalf("unable to add node: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test node: %v", err) + } + + edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) + if err := graph.AddLightningNode(node2); err != nil { + t.Fatalf("unable to add node: %v", err) + } + if err := graph.AddChannelEdge(edgeInfo); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + chanID := edgeInfo.ChannelID + from := edge2.Node.PubKeyBytes[:] + to := edge1.Node.PubKeyBytes[:] + + // We'll remove the no max_htlc field from the first edge policy, and + // all other opaque data, and serialize it. + edge1.MessageFlags = 0 + edge1.ExtraOpaqueData = nil + + var b bytes.Buffer + err = serializeChanEdgePolicy(&b, edge1, to) + if err != nil { + t.Fatalf("unable to serialize policy") + } + + // Set the max_htlc field. The extra bytes added to the serialization + // will be the opaque data containing the serialized field. + edge1.MessageFlags = lnwire.ChanUpdateOptionMaxHtlc + edge1.MaxHTLC = 13928598 + var b2 bytes.Buffer + err = serializeChanEdgePolicy(&b2, edge1, to) + if err != nil { + t.Fatalf("unable to serialize policy") + } + + withMaxHtlc := b2.Bytes() + + // Remove the opaque data from the serialization. + stripped := withMaxHtlc[:len(b.Bytes())] + + // Attempting to deserialize these bytes should return an error. + r := bytes.NewReader(stripped) + err = db.View(func(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return ErrGraphNotFound + } + + _, err = deserializeChanEdgePolicy(r, nodes) + if err != ErrEdgePolicyOptionalFieldNotFound { + t.Fatalf("expected "+ + "ErrEdgePolicyOptionalFieldNotFound, got %v", + err) + } + + return nil + }) + if err != nil { + t.Fatalf("error reading db: %v", err) + } + + // Put the stripped bytes in the DB. + err = db.Update(func(tx *bbolt.Tx) error { + edges := tx.Bucket(edgeBucket) + if edges == nil { + return ErrEdgeNotFound + } + + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return ErrEdgeNotFound + } + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], edge1.ChannelID) + + var scratch [8]byte + var indexKey [8 + 8]byte + copy(indexKey[:], scratch[:]) + byteOrder.PutUint64(indexKey[8:], edge1.ChannelID) + + updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) + if err != nil { + return err + } + + if err := updateIndex.Put(indexKey[:], nil); err != nil { + return err + } + + return edges.Put(edgeKey[:], stripped) + }) + if err != nil { + t.Fatalf("error writing db: %v", err) + } + + // And add the second, unmodified edge. + if err := graph.UpdateEdgePolicy(edge2); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + // Attempt to fetch the edge and policies from the DB. Since the policy + // we added is invalid according to the new format, it should be as we + // are not aware of the policy (indicated by the policy returned being + // nil) + dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID) + if err != nil { + t.Fatalf("unable to fetch channel by ID: %v", err) + } + + // The first edge should have a nil-policy returned + if dbEdge1 != nil { + t.Fatalf("expected db edge to be nil") + } + if err := compareEdgePolicies(dbEdge2, edge2); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) + + // Now add the original, unmodified edge policy, and make sure the edge + // policies then become fully populated. + if err := graph.UpdateEdgePolicy(edge1); err != nil { + t.Fatalf("unable to update edge: %v", err) + } + + dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByID(chanID) + if err != nil { + t.Fatalf("unable to fetch channel by ID: %v", err) + } + if err := compareEdgePolicies(dbEdge1, edge1); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + if err := compareEdgePolicies(dbEdge2, edge2); err != nil { + t.Fatalf("edge doesn't match: %v", err) + } + assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) +} + +// assertNumZombies queries the provided ChannelGraph for NumZombies, and +// asserts that the returned number is equal to expZombies. +func assertNumZombies(t *testing.T, graph *ChannelGraph, expZombies uint64) { + t.Helper() + + numZombies, err := graph.NumZombies() + if err != nil { + t.Fatalf("unable to query number of zombies: %v", err) + } + + if numZombies != expZombies { + t.Fatalf("expected %d zombies, found %d", + expZombies, numZombies) + } +} + +// TestGraphZombieIndex ensures that we can mark edges correctly as zombie/live. +func TestGraphZombieIndex(t *testing.T) { + t.Parallel() + + // We'll start by creating our test graph along with a test edge. + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to create test database: %v", err) + } + graph := db.ChannelGraph() + + node1, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test vertex: %v", err) + } + node2, err := createTestVertex(db) + if err != nil { + t.Fatalf("unable to create test vertex: %v", err) + } + + // Swap the nodes if the second's pubkey is smaller than the first. + // Without this, the comparisons at the end will fail probabilistically. + if bytes.Compare(node2.PubKeyBytes[:], node1.PubKeyBytes[:]) < 0 { + node1, node2 = node2, node1 + } + + edge, _, _ := createChannelEdge(db, node1, node2) + if err := graph.AddChannelEdge(edge); err != nil { + t.Fatalf("unable to create channel edge: %v", err) + } + + // Since the edge is known the graph and it isn't a zombie, IsZombieEdge + // should not report the channel as a zombie. + isZombie, _, _ := graph.IsZombieEdge(edge.ChannelID) + if isZombie { + t.Fatal("expected edge to not be marked as zombie") + } + assertNumZombies(t, graph, 0) + + // If we delete the edge and mark it as a zombie, then we should expect + // to see it within the index. + err = graph.DeleteChannelEdges(edge.ChannelID) + if err != nil { + t.Fatalf("unable to mark edge as zombie: %v", err) + } + isZombie, pubKey1, pubKey2 := graph.IsZombieEdge(edge.ChannelID) + if !isZombie { + t.Fatal("expected edge to be marked as zombie") + } + if pubKey1 != node1.PubKeyBytes { + t.Fatalf("expected pubKey1 %x, got %x", node1.PubKeyBytes, + pubKey1) + } + if pubKey2 != node2.PubKeyBytes { + t.Fatalf("expected pubKey2 %x, got %x", node2.PubKeyBytes, + pubKey2) + } + assertNumZombies(t, graph, 1) + + // Similarly, if we mark the same edge as live, we should no longer see + // it within the index. + if err := graph.MarkEdgeLive(edge.ChannelID); err != nil { + t.Fatalf("unable to mark edge as live: %v", err) + } + isZombie, _, _ = graph.IsZombieEdge(edge.ChannelID) + if isZombie { + t.Fatal("expected edge to not be marked as zombie") + } + assertNumZombies(t, graph, 0) +} + +// compareNodes is used to compare two LightningNodes while excluding the +// Features struct, which cannot be compared as the semantics for reserializing +// the featuresMap have not been defined. +func compareNodes(a, b *LightningNode) error { + if a.LastUpdate != b.LastUpdate { + return fmt.Errorf("node LastUpdate doesn't match: expected %v, \n"+ + "got %v", a.LastUpdate, b.LastUpdate) + } + if !reflect.DeepEqual(a.Addresses, b.Addresses) { + return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+ + "got %#v", a.Addresses, b.Addresses) + } + if !reflect.DeepEqual(a.PubKeyBytes, b.PubKeyBytes) { + return fmt.Errorf("PubKey doesn't match: expected %#v, \n "+ + "got %#v", a.PubKeyBytes, b.PubKeyBytes) + } + if !reflect.DeepEqual(a.Color, b.Color) { + return fmt.Errorf("Color doesn't match: expected %#v, \n "+ + "got %#v", a.Color, b.Color) + } + if !reflect.DeepEqual(a.Alias, b.Alias) { + return fmt.Errorf("Alias doesn't match: expected %#v, \n "+ + "got %#v", a.Alias, b.Alias) + } + if !reflect.DeepEqual(a.db, b.db) { + return fmt.Errorf("db doesn't match: expected %#v, \n "+ + "got %#v", a.db, b.db) + } + if !reflect.DeepEqual(a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) { + return fmt.Errorf("HaveNodeAnnouncement doesn't match: expected %#v, \n "+ + "got %#v", a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) + } + if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { + return fmt.Errorf("extra data doesn't match: %v vs %v", + a.ExtraOpaqueData, b.ExtraOpaqueData) + } + + return nil +} + +// compareEdgePolicies is used to compare two ChannelEdgePolices using +// compareNodes, so as to exclude comparisons of the Nodes' Features struct. +func compareEdgePolicies(a, b *ChannelEdgePolicy) error { + if a.ChannelID != b.ChannelID { + return fmt.Errorf("ChannelID doesn't match: expected %v, "+ + "got %v", a.ChannelID, b.ChannelID) + } + if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { + return fmt.Errorf("edge LastUpdate doesn't match: expected %#v, \n "+ + "got %#v", a.LastUpdate, b.LastUpdate) + } + if a.MessageFlags != b.MessageFlags { + return fmt.Errorf("MessageFlags doesn't match: expected %v, "+ + "got %v", a.MessageFlags, b.MessageFlags) + } + if a.ChannelFlags != b.ChannelFlags { + return fmt.Errorf("ChannelFlags doesn't match: expected %v, "+ + "got %v", a.ChannelFlags, b.ChannelFlags) + } + if a.TimeLockDelta != b.TimeLockDelta { + return fmt.Errorf("TimeLockDelta doesn't match: expected %v, "+ + "got %v", a.TimeLockDelta, b.TimeLockDelta) + } + if a.MinHTLC != b.MinHTLC { + return fmt.Errorf("MinHTLC doesn't match: expected %v, "+ + "got %v", a.MinHTLC, b.MinHTLC) + } + if a.MaxHTLC != b.MaxHTLC { + return fmt.Errorf("MaxHTLC doesn't match: expected %v, "+ + "got %v", a.MaxHTLC, b.MaxHTLC) + } + if a.FeeBaseMSat != b.FeeBaseMSat { + return fmt.Errorf("FeeBaseMSat doesn't match: expected %v, "+ + "got %v", a.FeeBaseMSat, b.FeeBaseMSat) + } + if a.FeeProportionalMillionths != b.FeeProportionalMillionths { + return fmt.Errorf("FeeProportionalMillionths doesn't match: "+ + "expected %v, got %v", a.FeeProportionalMillionths, + b.FeeProportionalMillionths) + } + if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { + return fmt.Errorf("extra data doesn't match: %v vs %v", + a.ExtraOpaqueData, b.ExtraOpaqueData) + } + if err := compareNodes(a.Node, b.Node); err != nil { + return err + } + if !reflect.DeepEqual(a.db, b.db) { + return fmt.Errorf("db doesn't match: expected %#v, \n "+ + "got %#v", a.db, b.db) + } + return nil +} + +// TestLightningNodeSigVerifcation checks that we can use the LightningNode's +// pubkey to verify signatures. +func TestLightningNodeSigVerification(t *testing.T) { + t.Parallel() + + // Create some dummy data to sign. + var data [32]byte + if _, err := prand.Read(data[:]); err != nil { + t.Fatalf("unable to read prand: %v", err) + } + + // Create private key and sign the data with it. + priv, err := btcec.NewPrivateKey(btcec.S256()) + if err != nil { + t.Fatalf("unable to crete priv key: %v", err) + } + + sign, err := priv.Sign(data[:]) + if err != nil { + t.Fatalf("unable to sign: %v", err) + } + + // Sanity check that the signature checks out. + if !sign.Verify(data[:], priv.PubKey()) { + t.Fatalf("signature doesn't check out") + } + + // Create a LightningNode from the same private key. + db, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + node, err := createLightningNode(db, priv) + if err != nil { + t.Fatalf("unable to create node: %v", err) + } + + // And finally check that we can verify the same signature from the + // pubkey returned from the lightning node. + nodePub, err := node.PubKey() + if err != nil { + t.Fatalf("unable to get pubkey: %v", err) + } + + if !sign.Verify(data[:], nodePub) { + t.Fatalf("unable to verify sig") + } +} + +// TestComputeFee tests fee calculation based on both in- and outgoing amt. +func TestComputeFee(t *testing.T) { + var ( + policy = ChannelEdgePolicy{ + FeeBaseMSat: 10000, + FeeProportionalMillionths: 30000, + } + outgoingAmt = lnwire.MilliSatoshi(1000000) + expectedFee = lnwire.MilliSatoshi(40000) + ) + + fee := policy.ComputeFee(outgoingAmt) + if fee != expectedFee { + t.Fatalf("expected fee %v, got %v", expectedFee, fee) + } + + fwdFee := policy.ComputeFeeFromIncoming(outgoingAmt + fee) + if fwdFee != expectedFee { + t.Fatalf("expected fee %v, but got %v", fee, fwdFee) + } +} diff --git a/channeldb/migration_01_to_11/invoice_test.go b/channeldb/migration_01_to_11/invoice_test.go new file mode 100644 index 00000000..795fe493 --- /dev/null +++ b/channeldb/migration_01_to_11/invoice_test.go @@ -0,0 +1,694 @@ +package migration_01_to_11 + +import ( + "crypto/rand" + "reflect" + "testing" + "time" + + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lnwire" +) + +func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { + var pre [32]byte + if _, err := rand.Read(pre[:]); err != nil { + return nil, err + } + + i := &Invoice{ + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + Terms: ContractTerm{ + PaymentPreimage: pre, + Value: value, + }, + Htlcs: map[CircuitKey]*InvoiceHTLC{}, + Expiry: 4000, + } + i.Memo = []byte("memo") + i.Receipt = []byte("receipt") + + // Create a random byte slice of MaxPaymentRequestSize bytes to be used + // as a dummy paymentrequest, and determine if it should be set based + // on one of the random bytes. + var r [MaxPaymentRequestSize]byte + if _, err := rand.Read(r[:]); err != nil { + return nil, err + } + if r[0]&1 == 0 { + i.PaymentRequest = r[:] + } else { + i.PaymentRequest = []byte("") + } + + return i, nil +} + +func TestInvoiceWorkflow(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // Create a fake invoice which we'll use several times in the tests + // below. + fakeInvoice := &Invoice{ + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + Htlcs: map[CircuitKey]*InvoiceHTLC{}, + } + fakeInvoice.Memo = []byte("memo") + fakeInvoice.Receipt = []byte("receipt") + fakeInvoice.PaymentRequest = []byte("") + copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) + fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) + + paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() + + // Add the invoice to the database, this should succeed as there aren't + // any existing invoices within the database with the same payment + // hash. + if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil { + t.Fatalf("unable to find invoice: %v", err) + } + + // Attempt to retrieve the invoice which was just added to the + // database. It should be found, and the invoice returned should be + // identical to the one created above. + dbInvoice, err := db.LookupInvoice(paymentHash) + if err != nil { + t.Fatalf("unable to find invoice: %v", err) + } + if !reflect.DeepEqual(*fakeInvoice, dbInvoice) { + t.Fatalf("invoice fetched from db doesn't match original %v vs %v", + spew.Sdump(fakeInvoice), spew.Sdump(dbInvoice)) + } + + // The add index of the invoice retrieved from the database should now + // be fully populated. As this is the first index written to the DB, + // the addIndex should be 1. + if dbInvoice.AddIndex != 1 { + t.Fatalf("wrong add index: expected %v, got %v", 1, + dbInvoice.AddIndex) + } + + // Settle the invoice, the version retrieved from the database should + // now have the settled bit toggle to true and a non-default + // SettledDate + payAmt := fakeInvoice.Terms.Value * 2 + _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + dbInvoice2, err := db.LookupInvoice(paymentHash) + if err != nil { + t.Fatalf("unable to fetch invoice: %v", err) + } + if dbInvoice2.Terms.State != ContractSettled { + t.Fatalf("invoice should now be settled but isn't") + } + if dbInvoice2.SettleDate.IsZero() { + t.Fatalf("invoice should have non-zero SettledDate but isn't") + } + + // Our 2x payment should be reflected, and also the settle index of 1 + // should also have been committed for this index. + if dbInvoice2.AmtPaid != payAmt { + t.Fatalf("wrong amt paid: expected %v, got %v", payAmt, + dbInvoice2.AmtPaid) + } + if dbInvoice2.SettleIndex != 1 { + t.Fatalf("wrong settle index: expected %v, got %v", 1, + dbInvoice2.SettleIndex) + } + + // Attempt to insert generated above again, this should fail as + // duplicates are rejected by the processing logic. + if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice { + t.Fatalf("invoice insertion should fail due to duplication, "+ + "instead %v", err) + } + + // Attempt to look up a non-existent invoice, this should also fail but + // with a "not found" error. + var fakeHash [32]byte + if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound { + t.Fatalf("lookup should have failed, instead %v", err) + } + + // Add 10 random invoices. + const numInvoices = 10 + amt := lnwire.NewMSatFromSatoshis(1000) + invoices := make([]*Invoice, numInvoices+1) + invoices[0] = &dbInvoice2 + for i := 1; i < len(invoices)-1; i++ { + invoice, err := randInvoice(amt) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + hash := invoice.Terms.PaymentPreimage.Hash() + if _, err := db.AddInvoice(invoice, hash); err != nil { + t.Fatalf("unable to add invoice %v", err) + } + + invoices[i] = invoice + } + + // Perform a scan to collect all the active invoices. + dbInvoices, err := db.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to fetch all invoices: %v", err) + } + + // The retrieve list of invoices should be identical as since we're + // using big endian, the invoices should be retrieved in ascending + // order (and the primary key should be incremented with each + // insertion). + for i := 0; i < len(invoices)-1; i++ { + if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) { + t.Fatalf("retrieved invoices don't match %v vs %v", + spew.Sdump(invoices[i]), + spew.Sdump(dbInvoices[i])) + } + } +} + +// TestInvoiceTimeSeries tests that newly added invoices invoices, as well as +// settled invoices are added to the database are properly placed in the add +// add or settle index which serves as an event time series. +func TestInvoiceAddTimeSeries(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // We'll start off by creating 20 random invoices, and inserting them + // into the database. + const numInvoices = 20 + amt := lnwire.NewMSatFromSatoshis(1000) + invoices := make([]Invoice, numInvoices) + for i := 0; i < len(invoices); i++ { + invoice, err := randInvoice(amt) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + if _, err := db.AddInvoice(invoice, paymentHash); err != nil { + t.Fatalf("unable to add invoice %v", err) + } + + invoices[i] = *invoice + } + + // With the invoices constructed, we'll now create a series of queries + // that we'll use to assert expected return values of + // InvoicesAddedSince. + addQueries := []struct { + sinceAddIndex uint64 + + resp []Invoice + }{ + // If we specify a value of zero, we shouldn't get any invoices + // back. + { + sinceAddIndex: 0, + }, + + // If we specify a value well beyond the number of inserted + // invoices, we shouldn't get any invoices back. + { + sinceAddIndex: 99999999, + }, + + // Using an index of 1 should result in all values, but the + // first one being returned. + { + sinceAddIndex: 1, + resp: invoices[1:], + }, + + // If we use an index of 10, then we should retrieve the + // reaming 10 invoices. + { + sinceAddIndex: 10, + resp: invoices[10:], + }, + } + + for i, query := range addQueries { + resp, err := db.InvoicesAddedSince(query.sinceAddIndex) + if err != nil { + t.Fatalf("unable to query: %v", err) + } + + if !reflect.DeepEqual(query.resp, resp) { + t.Fatalf("test #%v: expected %v, got %v", i, + spew.Sdump(query.resp), spew.Sdump(resp)) + } + } + + // We'll now only settle the latter half of each of those invoices. + for i := 10; i < len(invoices); i++ { + invoice := &invoices[i] + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + _, err := db.UpdateInvoice( + paymentHash, getUpdateInvoice(0), + ) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + } + + invoices, err = db.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to fetch invoices: %v", err) + } + + // We'll slice off the first 10 invoices, as we only settled the last + // 10. + invoices = invoices[10:] + + // We'll now prepare an additional set of queries to ensure the settle + // time series has properly been maintained in the database. + settleQueries := []struct { + sinceSettleIndex uint64 + + resp []Invoice + }{ + // If we specify a value of zero, we shouldn't get any settled + // invoices back. + { + sinceSettleIndex: 0, + }, + + // If we specify a value well beyond the number of settled + // invoices, we shouldn't get any invoices back. + { + sinceSettleIndex: 99999999, + }, + + // Using an index of 1 should result in the final 10 invoices + // being returned, as we only settled those. + { + sinceSettleIndex: 1, + resp: invoices[1:], + }, + } + + for i, query := range settleQueries { + resp, err := db.InvoicesSettledSince(query.sinceSettleIndex) + if err != nil { + t.Fatalf("unable to query: %v", err) + } + + if !reflect.DeepEqual(query.resp, resp) { + t.Fatalf("test #%v: expected %v, got %v", i, + spew.Sdump(query.resp), spew.Sdump(resp)) + } + } +} + +// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it +// twice, then the second time we also receive the invoice that we settled as a +// return argument. +func TestDuplicateSettleInvoice(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + db.now = func() time.Time { return time.Unix(1, 0) } + + // We'll start out by creating an invoice and writing it to the DB. + amt := lnwire.NewMSatFromSatoshis(1000) + invoice, err := randInvoice(amt) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + payHash := invoice.Terms.PaymentPreimage.Hash() + + if _, err := db.AddInvoice(invoice, payHash); err != nil { + t.Fatalf("unable to add invoice %v", err) + } + + // With the invoice in the DB, we'll now attempt to settle the invoice. + dbInvoice, err := db.UpdateInvoice( + payHash, getUpdateInvoice(amt), + ) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + + // We'll update what we expect the settle invoice to be so that our + // comparison below has the correct assumption. + invoice.SettleIndex = 1 + invoice.Terms.State = ContractSettled + invoice.AmtPaid = amt + invoice.SettleDate = dbInvoice.SettleDate + invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{ + {}: { + Amt: amt, + AcceptTime: time.Unix(1, 0), + ResolveTime: time.Unix(1, 0), + State: HtlcStateSettled, + }, + } + + // We should get back the exact same invoice that we just inserted. + if !reflect.DeepEqual(dbInvoice, invoice) { + t.Fatalf("wrong invoice after settle, expected %v got %v", + spew.Sdump(invoice), spew.Sdump(dbInvoice)) + } + + // If we try to settle the invoice again, then we should get the very + // same invoice back, but with an error this time. + dbInvoice, err = db.UpdateInvoice( + payHash, getUpdateInvoice(amt), + ) + if err != ErrInvoiceAlreadySettled { + t.Fatalf("expected ErrInvoiceAlreadySettled") + } + + if dbInvoice == nil { + t.Fatalf("invoice from db is nil after settle!") + } + + invoice.SettleDate = dbInvoice.SettleDate + if !reflect.DeepEqual(dbInvoice, invoice) { + t.Fatalf("wrong invoice after second settle, expected %v got %v", + spew.Sdump(invoice), spew.Sdump(dbInvoice)) + } +} + +// TestQueryInvoices ensures that we can properly query the invoice database for +// invoices using different types of queries. +func TestQueryInvoices(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatalf("unable to make test db: %v", err) + } + + // To begin the test, we'll add 50 invoices to the database. We'll + // assume that the index of the invoice within the database is the same + // as the amount of the invoice itself. + const numInvoices = 50 + for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ { + invoice, err := randInvoice(i) + if err != nil { + t.Fatalf("unable to create invoice: %v", err) + } + + paymentHash := invoice.Terms.PaymentPreimage.Hash() + + if _, err := db.AddInvoice(invoice, paymentHash); err != nil { + t.Fatalf("unable to add invoice: %v", err) + } + + // We'll only settle half of all invoices created. + if i%2 == 0 { + _, err := db.UpdateInvoice( + paymentHash, getUpdateInvoice(i), + ) + if err != nil { + t.Fatalf("unable to settle invoice: %v", err) + } + } + } + + // We'll then retrieve the set of all invoices and pending invoices. + // This will serve useful when comparing the expected responses of the + // query with the actual ones. + invoices, err := db.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to retrieve invoices: %v", err) + } + pendingInvoices, err := db.FetchAllInvoices(true) + if err != nil { + t.Fatalf("unable to retrieve pending invoices: %v", err) + } + + // The test will consist of several queries along with their respective + // expected response. Each query response should match its expected one. + testCases := []struct { + query InvoiceQuery + expected []Invoice + }{ + // Fetch all invoices with a single query. + { + query: InvoiceQuery{ + NumMaxInvoices: numInvoices, + }, + expected: invoices, + }, + // Fetch all invoices with a single query, reversed. + { + query: InvoiceQuery{ + Reversed: true, + NumMaxInvoices: numInvoices, + }, + expected: invoices, + }, + // Fetch the first 25 invoices. + { + query: InvoiceQuery{ + NumMaxInvoices: numInvoices / 2, + }, + expected: invoices[:numInvoices/2], + }, + // Fetch the first 10 invoices, but this time iterating + // backwards. + { + query: InvoiceQuery{ + IndexOffset: 11, + Reversed: true, + NumMaxInvoices: numInvoices, + }, + expected: invoices[:10], + }, + // Fetch the last 40 invoices. + { + query: InvoiceQuery{ + IndexOffset: 10, + NumMaxInvoices: numInvoices, + }, + expected: invoices[10:], + }, + // Fetch all but the first invoice. + { + query: InvoiceQuery{ + IndexOffset: 1, + NumMaxInvoices: numInvoices, + }, + expected: invoices[1:], + }, + // Fetch one invoice, reversed, with index offset 3. This + // should give us the second invoice in the array. + { + query: InvoiceQuery{ + IndexOffset: 3, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[1:2], + }, + // Same as above, at index 2. + { + query: InvoiceQuery{ + IndexOffset: 2, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[0:1], + }, + // Fetch one invoice, at index 1, reversed. Since invoice#1 is + // the very first, there won't be any left in a reverse search, + // so we expect no invoices to be returned. + { + query: InvoiceQuery{ + IndexOffset: 1, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: nil, + }, + // Same as above, but don't restrict the number of invoices to + // 1. + { + query: InvoiceQuery{ + IndexOffset: 1, + Reversed: true, + NumMaxInvoices: numInvoices, + }, + expected: nil, + }, + // Fetch one invoice, reversed, with no offset set. We expect + // the last invoice in the response. + { + query: InvoiceQuery{ + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[numInvoices-1:], + }, + // Fetch one invoice, reversed, the offset set at numInvoices+1. + // We expect this to return the last invoice. + { + query: InvoiceQuery{ + IndexOffset: numInvoices + 1, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[numInvoices-1:], + }, + // Same as above, at offset numInvoices. + { + query: InvoiceQuery{ + IndexOffset: numInvoices, + Reversed: true, + NumMaxInvoices: 1, + }, + expected: invoices[numInvoices-2 : numInvoices-1], + }, + // Fetch one invoice, at no offset (same as offset 0). We + // expect the first invoice only in the response. + { + query: InvoiceQuery{ + NumMaxInvoices: 1, + }, + expected: invoices[:1], + }, + // Same as above, at offset 1. + { + query: InvoiceQuery{ + IndexOffset: 1, + NumMaxInvoices: 1, + }, + expected: invoices[1:2], + }, + // Same as above, at offset 2. + { + query: InvoiceQuery{ + IndexOffset: 2, + NumMaxInvoices: 1, + }, + expected: invoices[2:3], + }, + // Same as above, at offset numInvoices-1. Expect the last + // invoice to be returned. + { + query: InvoiceQuery{ + IndexOffset: numInvoices - 1, + NumMaxInvoices: 1, + }, + expected: invoices[numInvoices-1:], + }, + // Same as above, at offset numInvoices. No invoices should be + // returned, as there are no invoices after this offset. + { + query: InvoiceQuery{ + IndexOffset: numInvoices, + NumMaxInvoices: 1, + }, + expected: nil, + }, + // Fetch all pending invoices with a single query. + { + query: InvoiceQuery{ + PendingOnly: true, + NumMaxInvoices: numInvoices, + }, + expected: pendingInvoices, + }, + // Fetch the first 12 pending invoices. + { + query: InvoiceQuery{ + PendingOnly: true, + NumMaxInvoices: numInvoices / 4, + }, + expected: pendingInvoices[:len(pendingInvoices)/2], + }, + // Fetch the first 5 pending invoices, but this time iterating + // backwards. + { + query: InvoiceQuery{ + IndexOffset: 10, + PendingOnly: true, + Reversed: true, + NumMaxInvoices: numInvoices, + }, + // Since we seek to the invoice with index 10 and + // iterate backwards, there should only be 5 pending + // invoices before it as every other invoice within the + // index is settled. + expected: pendingInvoices[:5], + }, + // Fetch the last 15 invoices. + { + query: InvoiceQuery{ + IndexOffset: 20, + PendingOnly: true, + NumMaxInvoices: numInvoices, + }, + // Since we seek to the invoice with index 20, there are + // 30 invoices left. From these 30, only 15 of them are + // still pending. + expected: pendingInvoices[len(pendingInvoices)-15:], + }, + } + + for i, testCase := range testCases { + response, err := db.QueryInvoices(testCase.query) + if err != nil { + t.Fatalf("unable to query invoice database: %v", err) + } + + if !reflect.DeepEqual(response.Invoices, testCase.expected) { + t.Fatalf("test #%d: query returned incorrect set of "+ + "invoices: expcted %v, got %v", i, + spew.Sdump(response.Invoices), + spew.Sdump(testCase.expected)) + } + } +} + +// getUpdateInvoice returns an invoice update callback that, when called, +// settles the invoice with the given amount. +func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { + return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { + if invoice.Terms.State == ContractSettled { + return nil, ErrInvoiceAlreadySettled + } + + update := &InvoiceUpdateDesc{ + Preimage: invoice.Terms.PaymentPreimage, + State: ContractSettled, + Htlcs: map[CircuitKey]*HtlcAcceptDesc{ + {}: { + Amt: amt, + }, + }, + } + + return update, nil + } +} diff --git a/channeldb/migration_01_to_11/invoices.go b/channeldb/migration_01_to_11/invoices.go new file mode 100644 index 00000000..5f40454a --- /dev/null +++ b/channeldb/migration_01_to_11/invoices.go @@ -0,0 +1,1320 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "time" + + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // UnknownPreimage is an all-zeroes preimage that indicates that the + // preimage for this invoice is not yet known. + UnknownPreimage lntypes.Preimage + + // invoiceBucket is the name of the bucket within the database that + // stores all data related to invoices no matter their final state. + // Within the invoice bucket, each invoice is keyed by its invoice ID + // which is a monotonically increasing uint32. + invoiceBucket = []byte("invoices") + + // paymentHashIndexBucket is the name of the sub-bucket within the + // invoiceBucket which indexes all invoices by their payment hash. The + // payment hash is the sha256 of the invoice's payment preimage. This + // index is used to detect duplicates, and also to provide a fast path + // for looking up incoming HTLCs to determine if we're able to settle + // them fully. + // + // maps: payHash => invoiceKey + invoiceIndexBucket = []byte("paymenthashes") + + // numInvoicesKey is the name of key which houses the auto-incrementing + // invoice ID which is essentially used as a primary key. With each + // invoice inserted, the primary key is incremented by one. This key is + // stored within the invoiceIndexBucket. Within the invoiceBucket + // invoices are uniquely identified by the invoice ID. + numInvoicesKey = []byte("nik") + + // addIndexBucket is an index bucket that we'll use to create a + // monotonically increasing set of add indexes. Each time we add a new + // invoice, this sequence number will be incremented and then populated + // within the new invoice. + // + // In addition to this sequence number, we map: + // + // addIndexNo => invoiceKey + addIndexBucket = []byte("invoice-add-index") + + // settleIndexBucket is an index bucket that we'll use to create a + // monotonically increasing integer for tracking a "settle index". Each + // time an invoice is settled, this sequence number will be incremented + // as populate within the newly settled invoice. + // + // In addition to this sequence number, we map: + // + // settleIndexNo => invoiceKey + settleIndexBucket = []byte("invoice-settle-index") + + // ErrInvoiceAlreadySettled is returned when the invoice is already + // settled. + ErrInvoiceAlreadySettled = errors.New("invoice already settled") + + // ErrInvoiceAlreadyCanceled is returned when the invoice is already + // canceled. + ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled") + + // ErrInvoiceAlreadyAccepted is returned when the invoice is already + // accepted. + ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted") + + // ErrInvoiceStillOpen is returned when the invoice is still open. + ErrInvoiceStillOpen = errors.New("invoice still open") +) + +const ( + // MaxMemoSize is maximum size of the memo field within invoices stored + // in the database. + MaxMemoSize = 1024 + + // MaxReceiptSize is the maximum size of the payment receipt stored + // within the database along side incoming/outgoing invoices. + MaxReceiptSize = 1024 + + // MaxPaymentRequestSize is the max size of a payment request for + // this invoice. + // TODO(halseth): determine the max length payment request when field + // lengths are final. + MaxPaymentRequestSize = 4096 + + // A set of tlv type definitions used to serialize invoice htlcs to the + // database. + chanIDType tlv.Type = 1 + htlcIDType tlv.Type = 3 + amtType tlv.Type = 5 + acceptHeightType tlv.Type = 7 + acceptTimeType tlv.Type = 9 + resolveTimeType tlv.Type = 11 + expiryHeightType tlv.Type = 13 + stateType tlv.Type = 15 +) + +// ContractState describes the state the invoice is in. +type ContractState uint8 + +const ( + // ContractOpen means the invoice has only been created. + ContractOpen ContractState = 0 + + // ContractSettled means the htlc is settled and the invoice has been + // paid. + ContractSettled ContractState = 1 + + // ContractCanceled means the invoice has been canceled. + ContractCanceled ContractState = 2 + + // ContractAccepted means the HTLC has been accepted but not settled + // yet. + ContractAccepted ContractState = 3 +) + +// String returns a human readable identifier for the ContractState type. +func (c ContractState) String() string { + switch c { + case ContractOpen: + return "Open" + case ContractSettled: + return "Settled" + case ContractCanceled: + return "Canceled" + case ContractAccepted: + return "Accepted" + } + + return "Unknown" +} + +// ContractTerm is a companion struct to the Invoice struct. This struct houses +// the necessary conditions required before the invoice can be considered fully +// settled by the payee. +type ContractTerm struct { + // PaymentPreimage is the preimage which is to be revealed in the + // occasion that an HTLC paying to the hash of this preimage is + // extended. + PaymentPreimage lntypes.Preimage + + // Value is the expected amount of milli-satoshis to be paid to an HTLC + // which can be satisfied by the above preimage. + Value lnwire.MilliSatoshi + + // State describes the state the invoice is in. + State ContractState +} + +// Invoice is a payment invoice generated by a payee in order to request +// payment for some good or service. The inclusion of invoices within Lightning +// creates a payment work flow for merchants very similar to that of the +// existing financial system within PayPal, etc. Invoices are added to the +// database when a payment is requested, then can be settled manually once the +// payment is received at the upper layer. For record keeping purposes, +// invoices are never deleted from the database, instead a bit is toggled +// denoting the invoice has been fully settled. Within the database, all +// invoices must have a unique payment hash which is generated by taking the +// sha256 of the payment preimage. +type Invoice struct { + // Memo is an optional memo to be stored along side an invoice. The + // memo may contain further details pertaining to the invoice itself, + // or any other message which fits within the size constraints. + Memo []byte + + // Receipt is an optional field dedicated for storing a + // cryptographically binding receipt of payment. + // + // TODO(roasbeef): document scheme. + Receipt []byte + + // PaymentRequest is an optional field where a payment request created + // for this invoice can be stored. + PaymentRequest []byte + + // FinalCltvDelta is the minimum required number of blocks before htlc + // expiry when the invoice is accepted. + FinalCltvDelta int32 + + // Expiry defines how long after creation this invoice should expire. + Expiry time.Duration + + // CreationDate is the exact time the invoice was created. + CreationDate time.Time + + // SettleDate is the exact time the invoice was settled. + SettleDate time.Time + + // Terms are the contractual payment terms of the invoice. Once all the + // terms have been satisfied by the payer, then the invoice can be + // considered fully fulfilled. + // + // TODO(roasbeef): later allow for multiple terms to fulfill the final + // invoice: payment fragmentation, etc. + Terms ContractTerm + + // AddIndex is an auto-incrementing integer that acts as a + // monotonically increasing sequence number for all invoices created. + // Clients can then use this field as a "checkpoint" of sorts when + // implementing a streaming RPC to notify consumers of instances where + // an invoice has been added before they re-connected. + // + // NOTE: This index starts at 1. + AddIndex uint64 + + // SettleIndex is an auto-incrementing integer that acts as a + // monotonically increasing sequence number for all settled invoices. + // Clients can then use this field as a "checkpoint" of sorts when + // implementing a streaming RPC to notify consumers of instances where + // an invoice has been settled before they re-connected. + // + // NOTE: This index starts at 1. + SettleIndex uint64 + + // AmtPaid is the final amount that we ultimately accepted for pay for + // this invoice. We specify this value independently as it's possible + // that the invoice originally didn't specify an amount, or the sender + // overpaid. + AmtPaid lnwire.MilliSatoshi + + // Htlcs records all htlcs that paid to this invoice. Some of these + // htlcs may have been marked as canceled. + Htlcs map[CircuitKey]*InvoiceHTLC +} + +// HtlcState defines the states an htlc paying to an invoice can be in. +type HtlcState uint8 + +const ( + // HtlcStateAccepted indicates the htlc is locked-in, but not resolved. + HtlcStateAccepted HtlcState = iota + + // HtlcStateCanceled indicates the htlc is canceled back to the + // sender. + HtlcStateCanceled + + // HtlcStateSettled indicates the htlc is settled. + HtlcStateSettled +) + +// InvoiceHTLC contains details about an htlc paying to this invoice. +type InvoiceHTLC struct { + // Amt is the amount that is carried by this htlc. + Amt lnwire.MilliSatoshi + + // AcceptHeight is the block height at which the invoice registry + // decided to accept this htlc as a payment to the invoice. At this + // height, the invoice cltv delay must have been met. + AcceptHeight uint32 + + // AcceptTime is the wall clock time at which the invoice registry + // decided to accept the htlc. + AcceptTime time.Time + + // ResolveTime is the wall clock time at which the invoice registry + // decided to settle the htlc. + ResolveTime time.Time + + // Expiry is the expiry height of this htlc. + Expiry uint32 + + // State indicates the state the invoice htlc is currently in. A + // canceled htlc isn't just removed from the invoice htlcs map, because + // we need AcceptHeight to properly cancel the htlc back. + State HtlcState +} + +// HtlcAcceptDesc describes the details of a newly accepted htlc. +type HtlcAcceptDesc struct { + // AcceptHeight is the block height at which this htlc was accepted. + AcceptHeight int32 + + // Amt is the amount that is carried by this htlc. + Amt lnwire.MilliSatoshi + + // Expiry is the expiry height of this htlc. + Expiry uint32 +} + +// InvoiceUpdateDesc describes the changes that should be applied to the +// invoice. +type InvoiceUpdateDesc struct { + // State is the new state that this invoice should progress to. + State ContractState + + // Htlcs describes the changes that need to be made to the invoice htlcs + // in the database. Htlc map entries with their value set should be + // added. If the map value is nil, the htlc should be canceled. + Htlcs map[CircuitKey]*HtlcAcceptDesc + + // Preimage must be set to the preimage when state is settled. + Preimage lntypes.Preimage +} + +// InvoiceUpdateCallback is a callback used in the db transaction to update the +// invoice. +type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) + +func validateInvoice(i *Invoice) error { + if len(i.Memo) > MaxMemoSize { + return fmt.Errorf("max length a memo is %v, and invoice "+ + "of length %v was provided", MaxMemoSize, len(i.Memo)) + } + if len(i.Receipt) > MaxReceiptSize { + return fmt.Errorf("max length a receipt is %v, and invoice "+ + "of length %v was provided", MaxReceiptSize, + len(i.Receipt)) + } + if len(i.PaymentRequest) > MaxPaymentRequestSize { + return fmt.Errorf("max length of payment request is %v, length "+ + "provided was %v", MaxPaymentRequestSize, + len(i.PaymentRequest)) + } + return nil +} + +// AddInvoice inserts the targeted invoice into the database. If the invoice has +// *any* payment hashes which already exists within the database, then the +// insertion will be aborted and rejected due to the strict policy banning any +// duplicate payment hashes. A side effect of this function is that it sets +// AddIndex on newInvoice. +func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( + uint64, error) { + + if err := validateInvoice(newInvoice); err != nil { + return 0, err + } + + var invoiceAddIndex uint64 + err := d.Update(func(tx *bbolt.Tx) error { + invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) + if err != nil { + return err + } + + invoiceIndex, err := invoices.CreateBucketIfNotExists( + invoiceIndexBucket, + ) + if err != nil { + return err + } + addIndex, err := invoices.CreateBucketIfNotExists( + addIndexBucket, + ) + if err != nil { + return err + } + + // Ensure that an invoice an identical payment hash doesn't + // already exist within the index. + if invoiceIndex.Get(paymentHash[:]) != nil { + return ErrDuplicateInvoice + } + + // If the current running payment ID counter hasn't yet been + // created, then create it now. + var invoiceNum uint32 + invoiceCounter := invoiceIndex.Get(numInvoicesKey) + if invoiceCounter == nil { + var scratch [4]byte + byteOrder.PutUint32(scratch[:], invoiceNum) + err := invoiceIndex.Put(numInvoicesKey, scratch[:]) + if err != nil { + return err + } + } else { + invoiceNum = byteOrder.Uint32(invoiceCounter) + } + + newIndex, err := putInvoice( + invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, + paymentHash, + ) + if err != nil { + return err + } + + invoiceAddIndex = newIndex + return nil + }) + if err != nil { + return 0, err + } + + return invoiceAddIndex, err +} + +// InvoicesAddedSince can be used by callers to seek into the event time series +// of all the invoices added in the database. The specified sinceAddIndex +// should be the highest add index that the caller knows of. This method will +// return all invoices with an add index greater than the specified +// sinceAddIndex. +// +// NOTE: The index starts from 1, as a result. We enforce that specifying a +// value below the starting index value is a noop. +func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { + var newInvoices []Invoice + + // If an index of zero was specified, then in order to maintain + // backwards compat, we won't send out any new invoices. + if sinceAddIndex == 0 { + return newInvoices, nil + } + + var startIndex [8]byte + byteOrder.PutUint64(startIndex[:], sinceAddIndex) + + err := d.DB.View(func(tx *bbolt.Tx) error { + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + addIndex := invoices.Bucket(addIndexBucket) + if addIndex == nil { + return ErrNoInvoicesCreated + } + + // We'll now run through each entry in the add index starting + // at our starting index. We'll continue until we reach the + // very end of the current key space. + invoiceCursor := addIndex.Cursor() + + // We'll seek to the starting index, then manually advance the + // cursor in order to skip the entry with the since add index. + invoiceCursor.Seek(startIndex[:]) + addSeqNo, invoiceKey := invoiceCursor.Next() + + for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() { + + // For each key found, we'll look up the actual + // invoice, then accumulate it into our return value. + invoice, err := fetchInvoice(invoiceKey, invoices) + if err != nil { + return err + } + + newInvoices = append(newInvoices, invoice) + } + + return nil + }) + switch { + // If no invoices have been created, then we'll return the empty set of + // invoices. + case err == ErrNoInvoicesCreated: + + case err != nil: + return nil, err + } + + return newInvoices, nil +} + +// LookupInvoice attempts to look up an invoice according to its 32 byte +// payment hash. If an invoice which can settle the HTLC identified by the +// passed payment hash isn't found, then an error is returned. Otherwise, the +// full invoice is returned. Before setting the incoming HTLC, the values +// SHOULD be checked to ensure the payer meets the agreed upon contractual +// terms of the payment. +func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { + var invoice Invoice + err := d.View(func(tx *bbolt.Tx) error { + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + invoiceIndex := invoices.Bucket(invoiceIndexBucket) + if invoiceIndex == nil { + return ErrNoInvoicesCreated + } + + // Check the invoice index to see if an invoice paying to this + // hash exists within the DB. + invoiceNum := invoiceIndex.Get(paymentHash[:]) + if invoiceNum == nil { + return ErrInvoiceNotFound + } + + // An invoice matching the payment hash has been found, so + // retrieve the record of the invoice itself. + i, err := fetchInvoice(invoiceNum, invoices) + if err != nil { + return err + } + invoice = i + + return nil + }) + if err != nil { + return invoice, err + } + + return invoice, nil +} + +// FetchAllInvoices returns all invoices currently stored within the database. +// If the pendingOnly param is true, then only unsettled invoices will be +// returned, skipping all invoices that are fully settled. +func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { + var invoices []Invoice + + err := d.View(func(tx *bbolt.Tx) error { + invoiceB := tx.Bucket(invoiceBucket) + if invoiceB == nil { + return ErrNoInvoicesCreated + } + + // Iterate through the entire key space of the top-level + // invoice bucket. If key with a non-nil value stores the next + // invoice ID which maps to the corresponding invoice. + return invoiceB.ForEach(func(k, v []byte) error { + if v == nil { + return nil + } + + invoiceReader := bytes.NewReader(v) + invoice, err := deserializeInvoice(invoiceReader) + if err != nil { + return err + } + + if pendingOnly && + invoice.Terms.State == ContractSettled { + + return nil + } + + invoices = append(invoices, invoice) + + return nil + }) + }) + if err != nil { + return nil, err + } + + return invoices, nil +} + +// InvoiceQuery represents a query to the invoice database. The query allows a +// caller to retrieve all invoices starting from a particular add index and +// limit the number of results returned. +type InvoiceQuery struct { + // IndexOffset is the offset within the add indices to start at. This + // can be used to start the response at a particular invoice. + IndexOffset uint64 + + // NumMaxInvoices is the maximum number of invoices that should be + // starting from the add index. + NumMaxInvoices uint64 + + // PendingOnly, if set, returns unsettled invoices starting from the + // add index. + PendingOnly bool + + // Reversed, if set, indicates that the invoices returned should start + // from the IndexOffset and go backwards. + Reversed bool +} + +// InvoiceSlice is the response to a invoice query. It includes the original +// query, the set of invoices that match the query, and an integer which +// represents the offset index of the last item in the set of returned invoices. +// This integer allows callers to resume their query using this offset in the +// event that the query's response exceeds the maximum number of returnable +// invoices. +type InvoiceSlice struct { + InvoiceQuery + + // Invoices is the set of invoices that matched the query above. + Invoices []Invoice + + // FirstIndexOffset is the index of the first element in the set of + // returned Invoices above. Callers can use this to resume their query + // in the event that the slice has too many events to fit into a single + // response. + FirstIndexOffset uint64 + + // LastIndexOffset is the index of the last element in the set of + // returned Invoices above. Callers can use this to resume their query + // in the event that the slice has too many events to fit into a single + // response. + LastIndexOffset uint64 +} + +// QueryInvoices allows a caller to query the invoice database for invoices +// within the specified add index range. +func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { + resp := InvoiceSlice{ + InvoiceQuery: q, + } + + err := d.View(func(tx *bbolt.Tx) error { + // If the bucket wasn't found, then there aren't any invoices + // within the database yet, so we can simply exit. + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + invoiceAddIndex := invoices.Bucket(addIndexBucket) + if invoiceAddIndex == nil { + return ErrNoInvoicesCreated + } + + // keyForIndex is a helper closure that retrieves the invoice + // key for the given add index of an invoice. + keyForIndex := func(c *bbolt.Cursor, index uint64) []byte { + var keyIndex [8]byte + byteOrder.PutUint64(keyIndex[:], index) + _, invoiceKey := c.Seek(keyIndex[:]) + return invoiceKey + } + + // nextKey is a helper closure to determine what the next + // invoice key is when iterating over the invoice add index. + nextKey := func(c *bbolt.Cursor) ([]byte, []byte) { + if q.Reversed { + return c.Prev() + } + return c.Next() + } + + // We'll be using a cursor to seek into the database and return + // a slice of invoices. We'll need to determine where to start + // our cursor depending on the parameters set within the query. + c := invoiceAddIndex.Cursor() + invoiceKey := keyForIndex(c, q.IndexOffset+1) + + // If the query is specifying reverse iteration, then we must + // handle a few offset cases. + if q.Reversed { + switch q.IndexOffset { + + // This indicates the default case, where no offset was + // specified. In that case we just start from the last + // invoice. + case 0: + _, invoiceKey = c.Last() + + // This indicates the offset being set to the very + // first invoice. Since there are no invoices before + // this offset, and the direction is reversed, we can + // return without adding any invoices to the response. + case 1: + return nil + + // Otherwise we start iteration at the invoice prior to + // the offset. + default: + invoiceKey = keyForIndex(c, q.IndexOffset-1) + } + } + + // If we know that a set of invoices exists, then we'll begin + // our seek through the bucket in order to satisfy the query. + // We'll continue until either we reach the end of the range, or + // reach our max number of invoices. + for ; invoiceKey != nil; _, invoiceKey = nextKey(c) { + // If our current return payload exceeds the max number + // of invoices, then we'll exit now. + if uint64(len(resp.Invoices)) >= q.NumMaxInvoices { + break + } + + invoice, err := fetchInvoice(invoiceKey, invoices) + if err != nil { + return err + } + + // Skip any settled invoices if the caller is only + // interested in unsettled. + if q.PendingOnly && + invoice.Terms.State == ContractSettled { + + continue + } + + // At this point, we've exhausted the offset, so we'll + // begin collecting invoices found within the range. + resp.Invoices = append(resp.Invoices, invoice) + } + + // If we iterated through the add index in reverse order, then + // we'll need to reverse the slice of invoices to return them in + // forward order. + if q.Reversed { + numInvoices := len(resp.Invoices) + for i := 0; i < numInvoices/2; i++ { + opposite := numInvoices - i - 1 + resp.Invoices[i], resp.Invoices[opposite] = + resp.Invoices[opposite], resp.Invoices[i] + } + } + + return nil + }) + if err != nil && err != ErrNoInvoicesCreated { + return resp, err + } + + // Finally, record the indexes of the first and last invoices returned + // so that the caller can resume from this point later on. + if len(resp.Invoices) > 0 { + resp.FirstIndexOffset = resp.Invoices[0].AddIndex + resp.LastIndexOffset = resp.Invoices[len(resp.Invoices)-1].AddIndex + } + + return resp, nil +} + +// UpdateInvoice attempts to update an invoice corresponding to the passed +// payment hash. If an invoice matching the passed payment hash doesn't exist +// within the database, then the action will fail with a "not found" error. +// +// The update is performed inside the same database transaction that fetches the +// invoice and is therefore atomic. The fields to update are controlled by the +// supplied callback. +func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, + callback InvoiceUpdateCallback) (*Invoice, error) { + + var updatedInvoice *Invoice + err := d.Update(func(tx *bbolt.Tx) error { + invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) + if err != nil { + return err + } + invoiceIndex, err := invoices.CreateBucketIfNotExists( + invoiceIndexBucket, + ) + if err != nil { + return err + } + settleIndex, err := invoices.CreateBucketIfNotExists( + settleIndexBucket, + ) + if err != nil { + return err + } + + // Check the invoice index to see if an invoice paying to this + // hash exists within the DB. + invoiceNum := invoiceIndex.Get(paymentHash[:]) + if invoiceNum == nil { + return ErrInvoiceNotFound + } + + updatedInvoice, err = d.updateInvoice( + paymentHash, invoices, settleIndex, invoiceNum, + callback, + ) + + return err + }) + + return updatedInvoice, err +} + +// InvoicesSettledSince can be used by callers to catch up any settled invoices +// they missed within the settled invoice time series. We'll return all known +// settled invoice that have a settle index higher than the passed +// sinceSettleIndex. +// +// NOTE: The index starts from 1, as a result. We enforce that specifying a +// value below the starting index value is a noop. +func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { + var settledInvoices []Invoice + + // If an index of zero was specified, then in order to maintain + // backwards compat, we won't send out any new invoices. + if sinceSettleIndex == 0 { + return settledInvoices, nil + } + + var startIndex [8]byte + byteOrder.PutUint64(startIndex[:], sinceSettleIndex) + + err := d.DB.View(func(tx *bbolt.Tx) error { + invoices := tx.Bucket(invoiceBucket) + if invoices == nil { + return ErrNoInvoicesCreated + } + + settleIndex := invoices.Bucket(settleIndexBucket) + if settleIndex == nil { + return ErrNoInvoicesCreated + } + + // We'll now run through each entry in the add index starting + // at our starting index. We'll continue until we reach the + // very end of the current key space. + invoiceCursor := settleIndex.Cursor() + + // We'll seek to the starting index, then manually advance the + // cursor in order to skip the entry with the since add index. + invoiceCursor.Seek(startIndex[:]) + seqNo, invoiceKey := invoiceCursor.Next() + + for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, invoiceKey = invoiceCursor.Next() { + + // For each key found, we'll look up the actual + // invoice, then accumulate it into our return value. + invoice, err := fetchInvoice(invoiceKey, invoices) + if err != nil { + return err + } + + settledInvoices = append(settledInvoices, invoice) + } + + return nil + }) + if err != nil { + return nil, err + } + + return settledInvoices, nil +} + +func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket, + i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( + uint64, error) { + + // Create the invoice key which is just the big-endian representation + // of the invoice number. + var invoiceKey [4]byte + byteOrder.PutUint32(invoiceKey[:], invoiceNum) + + // Increment the num invoice counter index so the next invoice bares + // the proper ID. + var scratch [4]byte + invoiceCounter := invoiceNum + 1 + byteOrder.PutUint32(scratch[:], invoiceCounter) + if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil { + return 0, err + } + + // Add the payment hash to the invoice index. This will let us quickly + // identify if we can settle an incoming payment, and also to possibly + // allow a single invoice to have multiple payment installations. + err := invoiceIndex.Put(paymentHash[:], invoiceKey[:]) + if err != nil { + return 0, err + } + + // Next, we'll obtain the next add invoice index (sequence + // number), so we can properly place this invoice within this + // event stream. + nextAddSeqNo, err := addIndex.NextSequence() + if err != nil { + return 0, err + } + + // With the next sequence obtained, we'll updating the event series in + // the add index bucket to map this current add counter to the index of + // this new invoice. + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) + if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil { + return 0, err + } + + i.AddIndex = nextAddSeqNo + + // Finally, serialize the invoice itself to be written to the disk. + var buf bytes.Buffer + if err := serializeInvoice(&buf, i); err != nil { + return 0, err + } + + if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil { + return 0, err + } + + return nextAddSeqNo, nil +} + +// serializeInvoice serializes an invoice to a writer. +// +// Note: this function is in use for a migration. Before making changes that +// would modify the on disk format, make a copy of the original code and store +// it with the migration. +func serializeInvoice(w io.Writer, i *Invoice) error { + if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.FinalCltvDelta); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, int64(i.Expiry)); err != nil { + return err + } + + birthBytes, err := i.CreationDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil { + return err + } + + settleBytes, err := i.SettleDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil { + return err + } + + if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil { + return err + } + + var scratch [8]byte + byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.AddIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil { + return err + } + + if err := serializeHtlcs(w, i.Htlcs); err != nil { + return err + } + + return nil +} + +// serializeHtlcs serializes a map containing circuit keys and invoice htlcs to +// a writer. +func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error { + for key, htlc := range htlcs { + // Encode the htlc in a tlv stream. + chanID := key.ChanID.ToUint64() + amt := uint64(htlc.Amt) + acceptTime := uint64(htlc.AcceptTime.UnixNano()) + resolveTime := uint64(htlc.ResolveTime.UnixNano()) + state := uint8(htlc.State) + + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(chanIDType, &chanID), + tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID), + tlv.MakePrimitiveRecord(amtType, &amt), + tlv.MakePrimitiveRecord( + acceptHeightType, &htlc.AcceptHeight, + ), + tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), + tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), + tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), + tlv.MakePrimitiveRecord(stateType, &state), + ) + if err != nil { + return err + } + + var b bytes.Buffer + if err := tlvStream.Encode(&b); err != nil { + return err + } + + // Write the length of the tlv stream followed by the stream + // bytes. + err = binary.Write(w, byteOrder, uint64(b.Len())) + if err != nil { + return err + } + + if _, err := w.Write(b.Bytes()); err != nil { + return err + } + } + + return nil +} + +func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) { + invoiceBytes := invoices.Get(invoiceNum) + if invoiceBytes == nil { + return Invoice{}, ErrInvoiceNotFound + } + + invoiceReader := bytes.NewReader(invoiceBytes) + + return deserializeInvoice(invoiceReader) +} + +func deserializeInvoice(r io.Reader) (Invoice, error) { + var err error + invoice := Invoice{} + + // TODO(roasbeef): use read full everywhere + invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "") + if err != nil { + return invoice, err + } + invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "") + if err != nil { + return invoice, err + } + + invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "") + if err != nil { + return invoice, err + } + + if err := binary.Read(r, byteOrder, &invoice.FinalCltvDelta); err != nil { + return invoice, err + } + + var expiry int64 + if err := binary.Read(r, byteOrder, &expiry); err != nil { + return invoice, err + } + invoice.Expiry = time.Duration(expiry) + + birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth") + if err != nil { + return invoice, err + } + if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil { + return invoice, err + } + + settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled") + if err != nil { + return invoice, err + } + if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil { + return invoice, err + } + + if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil { + return invoice, err + } + var scratch [8]byte + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return invoice, err + } + invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { + return invoice, err + } + + if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil { + return invoice, err + } + + invoice.Htlcs, err = deserializeHtlcs(r) + if err != nil { + return Invoice{}, err + } + + return invoice, nil +} + +// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it +// as a map. +func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { + htlcs := make(map[CircuitKey]*InvoiceHTLC, 0) + + for { + // Read the length of the tlv stream for this htlc. + var streamLen uint64 + if err := binary.Read(r, byteOrder, &streamLen); err != nil { + if err == io.EOF { + break + } + + return nil, err + } + + streamBytes := make([]byte, streamLen) + if _, err := r.Read(streamBytes); err != nil { + return nil, err + } + streamReader := bytes.NewReader(streamBytes) + + // Decode the contents into the htlc fields. + var ( + htlc InvoiceHTLC + key CircuitKey + chanID uint64 + state uint8 + acceptTime, resolveTime uint64 + amt uint64 + ) + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord(chanIDType, &chanID), + tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID), + tlv.MakePrimitiveRecord(amtType, &amt), + tlv.MakePrimitiveRecord( + acceptHeightType, &htlc.AcceptHeight, + ), + tlv.MakePrimitiveRecord(acceptTimeType, &acceptTime), + tlv.MakePrimitiveRecord(resolveTimeType, &resolveTime), + tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), + tlv.MakePrimitiveRecord(stateType, &state), + ) + if err != nil { + return nil, err + } + + if err := tlvStream.Decode(streamReader); err != nil { + return nil, err + } + + key.ChanID = lnwire.NewShortChanIDFromInt(chanID) + htlc.AcceptTime = time.Unix(0, int64(acceptTime)) + htlc.ResolveTime = time.Unix(0, int64(resolveTime)) + htlc.State = HtlcState(state) + htlc.Amt = lnwire.MilliSatoshi(amt) + + htlcs[key] = &htlc + } + + return htlcs, nil +} + +// copySlice allocates a new slice and copies the source into it. +func copySlice(src []byte) []byte { + dest := make([]byte, len(src)) + copy(dest, src) + return dest +} + +// copyInvoice makes a deep copy of the supplied invoice. +func copyInvoice(src *Invoice) *Invoice { + dest := Invoice{ + Memo: copySlice(src.Memo), + Receipt: copySlice(src.Receipt), + PaymentRequest: copySlice(src.PaymentRequest), + FinalCltvDelta: src.FinalCltvDelta, + CreationDate: src.CreationDate, + SettleDate: src.SettleDate, + Terms: src.Terms, + AddIndex: src.AddIndex, + SettleIndex: src.SettleIndex, + AmtPaid: src.AmtPaid, + Htlcs: make( + map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), + ), + } + + for k, v := range src.Htlcs { + dest.Htlcs[k] = v + } + + return &dest +} + +// updateInvoice fetches the invoice, obtains the update descriptor from the +// callback and applies the updates in a single db transaction. +func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucket, + invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) { + + invoice, err := fetchInvoice(invoiceNum, invoices) + if err != nil { + return nil, err + } + + preUpdateState := invoice.Terms.State + + // Create deep copy to prevent any accidental modification in the + // callback. + copy := copyInvoice(&invoice) + + // Call the callback and obtain the update descriptor. + update, err := callback(copy) + if err != nil { + return &invoice, err + } + + // Update invoice state. + invoice.Terms.State = update.State + + now := d.now() + + // Update htlc set. + for key, htlcUpdate := range update.Htlcs { + htlc, ok := invoice.Htlcs[key] + + // No update means the htlc needs to be canceled. + if htlcUpdate == nil { + if !ok { + return nil, fmt.Errorf("unknown htlc %v", key) + } + if htlc.State != HtlcStateAccepted { + return nil, fmt.Errorf("can only cancel " + + "accepted htlcs") + } + + htlc.State = HtlcStateCanceled + htlc.ResolveTime = now + invoice.AmtPaid -= htlc.Amt + + continue + } + + // Add new htlc paying to the invoice. + if ok { + return nil, fmt.Errorf("htlc %v already exists", key) + } + htlc = &InvoiceHTLC{ + Amt: htlcUpdate.Amt, + Expiry: htlcUpdate.Expiry, + AcceptHeight: uint32(htlcUpdate.AcceptHeight), + AcceptTime: now, + } + if preUpdateState == ContractSettled { + htlc.State = HtlcStateSettled + htlc.ResolveTime = now + } else { + htlc.State = HtlcStateAccepted + } + + invoice.Htlcs[key] = htlc + invoice.AmtPaid += htlc.Amt + } + + // If invoice moved to the settled state, update settle index and settle + // time. + if preUpdateState != invoice.Terms.State && + invoice.Terms.State == ContractSettled { + + if update.Preimage.Hash() != hash { + return nil, fmt.Errorf("preimage does not match") + } + invoice.Terms.PaymentPreimage = update.Preimage + + // Settle all accepted htlcs. + for _, htlc := range invoice.Htlcs { + if htlc.State != HtlcStateAccepted { + continue + } + + htlc.State = HtlcStateSettled + htlc.ResolveTime = now + } + + err := setSettleFields(settleIndex, invoiceNum, &invoice, now) + if err != nil { + return nil, err + } + } + + var buf bytes.Buffer + if err := serializeInvoice(&buf, &invoice); err != nil { + return nil, err + } + + if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil { + return nil, err + } + + return &invoice, nil +} + +func setSettleFields(settleIndex *bbolt.Bucket, invoiceNum []byte, + invoice *Invoice, now time.Time) error { + + // Now that we know the invoice hasn't already been settled, we'll + // update the settle index so we can place this settle event in the + // proper location within our time series. + nextSettleSeqNo, err := settleIndex.NextSequence() + if err != nil { + return err + } + + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) + if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil { + return err + } + + invoice.Terms.State = ContractSettled + invoice.SettleDate = now + invoice.SettleIndex = nextSettleSeqNo + + return nil +} diff --git a/channeldb/migration_01_to_11/legacy_serialization.go b/channeldb/migration_01_to_11/legacy_serialization.go new file mode 100644 index 00000000..5d731bff --- /dev/null +++ b/channeldb/migration_01_to_11/legacy_serialization.go @@ -0,0 +1,55 @@ +package migration_01_to_11 + +import ( + "io" +) + +// deserializeCloseChannelSummaryV6 reads the v6 database format for +// ChannelCloseSummary. +// +// NOTE: deprecated, only for migration. +func deserializeCloseChannelSummaryV6(r io.Reader) (*ChannelCloseSummary, error) { + c := &ChannelCloseSummary{} + + err := ReadElements(r, + &c.ChanPoint, &c.ShortChanID, &c.ChainHash, &c.ClosingTXID, + &c.CloseHeight, &c.RemotePub, &c.Capacity, &c.SettledBalance, + &c.TimeLockedBalance, &c.CloseType, &c.IsPending, + ) + if err != nil { + return nil, err + } + + // We'll now check to see if the channel close summary was encoded with + // any of the additional optional fields. + err = ReadElements(r, &c.RemoteCurrentRevocation) + switch { + case err == io.EOF: + return c, nil + + // If we got a non-eof error, then we know there's an actually issue. + // Otherwise, it may have been the case that this summary didn't have + // the set of optional fields. + case err != nil: + return nil, err + } + + if err := readChanConfig(r, &c.LocalChanConfig); err != nil { + return nil, err + } + + // Finally, we'll attempt to read the next unrevoked commitment point + // for the remote party. If we closed the channel before receiving a + // funding locked message, then this can be nil. As a result, we'll use + // the same technique to read the field, only if there's still data + // left in the buffer. + err = ReadElements(r, &c.RemoteNextRevocation) + if err != nil && err != io.EOF { + // If we got a non-eof error, then we know there's an actually + // issue. Otherwise, it may have been the case that this + // summary didn't have the set of optional fields. + return nil, err + } + + return c, nil +} diff --git a/channeldb/migration_01_to_11/log.go b/channeldb/migration_01_to_11/log.go new file mode 100644 index 00000000..17958b19 --- /dev/null +++ b/channeldb/migration_01_to_11/log.go @@ -0,0 +1,28 @@ +package migration_01_to_11 + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +func init() { + UseLogger(build.NewSubLogger("CHDB", nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/channeldb/migration_01_to_11/meta.go b/channeldb/migration_01_to_11/meta.go new file mode 100644 index 00000000..fbe7a0e4 --- /dev/null +++ b/channeldb/migration_01_to_11/meta.go @@ -0,0 +1,78 @@ +package migration_01_to_11 + +import "github.com/coreos/bbolt" + +var ( + // metaBucket stores all the meta information concerning the state of + // the database. + metaBucket = []byte("metadata") + + // dbVersionKey is a boltdb key and it's used for storing/retrieving + // current database version. + dbVersionKey = []byte("dbp") +) + +// Meta structure holds the database meta information. +type Meta struct { + // DbVersionNumber is the current schema version of the database. + DbVersionNumber uint32 +} + +// FetchMeta fetches the meta data from boltdb and returns filled meta +// structure. +func (d *DB) FetchMeta(tx *bbolt.Tx) (*Meta, error) { + meta := &Meta{} + + err := d.View(func(tx *bbolt.Tx) error { + return fetchMeta(meta, tx) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +// fetchMeta is an internal helper function used in order to allow callers to +// re-use a database transaction. See the publicly exported FetchMeta method +// for more information. +func fetchMeta(meta *Meta, tx *bbolt.Tx) error { + metaBucket := tx.Bucket(metaBucket) + if metaBucket == nil { + return ErrMetaNotFound + } + + data := metaBucket.Get(dbVersionKey) + if data == nil { + meta.DbVersionNumber = getLatestDBVersion(dbVersions) + } else { + meta.DbVersionNumber = byteOrder.Uint32(data) + } + + return nil +} + +// PutMeta writes the passed instance of the database met-data struct to disk. +func (d *DB) PutMeta(meta *Meta) error { + return d.Update(func(tx *bbolt.Tx) error { + return putMeta(meta, tx) + }) +} + +// putMeta is an internal helper function used in order to allow callers to +// re-use a database transaction. See the publicly exported PutMeta method for +// more information. +func putMeta(meta *Meta, tx *bbolt.Tx) error { + metaBucket, err := tx.CreateBucketIfNotExists(metaBucket) + if err != nil { + return err + } + + return putDbVersion(metaBucket, meta) +} + +func putDbVersion(metaBucket *bbolt.Bucket, meta *Meta) error { + scratch := make([]byte, 4) + byteOrder.PutUint32(scratch, meta.DbVersionNumber) + return metaBucket.Put(dbVersionKey, scratch) +} diff --git a/channeldb/migration_01_to_11/meta_test.go b/channeldb/migration_01_to_11/meta_test.go new file mode 100644 index 00000000..27e9369c --- /dev/null +++ b/channeldb/migration_01_to_11/meta_test.go @@ -0,0 +1,442 @@ +package migration_01_to_11 + +import ( + "bytes" + "io/ioutil" + "testing" + + "github.com/coreos/bbolt" + "github.com/go-errors/errors" +) + +// applyMigration is a helper test function that encapsulates the general steps +// which are needed to properly check the result of applying migration function. +func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), + migrationFunc migration, shouldFail bool) { + + cdb, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatal(err) + } + + // Create a test node that will be our source node. + testNode, err := createTestVertex(cdb) + if err != nil { + t.Fatal(err) + } + graph := cdb.ChannelGraph() + if err := graph.SetSourceNode(testNode); err != nil { + t.Fatal(err) + } + + // beforeMigration usually used for populating the database + // with test data. + beforeMigration(cdb) + + // Create test meta info with zero database version and put it on disk. + // Than creating the version list pretending that new version was added. + meta := &Meta{DbVersionNumber: 0} + if err := cdb.PutMeta(meta); err != nil { + t.Fatalf("unable to store meta data: %v", err) + } + + versions := []version{ + { + number: 0, + migration: nil, + }, + { + number: 1, + migration: migrationFunc, + }, + } + + defer func() { + if r := recover(); r != nil { + err = errors.New(r) + } + + if err == nil && shouldFail { + t.Fatal("error wasn't received on migration stage") + } else if err != nil && !shouldFail { + t.Fatalf("error was received on migration stage: %v", err) + } + + // afterMigration usually used for checking the database state and + // throwing the error if something went wrong. + afterMigration(cdb) + }() + + // Sync with the latest version - applying migration function. + err = cdb.syncVersions(versions) + if err != nil { + log.Error(err) + } +} + +// TestVersionFetchPut checks the propernces of fetch/put methods +// and also initialization of meta data in case if don't have any in +// database. +func TestVersionFetchPut(t *testing.T) { + t.Parallel() + + db, cleanUp, err := makeTestDB() + defer cleanUp() + if err != nil { + t.Fatal(err) + } + + meta, err := db.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != getLatestDBVersion(dbVersions) { + t.Fatal("initialization of meta information wasn't performed") + } + + newVersion := getLatestDBVersion(dbVersions) + 1 + meta.DbVersionNumber = newVersion + + if err := db.PutMeta(meta); err != nil { + t.Fatalf("update of meta failed %v", err) + } + + meta, err = db.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != newVersion { + t.Fatal("update of meta information wasn't performed") + } +} + +// TestOrderOfMigrations checks that migrations are applied in proper order. +func TestOrderOfMigrations(t *testing.T) { + t.Parallel() + + appliedMigration := -1 + versions := []version{ + {0, nil}, + {1, nil}, + {2, func(tx *bbolt.Tx) error { + appliedMigration = 2 + return nil + }}, + {3, func(tx *bbolt.Tx) error { + appliedMigration = 3 + return nil + }}, + } + + // Retrieve the migration that should be applied to db, as far as + // current version is 1, we skip zero and first versions. + migrations, _ := getMigrationsToApply(versions, 1) + + if len(migrations) != 2 { + t.Fatal("incorrect number of migrations to apply") + } + + // Apply first migration. + migrations[0](nil) + + // Check that first migration corresponds to the second version. + if appliedMigration != 2 { + t.Fatal("incorrect order of applying migrations") + } + + // Apply second migration. + migrations[1](nil) + + // Check that second migration corresponds to the third version. + if appliedMigration != 3 { + t.Fatal("incorrect order of applying migrations") + } +} + +// TestGlobalVersionList checks that there is no mistake in global version list +// in terms of version ordering. +func TestGlobalVersionList(t *testing.T) { + t.Parallel() + + if dbVersions == nil { + t.Fatal("can't find versions list") + } + + if len(dbVersions) == 0 { + t.Fatal("db versions list is empty") + } + + prev := dbVersions[0].number + for i := 1; i < len(dbVersions); i++ { + version := dbVersions[i].number + + if version == prev { + t.Fatal("duplicates db versions") + } + if version < prev { + t.Fatal("order of db versions is wrong") + } + + prev = version + } +} + +// TestMigrationWithPanic asserts that if migration logic panics, we will return +// to the original state unaltered. +func TestMigrationWithPanic(t *testing.T) { + t.Parallel() + + bucketPrefix := []byte("somebucket") + keyPrefix := []byte("someprefix") + beforeMigration := []byte("beforemigration") + afterMigration := []byte("aftermigration") + + beforeMigrationFunc := func(d *DB) { + // Insert data in database and in order then make sure that the + // key isn't changes in case of panic or fail. + d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, beforeMigration) + return nil + }) + } + + // Create migration function which changes the initially created data and + // throw the panic, in this case we pretending that something goes. + migrationWithPanic := func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, afterMigration) + panic("panic!") + } + + // Check that version of database and data wasn't changed. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 0 { + t.Fatal("migration panicked but version is changed") + } + + err = d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + value := bucket.Get(keyPrefix) + if !bytes.Equal(value, beforeMigration) { + return errors.New("migration failed but data is " + + "changed") + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrationWithPanic, + true) +} + +// TestMigrationWithFatal asserts that migrations which fail do not modify the +// database. +func TestMigrationWithFatal(t *testing.T) { + t.Parallel() + + bucketPrefix := []byte("somebucket") + keyPrefix := []byte("someprefix") + beforeMigration := []byte("beforemigration") + afterMigration := []byte("aftermigration") + + beforeMigrationFunc := func(d *DB) { + d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, beforeMigration) + return nil + }) + } + + // Create migration function which changes the initially created data and + // return the error, in this case we pretending that something goes + // wrong. + migrationWithFatal := func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, afterMigration) + return errors.New("some error") + } + + // Check that version of database and initial data wasn't changed. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 0 { + t.Fatal("migration failed but version is changed") + } + + err = d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + value := bucket.Get(keyPrefix) + if !bytes.Equal(value, beforeMigration) { + return errors.New("migration failed but data is " + + "changed") + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrationWithFatal, + true) +} + +// TestMigrationWithoutErrors asserts that a successful migration has its +// changes applied to the database. +func TestMigrationWithoutErrors(t *testing.T) { + t.Parallel() + + bucketPrefix := []byte("somebucket") + keyPrefix := []byte("someprefix") + beforeMigration := []byte("beforemigration") + afterMigration := []byte("aftermigration") + + // Populate database with initial data. + beforeMigrationFunc := func(d *DB) { + d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, beforeMigration) + return nil + }) + } + + // Create migration function which changes the initially created data. + migrationWithoutErrors := func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + bucket.Put(keyPrefix, afterMigration) + return nil + } + + // Check that version of database and data was properly changed. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("version number isn't changed after " + + "successfully applied migration") + } + + err = d.Update(func(tx *bbolt.Tx) error { + bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) + if err != nil { + return err + } + + value := bucket.Get(keyPrefix) + if !bytes.Equal(value, afterMigration) { + return errors.New("migration wasn't applied " + + "properly") + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrationWithoutErrors, + false) +} + +// TestMigrationReversion tests after performing a migration to a higher +// database version, opening the database with a lower latest db version returns +// ErrDBReversion. +func TestMigrationReversion(t *testing.T) { + t.Parallel() + + tempDirName, err := ioutil.TempDir("", "channeldb") + if err != nil { + t.Fatalf("unable to create temp dir: %v", err) + } + + cdb, err := Open(tempDirName) + if err != nil { + t.Fatalf("unable to open channeldb: %v", err) + } + + // Update the database metadata to point to one more than the highest + // known version. + err = cdb.Update(func(tx *bbolt.Tx) error { + newMeta := &Meta{ + DbVersionNumber: getLatestDBVersion(dbVersions) + 1, + } + + return putMeta(newMeta, tx) + }) + + // Close the database. Even if we succeeded, our next step is to reopen. + cdb.Close() + + if err != nil { + t.Fatalf("unable to increase db version: %v", err) + } + + _, err = Open(tempDirName) + if err != ErrDBReversion { + t.Fatalf("unexpected error when opening channeldb, "+ + "want: %v, got: %v", ErrDBReversion, err) + } +} diff --git a/channeldb/migration_01_to_11/migration_09_legacy_serialization.go b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go new file mode 100644 index 00000000..cc6614c9 --- /dev/null +++ b/channeldb/migration_01_to_11/migration_09_legacy_serialization.go @@ -0,0 +1,497 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "sort" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +var ( + // paymentBucket is the name of the bucket within the database that + // stores all data related to payments. + // + // Within the payments bucket, each invoice is keyed by its invoice ID + // which is a monotonically increasing uint64. BoltDB's sequence + // feature is used for generating monotonically increasing id. + // + // NOTE: Deprecated. Kept around for migration purposes. + paymentBucket = []byte("payments") + + // paymentStatusBucket is the name of the bucket within the database + // that stores the status of a payment indexed by the payment's + // preimage. + // + // NOTE: Deprecated. Kept around for migration purposes. + paymentStatusBucket = []byte("payment-status") +) + +// outgoingPayment represents a successful payment between the daemon and a +// remote node. Details such as the total fee paid, and the time of the payment +// are stored. +// +// NOTE: Deprecated. Kept around for migration purposes. +type outgoingPayment struct { + Invoice + + // Fee is the total fee paid for the payment in milli-satoshis. + Fee lnwire.MilliSatoshi + + // TotalTimeLock is the total cumulative time-lock in the HTLC extended + // from the second-to-last hop to the destination. + TimeLockLength uint32 + + // Path encodes the path the payment took through the network. The path + // excludes the outgoing node and consists of the hex-encoded + // compressed public key of each of the nodes involved in the payment. + Path [][33]byte + + // PaymentPreimage is the preImage of a successful payment. This is used + // to calculate the PaymentHash as well as serve as a proof of payment. + PaymentPreimage [32]byte +} + +// addPayment saves a successful payment to the database. It is assumed that +// all payment are sent using unique payment hashes. +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) addPayment(payment *outgoingPayment) error { + // Validate the field of the inner voice within the outgoing payment, + // these must also adhere to the same constraints as regular invoices. + if err := validateInvoice(&payment.Invoice); err != nil { + return err + } + + // We first serialize the payment before starting the database + // transaction so we can avoid creating a DB payment in the case of a + // serialization error. + var b bytes.Buffer + if err := serializeOutgoingPayment(&b, payment); err != nil { + return err + } + paymentBytes := b.Bytes() + + return db.Batch(func(tx *bbolt.Tx) error { + payments, err := tx.CreateBucketIfNotExists(paymentBucket) + if err != nil { + return err + } + + // Obtain the new unique sequence number for this payment. + paymentID, err := payments.NextSequence() + if err != nil { + return err + } + + // We use BigEndian for keys as it orders keys in + // ascending order. This allows bucket scans to order payments + // in the order in which they were created. + paymentIDBytes := make([]byte, 8) + binary.BigEndian.PutUint64(paymentIDBytes, paymentID) + + return payments.Put(paymentIDBytes, paymentBytes) + }) +} + +// fetchAllPayments returns all outgoing payments in DB. +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) { + var payments []*outgoingPayment + + err := db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(paymentBucket) + if bucket == nil { + return ErrNoPaymentsCreated + } + + return bucket.ForEach(func(k, v []byte) error { + // If the value is nil, then we ignore it as it may be + // a sub-bucket. + if v == nil { + return nil + } + + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + payments = append(payments, payment) + return nil + }) + }) + if err != nil { + return nil, err + } + + return payments, nil +} + +// fetchPaymentStatus returns the payment status for outgoing payment. +// If status of the payment isn't found, it will default to "StatusUnknown". +// +// NOTE: Deprecated. Kept around for migration purposes. +func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { + var paymentStatus = StatusUnknown + err := db.View(func(tx *bbolt.Tx) error { + var err error + paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash) + return err + }) + if err != nil { + return StatusUnknown, err + } + + return paymentStatus, nil +} + +// fetchPaymentStatusTx is a helper method that returns the payment status for +// outgoing payment. If status of the payment isn't found, it will default to +// "StatusUnknown". It accepts the boltdb transactions such that this method +// can be composed into other atomic operations. +// +// NOTE: Deprecated. Kept around for migration purposes. +func fetchPaymentStatusTx(tx *bbolt.Tx, paymentHash [32]byte) (PaymentStatus, error) { + // The default status for all payments that aren't recorded in database. + var paymentStatus = StatusUnknown + + bucket := tx.Bucket(paymentStatusBucket) + if bucket == nil { + return paymentStatus, nil + } + + paymentStatusBytes := bucket.Get(paymentHash[:]) + if paymentStatusBytes == nil { + return paymentStatus, nil + } + + paymentStatus.FromBytes(paymentStatusBytes) + + return paymentStatus, nil +} + +func serializeOutgoingPayment(w io.Writer, p *outgoingPayment) error { + var scratch [8]byte + + if err := serializeInvoiceLegacy(w, &p.Invoice); err != nil { + return err + } + + byteOrder.PutUint64(scratch[:], uint64(p.Fee)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + // First write out the length of the bytes to prefix the value. + pathLen := uint32(len(p.Path)) + byteOrder.PutUint32(scratch[:4], pathLen) + if _, err := w.Write(scratch[:4]); err != nil { + return err + } + + // Then with the path written, we write out the series of public keys + // involved in the path. + for _, hop := range p.Path { + if _, err := w.Write(hop[:]); err != nil { + return err + } + } + + byteOrder.PutUint32(scratch[:4], p.TimeLockLength) + if _, err := w.Write(scratch[:4]); err != nil { + return err + } + + if _, err := w.Write(p.PaymentPreimage[:]); err != nil { + return err + } + + return nil +} + +func deserializeOutgoingPayment(r io.Reader) (*outgoingPayment, error) { + var scratch [8]byte + + p := &outgoingPayment{} + + inv, err := deserializeInvoiceLegacy(r) + if err != nil { + return nil, err + } + p.Invoice = inv + + if _, err := r.Read(scratch[:]); err != nil { + return nil, err + } + p.Fee = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if _, err = r.Read(scratch[:4]); err != nil { + return nil, err + } + pathLen := byteOrder.Uint32(scratch[:4]) + + path := make([][33]byte, pathLen) + for i := uint32(0); i < pathLen; i++ { + if _, err := r.Read(path[i][:]); err != nil { + return nil, err + } + } + p.Path = path + + if _, err = r.Read(scratch[:4]); err != nil { + return nil, err + } + p.TimeLockLength = byteOrder.Uint32(scratch[:4]) + + if _, err := r.Read(p.PaymentPreimage[:]); err != nil { + return nil, err + } + + return p, nil +} + +// serializePaymentAttemptInfoMigration9 is the serializePaymentAttemptInfo +// version as existed when migration #9 was created. We keep this around, along +// with the methods below to ensure that clients that upgrade will use the +// correct version of this method. +func serializePaymentAttemptInfoMigration9(w io.Writer, a *PaymentAttemptInfo) error { + if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { + return err + } + + if err := serializeRouteMigration9(w, a.Route); err != nil { + return err + } + + return nil +} + +func serializeHopMigration9(w io.Writer, h *route.Hop) error { + if err := WriteElements(w, + h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, + h.AmtToForward, + ); err != nil { + return err + } + + return nil +} + +func serializeRouteMigration9(w io.Writer, r route.Route) error { + if err := WriteElements(w, + r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.Hops))); err != nil { + return err + } + + for _, h := range r.Hops { + if err := serializeHopMigration9(w, h); err != nil { + return err + } + } + + return nil +} + +func deserializePaymentAttemptInfoMigration9(r io.Reader) (*PaymentAttemptInfo, error) { + a := &PaymentAttemptInfo{} + err := ReadElements(r, &a.PaymentID, &a.SessionKey) + if err != nil { + return nil, err + } + a.Route, err = deserializeRouteMigration9(r) + if err != nil { + return nil, err + } + return a, nil +} + +func deserializeRouteMigration9(r io.Reader) (route.Route, error) { + rt := route.Route{} + if err := ReadElements(r, + &rt.TotalTimeLock, &rt.TotalAmount, + ); err != nil { + return rt, err + } + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return rt, err + } + copy(rt.SourcePubKey[:], pub) + + var numHops uint32 + if err := ReadElements(r, &numHops); err != nil { + return rt, err + } + + var hops []*route.Hop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHopMigration9(r) + if err != nil { + return rt, err + } + hops = append(hops, hop) + } + rt.Hops = hops + + return rt, nil +} + +func deserializeHopMigration9(r io.Reader) (*route.Hop, error) { + h := &route.Hop{} + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.PubKeyBytes[:], pub) + + if err := ReadElements(r, + &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, + ); err != nil { + return nil, err + } + + return h, nil +} + +// fetchPaymentsMigration9 returns all sent payments found in the DB using the +// payment attempt info format that was present as of migration #9. We need +// this as otherwise, the current FetchPayments version will use the latest +// decoding format. Note that we only need this for the +// TestOutgoingPaymentsMigration migration test case. +func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) { + var payments []*Payment + + err := db.View(func(tx *bbolt.Tx) error { + paymentsBucket := tx.Bucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil + } + + return paymentsBucket.ForEach(func(k, v []byte) error { + bucket := paymentsBucket.Bucket(k) + if bucket == nil { + // We only expect sub-buckets to be found in + // this top-level bucket. + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + + p, err := fetchPaymentMigration9(bucket) + if err != nil { + return err + } + + payments = append(payments, p) + + // For older versions of lnd, duplicate payments to a + // payment has was possible. These will be found in a + // sub-bucket indexed by their sequence number if + // available. + dup := bucket.Bucket(paymentDuplicateBucket) + if dup == nil { + return nil + } + + return dup.ForEach(func(k, v []byte) error { + subBucket := dup.Bucket(k) + if subBucket == nil { + // We one bucket for each duplicate to + // be found. + return fmt.Errorf("non bucket element" + + "in duplicate bucket") + } + + p, err := fetchPaymentMigration9(subBucket) + if err != nil { + return err + } + + payments = append(payments, p) + return nil + }) + }) + }) + if err != nil { + return nil, err + } + + // Before returning, sort the payments by their sequence number. + sort.Slice(payments, func(i, j int) bool { + return payments[i].sequenceNum < payments[j].sequenceNum + }) + + return payments, nil +} + +func fetchPaymentMigration9(bucket *bbolt.Bucket) (*Payment, error) { + var ( + err error + p = &Payment{} + ) + + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return nil, fmt.Errorf("sequence number not found") + } + + p.sequenceNum = binary.BigEndian.Uint64(seqBytes) + + // Get the payment status. + p.Status = fetchPaymentStatus(bucket) + + // Get the PaymentCreationInfo. + b := bucket.Get(paymentCreationInfoKey) + if b == nil { + return nil, fmt.Errorf("creation info not found") + } + + r := bytes.NewReader(b) + p.Info, err = deserializePaymentCreationInfo(r) + if err != nil { + return nil, err + + } + + // Get the PaymentAttemptInfo. This can be unset. + b = bucket.Get(paymentAttemptInfoKey) + if b != nil { + r = bytes.NewReader(b) + p.Attempt, err = deserializePaymentAttemptInfoMigration9(r) + if err != nil { + return nil, err + } + } + + // Get the payment preimage. This is only found for + // completed payments. + b = bucket.Get(paymentSettleInfoKey) + if b != nil { + var preimg lntypes.Preimage + copy(preimg[:], b[:]) + p.PaymentPreimage = &preimg + } + + // Get failure reason if available. + b = bucket.Get(paymentFailInfoKey) + if b != nil { + reason := FailureReason(b[0]) + p.Failure = &reason + } + + return p, nil +} diff --git a/channeldb/migration_01_to_11/migration_10_route_tlv_records.go b/channeldb/migration_01_to_11/migration_10_route_tlv_records.go new file mode 100644 index 00000000..a8478cda --- /dev/null +++ b/channeldb/migration_01_to_11/migration_10_route_tlv_records.go @@ -0,0 +1,236 @@ +package migration_01_to_11 + +import ( + "bytes" + "io" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/routing/route" +) + +// migrateRouteSerialization migrates the way we serialize routes across the +// entire database. At the time of writing of this migration, this includes our +// payment attempts, as well as the payment results in mission control. +func migrateRouteSerialization(tx *bbolt.Tx) error { + // First, we'll do all the payment attempts. + rootPaymentBucket := tx.Bucket(paymentsRootBucket) + if rootPaymentBucket == nil { + return nil + } + + // As we can't mutate a bucket while we're iterating over it with + // ForEach, we'll need to collect all the known payment hashes in + // memory first. + var payHashes [][]byte + err := rootPaymentBucket.ForEach(func(k, v []byte) error { + if v != nil { + return nil + } + + payHashes = append(payHashes, k) + return nil + }) + if err != nil { + return err + } + + // Now that we have all the payment hashes, we can carry out the + // migration itself. + for _, payHash := range payHashes { + payHashBucket := rootPaymentBucket.Bucket(payHash) + + // First, we'll migrate the main (non duplicate) payment to + // this hash. + err := migrateAttemptEncoding(tx, payHashBucket) + if err != nil { + return err + } + + // Now that we've migrated the main payment, we'll also check + // for any duplicate payments to the same payment hash. + dupBucket := payHashBucket.Bucket(paymentDuplicateBucket) + + // If there's no dup bucket, then we can move on to the next + // payment. + if dupBucket == nil { + continue + } + + // Otherwise, we'll now iterate through all the duplicate pay + // hashes and migrate those. + var dupSeqNos [][]byte + err = dupBucket.ForEach(func(k, v []byte) error { + dupSeqNos = append(dupSeqNos, k) + return nil + }) + if err != nil { + return err + } + + // Now in this second pass, we'll re-serialize their duplicate + // payment attempts under the new encoding. + for _, seqNo := range dupSeqNos { + dupPayHashBucket := dupBucket.Bucket(seqNo) + err := migrateAttemptEncoding(tx, dupPayHashBucket) + if err != nil { + return err + } + } + } + + log.Infof("Migration of route/hop serialization complete!") + + log.Infof("Migrating to new mission control store by clearing " + + "existing data") + + resultsKey := []byte("missioncontrol-results") + err = tx.DeleteBucket(resultsKey) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + log.Infof("Migration to new mission control completed!") + + return nil +} + +// migrateAttemptEncoding migrates payment attempts using the legacy format to +// the new format. +func migrateAttemptEncoding(tx *bbolt.Tx, payHashBucket *bbolt.Bucket) error { + payAttemptBytes := payHashBucket.Get(paymentAttemptInfoKey) + if payAttemptBytes == nil { + return nil + } + + // For our migration, we'll first read out the existing payment attempt + // using the legacy serialization of the attempt. + payAttemptReader := bytes.NewReader(payAttemptBytes) + payAttempt, err := deserializePaymentAttemptInfoLegacy( + payAttemptReader, + ) + if err != nil { + return err + } + + // Now that we have the old attempts, we'll explicitly mark this as + // needing a legacy payload, since after this migration, the modern + // payload will be the default if signalled. + for _, hop := range payAttempt.Route.Hops { + hop.LegacyPayload = true + } + + // Finally, we'll write out the payment attempt using the new encoding. + var b bytes.Buffer + err = serializePaymentAttemptInfo(&b, payAttempt) + if err != nil { + return err + } + + return payHashBucket.Put(paymentAttemptInfoKey, b.Bytes()) +} + +func deserializePaymentAttemptInfoLegacy(r io.Reader) (*PaymentAttemptInfo, error) { + a := &PaymentAttemptInfo{} + err := ReadElements(r, &a.PaymentID, &a.SessionKey) + if err != nil { + return nil, err + } + a.Route, err = deserializeRouteLegacy(r) + if err != nil { + return nil, err + } + return a, nil +} + +func serializePaymentAttemptInfoLegacy(w io.Writer, a *PaymentAttemptInfo) error { + if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { + return err + } + + if err := serializeRouteLegacy(w, a.Route); err != nil { + return err + } + + return nil +} + +func deserializeHopLegacy(r io.Reader) (*route.Hop, error) { + h := &route.Hop{} + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.PubKeyBytes[:], pub) + + if err := ReadElements(r, + &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, + ); err != nil { + return nil, err + } + + return h, nil +} + +func serializeHopLegacy(w io.Writer, h *route.Hop) error { + if err := WriteElements(w, + h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, + h.AmtToForward, + ); err != nil { + return err + } + + return nil +} + +func deserializeRouteLegacy(r io.Reader) (route.Route, error) { + rt := route.Route{} + if err := ReadElements(r, + &rt.TotalTimeLock, &rt.TotalAmount, + ); err != nil { + return rt, err + } + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return rt, err + } + copy(rt.SourcePubKey[:], pub) + + var numHops uint32 + if err := ReadElements(r, &numHops); err != nil { + return rt, err + } + + var hops []*route.Hop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHopLegacy(r) + if err != nil { + return rt, err + } + hops = append(hops, hop) + } + rt.Hops = hops + + return rt, nil +} + +func serializeRouteLegacy(w io.Writer, r route.Route) error { + if err := WriteElements(w, + r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.Hops))); err != nil { + return err + } + + for _, h := range r.Hops { + if err := serializeHopLegacy(w, h); err != nil { + return err + } + } + + return nil +} diff --git a/channeldb/migration_01_to_11/migration_11_invoices.go b/channeldb/migration_01_to_11/migration_11_invoices.go new file mode 100644 index 00000000..1ae969be --- /dev/null +++ b/channeldb/migration_01_to_11/migration_11_invoices.go @@ -0,0 +1,230 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + bitcoinCfg "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/zpay32" + litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" +) + +// migrateInvoices adds invoice htlcs and a separate cltv delta field to the +// invoices. +func migrateInvoices(tx *bbolt.Tx) error { + log.Infof("Migrating invoices to new invoice format") + + invoiceB := tx.Bucket(invoiceBucket) + if invoiceB == nil { + return nil + } + + // Iterate through the entire key space of the top-level invoice bucket. + // If key with a non-nil value stores the next invoice ID which maps to + // the corresponding invoice. Store those keys first, because it isn't + // safe to modify the bucket inside a ForEach loop. + var invoiceKeys [][]byte + err := invoiceB.ForEach(func(k, v []byte) error { + if v == nil { + return nil + } + + invoiceKeys = append(invoiceKeys, k) + + return nil + }) + if err != nil { + return err + } + + nets := []*bitcoinCfg.Params{ + &bitcoinCfg.MainNetParams, &bitcoinCfg.SimNetParams, + &bitcoinCfg.RegressionNetParams, &bitcoinCfg.TestNet3Params, + } + + ltcNets := []*litecoinCfg.Params{ + &litecoinCfg.MainNetParams, &litecoinCfg.SimNetParams, + &litecoinCfg.RegressionNetParams, &litecoinCfg.TestNet4Params, + } + for _, net := range ltcNets { + var convertedNet bitcoinCfg.Params + convertedNet.Bech32HRPSegwit = net.Bech32HRPSegwit + nets = append(nets, &convertedNet) + } + + // Iterate over all stored keys and migrate the invoices. + for _, k := range invoiceKeys { + v := invoiceB.Get(k) + + // Deserialize the invoice with the deserializing function that + // was in use for this version of the database. + invoiceReader := bytes.NewReader(v) + invoice, err := deserializeInvoiceLegacy(invoiceReader) + if err != nil { + return err + } + + if invoice.Terms.State == ContractAccepted { + return fmt.Errorf("cannot upgrade with invoice(s) " + + "in accepted state, see release notes") + } + + // Try to decode the payment request for every possible net to + // avoid passing a the active network to channeldb. This would + // be a layering violation, while this migration is only running + // once and will likely be removed in the future. + var payReq *zpay32.Invoice + for _, net := range nets { + payReq, err = zpay32.Decode( + string(invoice.PaymentRequest), net, + ) + if err == nil { + break + } + } + if payReq == nil { + return fmt.Errorf("cannot decode payreq") + } + invoice.FinalCltvDelta = int32(payReq.MinFinalCLTVExpiry()) + invoice.Expiry = payReq.Expiry() + + // Serialize the invoice in the new format and use it to replace + // the old invoice in the database. + var buf bytes.Buffer + if err := serializeInvoice(&buf, &invoice); err != nil { + return err + } + + err = invoiceB.Put(k, buf.Bytes()) + if err != nil { + return err + } + } + + log.Infof("Migration of invoices completed!") + return nil +} + +func deserializeInvoiceLegacy(r io.Reader) (Invoice, error) { + var err error + invoice := Invoice{} + + // TODO(roasbeef): use read full everywhere + invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "") + if err != nil { + return invoice, err + } + invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "") + if err != nil { + return invoice, err + } + + invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "") + if err != nil { + return invoice, err + } + + birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth") + if err != nil { + return invoice, err + } + if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil { + return invoice, err + } + + settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled") + if err != nil { + return invoice, err + } + if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil { + return invoice, err + } + + if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil { + return invoice, err + } + var scratch [8]byte + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return invoice, err + } + invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { + return invoice, err + } + + if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil { + return invoice, err + } + if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil { + return invoice, err + } + + return invoice, nil +} + +// serializeInvoiceLegacy serializes an invoice in the format of the previous db +// version. +func serializeInvoiceLegacy(w io.Writer, i *Invoice) error { + if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil { + return err + } + if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil { + return err + } + + birthBytes, err := i.CreationDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil { + return err + } + + settleBytes, err := i.SettleDate.MarshalBinary() + if err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil { + return err + } + + if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil { + return err + } + + var scratch [8]byte + byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, i.AddIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil { + return err + } + if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil { + return err + } + + return nil +} diff --git a/channeldb/migration_01_to_11/migration_11_invoices_test.go b/channeldb/migration_01_to_11/migration_11_invoices_test.go new file mode 100644 index 00000000..9c0c877a --- /dev/null +++ b/channeldb/migration_01_to_11/migration_11_invoices_test.go @@ -0,0 +1,193 @@ +package migration_01_to_11 + +import ( + "bytes" + "fmt" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + bitcoinCfg "github.com/btcsuite/btcd/chaincfg" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/zpay32" + litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" +) + +var ( + testPrivKeyBytes = []byte{ + 0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf, + 0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9, + 0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f, + 0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90, + } + + testCltvDelta = int32(50) +) + +// beforeMigrationFuncV11 insert the test invoices in the database. +func beforeMigrationFuncV11(t *testing.T, d *DB, invoices []Invoice) { + err := d.Update(func(tx *bbolt.Tx) error { + invoicesBucket, err := tx.CreateBucketIfNotExists( + invoiceBucket, + ) + if err != nil { + return err + } + + invoiceNum := uint32(1) + for _, invoice := range invoices { + var invoiceKey [4]byte + byteOrder.PutUint32(invoiceKey[:], invoiceNum) + invoiceNum++ + + var buf bytes.Buffer + err := serializeInvoiceLegacy(&buf, &invoice) // nolint:scopelint + if err != nil { + return err + } + + err = invoicesBucket.Put( + invoiceKey[:], buf.Bytes(), + ) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +// TestMigrateInvoices checks that invoices are migrated correctly. +func TestMigrateInvoices(t *testing.T) { + t.Parallel() + + payReqBtc, err := getPayReq(&bitcoinCfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + var ltcNetParams bitcoinCfg.Params + ltcNetParams.Bech32HRPSegwit = litecoinCfg.MainNetParams.Bech32HRPSegwit + payReqLtc, err := getPayReq(<cNetParams) + if err != nil { + t.Fatal(err) + } + + invoices := []Invoice{ + { + PaymentRequest: []byte(payReqBtc), + }, + { + PaymentRequest: []byte(payReqLtc), + }, + } + + // Verify that all invoices were migrated. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'invoices' wasn't applied") + } + + dbInvoices, err := d.FetchAllInvoices(false) + if err != nil { + t.Fatalf("unable to fetch invoices: %v", err) + } + + if len(invoices) != len(dbInvoices) { + t.Fatalf("expected %d invoices, got %d", len(invoices), + len(dbInvoices)) + } + + for _, dbInvoice := range dbInvoices { + if dbInvoice.FinalCltvDelta != testCltvDelta { + t.Fatal("incorrect final cltv delta") + } + if dbInvoice.Expiry != 3600*time.Second { + t.Fatal("incorrect expiry") + } + if len(dbInvoice.Htlcs) != 0 { + t.Fatal("expected no htlcs after migration") + } + } + } + + applyMigration(t, + func(d *DB) { beforeMigrationFuncV11(t, d, invoices) }, + afterMigrationFunc, + migrateInvoices, + false) +} + +// TestMigrateInvoicesHodl checks that a hodl invoice in the accepted state +// fails the migration. +func TestMigrateInvoicesHodl(t *testing.T) { + t.Parallel() + + payReqBtc, err := getPayReq(&bitcoinCfg.MainNetParams) + if err != nil { + t.Fatal(err) + } + + invoices := []Invoice{ + { + PaymentRequest: []byte(payReqBtc), + Terms: ContractTerm{ + State: ContractAccepted, + }, + }, + } + + applyMigration(t, + func(d *DB) { beforeMigrationFuncV11(t, d, invoices) }, + func(d *DB) {}, + migrateInvoices, + true) +} + +// signDigestCompact generates a test signature to be used in the generation of +// test payment requests. +func signDigestCompact(hash []byte) ([]byte, error) { + // Should the signature reference a compressed public key or not. + isCompressedKey := true + + privKey, _ := btcec.PrivKeyFromBytes(btcec.S256(), testPrivKeyBytes) + + // btcec.SignCompact returns a pubkey-recoverable signature + sig, err := btcec.SignCompact( + btcec.S256(), privKey, hash, isCompressedKey, + ) + if err != nil { + return nil, fmt.Errorf("can't sign the hash: %v", err) + } + + return sig, nil +} + +// getPayReq creates a payment request for the given net. +func getPayReq(net *bitcoinCfg.Params) (string, error) { + options := []func(*zpay32.Invoice){ + zpay32.CLTVExpiry(uint64(testCltvDelta)), + zpay32.Description("test"), + } + + payReq, err := zpay32.NewInvoice( + net, [32]byte{}, time.Unix(1, 0), options..., + ) + if err != nil { + return "", err + } + return payReq.Encode( + zpay32.MessageSigner{ + SignCompact: signDigestCompact, + }, + ) +} diff --git a/channeldb/migration_01_to_11/migrations.go b/channeldb/migration_01_to_11/migrations.go new file mode 100644 index 00000000..3e296d02 --- /dev/null +++ b/channeldb/migration_01_to_11/migrations.go @@ -0,0 +1,939 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "fmt" + + "github.com/btcsuite/btcd/btcec" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// migrateNodeAndEdgeUpdateIndex is a migration function that will update the +// database from version 0 to version 1. In version 1, we add two new indexes +// (one for nodes and one for edges) to keep track of the last time a node or +// edge was updated on the network. These new indexes allow us to implement the +// new graph sync protocol added. +func migrateNodeAndEdgeUpdateIndex(tx *bbolt.Tx) error { + // First, we'll populating the node portion of the new index. Before we + // can add new values to the index, we'll first create the new bucket + // where these items will be housed. + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return fmt.Errorf("unable to create node bucket: %v", err) + } + nodeUpdateIndex, err := nodes.CreateBucketIfNotExists( + nodeUpdateIndexBucket, + ) + if err != nil { + return fmt.Errorf("unable to create node update index: %v", err) + } + + log.Infof("Populating new node update index bucket") + + // Now that we know the bucket has been created, we'll iterate over the + // entire node bucket so we can add the (updateTime || nodePub) key + // into the node update index. + err = nodes.ForEach(func(nodePub, nodeInfo []byte) error { + if len(nodePub) != 33 { + return nil + } + + log.Tracef("Adding %x to node update index", nodePub) + + // The first 8 bytes of a node's serialize data is the update + // time, so we can extract that without decoding the entire + // structure. + updateTime := nodeInfo[:8] + + // Now that we have the update time, we can construct the key + // to insert into the index. + var indexKey [8 + 33]byte + copy(indexKey[:8], updateTime) + copy(indexKey[8:], nodePub) + + return nodeUpdateIndex.Put(indexKey[:], nil) + }) + if err != nil { + return fmt.Errorf("unable to update node indexes: %v", err) + } + + log.Infof("Populating new edge update index bucket") + + // With the set of nodes updated, we'll now update all edges to have a + // corresponding entry in the edge update index. + edges, err := tx.CreateBucketIfNotExists(edgeBucket) + if err != nil { + return fmt.Errorf("unable to create edge bucket: %v", err) + } + edgeUpdateIndex, err := edges.CreateBucketIfNotExists( + edgeUpdateIndexBucket, + ) + if err != nil { + return fmt.Errorf("unable to create edge update index: %v", err) + } + + // We'll now run through each edge policy in the database, and update + // the index to ensure each edge has the proper record. + err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error { + if len(edgeKey) != 41 { + return nil + } + + // Now that we know this is the proper record, we'll grab the + // channel ID (last 8 bytes of the key), and then decode the + // edge policy so we can access the update time. + chanID := edgeKey[33:] + edgePolicyReader := bytes.NewReader(edgePolicyBytes) + + edgePolicy, err := deserializeChanEdgePolicy( + edgePolicyReader, nodes, + ) + if err != nil { + return err + } + + log.Tracef("Adding chan_id=%v to edge update index", + edgePolicy.ChannelID) + + // We'll now construct the index key using the channel ID, and + // the last time it was updated: (updateTime || chanID). + var indexKey [8 + 8]byte + byteOrder.PutUint64( + indexKey[:], uint64(edgePolicy.LastUpdate.Unix()), + ) + copy(indexKey[8:], chanID) + + return edgeUpdateIndex.Put(indexKey[:], nil) + }) + if err != nil { + return fmt.Errorf("unable to update edge indexes: %v", err) + } + + log.Infof("Migration to node and edge update indexes complete!") + + return nil +} + +// migrateInvoiceTimeSeries is a database migration that assigns all existing +// invoices an index in the add and/or the settle index. Additionally, all +// existing invoices will have their bytes padded out in order to encode the +// add+settle index as well as the amount paid. +func migrateInvoiceTimeSeries(tx *bbolt.Tx) error { + invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) + if err != nil { + return err + } + + addIndex, err := invoices.CreateBucketIfNotExists( + addIndexBucket, + ) + if err != nil { + return err + } + settleIndex, err := invoices.CreateBucketIfNotExists( + settleIndexBucket, + ) + if err != nil { + return err + } + + log.Infof("Migrating invoice database to new time series format") + + // Now that we have all the buckets we need, we'll run through each + // invoice in the database, and update it to reflect the new format + // expected post migration. + // NOTE: we store the converted invoices and put them back into the + // database after the loop, since modifying the bucket within the + // ForEach loop is not safe. + var invoicesKeys [][]byte + var invoicesValues [][]byte + err = invoices.ForEach(func(invoiceNum, invoiceBytes []byte) error { + // If this is a sub bucket, then we'll skip it. + if invoiceBytes == nil { + return nil + } + + // First, we'll make a copy of the encoded invoice bytes. + invoiceBytesCopy := make([]byte, len(invoiceBytes)) + copy(invoiceBytesCopy, invoiceBytes) + + // With the bytes copied over, we'll append 24 additional + // bytes. We do this so we can decode the invoice under the new + // serialization format. + padding := bytes.Repeat([]byte{0}, 24) + invoiceBytesCopy = append(invoiceBytesCopy, padding...) + + invoiceReader := bytes.NewReader(invoiceBytesCopy) + invoice, err := deserializeInvoiceLegacy(invoiceReader) + if err != nil { + return fmt.Errorf("unable to decode invoice: %v", err) + } + + // Now that we have the fully decoded invoice, we can update + // the various indexes that we're added, and finally the + // invoice itself before re-inserting it. + + // First, we'll get the new sequence in the addIndex in order + // to create the proper mapping. + nextAddSeqNo, err := addIndex.NextSequence() + if err != nil { + return err + } + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) + err = addIndex.Put(seqNoBytes[:], invoiceNum[:]) + if err != nil { + return err + } + + log.Tracef("Adding invoice (preimage=%x, add_index=%v) to add "+ + "time series", invoice.Terms.PaymentPreimage[:], + nextAddSeqNo) + + // Next, we'll check if the invoice has been settled or not. If + // so, then we'll also add it to the settle index. + var nextSettleSeqNo uint64 + if invoice.Terms.State == ContractSettled { + nextSettleSeqNo, err = settleIndex.NextSequence() + if err != nil { + return err + } + + var seqNoBytes [8]byte + byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) + err := settleIndex.Put(seqNoBytes[:], invoiceNum) + if err != nil { + return err + } + + invoice.AmtPaid = invoice.Terms.Value + + log.Tracef("Adding invoice (preimage=%x, "+ + "settle_index=%v) to add time series", + invoice.Terms.PaymentPreimage[:], + nextSettleSeqNo) + } + + // Finally, we'll update the invoice itself with the new + // indexing information as well as the amount paid if it has + // been settled or not. + invoice.AddIndex = nextAddSeqNo + invoice.SettleIndex = nextSettleSeqNo + + // We've fully migrated an invoice, so we'll now update the + // invoice in-place. + var b bytes.Buffer + if err := serializeInvoiceLegacy(&b, &invoice); err != nil { + return err + } + + // Save the key and value pending update for after the ForEach + // is done. + invoicesKeys = append(invoicesKeys, invoiceNum) + invoicesValues = append(invoicesValues, b.Bytes()) + return nil + }) + if err != nil { + return err + } + + // Now put the converted invoices into the DB. + for i := range invoicesKeys { + key := invoicesKeys[i] + value := invoicesValues[i] + if err := invoices.Put(key, value); err != nil { + return err + } + } + + log.Infof("Migration to invoice time series index complete!") + + return nil +} + +// migrateInvoiceTimeSeriesOutgoingPayments is a follow up to the +// migrateInvoiceTimeSeries migration. As at the time of writing, the +// OutgoingPayment struct embeddeds an instance of the Invoice struct. As a +// result, we also need to migrate the internal invoice to the new format. +func migrateInvoiceTimeSeriesOutgoingPayments(tx *bbolt.Tx) error { + payBucket := tx.Bucket(paymentBucket) + if payBucket == nil { + return nil + } + + log.Infof("Migrating invoice database to new outgoing payment format") + + // We store the keys and values we want to modify since it is not safe + // to modify them directly within the ForEach loop. + var paymentKeys [][]byte + var paymentValues [][]byte + err := payBucket.ForEach(func(payID, paymentBytes []byte) error { + log.Tracef("Migrating payment %x", payID[:]) + + // The internal invoices for each payment only contain a + // populated contract term, and creation date, as a result, + // most of the bytes will be "empty". + + // We'll calculate the end of the invoice index assuming a + // "minimal" index that's embedded within the greater + // OutgoingPayment. The breakdown is: + // 3 bytes empty var bytes, 16 bytes creation date, 16 bytes + // settled date, 32 bytes payment pre-image, 8 bytes value, 1 + // byte settled. + endOfInvoiceIndex := 1 + 1 + 1 + 16 + 16 + 32 + 8 + 1 + + // We'll now extract the prefix of the pure invoice embedded + // within. + invoiceBytes := paymentBytes[:endOfInvoiceIndex] + + // With the prefix extracted, we'll copy over the invoice, and + // also add padding for the new 24 bytes of fields, and finally + // append the remainder of the outgoing payment. + paymentCopy := make([]byte, len(invoiceBytes)) + copy(paymentCopy[:], invoiceBytes) + + padding := bytes.Repeat([]byte{0}, 24) + paymentCopy = append(paymentCopy, padding...) + paymentCopy = append( + paymentCopy, paymentBytes[endOfInvoiceIndex:]..., + ) + + // At this point, we now have the new format of the outgoing + // payments, we'll attempt to deserialize it to ensure the + // bytes are properly formatted. + paymentReader := bytes.NewReader(paymentCopy) + _, err := deserializeOutgoingPayment(paymentReader) + if err != nil { + return fmt.Errorf("unable to deserialize payment: %v", err) + } + + // Now that we know the modifications was successful, we'll + // store it to our slice of keys and values, and write it back + // to disk in the new format after the ForEach loop is over. + paymentKeys = append(paymentKeys, payID) + paymentValues = append(paymentValues, paymentCopy) + return nil + }) + if err != nil { + return err + } + + // Finally store the updated payments to the bucket. + for i := range paymentKeys { + key := paymentKeys[i] + value := paymentValues[i] + if err := payBucket.Put(key, value); err != nil { + return err + } + } + + log.Infof("Migration to outgoing payment invoices complete!") + + return nil +} + +// migrateEdgePolicies is a migration function that will update the edges +// bucket. It ensure that edges with unknown policies will also have an entry +// in the bucket. After the migration, there will be two edge entries for +// every channel, regardless of whether the policies are known. +func migrateEdgePolicies(tx *bbolt.Tx) error { + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return nil + } + + edges := tx.Bucket(edgeBucket) + if edges == nil { + return nil + } + + edgeIndex := edges.Bucket(edgeIndexBucket) + if edgeIndex == nil { + return nil + } + + // checkKey gets the policy from the database with a low-level call + // so that it is still possible to distinguish between unknown and + // not present. + checkKey := func(channelId uint64, keyBytes []byte) error { + var channelID [8]byte + byteOrder.PutUint64(channelID[:], channelId) + + _, err := fetchChanEdgePolicy(edges, + channelID[:], keyBytes, nodes) + + if err == ErrEdgeNotFound { + log.Tracef("Adding unknown edge policy present for node %x, channel %v", + keyBytes, channelId) + + err := putChanEdgePolicyUnknown(edges, channelId, keyBytes) + if err != nil { + return err + } + + return nil + } + + return err + } + + // Iterate over all channels and check both edge policies. + err := edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { + infoReader := bytes.NewReader(edgeInfoBytes) + edgeInfo, err := deserializeChanEdgeInfo(infoReader) + if err != nil { + return err + } + + for _, key := range [][]byte{edgeInfo.NodeKey1Bytes[:], + edgeInfo.NodeKey2Bytes[:]} { + + if err := checkKey(edgeInfo.ChannelID, key); err != nil { + return err + } + } + + return nil + }) + + if err != nil { + return fmt.Errorf("unable to update edge policies: %v", err) + } + + log.Infof("Migration of edge policies complete!") + + return nil +} + +// paymentStatusesMigration is a database migration intended for adding payment +// statuses for each existing payment entity in bucket to be able control +// transitions of statuses and prevent cases such as double payment +func paymentStatusesMigration(tx *bbolt.Tx) error { + // Get the bucket dedicated to storing statuses of payments, + // where a key is payment hash, value is payment status. + paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) + if err != nil { + return err + } + + log.Infof("Migrating database to support payment statuses") + + circuitAddKey := []byte("circuit-adds") + circuits := tx.Bucket(circuitAddKey) + if circuits != nil { + log.Infof("Marking all known circuits with status InFlight") + + err = circuits.ForEach(func(k, v []byte) error { + // Parse the first 8 bytes as the short chan ID for the + // circuit. We'll skip all short chan IDs are not + // locally initiated, which includes all non-zero short + // chan ids. + chanID := binary.BigEndian.Uint64(k[:8]) + if chanID != 0 { + return nil + } + + // The payment hash is the third item in the serialized + // payment circuit. The first two items are an AddRef + // (10 bytes) and the incoming circuit key (16 bytes). + const payHashOffset = 10 + 16 + + paymentHash := v[payHashOffset : payHashOffset+32] + + return paymentStatuses.Put( + paymentHash[:], StatusInFlight.Bytes(), + ) + }) + if err != nil { + return err + } + } + + log.Infof("Marking all existing payments with status Completed") + + // Get the bucket dedicated to storing payments + bucket := tx.Bucket(paymentBucket) + if bucket == nil { + return nil + } + + // For each payment in the bucket, deserialize the payment and mark it + // as completed. + err = bucket.ForEach(func(k, v []byte) error { + // Ignores if it is sub-bucket. + if v == nil { + return nil + } + + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + // Calculate payment hash for current payment. + paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) + + // Update status for current payment to completed. If it fails, + // the migration is aborted and the payment bucket is returned + // to its previous state. + return paymentStatuses.Put(paymentHash[:], StatusSucceeded.Bytes()) + }) + if err != nil { + return err + } + + log.Infof("Migration of payment statuses complete!") + + return nil +} + +// migratePruneEdgeUpdateIndex is a database migration that attempts to resolve +// some lingering bugs with regards to edge policies and their update index. +// Stale entries within the edge update index were not being properly pruned due +// to a miscalculation on the offset of an edge's policy last update. This +// migration also fixes the case where the public keys within edge policies were +// being serialized with an extra byte, causing an even greater error when +// attempting to perform the offset calculation described earlier. +func migratePruneEdgeUpdateIndex(tx *bbolt.Tx) error { + // To begin the migration, we'll retrieve the update index bucket. If it + // does not exist, we have nothing left to do so we can simply exit. + edges := tx.Bucket(edgeBucket) + if edges == nil { + return nil + } + edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) + if edgeUpdateIndex == nil { + return nil + } + + // Retrieve some buckets that will be needed later on. These should + // already exist given the assumption that the buckets above do as + // well. + edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) + if err != nil { + return fmt.Errorf("error creating edge index bucket: %s", err) + } + if edgeIndex == nil { + return fmt.Errorf("unable to create/fetch edge index " + + "bucket") + } + nodes, err := tx.CreateBucketIfNotExists(nodeBucket) + if err != nil { + return fmt.Errorf("unable to make node bucket") + } + + log.Info("Migrating database to properly prune edge update index") + + // We'll need to properly prune all the outdated entries within the edge + // update index. To do so, we'll gather all of the existing policies + // within the graph to re-populate them later on. + var edgeKeys [][]byte + err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error { + // All valid entries are indexed by a public key (33 bytes) + // followed by a channel ID (8 bytes), so we'll skip any entries + // with keys that do not match this. + if len(edgeKey) != 33+8 { + return nil + } + + edgeKeys = append(edgeKeys, edgeKey) + + return nil + }) + if err != nil { + return fmt.Errorf("unable to gather existing edge policies: %v", + err) + } + + log.Info("Constructing set of edge update entries to purge.") + + // Build the set of keys that we will remove from the edge update index. + // This will include all keys contained within the bucket. + var updateKeysToRemove [][]byte + err = edgeUpdateIndex.ForEach(func(updKey, _ []byte) error { + updateKeysToRemove = append(updateKeysToRemove, updKey) + return nil + }) + if err != nil { + return fmt.Errorf("unable to gather existing edge updates: %v", + err) + } + + log.Infof("Removing %d entries from edge update index.", + len(updateKeysToRemove)) + + // With the set of keys contained in the edge update index constructed, + // we'll proceed in purging all of them from the index. + for _, updKey := range updateKeysToRemove { + if err := edgeUpdateIndex.Delete(updKey); err != nil { + return err + } + } + + log.Infof("Repopulating edge update index with %d valid entries.", + len(edgeKeys)) + + // For each edge key, we'll retrieve the policy, deserialize it, and + // re-add it to the different buckets. By doing so, we'll ensure that + // all existing edge policies are serialized correctly within their + // respective buckets and that the correct entries are populated within + // the edge update index. + for _, edgeKey := range edgeKeys { + edgePolicyBytes := edges.Get(edgeKey) + + // Skip any entries with unknown policies as there will not be + // any entries for them in the edge update index. + if bytes.Equal(edgePolicyBytes[:], unknownPolicy) { + continue + } + + edgePolicy, err := deserializeChanEdgePolicy( + bytes.NewReader(edgePolicyBytes), nodes, + ) + if err != nil { + return err + } + + _, err = updateEdgePolicy(tx, edgePolicy) + if err != nil { + return err + } + } + + log.Info("Migration to properly prune edge update index complete!") + + return nil +} + +// migrateOptionalChannelCloseSummaryFields migrates the serialized format of +// ChannelCloseSummary to a format where optional fields' presence is indicated +// with boolean markers. +func migrateOptionalChannelCloseSummaryFields(tx *bbolt.Tx) error { + closedChanBucket := tx.Bucket(closedChannelBucket) + if closedChanBucket == nil { + return nil + } + + log.Info("Migrating to new closed channel format...") + + // We store the converted keys and values and put them back into the + // database after the loop, since modifying the bucket within the + // ForEach loop is not safe. + var closedChansKeys [][]byte + var closedChansValues [][]byte + err := closedChanBucket.ForEach(func(chanID, summary []byte) error { + r := bytes.NewReader(summary) + + // Read the old (v6) format from the database. + c, err := deserializeCloseChannelSummaryV6(r) + if err != nil { + return err + } + + // Serialize using the new format, and put back into the + // bucket. + var b bytes.Buffer + if err := serializeChannelCloseSummary(&b, c); err != nil { + return err + } + + // Now that we know the modifications was successful, we'll + // Store the key and value to our slices, and write it back to + // disk in the new format after the ForEach loop is over. + closedChansKeys = append(closedChansKeys, chanID) + closedChansValues = append(closedChansValues, b.Bytes()) + return nil + }) + if err != nil { + return fmt.Errorf("unable to update closed channels: %v", err) + } + + // Now put the new format back into the DB. + for i := range closedChansKeys { + key := closedChansKeys[i] + value := closedChansValues[i] + if err := closedChanBucket.Put(key, value); err != nil { + return err + } + } + + log.Info("Migration to new closed channel format complete!") + + return nil +} + +var messageStoreBucket = []byte("message-store") + +// migrateGossipMessageStoreKeys migrates the key format for gossip messages +// found in the message store to a new one that takes into consideration the of +// the message being stored. +func migrateGossipMessageStoreKeys(tx *bbolt.Tx) error { + // We'll start by retrieving the bucket in which these messages are + // stored within. If there isn't one, there's nothing left for us to do + // so we can avoid the migration. + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return nil + } + + log.Info("Migrating to the gossip message store new key format") + + // Otherwise we'll proceed with the migration. We'll start by coalescing + // all the current messages within the store, which are indexed by the + // public key of the peer which they should be sent to, followed by the + // short channel ID of the channel for which the message belongs to. We + // should only expect to find channel announcement signatures as that + // was the only support message type previously. + msgs := make(map[[33 + 8]byte]*lnwire.AnnounceSignatures) + err := messageStore.ForEach(func(k, v []byte) error { + var msgKey [33 + 8]byte + copy(msgKey[:], k) + + msg := &lnwire.AnnounceSignatures{} + if err := msg.Decode(bytes.NewReader(v), 0); err != nil { + return err + } + + msgs[msgKey] = msg + + return nil + + }) + if err != nil { + return err + } + + // Then, we'll go over all of our messages, remove their previous entry, + // and add another with the new key format. Once we've done this for + // every message, we can consider the migration complete. + for oldMsgKey, msg := range msgs { + if err := messageStore.Delete(oldMsgKey[:]); err != nil { + return err + } + + // Construct the new key for which we'll find this message with + // in the store. It'll be the same as the old, but we'll also + // include the message type. + var msgType [2]byte + binary.BigEndian.PutUint16(msgType[:], uint16(msg.MsgType())) + newMsgKey := append(oldMsgKey[:], msgType[:]...) + + // Serialize the message with its wire encoding. + var b bytes.Buffer + if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil { + return err + } + + if err := messageStore.Put(newMsgKey, b.Bytes()); err != nil { + return err + } + } + + log.Info("Migration to the gossip message store new key format complete!") + + return nil +} + +// migrateOutgoingPayments moves the OutgoingPayments into a new bucket format +// where they all reside in a top-level bucket indexed by the payment hash. In +// this sub-bucket we store information relevant to this payment, such as the +// payment status. +// +// Since the router cannot handle resumed payments that have the status +// InFlight (we have no PaymentAttemptInfo available for pre-migration +// payments) we delete those statuses, so only Completed payments remain in the +// new bucket structure. +func migrateOutgoingPayments(tx *bbolt.Tx) error { + log.Infof("Migrating outgoing payments to new bucket structure") + + oldPayments := tx.Bucket(paymentBucket) + + // Return early if there are no payments to migrate. + if oldPayments == nil { + log.Infof("No outgoing payments found, nothing to migrate.") + return nil + } + + newPayments, err := tx.CreateBucket(paymentsRootBucket) + if err != nil { + return err + } + + // Helper method to get the source pubkey. We define it such that we + // only attempt to fetch it if needed. + sourcePub := func() ([33]byte, error) { + var pub [33]byte + nodes := tx.Bucket(nodeBucket) + if nodes == nil { + return pub, ErrGraphNotFound + } + + selfPub := nodes.Get(sourceKey) + if selfPub == nil { + return pub, ErrSourceNodeNotSet + } + copy(pub[:], selfPub[:]) + return pub, nil + } + + err = oldPayments.ForEach(func(k, v []byte) error { + // Ignores if it is sub-bucket. + if v == nil { + return nil + } + + // Read the old payment format. + r := bytes.NewReader(v) + payment, err := deserializeOutgoingPayment(r) + if err != nil { + return err + } + + // Calculate payment hash from the payment preimage. + paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) + + // Now create and add a PaymentCreationInfo to the bucket. + c := &PaymentCreationInfo{ + PaymentHash: paymentHash, + Value: payment.Terms.Value, + CreationDate: payment.CreationDate, + PaymentRequest: payment.PaymentRequest, + } + + var infoBuf bytes.Buffer + if err := serializePaymentCreationInfo(&infoBuf, c); err != nil { + return err + } + + sourcePubKey, err := sourcePub() + if err != nil { + return err + } + + // Do the same for the PaymentAttemptInfo. + totalAmt := payment.Terms.Value + payment.Fee + rt := route.Route{ + TotalTimeLock: payment.TimeLockLength, + TotalAmount: totalAmt, + SourcePubKey: sourcePubKey, + Hops: []*route.Hop{}, + } + for _, hop := range payment.Path { + rt.Hops = append(rt.Hops, &route.Hop{ + PubKeyBytes: hop, + AmtToForward: totalAmt, + }) + } + + // Since the old format didn't store the fee for individual + // hops, we let the last hop eat the whole fee for the total to + // add up. + if len(rt.Hops) > 0 { + rt.Hops[len(rt.Hops)-1].AmtToForward = payment.Terms.Value + } + + // Since we don't have the session key for old payments, we + // create a random one to be able to serialize the attempt + // info. + priv, _ := btcec.NewPrivateKey(btcec.S256()) + s := &PaymentAttemptInfo{ + PaymentID: 0, // unknown. + SessionKey: priv, // unknown. + Route: rt, + } + + var attemptBuf bytes.Buffer + if err := serializePaymentAttemptInfoMigration9(&attemptBuf, s); err != nil { + return err + } + + // Reuse the existing payment sequence number. + var seqNum [8]byte + copy(seqNum[:], k) + + // Create a bucket indexed by the payment hash. + bucket, err := newPayments.CreateBucket(paymentHash[:]) + + // If the bucket already exists, it means that we are migrating + // from a database containing duplicate payments to a payment + // hash. To keep this information, we store such duplicate + // payments in a sub-bucket. + if err == bbolt.ErrBucketExists { + pHashBucket := newPayments.Bucket(paymentHash[:]) + + // Create a bucket for duplicate payments within this + // payment hash's bucket. + dup, err := pHashBucket.CreateBucketIfNotExists( + paymentDuplicateBucket, + ) + if err != nil { + return err + } + + // Each duplicate will get its own sub-bucket within + // this bucket, so use their sequence number to index + // them by. + bucket, err = dup.CreateBucket(seqNum[:]) + if err != nil { + return err + } + + } else if err != nil { + return err + } + + // Store the payment's information to the bucket. + err = bucket.Put(paymentSequenceKey, seqNum[:]) + if err != nil { + return err + } + + err = bucket.Put(paymentCreationInfoKey, infoBuf.Bytes()) + if err != nil { + return err + } + + err = bucket.Put(paymentAttemptInfoKey, attemptBuf.Bytes()) + if err != nil { + return err + } + + err = bucket.Put(paymentSettleInfoKey, payment.PaymentPreimage[:]) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return err + } + + // To continue producing unique sequence numbers, we set the sequence + // of the new bucket to that of the old one. + seq := oldPayments.Sequence() + if err := newPayments.SetSequence(seq); err != nil { + return err + } + + // Now we delete the old buckets. Deleting the payment status buckets + // deletes all payment statuses other than Complete. + err = tx.DeleteBucket(paymentStatusBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + // Finally delete the old payment bucket. + err = tx.DeleteBucket(paymentBucket) + if err != nil && err != bbolt.ErrBucketNotFound { + return err + } + + log.Infof("Migration of outgoing payment bucket structure completed!") + return nil +} diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go new file mode 100644 index 00000000..8a9076fb --- /dev/null +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -0,0 +1,952 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "fmt" + "math/rand" + "reflect" + "testing" + "time" + + "github.com/btcsuite/btcutil" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// TestPaymentStatusesMigration checks that already completed payments will have +// their payment statuses set to Completed after the migration. +func TestPaymentStatusesMigration(t *testing.T) { + t.Parallel() + + fakePayment := makeFakePayment() + paymentHash := sha256.Sum256(fakePayment.PaymentPreimage[:]) + + // Add fake payment to test database, verifying that it was created, + // that we have only one payment, and its status is not "Completed". + beforeMigrationFunc := func(d *DB) { + if err := d.addPayment(fakePayment); err != nil { + t.Fatalf("unable to add payment: %v", err) + } + + payments, err := d.fetchAllPayments() + if err != nil { + t.Fatalf("unable to fetch payments: %v", err) + } + + if len(payments) != 1 { + t.Fatalf("wrong qty of paymets: expected 1, got %v", + len(payments)) + } + + paymentStatus, err := d.fetchPaymentStatus(paymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + // We should receive default status if we have any in database. + if paymentStatus != StatusUnknown { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusUnknown.String(), paymentStatus.String()) + } + + // Lastly, we'll add a locally-sourced circuit and + // non-locally-sourced circuit to the circuit map. The + // locally-sourced payment should end up with an InFlight + // status, while the other should remain unchanged, which + // defaults to Grounded. + err = d.Update(func(tx *bbolt.Tx) error { + circuits, err := tx.CreateBucketIfNotExists( + []byte("circuit-adds"), + ) + if err != nil { + return err + } + + groundedKey := make([]byte, 16) + binary.BigEndian.PutUint64(groundedKey[:8], 1) + binary.BigEndian.PutUint64(groundedKey[8:], 1) + + // Generated using TestHalfCircuitSerialization with nil + // ErrorEncrypter, which is the case for locally-sourced + // payments. No payment status should end up being set + // for this circuit, since the short channel id of the + // key is non-zero (e.g., a forwarded circuit). This + // will default it to Grounded. + groundedCircuit := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + // start payment hash + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // end payment hash + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, + 0x42, 0x40, 0x00, + } + + err = circuits.Put(groundedKey, groundedCircuit) + if err != nil { + return err + } + + inFlightKey := make([]byte, 16) + binary.BigEndian.PutUint64(inFlightKey[:8], 0) + binary.BigEndian.PutUint64(inFlightKey[8:], 1) + + // Generated using TestHalfCircuitSerialization with nil + // ErrorEncrypter, which is not the case for forwarded + // payments, but should have no impact on the + // correctness of the test. The payment status for this + // circuit should be set to InFlight, since the short + // channel id in the key is 0 (sourceHop). + inFlightCircuit := []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + // start payment hash + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // end payment hash + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, + 0x42, 0x40, 0x00, + } + + return circuits.Put(inFlightKey, inFlightCircuit) + }) + if err != nil { + t.Fatalf("unable to add circuit map entry: %v", err) + } + } + + // Verify that the created payment status is "Completed" for our one + // fake payment. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'paymentStatusesMigration' wasn't applied") + } + + // Check that our completed payments were migrated. + paymentStatus, err := d.fetchPaymentStatus(paymentHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusSucceeded { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusSucceeded.String(), paymentStatus.String()) + } + + inFlightHash := [32]byte{ + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + + // Check that the locally sourced payment was transitioned to + // InFlight. + paymentStatus, err = d.fetchPaymentStatus(inFlightHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusInFlight { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusInFlight.String(), paymentStatus.String()) + } + + groundedHash := [32]byte{ + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + + // Check that non-locally sourced payments remain in the default + // Grounded state. + paymentStatus, err = d.fetchPaymentStatus(groundedHash) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != StatusUnknown { + t.Fatalf("wrong payment status: expected %v, got %v", + StatusUnknown.String(), paymentStatus.String()) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + paymentStatusesMigration, + false) +} + +// TestMigrateOptionalChannelCloseSummaryFields properly converts a +// ChannelCloseSummary to the v7 format, where optional fields have their +// presence indicated with boolean markers. +func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) { + t.Parallel() + + chanState, err := createTestChannelState(nil) + if err != nil { + t.Fatalf("unable to create channel state: %v", err) + } + + var chanPointBuf bytes.Buffer + err = writeOutpoint(&chanPointBuf, &chanState.FundingOutpoint) + if err != nil { + t.Fatalf("unable to write outpoint: %v", err) + } + + chanID := chanPointBuf.Bytes() + + testCases := []struct { + closeSummary *ChannelCloseSummary + oldSerialization func(c *ChannelCloseSummary) []byte + }{ + { + // A close summary where none of the new fields are + // set. + closeSummary: &ChannelCloseSummary{ + ChanPoint: chanState.FundingOutpoint, + ShortChanID: chanState.ShortChanID(), + ChainHash: chanState.ChainHash, + ClosingTXID: testTx.TxHash(), + CloseHeight: 100, + RemotePub: chanState.IdentityPub, + Capacity: chanState.Capacity, + SettledBalance: btcutil.Amount(50000), + CloseType: RemoteForceClose, + IsPending: true, + + // The last fields will be unset. + RemoteCurrentRevocation: nil, + LocalChanConfig: ChannelConfig{}, + RemoteNextRevocation: nil, + }, + + // In the old format the last field written is the + // IsPendingField. It should be converted by adding an + // extra boolean marker at the end to indicate that the + // remaining fields are not there. + oldSerialization: func(cs *ChannelCloseSummary) []byte { + var buf bytes.Buffer + err := WriteElements(&buf, cs.ChanPoint, + cs.ShortChanID, cs.ChainHash, + cs.ClosingTXID, cs.CloseHeight, + cs.RemotePub, cs.Capacity, + cs.SettledBalance, cs.TimeLockedBalance, + cs.CloseType, cs.IsPending, + ) + if err != nil { + t.Fatal(err) + } + + // For the old format, these are all the fields + // that are written. + return buf.Bytes() + }, + }, + { + // A close summary where the new fields are present, + // but the optional RemoteNextRevocation field is not + // set. + closeSummary: &ChannelCloseSummary{ + ChanPoint: chanState.FundingOutpoint, + ShortChanID: chanState.ShortChanID(), + ChainHash: chanState.ChainHash, + ClosingTXID: testTx.TxHash(), + CloseHeight: 100, + RemotePub: chanState.IdentityPub, + Capacity: chanState.Capacity, + SettledBalance: btcutil.Amount(50000), + CloseType: RemoteForceClose, + IsPending: true, + RemoteCurrentRevocation: chanState.RemoteCurrentRevocation, + LocalChanConfig: chanState.LocalChanCfg, + + // RemoteNextRevocation is optional, and here + // it is not set. + RemoteNextRevocation: nil, + }, + + // In the old format the last field written is the + // LocalChanConfig. This indicates that the optional + // RemoteNextRevocation field is not present. It should + // be converted by adding boolean markers for all these + // fields. + oldSerialization: func(cs *ChannelCloseSummary) []byte { + var buf bytes.Buffer + err := WriteElements(&buf, cs.ChanPoint, + cs.ShortChanID, cs.ChainHash, + cs.ClosingTXID, cs.CloseHeight, + cs.RemotePub, cs.Capacity, + cs.SettledBalance, cs.TimeLockedBalance, + cs.CloseType, cs.IsPending, + ) + if err != nil { + t.Fatal(err) + } + + err = WriteElements(&buf, cs.RemoteCurrentRevocation) + if err != nil { + t.Fatal(err) + } + + err = writeChanConfig(&buf, &cs.LocalChanConfig) + if err != nil { + t.Fatal(err) + } + + // RemoteNextRevocation is not written. + return buf.Bytes() + }, + }, + { + // A close summary where all fields are present. + closeSummary: &ChannelCloseSummary{ + ChanPoint: chanState.FundingOutpoint, + ShortChanID: chanState.ShortChanID(), + ChainHash: chanState.ChainHash, + ClosingTXID: testTx.TxHash(), + CloseHeight: 100, + RemotePub: chanState.IdentityPub, + Capacity: chanState.Capacity, + SettledBalance: btcutil.Amount(50000), + CloseType: RemoteForceClose, + IsPending: true, + RemoteCurrentRevocation: chanState.RemoteCurrentRevocation, + LocalChanConfig: chanState.LocalChanCfg, + + // RemoteNextRevocation is optional, and in + // this case we set it. + RemoteNextRevocation: chanState.RemoteNextRevocation, + }, + + // In the old format all the fields are written. It + // should be converted by adding boolean markers for + // all these fields. + oldSerialization: func(cs *ChannelCloseSummary) []byte { + var buf bytes.Buffer + err := WriteElements(&buf, cs.ChanPoint, + cs.ShortChanID, cs.ChainHash, + cs.ClosingTXID, cs.CloseHeight, + cs.RemotePub, cs.Capacity, + cs.SettledBalance, cs.TimeLockedBalance, + cs.CloseType, cs.IsPending, + ) + if err != nil { + t.Fatal(err) + } + + err = WriteElements(&buf, cs.RemoteCurrentRevocation) + if err != nil { + t.Fatal(err) + } + + err = writeChanConfig(&buf, &cs.LocalChanConfig) + if err != nil { + t.Fatal(err) + } + + err = WriteElements(&buf, cs.RemoteNextRevocation) + if err != nil { + t.Fatal(err) + } + + return buf.Bytes() + }, + }, + } + + for _, test := range testCases { + + // Before the migration we must add the old format to the DB. + beforeMigrationFunc := func(d *DB) { + + // Get the old serialization format for this test's + // close summary, and it to the closed channel bucket. + old := test.oldSerialization(test.closeSummary) + err = d.Update(func(tx *bbolt.Tx) error { + closedChanBucket, err := tx.CreateBucketIfNotExists( + closedChannelBucket, + ) + if err != nil { + return err + } + return closedChanBucket.Put(chanID, old) + }) + if err != nil { + t.Fatalf("unable to add old serialization: %v", + err) + } + } + + // After the migration it should be found in the new format. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration wasn't applied") + } + + // We generate the new serialized version, to check + // against what is found in the DB. + var b bytes.Buffer + err = serializeChannelCloseSummary(&b, test.closeSummary) + if err != nil { + t.Fatalf("unable to serialize: %v", err) + } + newSerialization := b.Bytes() + + var dbSummary []byte + err = d.View(func(tx *bbolt.Tx) error { + closedChanBucket := tx.Bucket(closedChannelBucket) + if closedChanBucket == nil { + return errors.New("unable to find bucket") + } + + // Get the serialized verision from the DB and + // make sure it matches what we expected. + dbSummary = closedChanBucket.Get(chanID) + if !bytes.Equal(dbSummary, newSerialization) { + return fmt.Errorf("unexpected new " + + "serialization") + } + return nil + }) + if err != nil { + t.Fatalf("unable to view DB: %v", err) + } + + // Finally we fetch the deserialized summary from the + // DB and check that it is equal to our original one. + dbChannels, err := d.FetchClosedChannels(false) + if err != nil { + t.Fatalf("unable to fetch closed channels: %v", + err) + } + + if len(dbChannels) != 1 { + t.Fatalf("expected 1 closed channels, found %v", + len(dbChannels)) + } + + dbChan := dbChannels[0] + if !reflect.DeepEqual(dbChan, test.closeSummary) { + dbChan.RemotePub.Curve = nil + test.closeSummary.RemotePub.Curve = nil + t.Fatalf("not equal: %v vs %v", + spew.Sdump(dbChan), + spew.Sdump(test.closeSummary)) + } + + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrateOptionalChannelCloseSummaryFields, + false) + } +} + +// TestMigrateGossipMessageStoreKeys ensures that the migration to the new +// gossip message store key format is successful/unsuccessful under various +// scenarios. +func TestMigrateGossipMessageStoreKeys(t *testing.T) { + t.Parallel() + + // Construct the message which we'll use to test the migration, along + // with its old and new key formats. + shortChanID := lnwire.ShortChannelID{BlockHeight: 10} + msg := &lnwire.AnnounceSignatures{ShortChannelID: shortChanID} + + var oldMsgKey [33 + 8]byte + copy(oldMsgKey[:33], pubKey.SerializeCompressed()) + binary.BigEndian.PutUint64(oldMsgKey[33:41], shortChanID.ToUint64()) + + var newMsgKey [33 + 8 + 2]byte + copy(newMsgKey[:41], oldMsgKey[:]) + binary.BigEndian.PutUint16(newMsgKey[41:43], uint16(msg.MsgType())) + + // Before the migration, we'll create the bucket where the messages + // should live and insert them. + beforeMigration := func(db *DB) { + var b bytes.Buffer + if err := msg.Encode(&b, 0); err != nil { + t.Fatalf("unable to serialize message: %v", err) + } + + err := db.Update(func(tx *bbolt.Tx) error { + messageStore, err := tx.CreateBucketIfNotExists( + messageStoreBucket, + ) + if err != nil { + return err + } + + return messageStore.Put(oldMsgKey[:], b.Bytes()) + }) + if err != nil { + t.Fatal(err) + } + } + + // After the migration, we'll make sure that: + // 1. We cannot find the message under its old key. + // 2. We can find the message under its new key. + // 3. The message matches the original. + afterMigration := func(db *DB) { + meta, err := db.FetchMeta(nil) + if err != nil { + t.Fatalf("unable to fetch db version: %v", err) + } + if meta.DbVersionNumber != 1 { + t.Fatalf("migration should have succeeded but didn't") + } + + var rawMsg []byte + err = db.View(func(tx *bbolt.Tx) error { + messageStore := tx.Bucket(messageStoreBucket) + if messageStore == nil { + return errors.New("message store bucket not " + + "found") + } + rawMsg = messageStore.Get(oldMsgKey[:]) + if rawMsg != nil { + t.Fatal("expected to not find message under " + + "old key, but did") + } + rawMsg = messageStore.Get(newMsgKey[:]) + if rawMsg == nil { + return fmt.Errorf("expected to find message " + + "under new key, but didn't") + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + + gotMsg, err := lnwire.ReadMessage(bytes.NewReader(rawMsg), 0) + if err != nil { + t.Fatalf("unable to deserialize raw message: %v", err) + } + if !reflect.DeepEqual(msg, gotMsg) { + t.Fatalf("expected message: %v\ngot message: %v", + spew.Sdump(msg), spew.Sdump(gotMsg)) + } + } + + applyMigration( + t, beforeMigration, afterMigration, + migrateGossipMessageStoreKeys, false, + ) +} + +// TestOutgoingPaymentsMigration checks that OutgoingPayments are migrated to a +// new bucket structure after the migration. +func TestOutgoingPaymentsMigration(t *testing.T) { + t.Parallel() + + const numPayments = 4 + var oldPayments []*outgoingPayment + + // Add fake payments to test database, verifying that it was created. + beforeMigrationFunc := func(d *DB) { + for i := 0; i < numPayments; i++ { + var p *outgoingPayment + var err error + + // We fill the database with random payments. For the + // very last one we'll use a duplicate of the first, to + // ensure we are able to handle migration from a + // database that has copies. + if i < numPayments-1 { + p, err = makeRandomFakePayment() + if err != nil { + t.Fatalf("unable to create payment: %v", + err) + } + } else { + p = oldPayments[0] + } + + if err := d.addPayment(p); err != nil { + t.Fatalf("unable to add payment: %v", err) + } + + oldPayments = append(oldPayments, p) + } + + payments, err := d.fetchAllPayments() + if err != nil { + t.Fatalf("unable to fetch payments: %v", err) + } + + if len(payments) != numPayments { + t.Fatalf("wrong qty of paymets: expected %d got %v", + numPayments, len(payments)) + } + } + + // Verify that all payments were migrated. + afterMigrationFunc := func(d *DB) { + meta, err := d.FetchMeta(nil) + if err != nil { + t.Fatal(err) + } + + if meta.DbVersionNumber != 1 { + t.Fatal("migration 'paymentStatusesMigration' wasn't applied") + } + + sentPayments, err := d.fetchPaymentsMigration9() + if err != nil { + t.Fatalf("unable to fetch sent payments: %v", err) + } + + if len(sentPayments) != numPayments { + t.Fatalf("expected %d payments, got %d", numPayments, + len(sentPayments)) + } + + graph := d.ChannelGraph() + sourceNode, err := graph.SourceNode() + if err != nil { + t.Fatalf("unable to fetch source node: %v", err) + } + + for i, p := range sentPayments { + // The payment status should be Completed. + if p.Status != StatusSucceeded { + t.Fatalf("expected Completed, got %v", p.Status) + } + + // Check that the sequence number is preserved. They + // start counting at 1. + if p.sequenceNum != uint64(i+1) { + t.Fatalf("expected seqnum %d, got %d", i, + p.sequenceNum) + } + + // Order of payments should be be preserved. + old := oldPayments[i] + + // Check the individial fields. + if p.Info.Value != old.Terms.Value { + t.Fatalf("value mismatch") + } + + if p.Info.CreationDate != old.CreationDate { + t.Fatalf("date mismatch") + } + + if !bytes.Equal(p.Info.PaymentRequest, old.PaymentRequest) { + t.Fatalf("payreq mismatch") + } + + if *p.PaymentPreimage != old.PaymentPreimage { + t.Fatalf("preimage mismatch") + } + + if p.Attempt.Route.TotalFees() != old.Fee { + t.Fatalf("Fee mismatch") + } + + if p.Attempt.Route.TotalAmount != old.Fee+old.Terms.Value { + t.Fatalf("Total amount mismatch") + } + + if p.Attempt.Route.TotalTimeLock != old.TimeLockLength { + t.Fatalf("timelock mismatch") + } + + if p.Attempt.Route.SourcePubKey != sourceNode.PubKeyBytes { + t.Fatalf("source mismatch: %x vs %x", + p.Attempt.Route.SourcePubKey[:], + sourceNode.PubKeyBytes[:]) + } + + for i, hop := range old.Path { + if hop != p.Attempt.Route.Hops[i].PubKeyBytes { + t.Fatalf("path mismatch") + } + } + } + + // Finally, check that the payment sequence number is updated + // to reflect the migrated payments. + err = d.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return fmt.Errorf("payments bucket not found") + } + + seq := payments.Sequence() + if seq != numPayments { + return fmt.Errorf("expected sequence to be "+ + "%d, got %d", numPayments, seq) + } + + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrateOutgoingPayments, + false) +} + +func makeRandPaymentCreationInfo() (*PaymentCreationInfo, error) { + var payHash lntypes.Hash + if _, err := rand.Read(payHash[:]); err != nil { + return nil, err + } + + return &PaymentCreationInfo{ + PaymentHash: payHash, + Value: lnwire.MilliSatoshi(rand.Int63()), + CreationDate: time.Now(), + PaymentRequest: []byte("test"), + }, nil +} + +// TestPaymentRouteSerialization tests that we're able to properly migrate +// existing payments on disk that contain the traversed routes to the new +// routing format which supports the TLV payloads. We also test that the +// migration is able to handle duplicate payment attempts. +func TestPaymentRouteSerialization(t *testing.T) { + t.Parallel() + + legacyHop1 := &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + LegacyPayload: true, + AmtToForward: 555, + } + legacyHop2 := &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + LegacyPayload: true, + AmtToForward: 555, + } + legacyRoute := route.Route{ + TotalTimeLock: 123, + TotalAmount: 1234567, + SourcePubKey: route.NewVertex(pub), + Hops: []*route.Hop{legacyHop1, legacyHop2}, + } + + const numPayments = 4 + var oldPayments []*Payment + + sharedPayAttempt := PaymentAttemptInfo{ + PaymentID: 1, + SessionKey: priv, + Route: legacyRoute, + } + + // We'll first add a series of fake payments, using the existing legacy + // serialization format. + beforeMigrationFunc := func(d *DB) { + err := d.Update(func(tx *bbolt.Tx) error { + paymentsBucket, err := tx.CreateBucket( + paymentsRootBucket, + ) + if err != nil { + t.Fatalf("unable to create new payments "+ + "bucket: %v", err) + } + + for i := 0; i < numPayments; i++ { + var seqNum [8]byte + byteOrder.PutUint64(seqNum[:], uint64(i)) + + // All payments will be randomly generated, + // other than the final payment. We'll force + // the final payment to re-use an existing + // payment hash so we can insert it into the + // duplicate payment hash bucket. + var payInfo *PaymentCreationInfo + if i < numPayments-1 { + payInfo, err = makeRandPaymentCreationInfo() + if err != nil { + t.Fatalf("unable to create "+ + "payment: %v", err) + } + } else { + payInfo = oldPayments[0].Info + } + + // Next, legacy encoded when needed, we'll + // serialize the info and the attempt. + var payInfoBytes bytes.Buffer + err = serializePaymentCreationInfo( + &payInfoBytes, payInfo, + ) + if err != nil { + t.Fatalf("unable to encode pay "+ + "info: %v", err) + } + var payAttemptBytes bytes.Buffer + err = serializePaymentAttemptInfoLegacy( + &payAttemptBytes, &sharedPayAttempt, + ) + if err != nil { + t.Fatalf("unable to encode payment attempt: "+ + "%v", err) + } + + // Before we write to disk, we'll need to fetch + // the proper bucket. If this is the duplicate + // payment, then we'll grab the dup bucket, + // otherwise, we'll use the top level bucket. + var payHashBucket *bbolt.Bucket + if i < numPayments-1 { + payHashBucket, err = paymentsBucket.CreateBucket( + payInfo.PaymentHash[:], + ) + if err != nil { + t.Fatalf("unable to create payments bucket: %v", err) + } + } else { + payHashBucket = paymentsBucket.Bucket( + payInfo.PaymentHash[:], + ) + dupPayBucket, err := payHashBucket.CreateBucket( + paymentDuplicateBucket, + ) + if err != nil { + t.Fatalf("unable to create "+ + "dup hash bucket: %v", err) + } + + payHashBucket, err = dupPayBucket.CreateBucket( + seqNum[:], + ) + if err != nil { + t.Fatalf("unable to make dup "+ + "bucket: %v", err) + } + } + + err = payHashBucket.Put(paymentSequenceKey, seqNum[:]) + if err != nil { + t.Fatalf("unable to write seqno: %v", err) + } + + err = payHashBucket.Put( + paymentCreationInfoKey, payInfoBytes.Bytes(), + ) + if err != nil { + t.Fatalf("unable to write creation "+ + "info: %v", err) + } + + err = payHashBucket.Put( + paymentAttemptInfoKey, payAttemptBytes.Bytes(), + ) + if err != nil { + t.Fatalf("unable to write attempt "+ + "info: %v", err) + } + + oldPayments = append(oldPayments, &Payment{ + Info: payInfo, + Attempt: &sharedPayAttempt, + }) + } + + return nil + }) + if err != nil { + t.Fatalf("unable to create test payments: %v", err) + } + } + + afterMigrationFunc := func(d *DB) { + newPayments, err := d.FetchPayments() + if err != nil { + t.Fatalf("unable to fetch new payments: %v", err) + } + + if len(newPayments) != numPayments { + t.Fatalf("expected %d payments, got %d", numPayments, + len(newPayments)) + } + + for i, p := range newPayments { + // Order of payments should be be preserved. + old := oldPayments[i] + + if p.Attempt.PaymentID != old.Attempt.PaymentID { + t.Fatalf("wrong pay ID: expected %v, got %v", + p.Attempt.PaymentID, + old.Attempt.PaymentID) + } + + if p.Attempt.Route.TotalFees() != old.Attempt.Route.TotalFees() { + t.Fatalf("Fee mismatch") + } + + if p.Attempt.Route.TotalAmount != old.Attempt.Route.TotalAmount { + t.Fatalf("Total amount mismatch") + } + + if p.Attempt.Route.TotalTimeLock != old.Attempt.Route.TotalTimeLock { + t.Fatalf("timelock mismatch") + } + + if p.Attempt.Route.SourcePubKey != old.Attempt.Route.SourcePubKey { + t.Fatalf("source mismatch: %x vs %x", + p.Attempt.Route.SourcePubKey[:], + old.Attempt.Route.SourcePubKey[:]) + } + + for i, hop := range p.Attempt.Route.Hops { + if !reflect.DeepEqual(hop, legacyRoute.Hops[i]) { + t.Fatalf("hop mismatch") + } + } + } + } + + applyMigration(t, + beforeMigrationFunc, + afterMigrationFunc, + migrateRouteSerialization, + false) +} diff --git a/channeldb/migration_01_to_11/nodes.go b/channeldb/migration_01_to_11/nodes.go new file mode 100644 index 00000000..f40359e8 --- /dev/null +++ b/channeldb/migration_01_to_11/nodes.go @@ -0,0 +1,316 @@ +package migration_01_to_11 + +import ( + "bytes" + "io" + "net" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" +) + +var ( + // nodeInfoBucket stores metadata pertaining to nodes that we've had + // direct channel-based correspondence with. This bucket allows one to + // query for all open channels pertaining to the node by exploring each + // node's sub-bucket within the openChanBucket. + nodeInfoBucket = []byte("nib") +) + +// LinkNode stores metadata related to node's that we have/had a direct +// channel open with. Information such as the Bitcoin network the node +// advertised, and its identity public key are also stored. Additionally, this +// struct and the bucket its stored within have store data similar to that of +// Bitcoin's addrmanager. The TCP address information stored within the struct +// can be used to establish persistent connections will all channel +// counterparties on daemon startup. +// +// TODO(roasbeef): also add current OnionKey plus rotation schedule? +// TODO(roasbeef): add bitfield for supported services +// * possibly add a wire.NetAddress type, type +type LinkNode struct { + // Network indicates the Bitcoin network that the LinkNode advertises + // for incoming channel creation. + Network wire.BitcoinNet + + // IdentityPub is the node's current identity public key. Any + // channel/topology related information received by this node MUST be + // signed by this public key. + IdentityPub *btcec.PublicKey + + // LastSeen tracks the last time this node was seen within the network. + // A node should be marked as seen if the daemon either is able to + // establish an outgoing connection to the node or receives a new + // incoming connection from the node. This timestamp (stored in unix + // epoch) may be used within a heuristic which aims to determine when a + // channel should be unilaterally closed due to inactivity. + // + // TODO(roasbeef): replace with block hash/height? + // * possibly add a time-value metric into the heuristic? + LastSeen time.Time + + // Addresses is a list of IP address in which either we were able to + // reach the node over in the past, OR we received an incoming + // authenticated connection for the stored identity public key. + Addresses []net.Addr + + db *DB +} + +// NewLinkNode creates a new LinkNode from the provided parameters, which is +// backed by an instance of channeldb. +func (db *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, + addrs ...net.Addr) *LinkNode { + + return &LinkNode{ + Network: bitNet, + IdentityPub: pub, + LastSeen: time.Now(), + Addresses: addrs, + db: db, + } +} + +// UpdateLastSeen updates the last time this node was directly encountered on +// the Lightning Network. +func (l *LinkNode) UpdateLastSeen(lastSeen time.Time) error { + l.LastSeen = lastSeen + + return l.Sync() +} + +// AddAddress appends the specified TCP address to the list of known addresses +// this node is/was known to be reachable at. +func (l *LinkNode) AddAddress(addr net.Addr) error { + for _, a := range l.Addresses { + if a.String() == addr.String() { + return nil + } + } + + l.Addresses = append(l.Addresses, addr) + + return l.Sync() +} + +// Sync performs a full database sync which writes the current up-to-date data +// within the struct to the database. +func (l *LinkNode) Sync() error { + + // Finally update the database by storing the link node and updating + // any relevant indexes. + return l.db.Update(func(tx *bbolt.Tx) error { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return ErrLinkNodesNotFound + } + + return putLinkNode(nodeMetaBucket, l) + }) +} + +// putLinkNode serializes then writes the encoded version of the passed link +// node into the nodeMetaBucket. This function is provided in order to allow +// the ability to re-use a database transaction across many operations. +func putLinkNode(nodeMetaBucket *bbolt.Bucket, l *LinkNode) error { + // First serialize the LinkNode into its raw-bytes encoding. + var b bytes.Buffer + if err := serializeLinkNode(&b, l); err != nil { + return err + } + + // Finally insert the link-node into the node metadata bucket keyed + // according to the its pubkey serialized in compressed form. + nodePub := l.IdentityPub.SerializeCompressed() + return nodeMetaBucket.Put(nodePub, b.Bytes()) +} + +// DeleteLinkNode removes the link node with the given identity from the +// database. +func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error { + return db.Update(func(tx *bbolt.Tx) error { + return db.deleteLinkNode(tx, identity) + }) +} + +func (db *DB) deleteLinkNode(tx *bbolt.Tx, identity *btcec.PublicKey) error { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return ErrLinkNodesNotFound + } + + pubKey := identity.SerializeCompressed() + return nodeMetaBucket.Delete(pubKey) +} + +// FetchLinkNode attempts to lookup the data for a LinkNode based on a target +// identity public key. If a particular LinkNode for the passed identity public +// key cannot be found, then ErrNodeNotFound if returned. +func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { + var linkNode *LinkNode + err := db.View(func(tx *bbolt.Tx) error { + node, err := fetchLinkNode(tx, identity) + if err != nil { + return err + } + + linkNode = node + return nil + }) + + return linkNode, err +} + +func fetchLinkNode(tx *bbolt.Tx, targetPub *btcec.PublicKey) (*LinkNode, error) { + // First fetch the bucket for storing node metadata, bailing out early + // if it hasn't been created yet. + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return nil, ErrLinkNodesNotFound + } + + // If a link node for that particular public key cannot be located, + // then exit early with an ErrNodeNotFound. + pubKey := targetPub.SerializeCompressed() + nodeBytes := nodeMetaBucket.Get(pubKey) + if nodeBytes == nil { + return nil, ErrNodeNotFound + } + + // Finally, decode and allocate a fresh LinkNode object to be returned + // to the caller. + nodeReader := bytes.NewReader(nodeBytes) + return deserializeLinkNode(nodeReader) +} + +// TODO(roasbeef): update link node addrs in server upon connection + +// FetchAllLinkNodes starts a new database transaction to fetch all nodes with +// whom we have active channels with. +func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { + var linkNodes []*LinkNode + err := db.View(func(tx *bbolt.Tx) error { + nodes, err := db.fetchAllLinkNodes(tx) + if err != nil { + return err + } + + linkNodes = nodes + return nil + }) + if err != nil { + return nil, err + } + + return linkNodes, nil +} + +// fetchAllLinkNodes uses an existing database transaction to fetch all nodes +// with whom we have active channels with. +func (db *DB) fetchAllLinkNodes(tx *bbolt.Tx) ([]*LinkNode, error) { + nodeMetaBucket := tx.Bucket(nodeInfoBucket) + if nodeMetaBucket == nil { + return nil, ErrLinkNodesNotFound + } + + var linkNodes []*LinkNode + err := nodeMetaBucket.ForEach(func(k, v []byte) error { + if v == nil { + return nil + } + + nodeReader := bytes.NewReader(v) + linkNode, err := deserializeLinkNode(nodeReader) + if err != nil { + return err + } + + linkNodes = append(linkNodes, linkNode) + return nil + }) + if err != nil { + return nil, err + } + + return linkNodes, nil +} + +func serializeLinkNode(w io.Writer, l *LinkNode) error { + var buf [8]byte + + byteOrder.PutUint32(buf[:4], uint32(l.Network)) + if _, err := w.Write(buf[:4]); err != nil { + return err + } + + serializedID := l.IdentityPub.SerializeCompressed() + if _, err := w.Write(serializedID); err != nil { + return err + } + + seenUnix := uint64(l.LastSeen.Unix()) + byteOrder.PutUint64(buf[:], seenUnix) + if _, err := w.Write(buf[:]); err != nil { + return err + } + + numAddrs := uint32(len(l.Addresses)) + byteOrder.PutUint32(buf[:4], numAddrs) + if _, err := w.Write(buf[:4]); err != nil { + return err + } + + for _, addr := range l.Addresses { + if err := serializeAddr(w, addr); err != nil { + return err + } + } + + return nil +} + +func deserializeLinkNode(r io.Reader) (*LinkNode, error) { + var ( + err error + buf [8]byte + ) + + node := &LinkNode{} + + if _, err := io.ReadFull(r, buf[:4]); err != nil { + return nil, err + } + node.Network = wire.BitcoinNet(byteOrder.Uint32(buf[:4])) + + var pub [33]byte + if _, err := io.ReadFull(r, pub[:]); err != nil { + return nil, err + } + node.IdentityPub, err = btcec.ParsePubKey(pub[:], btcec.S256()) + if err != nil { + return nil, err + } + + if _, err := io.ReadFull(r, buf[:]); err != nil { + return nil, err + } + node.LastSeen = time.Unix(int64(byteOrder.Uint64(buf[:])), 0) + + if _, err := io.ReadFull(r, buf[:4]); err != nil { + return nil, err + } + numAddrs := byteOrder.Uint32(buf[:4]) + + node.Addresses = make([]net.Addr, numAddrs) + for i := uint32(0); i < numAddrs; i++ { + addr, err := deserializeAddr(r) + if err != nil { + return nil, err + } + node.Addresses[i] = addr + } + + return node, nil +} diff --git a/channeldb/migration_01_to_11/nodes_test.go b/channeldb/migration_01_to_11/nodes_test.go new file mode 100644 index 00000000..481dc5bd --- /dev/null +++ b/channeldb/migration_01_to_11/nodes_test.go @@ -0,0 +1,140 @@ +package migration_01_to_11 + +import ( + "bytes" + "net" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" +) + +func TestLinkNodeEncodeDecode(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + // First we'll create some initial data to use for populating our test + // LinkNode instances. + _, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + _, pub2 := btcec.PrivKeyFromBytes(btcec.S256(), rev[:]) + addr1, err := net.ResolveTCPAddr("tcp", "10.0.0.1:9000") + if err != nil { + t.Fatalf("unable to create test addr: %v", err) + } + addr2, err := net.ResolveTCPAddr("tcp", "10.0.0.2:9000") + if err != nil { + t.Fatalf("unable to create test addr: %v", err) + } + + // Create two fresh link node instances with the above dummy data, then + // fully sync both instances to disk. + node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1) + node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2) + if err := node1.Sync(); err != nil { + t.Fatalf("unable to sync node: %v", err) + } + if err := node2.Sync(); err != nil { + t.Fatalf("unable to sync node: %v", err) + } + + // Fetch all current link nodes from the database, they should exactly + // match the two created above. + originalNodes := []*LinkNode{node2, node1} + linkNodes, err := cdb.FetchAllLinkNodes() + if err != nil { + t.Fatalf("unable to fetch nodes: %v", err) + } + for i, node := range linkNodes { + if originalNodes[i].Network != node.Network { + t.Fatalf("node networks don't match: expected %v, got %v", + originalNodes[i].Network, node.Network) + } + + originalPubkey := originalNodes[i].IdentityPub.SerializeCompressed() + dbPubkey := node.IdentityPub.SerializeCompressed() + if !bytes.Equal(originalPubkey, dbPubkey) { + t.Fatalf("node pubkeys don't match: expected %x, got %x", + originalPubkey, dbPubkey) + } + if originalNodes[i].LastSeen.Unix() != node.LastSeen.Unix() { + t.Fatalf("last seen timestamps don't match: expected %v got %v", + originalNodes[i].LastSeen.Unix(), node.LastSeen.Unix()) + } + if originalNodes[i].Addresses[0].String() != node.Addresses[0].String() { + t.Fatalf("addresses don't match: expected %v, got %v", + originalNodes[i].Addresses, node.Addresses) + } + } + + // Next, we'll exercise the methods to append additional IP + // addresses, and also to update the last seen time. + if err := node1.UpdateLastSeen(time.Now()); err != nil { + t.Fatalf("unable to update last seen: %v", err) + } + if err := node1.AddAddress(addr2); err != nil { + t.Fatalf("unable to update addr: %v", err) + } + + // Fetch the same node from the database according to its public key. + node1DB, err := cdb.FetchLinkNode(pub1) + if err != nil { + t.Fatalf("unable to find node: %v", err) + } + + // Both the last seen timestamp and the list of reachable addresses for + // the node should be updated. + if node1DB.LastSeen.Unix() != node1.LastSeen.Unix() { + t.Fatalf("last seen timestamps don't match: expected %v got %v", + node1.LastSeen.Unix(), node1DB.LastSeen.Unix()) + } + if len(node1DB.Addresses) != 2 { + t.Fatalf("wrong length for node1 addresses: expected %v, got %v", + 2, len(node1DB.Addresses)) + } + if node1DB.Addresses[0].String() != addr1.String() { + t.Fatalf("wrong address for node: expected %v, got %v", + addr1.String(), node1DB.Addresses[0].String()) + } + if node1DB.Addresses[1].String() != addr2.String() { + t.Fatalf("wrong address for node: expected %v, got %v", + addr2.String(), node1DB.Addresses[1].String()) + } +} + +func TestDeleteLinkNode(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) + addr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1337, + } + linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr) + if err := linkNode.Sync(); err != nil { + t.Fatalf("unable to write link node to db: %v", err) + } + + if _, err := cdb.FetchLinkNode(pubKey); err != nil { + t.Fatalf("unable to find link node: %v", err) + } + + if err := cdb.DeleteLinkNode(pubKey); err != nil { + t.Fatalf("unable to delete link node from db: %v", err) + } + + if _, err := cdb.FetchLinkNode(pubKey); err == nil { + t.Fatal("should not have found link node in db, but did") + } +} diff --git a/channeldb/migration_01_to_11/options.go b/channeldb/migration_01_to_11/options.go new file mode 100644 index 00000000..c3cc2c4a --- /dev/null +++ b/channeldb/migration_01_to_11/options.go @@ -0,0 +1,62 @@ +package migration_01_to_11 + +const ( + // DefaultRejectCacheSize is the default number of rejectCacheEntries to + // cache for use in the rejection cache of incoming gossip traffic. This + // produces a cache size of around 1MB. + DefaultRejectCacheSize = 50000 + + // DefaultChannelCacheSize is the default number of ChannelEdges cached + // in order to reply to gossip queries. This produces a cache size of + // around 40MB. + DefaultChannelCacheSize = 20000 +) + +// Options holds parameters for tuning and customizing a channeldb.DB. +type Options struct { + // RejectCacheSize is the maximum number of rejectCacheEntries to hold + // in the rejection cache. + RejectCacheSize int + + // ChannelCacheSize is the maximum number of ChannelEdges to hold in the + // channel cache. + ChannelCacheSize int + + // NoFreelistSync, if true, prevents the database from syncing its + // freelist to disk, resulting in improved performance at the expense of + // increased startup time. + NoFreelistSync bool +} + +// DefaultOptions returns an Options populated with default values. +func DefaultOptions() Options { + return Options{ + RejectCacheSize: DefaultRejectCacheSize, + ChannelCacheSize: DefaultChannelCacheSize, + NoFreelistSync: true, + } +} + +// OptionModifier is a function signature for modifying the default Options. +type OptionModifier func(*Options) + +// OptionSetRejectCacheSize sets the RejectCacheSize to n. +func OptionSetRejectCacheSize(n int) OptionModifier { + return func(o *Options) { + o.RejectCacheSize = n + } +} + +// OptionSetChannelCacheSize sets the ChannelCacheSize to n. +func OptionSetChannelCacheSize(n int) OptionModifier { + return func(o *Options) { + o.ChannelCacheSize = n + } +} + +// OptionSetSyncFreelist allows the database to sync its freelist. +func OptionSetSyncFreelist(b bool) OptionModifier { + return func(o *Options) { + o.NoFreelistSync = !b + } +} diff --git a/channeldb/migration_01_to_11/payment_control.go b/channeldb/migration_01_to_11/payment_control.go new file mode 100644 index 00000000..83b1649a --- /dev/null +++ b/channeldb/migration_01_to_11/payment_control.go @@ -0,0 +1,497 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" +) + +var ( + // ErrAlreadyPaid signals we have already paid this payment hash. + ErrAlreadyPaid = errors.New("invoice is already paid") + + // ErrPaymentInFlight signals that payment for this payment hash is + // already "in flight" on the network. + ErrPaymentInFlight = errors.New("payment is in transition") + + // ErrPaymentNotInitiated is returned if payment wasn't initiated in + // switch. + ErrPaymentNotInitiated = errors.New("payment isn't initiated") + + // ErrPaymentAlreadySucceeded is returned in the event we attempt to + // change the status of a payment already succeeded. + ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded") + + // ErrPaymentAlreadyFailed is returned in the event we attempt to + // re-fail a failed payment. + ErrPaymentAlreadyFailed = errors.New("payment has already failed") + + // ErrUnknownPaymentStatus is returned when we do not recognize the + // existing state of a payment. + ErrUnknownPaymentStatus = errors.New("unknown payment status") + + // errNoAttemptInfo is returned when no attempt info is stored yet. + errNoAttemptInfo = errors.New("unable to find attempt info for " + + "inflight payment") +) + +// PaymentControl implements persistence for payments and payment attempts. +type PaymentControl struct { + db *DB +} + +// NewPaymentControl creates a new instance of the PaymentControl. +func NewPaymentControl(db *DB) *PaymentControl { + return &PaymentControl{ + db: db, + } +} + +// InitPayment checks or records the given PaymentCreationInfo with the DB, +// making sure it does not already exist as an in-flight payment. Then this +// method returns successfully, the payment is guranteeed to be in the InFlight +// state. +func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, + info *PaymentCreationInfo) error { + + var b bytes.Buffer + if err := serializePaymentCreationInfo(&b, info); err != nil { + return err + } + infoBytes := b.Bytes() + + var updateErr error + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := createPaymentBucket(tx, paymentHash) + if err != nil { + return err + } + + // Get the existing status of this payment, if any. + paymentStatus := fetchPaymentStatus(bucket) + + switch paymentStatus { + + // We allow retrying failed payments. + case StatusFailed: + + // This is a new payment that is being initialized for the + // first time. + case StatusUnknown: + + // We already have an InFlight payment on the network. We will + // disallow any new payments. + case StatusInFlight: + updateErr = ErrPaymentInFlight + return nil + + // We've already succeeded a payment to this payment hash, + // forbid the switch from sending another. + case StatusSucceeded: + updateErr = ErrAlreadyPaid + return nil + + default: + updateErr = ErrUnknownPaymentStatus + return nil + } + + // Obtain a new sequence number for this payment. This is used + // to sort the payments in order of creation, and also acts as + // a unique identifier for each payment. + sequenceNum, err := nextPaymentSequence(tx) + if err != nil { + return err + } + + err = bucket.Put(paymentSequenceKey, sequenceNum) + if err != nil { + return err + } + + // Add the payment info to the bucket, which contains the + // static information for this payment + err = bucket.Put(paymentCreationInfoKey, infoBytes) + if err != nil { + return err + } + + // We'll delete any lingering attempt info to start with, in + // case we are initializing a payment that was attempted + // earlier, but left in a state where we could retry. + err = bucket.Delete(paymentAttemptInfoKey) + if err != nil { + return err + } + + // Also delete any lingering failure info now that we are + // re-attempting. + return bucket.Delete(paymentFailInfoKey) + }) + if err != nil { + return err + } + + return updateErr +} + +// RegisterAttempt atomically records the provided PaymentAttemptInfo to the +// DB. +func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, + attempt *PaymentAttemptInfo) error { + + // Serialize the information before opening the db transaction. + var a bytes.Buffer + if err := serializePaymentAttemptInfo(&a, attempt); err != nil { + return err + } + attemptBytes := a.Bytes() + + var updateErr error + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only register attempts for payments that are + // in-flight. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Add the payment attempt to the payments bucket. + return bucket.Put(paymentAttemptInfoKey, attemptBytes) + }) + if err != nil { + return err + } + + return updateErr +} + +// Success transitions a payment into the Succeeded state. After invoking this +// method, InitPayment should always return an error to prevent us from making +// duplicate payments to the same payment hash. The provided preimage is +// atomically saved to the DB for record keeping. +func (p *PaymentControl) Success(paymentHash lntypes.Hash, + preimage lntypes.Preimage) (*route.Route, error) { + + var ( + updateErr error + route *route.Route + ) + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only mark in-flight payments as succeeded. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Record the successful payment info atomically to the + // payments record. + err = bucket.Put(paymentSettleInfoKey, preimage[:]) + if err != nil { + return err + } + + // Retrieve attempt info for the notification. + attempt, err := fetchPaymentAttempt(bucket) + if err != nil { + return err + } + + route = &attempt.Route + + return nil + }) + if err != nil { + return nil, err + } + + return route, updateErr +} + +// Fail transitions a payment into the Failed state, and records the reason the +// payment failed. After invoking this method, InitPayment should return nil on +// its next call for this payment hash, allowing the switch to make a +// subsequent payment. +func (p *PaymentControl) Fail(paymentHash lntypes.Hash, + reason FailureReason) (*route.Route, error) { + + var ( + updateErr error + route *route.Route + ) + err := p.db.Batch(func(tx *bbolt.Tx) error { + // Reset the update error, to avoid carrying over an error + // from a previous execution of the batched db transaction. + updateErr = nil + + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err == ErrPaymentNotInitiated { + updateErr = ErrPaymentNotInitiated + return nil + } else if err != nil { + return err + } + + // We can only mark in-flight payments as failed. + if err := ensureInFlight(bucket); err != nil { + updateErr = err + return nil + } + + // Put the failure reason in the bucket for record keeping. + v := []byte{byte(reason)} + err = bucket.Put(paymentFailInfoKey, v) + if err != nil { + return err + } + + // Retrieve attempt info for the notification, if available. + attempt, err := fetchPaymentAttempt(bucket) + if err != nil && err != errNoAttemptInfo { + return err + } + if err != errNoAttemptInfo { + route = &attempt.Route + } + + return nil + }) + if err != nil { + return nil, err + } + + return route, updateErr +} + +// FetchPayment returns information about a payment from the database. +func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) ( + *Payment, error) { + + var payment *Payment + err := p.db.View(func(tx *bbolt.Tx) error { + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err != nil { + return err + } + + payment, err = fetchPayment(bucket) + + return err + }) + if err != nil { + return nil, err + } + + return payment, nil +} + +// createPaymentBucket creates or fetches the sub-bucket assigned to this +// payment hash. +func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( + *bbolt.Bucket, error) { + + payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) + if err != nil { + return nil, err + } + + return payments.CreateBucketIfNotExists(paymentHash[:]) +} + +// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If +// the bucket does not exist, it returns ErrPaymentNotInitiated. +func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( + *bbolt.Bucket, error) { + + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil, ErrPaymentNotInitiated + } + + bucket := payments.Bucket(paymentHash[:]) + if bucket == nil { + return nil, ErrPaymentNotInitiated + } + + return bucket, nil + +} + +// nextPaymentSequence returns the next sequence number to store for a new +// payment. +func nextPaymentSequence(tx *bbolt.Tx) ([]byte, error) { + payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) + if err != nil { + return nil, err + } + + seq, err := payments.NextSequence() + if err != nil { + return nil, err + } + + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, seq) + return b, nil +} + +// fetchPaymentStatus fetches the payment status of the payment. If the payment +// isn't found, it will default to "StatusUnknown". +func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus { + if bucket.Get(paymentSettleInfoKey) != nil { + return StatusSucceeded + } + + if bucket.Get(paymentFailInfoKey) != nil { + return StatusFailed + } + + if bucket.Get(paymentCreationInfoKey) != nil { + return StatusInFlight + } + + return StatusUnknown +} + +// ensureInFlight checks whether the payment found in the given bucket has +// status InFlight, and returns an error otherwise. This should be used to +// ensure we only mark in-flight payments as succeeded or failed. +func ensureInFlight(bucket *bbolt.Bucket) error { + paymentStatus := fetchPaymentStatus(bucket) + + switch { + + // The payment was indeed InFlight, return. + case paymentStatus == StatusInFlight: + return nil + + // Our records show the payment as unknown, meaning it never + // should have left the switch. + case paymentStatus == StatusUnknown: + return ErrPaymentNotInitiated + + // The payment succeeded previously. + case paymentStatus == StatusSucceeded: + return ErrPaymentAlreadySucceeded + + // The payment was already failed. + case paymentStatus == StatusFailed: + return ErrPaymentAlreadyFailed + + default: + return ErrUnknownPaymentStatus + } +} + +// fetchPaymentAttempt fetches the payment attempt from the bucket. +func fetchPaymentAttempt(bucket *bbolt.Bucket) (*PaymentAttemptInfo, error) { + attemptData := bucket.Get(paymentAttemptInfoKey) + if attemptData == nil { + return nil, errNoAttemptInfo + } + + r := bytes.NewReader(attemptData) + return deserializePaymentAttemptInfo(r) +} + +// InFlightPayment is a wrapper around a payment that has status InFlight. +type InFlightPayment struct { + // Info is the PaymentCreationInfo of the in-flight payment. + Info *PaymentCreationInfo + + // Attempt contains information about the last payment attempt that was + // made to this payment hash. + // + // NOTE: Might be nil. + Attempt *PaymentAttemptInfo +} + +// FetchInFlightPayments returns all payments with status InFlight. +func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { + var inFlights []*InFlightPayment + err := p.db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + return payments.ForEach(func(k, _ []byte) error { + bucket := payments.Bucket(k) + if bucket == nil { + return fmt.Errorf("non bucket element") + } + + // If the status is not InFlight, we can return early. + paymentStatus := fetchPaymentStatus(bucket) + if paymentStatus != StatusInFlight { + return nil + } + + var ( + inFlight = &InFlightPayment{} + err error + ) + + // Get the CreationInfo. + b := bucket.Get(paymentCreationInfoKey) + if b == nil { + return fmt.Errorf("unable to find creation " + + "info for inflight payment") + } + + r := bytes.NewReader(b) + inFlight.Info, err = deserializePaymentCreationInfo(r) + if err != nil { + return err + } + + // Now get the attempt info. It could be that there is + // no attempt info yet. + inFlight.Attempt, err = fetchPaymentAttempt(bucket) + if err != nil && err != errNoAttemptInfo { + return err + } + + inFlights = append(inFlights, inFlight) + return nil + }) + }) + if err != nil { + return nil, err + } + + return inFlights, nil +} diff --git a/channeldb/migration_01_to_11/payment_control_test.go b/channeldb/migration_01_to_11/payment_control_test.go new file mode 100644 index 00000000..9868475e --- /dev/null +++ b/channeldb/migration_01_to_11/payment_control_test.go @@ -0,0 +1,550 @@ +package migration_01_to_11 + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "io/ioutil" + "reflect" + "testing" + "time" + + "github.com/btcsuite/fastsha256" + "github.com/coreos/bbolt" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" +) + +func initDB() (*DB, error) { + tempPath, err := ioutil.TempDir("", "switchdb") + if err != nil { + return nil, err + } + + db, err := Open(tempPath) + if err != nil { + return nil, err + } + + return db, err +} + +func genPreimage() ([32]byte, error) { + var preimage [32]byte + if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { + return preimage, err + } + return preimage, nil +} + +func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo, + lntypes.Preimage, error) { + + preimage, err := genPreimage() + if err != nil { + return nil, nil, preimage, fmt.Errorf("unable to "+ + "generate preimage: %v", err) + } + + rhash := fastsha256.Sum256(preimage[:]) + return &PaymentCreationInfo{ + PaymentHash: rhash, + Value: 1, + CreationDate: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("hola"), + }, + &PaymentAttemptInfo{ + PaymentID: 1, + SessionKey: priv, + Route: testRoute, + }, preimage, nil +} + +// TestPaymentControlSwitchFail checks that payment status returns to Failed +// status after failing, and that InitPayment allows another HTLC for the +// same payment hash. +func TestPaymentControlSwitchFail(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + _, err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Failed. + assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + &failReason, + ) + + // Sends the htlc again, which should succeed since the prior payment + // failed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Record a new attempt. + attempt.PaymentID = 2 + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + nil, + ) + + // Verifies that status was changed to StatusSucceeded. + var route *route.Route + route, err = pControl.Success(info.PaymentHash, preimg) + if err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + err = assertRouteEqual(route, &attempt.Route) + if err != nil { + t.Fatalf("unexpected route returned: %v vs %v: %v", + spew.Sdump(attempt.Route), spew.Sdump(*route), err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) + + // Attempt a final payment, which should now fail since the prior + // payment succeed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrAlreadyPaid { + t.Fatalf("unable to send htlc message: %v", err) + } +} + +// TestPaymentControlSwitchDoubleSend checks the ability of payment control to +// prevent double sending of htlc message, when message is in StatusInFlight. +func TestPaymentControlSwitchDoubleSend(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate base status and move it to + // StatusInFlight and verifies that it was changed. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, + nil, + ) + + // Try to initiate double sending of htlc message with the same + // payment hash, should result in error indicating that payment has + // already been sent. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") + } + + // Record an attempt. + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, + nil, + ) + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrPaymentInFlight { + t.Fatalf("payment control wrong behaviour: " + + "double sending must trigger ErrPaymentInFlight error") + } + + // After settling, the error should be ErrAlreadyPaid. + if _, err := pControl.Success(info.PaymentHash, preimg); err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) + + err = pControl.InitPayment(info.PaymentHash, info) + if err != ErrAlreadyPaid { + t.Fatalf("unable to send htlc message: %v", err) + } +} + +// TestPaymentControlSuccessesWithoutInFlight checks that the payment +// control will disallow calls to Success when no payment is in flight. +func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, _, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Attempt to complete the payment should fail. + _, err = pControl.Success(info.PaymentHash, preimg) + if err != ErrPaymentNotInitiated { + t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) + assertPaymentInfo( + t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, + nil, + ) +} + +// TestPaymentControlFailsWithoutInFlight checks that a strict payment +// control will disallow calls to Fail when no payment is in flight. +func TestPaymentControlFailsWithoutInFlight(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + info, _, _, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Calling Fail should return an error. + _, err = pControl.Fail(info.PaymentHash, FailureReasonNoRoute) + if err != ErrPaymentNotInitiated { + t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) + assertPaymentInfo( + t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil, + ) +} + +// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only +// deletes payments from the database that are not in-flight. +func TestPaymentControlDeleteNonInFligt(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewPaymentControl(db) + + payments := []struct { + failed bool + success bool + }{ + { + failed: true, + success: false, + }, + { + failed: false, + success: true, + }, + { + failed: false, + success: false, + }, + } + + for _, p := range payments { + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatalf("unable to generate htlc message: %v", err) + } + + // Sends base htlc message which initiate StatusInFlight. + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatalf("unable to send htlc message: %v", err) + } + + if p.failed { + // Fail the payment, which should moved it to Failed. + failReason := FailureReasonNoRoute + _, err = pControl.Fail(info.PaymentHash, failReason) + if err != nil { + t.Fatalf("unable to fail payment hash: %v", err) + } + + // Verify the status is indeed Failed. + assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, + lntypes.Preimage{}, &failReason, + ) + } else if p.success { + // Verifies that status was changed to StatusSucceeded. + _, err := pControl.Success(info.PaymentHash, preimg) + if err != nil { + t.Fatalf("error shouldn't have been received, got: %v", err) + } + + assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, preimg, nil, + ) + } else { + assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) + assertPaymentInfo( + t, db, info.PaymentHash, info, attempt, + lntypes.Preimage{}, nil, + ) + } + } + + // Delete payments. + if err := db.DeletePayments(); err != nil { + t.Fatal(err) + } + + // This should leave the in-flight payment. + dbPayments, err := db.FetchPayments() + if err != nil { + t.Fatal(err) + } + + if len(dbPayments) != 1 { + t.Fatalf("expected one payment, got %d", len(dbPayments)) + } + + status := dbPayments[0].Status + if status != StatusInFlight { + t.Fatalf("expected in-fligth status, got %v", status) + } +} + +func assertPaymentStatus(t *testing.T, db *DB, + hash [32]byte, expStatus PaymentStatus) { + + t.Helper() + + var paymentStatus = StatusUnknown + err := db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + bucket := payments.Bucket(hash[:]) + if bucket == nil { + return nil + } + + // Get the existing status of this payment, if any. + paymentStatus = fetchPaymentStatus(bucket) + return nil + }) + if err != nil { + t.Fatalf("unable to fetch payment status: %v", err) + } + + if paymentStatus != expStatus { + t.Fatalf("payment status mismatch: expected %v, got %v", + expStatus, paymentStatus) + } +} + +func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error { + b := bucket.Get(paymentCreationInfoKey) + switch { + case b == nil && c == nil: + return nil + case b == nil: + return fmt.Errorf("expected creation info not found") + case c == nil: + return fmt.Errorf("unexpected creation info found") + } + + r := bytes.NewReader(b) + c2, err := deserializePaymentCreationInfo(r) + if err != nil { + return err + } + if !reflect.DeepEqual(c, c2) { + return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v", + spew.Sdump(c), spew.Sdump(c2)) + } + + return nil +} + +func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error { + b := bucket.Get(paymentAttemptInfoKey) + switch { + case b == nil && a == nil: + return nil + case b == nil: + return fmt.Errorf("expected attempt info not found") + case a == nil: + return fmt.Errorf("unexpected attempt info found") + } + + r := bytes.NewReader(b) + a2, err := deserializePaymentAttemptInfo(r) + if err != nil { + return err + } + + return assertRouteEqual(&a.Route, &a2.Route) +} + +func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error { + zero := lntypes.Preimage{} + b := bucket.Get(paymentSettleInfoKey) + switch { + case b == nil && preimg == zero: + return nil + case b == nil: + return fmt.Errorf("expected preimage not found") + case preimg == zero: + return fmt.Errorf("unexpected preimage found") + } + + var pre2 lntypes.Preimage + copy(pre2[:], b[:]) + if preimg != pre2 { + return fmt.Errorf("Preimages don't match: %x vs %x", + preimg, pre2) + } + + return nil +} + +func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error { + b := bucket.Get(paymentFailInfoKey) + switch { + case b == nil && failReason == nil: + return nil + case b == nil: + return fmt.Errorf("expected fail info not found") + case failReason == nil: + return fmt.Errorf("unexpected fail info found") + } + + failReason2 := FailureReason(b[0]) + if *failReason != failReason2 { + return fmt.Errorf("Failure infos don't match: %v vs %v", + *failReason, failReason2) + } + + return nil +} + +func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash, + c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage, + f *FailureReason) { + + t.Helper() + + err := db.View(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil && c == nil { + return nil + } + if payments == nil { + return fmt.Errorf("sent payments not found") + } + + bucket := payments.Bucket(hash[:]) + if bucket == nil && c == nil { + return nil + } + + if bucket == nil { + return fmt.Errorf("payment not found") + } + + if err := checkPaymentCreationInfo(bucket, c); err != nil { + return err + } + + if err := checkPaymentAttemptInfo(bucket, a); err != nil { + return err + } + + if err := checkSettleInfo(bucket, s); err != nil { + return err + } + + if err := checkFailInfo(bucket, f); err != nil { + return err + } + return nil + }) + if err != nil { + t.Fatalf("assert payment info failed: %v", err) + } + +} diff --git a/channeldb/migration_01_to_11/payments.go b/channeldb/migration_01_to_11/payments.go new file mode 100644 index 00000000..fd3db5a1 --- /dev/null +++ b/channeldb/migration_01_to_11/payments.go @@ -0,0 +1,669 @@ +package migration_01_to_11 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "sort" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // paymentsRootBucket is the name of the top-level bucket within the + // database that stores all data related to payments. Within this + // bucket, each payment hash its own sub-bucket keyed by its payment + // hash. + // + // Bucket hierarchy: + // + // root-bucket + // | + // |-- + // | |--sequence-key: + // | |--creation-info-key: + // | |--attempt-info-key: + // | |--settle-info-key: + // | |--fail-info-key: + // | | + // | |--duplicate-bucket (only for old, completed payments) + // | | + // | |-- + // | | |--sequence-key: + // | | |--creation-info-key: + // | | |--attempt-info-key: + // | | |--settle-info-key: + // | | |--fail-info-key: + // | | + // | |-- + // | | | + // | ... ... + // | + // |-- + // | | + // | ... + // ... + // + paymentsRootBucket = []byte("payments-root-bucket") + + // paymentDublicateBucket is the name of a optional sub-bucket within + // the payment hash bucket, that is used to hold duplicate payments to + // a payment hash. This is needed to support information from earlier + // versions of lnd, where it was possible to pay to a payment hash more + // than once. + paymentDuplicateBucket = []byte("payment-duplicate-bucket") + + // paymentSequenceKey is a key used in the payment's sub-bucket to + // store the sequence number of the payment. + paymentSequenceKey = []byte("payment-sequence-key") + + // paymentCreationInfoKey is a key used in the payment's sub-bucket to + // store the creation info of the payment. + paymentCreationInfoKey = []byte("payment-creation-info") + + // paymentAttemptInfoKey is a key used in the payment's sub-bucket to + // store the info about the latest attempt that was done for the + // payment in question. + paymentAttemptInfoKey = []byte("payment-attempt-info") + + // paymentSettleInfoKey is a key used in the payment's sub-bucket to + // store the settle info of the payment. + paymentSettleInfoKey = []byte("payment-settle-info") + + // paymentFailInfoKey is a key used in the payment's sub-bucket to + // store information about the reason a payment failed. + paymentFailInfoKey = []byte("payment-fail-info") +) + +// FailureReason encodes the reason a payment ultimately failed. +type FailureReason byte + +const ( + // FailureReasonTimeout indicates that the payment did timeout before a + // successful payment attempt was made. + FailureReasonTimeout FailureReason = 0 + + // FailureReasonNoRoute indicates no successful route to the + // destination was found during path finding. + FailureReasonNoRoute FailureReason = 1 + + // FailureReasonError indicates that an unexpected error happened during + // payment. + FailureReasonError FailureReason = 2 + + // FailureReasonIncorrectPaymentDetails indicates that either the hash + // is unknown or the final cltv delta or amount is incorrect. + FailureReasonIncorrectPaymentDetails FailureReason = 3 + + // TODO(halseth): cancel state. + + // TODO(joostjager): Add failure reasons for: + // LocalLiquidityInsufficient, RemoteCapacityInsufficient. +) + +// String returns a human readable FailureReason +func (r FailureReason) String() string { + switch r { + case FailureReasonTimeout: + return "timeout" + case FailureReasonNoRoute: + return "no_route" + case FailureReasonError: + return "error" + case FailureReasonIncorrectPaymentDetails: + return "incorrect_payment_details" + } + + return "unknown" +} + +// PaymentStatus represent current status of payment +type PaymentStatus byte + +const ( + // StatusUnknown is the status where a payment has never been initiated + // and hence is unknown. + StatusUnknown PaymentStatus = 0 + + // StatusInFlight is the status where a payment has been initiated, but + // a response has not been received. + StatusInFlight PaymentStatus = 1 + + // StatusSucceeded is the status where a payment has been initiated and + // the payment was completed successfully. + StatusSucceeded PaymentStatus = 2 + + // StatusFailed is the status where a payment has been initiated and a + // failure result has come back. + StatusFailed PaymentStatus = 3 +) + +// Bytes returns status as slice of bytes. +func (ps PaymentStatus) Bytes() []byte { + return []byte{byte(ps)} +} + +// FromBytes sets status from slice of bytes. +func (ps *PaymentStatus) FromBytes(status []byte) error { + if len(status) != 1 { + return errors.New("payment status is empty") + } + + switch PaymentStatus(status[0]) { + case StatusUnknown, StatusInFlight, StatusSucceeded, StatusFailed: + *ps = PaymentStatus(status[0]) + default: + return errors.New("unknown payment status") + } + + return nil +} + +// String returns readable representation of payment status. +func (ps PaymentStatus) String() string { + switch ps { + case StatusUnknown: + return "Unknown" + case StatusInFlight: + return "In Flight" + case StatusSucceeded: + return "Succeeded" + case StatusFailed: + return "Failed" + default: + return "Unknown" + } +} + +// PaymentCreationInfo is the information necessary to have ready when +// initiating a payment, moving it into state InFlight. +type PaymentCreationInfo struct { + // PaymentHash is the hash this payment is paying to. + PaymentHash lntypes.Hash + + // Value is the amount we are paying. + Value lnwire.MilliSatoshi + + // CreatingDate is the time when this payment was initiated. + CreationDate time.Time + + // PaymentRequest is the full payment request, if any. + PaymentRequest []byte +} + +// PaymentAttemptInfo contains information about a specific payment attempt for +// a given payment. This information is used by the router to handle any errors +// coming back after an attempt is made, and to query the switch about the +// status of a payment. For settled payment this will be the information for +// the succeeding payment attempt. +type PaymentAttemptInfo struct { + // PaymentID is the unique ID used for this attempt. + PaymentID uint64 + + // SessionKey is the ephemeral key used for this payment attempt. + SessionKey *btcec.PrivateKey + + // Route is the route attempted to send the HTLC. + Route route.Route +} + +// Payment is a wrapper around a payment's PaymentCreationInfo, +// PaymentAttemptInfo, and preimage. All payments will have the +// PaymentCreationInfo set, the PaymentAttemptInfo will be set only if at least +// one payment attempt has been made, while only completed payments will have a +// non-zero payment preimage. +type Payment struct { + // sequenceNum is a unique identifier used to sort the payments in + // order of creation. + sequenceNum uint64 + + // Status is the current PaymentStatus of this payment. + Status PaymentStatus + + // Info holds all static information about this payment, and is + // populated when the payment is initiated. + Info *PaymentCreationInfo + + // Attempt is the information about the last payment attempt made. + // + // NOTE: Can be nil if no attempt is yet made. + Attempt *PaymentAttemptInfo + + // PaymentPreimage is the preimage of a successful payment. This serves + // as a proof of payment. It will only be non-nil for settled payments. + // + // NOTE: Can be nil if payment is not settled. + PaymentPreimage *lntypes.Preimage + + // Failure is a failure reason code indicating the reason the payment + // failed. It is only non-nil for failed payments. + // + // NOTE: Can be nil if payment is not failed. + Failure *FailureReason +} + +// FetchPayments returns all sent payments found in the DB. +func (db *DB) FetchPayments() ([]*Payment, error) { + var payments []*Payment + + err := db.View(func(tx *bbolt.Tx) error { + paymentsBucket := tx.Bucket(paymentsRootBucket) + if paymentsBucket == nil { + return nil + } + + return paymentsBucket.ForEach(func(k, v []byte) error { + bucket := paymentsBucket.Bucket(k) + if bucket == nil { + // We only expect sub-buckets to be found in + // this top-level bucket. + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + + p, err := fetchPayment(bucket) + if err != nil { + return err + } + + payments = append(payments, p) + + // For older versions of lnd, duplicate payments to a + // payment has was possible. These will be found in a + // sub-bucket indexed by their sequence number if + // available. + dup := bucket.Bucket(paymentDuplicateBucket) + if dup == nil { + return nil + } + + return dup.ForEach(func(k, v []byte) error { + subBucket := dup.Bucket(k) + if subBucket == nil { + // We one bucket for each duplicate to + // be found. + return fmt.Errorf("non bucket element" + + "in duplicate bucket") + } + + p, err := fetchPayment(subBucket) + if err != nil { + return err + } + + payments = append(payments, p) + return nil + }) + }) + }) + if err != nil { + return nil, err + } + + // Before returning, sort the payments by their sequence number. + sort.Slice(payments, func(i, j int) bool { + return payments[i].sequenceNum < payments[j].sequenceNum + }) + + return payments, nil +} + +func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) { + var ( + err error + p = &Payment{} + ) + + seqBytes := bucket.Get(paymentSequenceKey) + if seqBytes == nil { + return nil, fmt.Errorf("sequence number not found") + } + + p.sequenceNum = binary.BigEndian.Uint64(seqBytes) + + // Get the payment status. + p.Status = fetchPaymentStatus(bucket) + + // Get the PaymentCreationInfo. + b := bucket.Get(paymentCreationInfoKey) + if b == nil { + return nil, fmt.Errorf("creation info not found") + } + + r := bytes.NewReader(b) + p.Info, err = deserializePaymentCreationInfo(r) + if err != nil { + return nil, err + + } + + // Get the PaymentAttemptInfo. This can be unset. + b = bucket.Get(paymentAttemptInfoKey) + if b != nil { + r = bytes.NewReader(b) + p.Attempt, err = deserializePaymentAttemptInfo(r) + if err != nil { + return nil, err + } + } + + // Get the payment preimage. This is only found for + // completed payments. + b = bucket.Get(paymentSettleInfoKey) + if b != nil { + var preimg lntypes.Preimage + copy(preimg[:], b[:]) + p.PaymentPreimage = &preimg + } + + // Get failure reason if available. + b = bucket.Get(paymentFailInfoKey) + if b != nil { + reason := FailureReason(b[0]) + p.Failure = &reason + } + + return p, nil +} + +// DeletePayments deletes all completed and failed payments from the DB. +func (db *DB) DeletePayments() error { + return db.Update(func(tx *bbolt.Tx) error { + payments := tx.Bucket(paymentsRootBucket) + if payments == nil { + return nil + } + + var deleteBuckets [][]byte + err := payments.ForEach(func(k, _ []byte) error { + bucket := payments.Bucket(k) + if bucket == nil { + // We only expect sub-buckets to be found in + // this top-level bucket. + return fmt.Errorf("non bucket element in " + + "payments bucket") + } + + // If the status is InFlight, we cannot safely delete + // the payment information, so we return early. + paymentStatus := fetchPaymentStatus(bucket) + if paymentStatus == StatusInFlight { + return nil + } + + deleteBuckets = append(deleteBuckets, k) + return nil + }) + if err != nil { + return err + } + + for _, k := range deleteBuckets { + if err := payments.DeleteBucket(k); err != nil { + return err + } + } + + return nil + }) +} + +func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { + var scratch [8]byte + + if _, err := w.Write(c.PaymentHash[:]); err != nil { + return err + } + + byteOrder.PutUint64(scratch[:], uint64(c.Value)) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + byteOrder.PutUint64(scratch[:], uint64(c.CreationDate.Unix())) + if _, err := w.Write(scratch[:]); err != nil { + return err + } + + byteOrder.PutUint32(scratch[:4], uint32(len(c.PaymentRequest))) + if _, err := w.Write(scratch[:4]); err != nil { + return err + } + + if _, err := w.Write(c.PaymentRequest[:]); err != nil { + return err + } + + return nil +} + +func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) { + var scratch [8]byte + + c := &PaymentCreationInfo{} + + if _, err := io.ReadFull(r, c.PaymentHash[:]); err != nil { + return nil, err + } + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return nil, err + } + c.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) + + if _, err := io.ReadFull(r, scratch[:]); err != nil { + return nil, err + } + c.CreationDate = time.Unix(int64(byteOrder.Uint64(scratch[:])), 0) + + if _, err := io.ReadFull(r, scratch[:4]); err != nil { + return nil, err + } + + reqLen := uint32(byteOrder.Uint32(scratch[:4])) + payReq := make([]byte, reqLen) + if reqLen > 0 { + if _, err := io.ReadFull(r, payReq[:]); err != nil { + return nil, err + } + } + c.PaymentRequest = payReq + + return c, nil +} + +func serializePaymentAttemptInfo(w io.Writer, a *PaymentAttemptInfo) error { + if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { + return err + } + + if err := SerializeRoute(w, a.Route); err != nil { + return err + } + + return nil +} + +func deserializePaymentAttemptInfo(r io.Reader) (*PaymentAttemptInfo, error) { + a := &PaymentAttemptInfo{} + err := ReadElements(r, &a.PaymentID, &a.SessionKey) + if err != nil { + return nil, err + } + a.Route, err = DeserializeRoute(r) + if err != nil { + return nil, err + } + return a, nil +} + +func serializeHop(w io.Writer, h *route.Hop) error { + if err := WriteElements(w, + h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, + h.AmtToForward, + ); err != nil { + return err + } + + if err := binary.Write(w, byteOrder, h.LegacyPayload); err != nil { + return err + } + + // For legacy payloads, we don't need to write any TLV records, so + // we'll write a zero indicating the our serialized TLV map has no + // records. + if h.LegacyPayload { + return WriteElements(w, uint32(0)) + } + + // Otherwise, we'll transform our slice of records into a map of the + // raw bytes, then serialize them in-line with a length (number of + // elements) prefix. + mapRecords, err := tlv.RecordsToMap(h.TLVRecords) + if err != nil { + return err + } + + numRecords := uint32(len(mapRecords)) + if err := WriteElements(w, numRecords); err != nil { + return err + } + + for recordType, rawBytes := range mapRecords { + if err := WriteElements(w, recordType); err != nil { + return err + } + + if err := wire.WriteVarBytes(w, 0, rawBytes); err != nil { + return err + } + } + + return nil +} + +// maxOnionPayloadSize is the largest Sphinx payload possible, so we don't need +// to read/write a TLV stream larger than this. +const maxOnionPayloadSize = 1300 + +func deserializeHop(r io.Reader) (*route.Hop, error) { + h := &route.Hop{} + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return nil, err + } + copy(h.PubKeyBytes[:], pub) + + if err := ReadElements(r, + &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, + ); err != nil { + return nil, err + } + + // TODO(roasbeef): change field to allow LegacyPayload false to be the + // legacy default? + err := binary.Read(r, byteOrder, &h.LegacyPayload) + if err != nil { + return nil, err + } + + var numElements uint32 + if err := ReadElements(r, &numElements); err != nil { + return nil, err + } + + // If there're no elements, then we can return early. + if numElements == 0 { + return h, nil + } + + tlvMap := make(map[uint64][]byte) + for i := uint32(0); i < numElements; i++ { + var tlvType uint64 + if err := ReadElements(r, &tlvType); err != nil { + return nil, err + } + + rawRecordBytes, err := wire.ReadVarBytes( + r, 0, maxOnionPayloadSize, "tlv", + ) + if err != nil { + return nil, err + } + + tlvMap[tlvType] = rawRecordBytes + } + + tlvRecords, err := tlv.MapToRecords(tlvMap) + if err != nil { + return nil, err + } + + h.TLVRecords = tlvRecords + + return h, nil +} + +// SerializeRoute serializes a route. +func SerializeRoute(w io.Writer, r route.Route) error { + if err := WriteElements(w, + r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], + ); err != nil { + return err + } + + if err := WriteElements(w, uint32(len(r.Hops))); err != nil { + return err + } + + for _, h := range r.Hops { + if err := serializeHop(w, h); err != nil { + return err + } + } + + return nil +} + +// DeserializeRoute deserializes a route. +func DeserializeRoute(r io.Reader) (route.Route, error) { + rt := route.Route{} + if err := ReadElements(r, + &rt.TotalTimeLock, &rt.TotalAmount, + ); err != nil { + return rt, err + } + + var pub []byte + if err := ReadElements(r, &pub); err != nil { + return rt, err + } + copy(rt.SourcePubKey[:], pub) + + var numHops uint32 + if err := ReadElements(r, &numHops); err != nil { + return rt, err + } + + var hops []*route.Hop + for i := uint32(0); i < numHops; i++ { + hop, err := deserializeHop(r) + if err != nil { + return rt, err + } + hops = append(hops, hop) + } + rt.Hops = hops + + return rt, nil +} diff --git a/channeldb/migration_01_to_11/payments_test.go b/channeldb/migration_01_to_11/payments_test.go new file mode 100644 index 00000000..07307941 --- /dev/null +++ b/channeldb/migration_01_to_11/payments_test.go @@ -0,0 +1,324 @@ +package migration_01_to_11 + +import ( + "bytes" + "errors" + "fmt" + "math/rand" + "reflect" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + priv, _ = btcec.NewPrivateKey(btcec.S256()) + pub = priv.PubKey() + + tlvBytes = []byte{1, 2, 3} + tlvEncoder = tlv.StubEncoder(tlvBytes) + testHop1 = &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + AmtToForward: 555, + TLVRecords: []tlv.Record{ + tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil), + tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil), + }, + } + + testHop2 = &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + AmtToForward: 555, + LegacyPayload: true, + } + + testRoute = route.Route{ + TotalTimeLock: 123, + TotalAmount: 1234567, + SourcePubKey: route.NewVertex(pub), + Hops: []*route.Hop{ + testHop1, + testHop2, + }, + } +) + +func makeFakePayment() *outgoingPayment { + fakeInvoice := &Invoice{ + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + Memo: []byte("fake memo"), + Receipt: []byte("fake receipt"), + PaymentRequest: []byte(""), + } + + copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) + fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) + + fakePath := make([][33]byte, 3) + for i := 0; i < 3; i++ { + copy(fakePath[i][:], bytes.Repeat([]byte{byte(i)}, 33)) + } + + fakePayment := &outgoingPayment{ + Invoice: *fakeInvoice, + Fee: 101, + Path: fakePath, + TimeLockLength: 1000, + } + copy(fakePayment.PaymentPreimage[:], rev[:]) + return fakePayment +} + +func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) { + var preimg lntypes.Preimage + copy(preimg[:], rev[:]) + + c := &PaymentCreationInfo{ + PaymentHash: preimg.Hash(), + Value: 1000, + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte(""), + } + + a := &PaymentAttemptInfo{ + PaymentID: 44, + SessionKey: priv, + Route: testRoute, + } + return c, a +} + +// randomBytes creates random []byte with length in range [minLen, maxLen) +func randomBytes(minLen, maxLen int) ([]byte, error) { + randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen)) + + if _, err := rand.Read(randBuf); err != nil { + return nil, fmt.Errorf("Internal error. "+ + "Cannot generate random string: %v", err) + } + + return randBuf, nil +} + +func makeRandomFakePayment() (*outgoingPayment, error) { + var err error + fakeInvoice := &Invoice{ + // Use single second precision to avoid false positive test + // failures due to the monotonic time component. + CreationDate: time.Unix(time.Now().Unix(), 0), + } + + fakeInvoice.Memo, err = randomBytes(1, 50) + if err != nil { + return nil, err + } + + fakeInvoice.Receipt, err = randomBytes(1, 50) + if err != nil { + return nil, err + } + + fakeInvoice.PaymentRequest, err = randomBytes(1, 50) + if err != nil { + return nil, err + } + + preImg, err := randomBytes(32, 33) + if err != nil { + return nil, err + } + copy(fakeInvoice.Terms.PaymentPreimage[:], preImg) + + fakeInvoice.Terms.Value = lnwire.MilliSatoshi(rand.Intn(10000)) + + fakePathLen := 1 + rand.Intn(5) + fakePath := make([][33]byte, fakePathLen) + for i := 0; i < fakePathLen; i++ { + b, err := randomBytes(33, 34) + if err != nil { + return nil, err + } + copy(fakePath[i][:], b) + } + + fakePayment := &outgoingPayment{ + Invoice: *fakeInvoice, + Fee: lnwire.MilliSatoshi(rand.Intn(1001)), + Path: fakePath, + TimeLockLength: uint32(rand.Intn(10000)), + } + copy(fakePayment.PaymentPreimage[:], fakeInvoice.Terms.PaymentPreimage[:]) + + return fakePayment, nil +} + +func TestSentPaymentSerialization(t *testing.T) { + t.Parallel() + + c, s := makeFakeInfo() + + var b bytes.Buffer + if err := serializePaymentCreationInfo(&b, c); err != nil { + t.Fatalf("unable to serialize creation info: %v", err) + } + + newCreationInfo, err := deserializePaymentCreationInfo(&b) + if err != nil { + t.Fatalf("unable to deserialize creation info: %v", err) + } + + if !reflect.DeepEqual(c, newCreationInfo) { + t.Fatalf("Payments do not match after "+ + "serialization/deserialization %v vs %v", + spew.Sdump(c), spew.Sdump(newCreationInfo), + ) + } + + b.Reset() + if err := serializePaymentAttemptInfo(&b, s); err != nil { + t.Fatalf("unable to serialize info: %v", err) + } + + newAttemptInfo, err := deserializePaymentAttemptInfo(&b) + if err != nil { + t.Fatalf("unable to deserialize info: %v", err) + } + + // First we verify all the records match up porperly, as they aren't + // able to be properly compared using reflect.DeepEqual. + err = assertRouteEqual(&s.Route, &newAttemptInfo.Route) + if err != nil { + t.Fatalf("Routes do not match after "+ + "serialization/deserialization: %v", err) + } + + // Clear routes to allow DeepEqual to compare the remaining fields. + newAttemptInfo.Route = route.Route{} + s.Route = route.Route{} + + if !reflect.DeepEqual(s, newAttemptInfo) { + s.SessionKey.Curve = nil + newAttemptInfo.SessionKey.Curve = nil + t.Fatalf("Payments do not match after "+ + "serialization/deserialization %v vs %v", + spew.Sdump(s), spew.Sdump(newAttemptInfo), + ) + } +} + +// assertRouteEquals compares to routes for equality and returns an error if +// they are not equal. +func assertRouteEqual(a, b *route.Route) error { + err := assertRouteHopRecordsEqual(a, b) + if err != nil { + return err + } + + // TLV records have already been compared and need to be cleared to + // properly compare the remaining fields using DeepEqual. + copyRouteNoHops := func(r *route.Route) *route.Route { + copy := *r + copy.Hops = make([]*route.Hop, len(r.Hops)) + for i, hop := range r.Hops { + hopCopy := *hop + hopCopy.TLVRecords = nil + copy.Hops[i] = &hopCopy + } + return © + } + + if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) { + return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", + spew.Sdump(a), spew.Sdump(b)) + } + + return nil +} + +func assertRouteHopRecordsEqual(r1, r2 *route.Route) error { + if len(r1.Hops) != len(r2.Hops) { + return errors.New("route hop count mismatch") + } + + for i := 0; i < len(r1.Hops); i++ { + records1 := r1.Hops[i].TLVRecords + records2 := r2.Hops[i].TLVRecords + if len(records1) != len(records2) { + return fmt.Errorf("route record count for hop %v "+ + "mismatch", i) + } + + for j := 0; j < len(records1); j++ { + expectedRecord := records1[j] + newRecord := records2[j] + + err := assertHopRecordsEqual(expectedRecord, newRecord) + if err != nil { + return fmt.Errorf("route record mismatch: %v", err) + } + } + } + + return nil +} + +func assertHopRecordsEqual(h1, h2 tlv.Record) error { + if h1.Type() != h2.Type() { + return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(), + h2.Type()) + } + + var b bytes.Buffer + if err := h2.Encode(&b); err != nil { + return fmt.Errorf("unable to encode record: %v", err) + } + + if !bytes.Equal(b.Bytes(), tlvBytes) { + return fmt.Errorf("wrong raw record: expected %x, got %x", + tlvBytes, b.Bytes()) + } + + if h1.Size() != h2.Size() { + return fmt.Errorf("wrong size: expected %v, "+ + "got %v", h1.Size(), h2.Size()) + } + + return nil +} + +func TestRouteSerialization(t *testing.T) { + t.Parallel() + + var b bytes.Buffer + if err := SerializeRoute(&b, testRoute); err != nil { + t.Fatal(err) + } + + r := bytes.NewReader(b.Bytes()) + route2, err := DeserializeRoute(r) + if err != nil { + t.Fatal(err) + } + + // First we verify all the records match up porperly, as they aren't + // able to be properly compared using reflect.DeepEqual. + err = assertRouteEqual(&testRoute, &route2) + if err != nil { + t.Fatalf("routes not equal: \n%v vs \n%v", + spew.Sdump(testRoute), spew.Sdump(route2)) + } +} diff --git a/channeldb/migration_01_to_11/reject_cache.go b/channeldb/migration_01_to_11/reject_cache.go new file mode 100644 index 00000000..c54d78a8 --- /dev/null +++ b/channeldb/migration_01_to_11/reject_cache.go @@ -0,0 +1,95 @@ +package migration_01_to_11 + +// rejectFlags is a compact representation of various metadata stored by the +// reject cache about a particular channel. +type rejectFlags uint8 + +const ( + // rejectFlagExists is a flag indicating whether the channel exists, + // i.e. the channel is open and has a recent channel update. If this + // flag is not set, the channel is either a zombie or unknown. + rejectFlagExists rejectFlags = 1 << iota + + // rejectFlagZombie is a flag indicating whether the channel is a + // zombie, i.e. the channel is open but has no recent channel updates. + rejectFlagZombie +) + +// packRejectFlags computes the rejectFlags corresponding to the passed boolean +// values indicating whether the edge exists or is a zombie. +func packRejectFlags(exists, isZombie bool) rejectFlags { + var flags rejectFlags + if exists { + flags |= rejectFlagExists + } + if isZombie { + flags |= rejectFlagZombie + } + + return flags +} + +// unpack returns the booleans packed into the rejectFlags. The first indicates +// if the edge exists in our graph, the second indicates if the edge is a +// zombie. +func (f rejectFlags) unpack() (bool, bool) { + return f&rejectFlagExists == rejectFlagExists, + f&rejectFlagZombie == rejectFlagZombie +} + +// rejectCacheEntry caches frequently accessed information about a channel, +// including the timestamps of its latest edge policies and whether or not the +// channel exists in the graph. +type rejectCacheEntry struct { + upd1Time int64 + upd2Time int64 + flags rejectFlags +} + +// rejectCache is an in-memory cache used to improve the performance of +// HasChannelEdge. It caches information about the whether or channel exists, as +// well as the most recent timestamps for each policy (if they exists). +type rejectCache struct { + n int + edges map[uint64]rejectCacheEntry +} + +// newRejectCache creates a new rejectCache with maximum capacity of n entries. +func newRejectCache(n int) *rejectCache { + return &rejectCache{ + n: n, + edges: make(map[uint64]rejectCacheEntry, n), + } +} + +// get returns the entry from the cache for chanid, if it exists. +func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) { + entry, ok := c.edges[chanid] + return entry, ok +} + +// insert adds the entry to the reject cache. If an entry for chanid already +// exists, it will be replaced with the new entry. If the entry doesn't exists, +// it will be inserted to the cache, performing a random eviction if the cache +// is at capacity. +func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) { + // If entry exists, replace it. + if _, ok := c.edges[chanid]; ok { + c.edges[chanid] = entry + return + } + + // Otherwise, evict an entry at random and insert. + if len(c.edges) == c.n { + for id := range c.edges { + delete(c.edges, id) + break + } + } + c.edges[chanid] = entry +} + +// remove deletes an entry for chanid from the cache, if it exists. +func (c *rejectCache) remove(chanid uint64) { + delete(c.edges, chanid) +} diff --git a/channeldb/migration_01_to_11/reject_cache_test.go b/channeldb/migration_01_to_11/reject_cache_test.go new file mode 100644 index 00000000..e15e0a10 --- /dev/null +++ b/channeldb/migration_01_to_11/reject_cache_test.go @@ -0,0 +1,107 @@ +package migration_01_to_11 + +import ( + "reflect" + "testing" +) + +// TestRejectCache checks the behavior of the rejectCache with respect to insertion, +// eviction, and removal of cache entries. +func TestRejectCache(t *testing.T) { + const cacheSize = 100 + + // Create a new reject cache with the configured max size. + c := newRejectCache(cacheSize) + + // As a sanity check, assert that querying the empty cache does not + // return an entry. + _, ok := c.get(0) + if ok { + t.Fatalf("reject cache should be empty") + } + + // Now, fill up the cache entirely. + for i := uint64(0); i < cacheSize; i++ { + c.insert(i, entryForInt(i)) + } + + // Assert that the cache has all of the entries just inserted, since no + // eviction should occur until we try to surpass the max size. + assertHasEntries(t, c, 0, cacheSize) + + // Now, insert a new element that causes the cache to evict an element. + c.insert(cacheSize, entryForInt(cacheSize)) + + // Assert that the cache has this last entry, as the cache should evict + // some prior element and not the newly inserted one. + assertHasEntries(t, c, cacheSize, cacheSize) + + // Iterate over all inserted elements and construct a set of the evicted + // elements. + evicted := make(map[uint64]struct{}) + for i := uint64(0); i < cacheSize+1; i++ { + _, ok := c.get(i) + if !ok { + evicted[i] = struct{}{} + } + } + + // Assert that exactly one element has been evicted. + numEvicted := len(evicted) + if numEvicted != 1 { + t.Fatalf("expected one evicted entry, got: %d", numEvicted) + } + + // Remove the highest item which initially caused the eviction and + // reinsert the element that was evicted prior. + c.remove(cacheSize) + for i := range evicted { + c.insert(i, entryForInt(i)) + } + + // Since the removal created an extra slot, the last insertion should + // not have caused an eviction and the entries for all channels in the + // original set that filled the cache should be present. + assertHasEntries(t, c, 0, cacheSize) + + // Finally, reinsert the existing set back into the cache and test that + // the cache still has all the entries. If the randomized eviction were + // happening on inserts for existing cache items, we expect this to fail + // with high probability. + for i := uint64(0); i < cacheSize; i++ { + c.insert(i, entryForInt(i)) + } + assertHasEntries(t, c, 0, cacheSize) + +} + +// assertHasEntries queries the reject cache for all channels in the range [start, +// end), asserting that they exist and their value matches the entry produced by +// entryForInt. +func assertHasEntries(t *testing.T, c *rejectCache, start, end uint64) { + t.Helper() + + for i := start; i < end; i++ { + entry, ok := c.get(i) + if !ok { + t.Fatalf("reject cache should contain chan %d", i) + } + + expEntry := entryForInt(i) + if !reflect.DeepEqual(entry, expEntry) { + t.Fatalf("entry mismatch, want: %v, got: %v", + expEntry, entry) + } + } +} + +// entryForInt generates a unique rejectCacheEntry given an integer. +func entryForInt(i uint64) rejectCacheEntry { + exists := i%2 == 0 + isZombie := i%3 == 0 + return rejectCacheEntry{ + upd1Time: int64(2 * i), + upd2Time: int64(2*i + 1), + flags: packRejectFlags(exists, isZombie), + } +} diff --git a/channeldb/migration_01_to_11/waitingproof.go b/channeldb/migration_01_to_11/waitingproof.go new file mode 100644 index 00000000..64729116 --- /dev/null +++ b/channeldb/migration_01_to_11/waitingproof.go @@ -0,0 +1,251 @@ +package migration_01_to_11 + +import ( + "encoding/binary" + "sync" + + "io" + + "bytes" + + "github.com/coreos/bbolt" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // waitingProofsBucketKey byte string name of the waiting proofs store. + waitingProofsBucketKey = []byte("waitingproofs") + + // ErrWaitingProofNotFound is returned if waiting proofs haven't been + // found by db. + ErrWaitingProofNotFound = errors.New("waiting proofs haven't been " + + "found") + + // ErrWaitingProofAlreadyExist is returned if waiting proofs haven't been + // found by db. + ErrWaitingProofAlreadyExist = errors.New("waiting proof with such " + + "key already exist") +) + +// WaitingProofStore is the bold db map-like storage for half announcement +// signatures. The one responsibility of this storage is to be able to +// retrieve waiting proofs after client restart. +type WaitingProofStore struct { + // cache is used in order to reduce the number of redundant get + // calls, when object isn't stored in it. + cache map[WaitingProofKey]struct{} + db *DB + mu sync.RWMutex +} + +// NewWaitingProofStore creates new instance of proofs storage. +func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { + s := &WaitingProofStore{ + db: db, + cache: make(map[WaitingProofKey]struct{}), + } + + if err := s.ForAll(func(proof *WaitingProof) error { + s.cache[proof.Key()] = struct{}{} + return nil + }); err != nil && err != ErrWaitingProofNotFound { + return nil, err + } + + return s, nil +} + +// Add adds new waiting proof in the storage. +func (s *WaitingProofStore) Add(proof *WaitingProof) error { + s.mu.Lock() + defer s.mu.Unlock() + + err := s.db.Update(func(tx *bbolt.Tx) error { + var err error + var b bytes.Buffer + + // Get or create the bucket. + bucket, err := tx.CreateBucketIfNotExists(waitingProofsBucketKey) + if err != nil { + return err + } + + // Encode the objects and place it in the bucket. + if err := proof.Encode(&b); err != nil { + return err + } + + key := proof.Key() + + return bucket.Put(key[:], b.Bytes()) + }) + if err != nil { + return err + } + + // Knowing that the write succeeded, we can now update the in-memory + // cache with the proof's key. + s.cache[proof.Key()] = struct{}{} + + return nil +} + +// Remove removes the proof from storage by its key. +func (s *WaitingProofStore) Remove(key WaitingProofKey) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.cache[key]; !ok { + return ErrWaitingProofNotFound + } + + err := s.db.Update(func(tx *bbolt.Tx) error { + // Get or create the top bucket. + bucket := tx.Bucket(waitingProofsBucketKey) + if bucket == nil { + return ErrWaitingProofNotFound + } + + return bucket.Delete(key[:]) + }) + if err != nil { + return err + } + + // Since the proof was successfully deleted from the store, we can now + // remove it from the in-memory cache. + delete(s.cache, key) + + return nil +} + +// ForAll iterates thought all waiting proofs and passing the waiting proof +// in the given callback. +func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { + return s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(waitingProofsBucketKey) + if bucket == nil { + return ErrWaitingProofNotFound + } + + // Iterate over objects buckets. + return bucket.ForEach(func(k, v []byte) error { + // Skip buckets fields. + if v == nil { + return nil + } + + r := bytes.NewReader(v) + proof := &WaitingProof{} + if err := proof.Decode(r); err != nil { + return err + } + + return cb(proof) + }) + }) +} + +// Get returns the object which corresponds to the given index. +func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { + proof := &WaitingProof{} + + s.mu.RLock() + defer s.mu.RUnlock() + + if _, ok := s.cache[key]; !ok { + return nil, ErrWaitingProofNotFound + } + + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(waitingProofsBucketKey) + if bucket == nil { + return ErrWaitingProofNotFound + } + + // Iterate over objects buckets. + v := bucket.Get(key[:]) + if v == nil { + return ErrWaitingProofNotFound + } + + r := bytes.NewReader(v) + return proof.Decode(r) + }) + + return proof, err +} + +// WaitingProofKey is the proof key which uniquely identifies the waiting +// proof object. The goal of this key is distinguish the local and remote +// proof for the same channel id. +type WaitingProofKey [9]byte + +// WaitingProof is the storable object, which encapsulate the half proof and +// the information about from which side this proof came. This structure is +// needed to make channel proof exchange persistent, so that after client +// restart we may receive remote/local half proof and process it. +type WaitingProof struct { + *lnwire.AnnounceSignatures + isRemote bool +} + +// NewWaitingProof constructs a new waiting prof instance. +func NewWaitingProof(isRemote bool, proof *lnwire.AnnounceSignatures) *WaitingProof { + return &WaitingProof{ + AnnounceSignatures: proof, + isRemote: isRemote, + } +} + +// OppositeKey returns the key which uniquely identifies opposite waiting proof. +func (p *WaitingProof) OppositeKey() WaitingProofKey { + var key [9]byte + binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64()) + + if !p.isRemote { + key[8] = 1 + } + return key +} + +// Key returns the key which uniquely identifies waiting proof. +func (p *WaitingProof) Key() WaitingProofKey { + var key [9]byte + binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64()) + + if p.isRemote { + key[8] = 1 + } + return key +} + +// Encode writes the internal representation of waiting proof in byte stream. +func (p *WaitingProof) Encode(w io.Writer) error { + if err := binary.Write(w, byteOrder, p.isRemote); err != nil { + return err + } + + if err := p.AnnounceSignatures.Encode(w, 0); err != nil { + return err + } + + return nil +} + +// Decode reads the data from the byte stream and initializes the +// waiting proof object with it. +func (p *WaitingProof) Decode(r io.Reader) error { + if err := binary.Read(r, byteOrder, &p.isRemote); err != nil { + return err + } + + msg := &lnwire.AnnounceSignatures{} + if err := msg.Decode(r, 0); err != nil { + return err + } + + (*p).AnnounceSignatures = msg + return nil +} diff --git a/channeldb/migration_01_to_11/waitingproof_test.go b/channeldb/migration_01_to_11/waitingproof_test.go new file mode 100644 index 00000000..968f1157 --- /dev/null +++ b/channeldb/migration_01_to_11/waitingproof_test.go @@ -0,0 +1,59 @@ +package migration_01_to_11 + +import ( + "testing" + + "reflect" + + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/lnwire" +) + +// TestWaitingProofStore tests add/get/remove functions of the waiting proof +// storage. +func TestWaitingProofStore(t *testing.T) { + t.Parallel() + + db, cleanup, err := makeTestDB() + if err != nil { + t.Fatalf("failed to make test database: %s", err) + } + defer cleanup() + + proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ + NodeSignature: wireSig, + BitcoinSignature: wireSig, + }) + + store, err := NewWaitingProofStore(db) + if err != nil { + t.Fatalf("unable to create the waiting proofs storage: %v", + err) + } + + if err := store.Add(proof1); err != nil { + t.Fatalf("unable add proof to storage: %v", err) + } + + proof2, err := store.Get(proof1.Key()) + if err != nil { + t.Fatalf("unable retrieve proof from storage: %v", err) + } + if !reflect.DeepEqual(proof1, proof2) { + t.Fatal("wrong proof retrieved") + } + + if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound { + t.Fatalf("proof shouldn't be found: %v", err) + } + + if err := store.Remove(proof1.Key()); err != nil { + t.Fatalf("unable remove proof from storage: %v", err) + } + + if err := store.ForAll(func(proof *WaitingProof) error { + return errors.New("storage should be empty") + }); err != nil && err != ErrWaitingProofNotFound { + t.Fatal(err) + } +} diff --git a/channeldb/migration_01_to_11/witness_cache.go b/channeldb/migration_01_to_11/witness_cache.go new file mode 100644 index 00000000..69de1054 --- /dev/null +++ b/channeldb/migration_01_to_11/witness_cache.go @@ -0,0 +1,229 @@ +package migration_01_to_11 + +import ( + "fmt" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lntypes" +) + +var ( + // ErrNoWitnesses is an error that's returned when no new witnesses have + // been added to the WitnessCache. + ErrNoWitnesses = fmt.Errorf("no witnesses") + + // ErrUnknownWitnessType is returned if a caller attempts to + ErrUnknownWitnessType = fmt.Errorf("unknown witness type") +) + +// WitnessType is enum that denotes what "type" of witness is being +// stored/retrieved. As the WitnessCache itself is agnostic and doesn't enforce +// any structure on added witnesses, we use this type to partition the +// witnesses on disk, and also to know how to map a witness to its look up key. +type WitnessType uint8 + +var ( + // Sha256HashWitness is a witness that is simply the pre image to a + // hash image. In order to map to its key, we'll use sha256. + Sha256HashWitness WitnessType = 1 +) + +// toDBKey is a helper method that maps a witness type to the key that we'll +// use to store it within the database. +func (w WitnessType) toDBKey() ([]byte, error) { + switch w { + + case Sha256HashWitness: + return []byte{byte(w)}, nil + + default: + return nil, ErrUnknownWitnessType + } +} + +var ( + // witnessBucketKey is the name of the bucket that we use to store all + // witnesses encountered. Within this bucket, we'll create a sub-bucket for + // each witness type. + witnessBucketKey = []byte("byte") +) + +// WitnessCache is a persistent cache of all witnesses we've encountered on the +// network. In the case of multi-hop, multi-step contracts, a cache of all +// witnesses can be useful in the case of partial contract resolution. If +// negotiations break down, we may be forced to locate the witness for a +// portion of the contract on-chain. In this case, we'll then add that witness +// to the cache so the incoming contract can fully resolve witness. +// Additionally, as one MUST always use a unique witness on the network, we may +// use this cache to detect duplicate witnesses. +// +// TODO(roasbeef): need expiry policy? +// * encrypt? +type WitnessCache struct { + db *DB +} + +// NewWitnessCache returns a new instance of the witness cache. +func (d *DB) NewWitnessCache() *WitnessCache { + return &WitnessCache{ + db: d, + } +} + +// witnessEntry is a key-value struct that holds each key -> witness pair, used +// when inserting records into the cache. +type witnessEntry struct { + key []byte + witness []byte +} + +// AddSha256Witnesses adds a batch of new sha256 preimages into the witness +// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the +// preimages' witness type. +func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error { + // Optimistically compute the preimages' hashes before attempting to + // start the db transaction. + entries := make([]witnessEntry, 0, len(preimages)) + for i := range preimages { + hash := preimages[i].Hash() + entries = append(entries, witnessEntry{ + key: hash[:], + witness: preimages[i][:], + }) + } + + return w.addWitnessEntries(Sha256HashWitness, entries) +} + +// addWitnessEntries inserts the witnessEntry key-value pairs into the cache, +// using the appropriate witness type to segment the namespace of possible +// witness types. +func (w *WitnessCache) addWitnessEntries(wType WitnessType, + entries []witnessEntry) error { + + // Exit early if there are no witnesses to add. + if len(entries) == 0 { + return nil + } + + return w.db.Batch(func(tx *bbolt.Tx) error { + witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) + if err != nil { + return err + } + + witnessTypeBucketKey, err := wType.toDBKey() + if err != nil { + return err + } + witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists( + witnessTypeBucketKey, + ) + if err != nil { + return err + } + + for _, entry := range entries { + err = witnessTypeBucket.Put(entry.key, entry.witness) + if err != nil { + return err + } + } + + return nil + }) +} + +// LookupSha256Witness attempts to lookup the preimage for a sha256 hash. If +// the witness isn't found, ErrNoWitnesses will be returned. +func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) { + witness, err := w.lookupWitness(Sha256HashWitness, hash[:]) + if err != nil { + return lntypes.Preimage{}, err + } + + return lntypes.MakePreimage(witness) +} + +// lookupWitness attempts to lookup a witness according to its type and also +// its witness key. In the case that the witness isn't found, ErrNoWitnesses +// will be returned. +func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) { + var witness []byte + err := w.db.View(func(tx *bbolt.Tx) error { + witnessBucket := tx.Bucket(witnessBucketKey) + if witnessBucket == nil { + return ErrNoWitnesses + } + + witnessTypeBucketKey, err := wType.toDBKey() + if err != nil { + return err + } + witnessTypeBucket := witnessBucket.Bucket(witnessTypeBucketKey) + if witnessTypeBucket == nil { + return ErrNoWitnesses + } + + dbWitness := witnessTypeBucket.Get(witnessKey) + if dbWitness == nil { + return ErrNoWitnesses + } + + witness = make([]byte, len(dbWitness)) + copy(witness[:], dbWitness) + + return nil + }) + if err != nil { + return nil, err + } + + return witness, nil +} + +// DeleteSha256Witness attempts to delete a sha256 preimage identified by hash. +func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error { + return w.deleteWitness(Sha256HashWitness, hash[:]) +} + +// deleteWitness attempts to delete a particular witness from the database. +func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error { + return w.db.Batch(func(tx *bbolt.Tx) error { + witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) + if err != nil { + return err + } + + witnessTypeBucketKey, err := wType.toDBKey() + if err != nil { + return err + } + witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists( + witnessTypeBucketKey, + ) + if err != nil { + return err + } + + return witnessTypeBucket.Delete(witnessKey) + }) +} + +// DeleteWitnessClass attempts to delete an *entire* class of witnesses. After +// this function return with a non-nil error, +func (w *WitnessCache) DeleteWitnessClass(wType WitnessType) error { + return w.db.Batch(func(tx *bbolt.Tx) error { + witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) + if err != nil { + return err + } + + witnessTypeBucketKey, err := wType.toDBKey() + if err != nil { + return err + } + + return witnessBucket.DeleteBucket(witnessTypeBucketKey) + }) +} diff --git a/channeldb/migration_01_to_11/witness_cache_test.go b/channeldb/migration_01_to_11/witness_cache_test.go new file mode 100644 index 00000000..92836abe --- /dev/null +++ b/channeldb/migration_01_to_11/witness_cache_test.go @@ -0,0 +1,238 @@ +package migration_01_to_11 + +import ( + "crypto/sha256" + "testing" + + "github.com/lightningnetwork/lnd/lntypes" +) + +// TestWitnessCacheSha256Retrieval tests that we're able to add and lookup new +// sha256 preimages to the witness cache. +func TestWitnessCacheSha256Retrieval(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + wCache := cdb.NewWitnessCache() + + // We'll be attempting to add then lookup two simple sha256 preimages + // within this test. + preimage1 := lntypes.Preimage(rev) + preimage2 := lntypes.Preimage(key) + + preimages := []lntypes.Preimage{preimage1, preimage2} + hashes := []lntypes.Hash{preimage1.Hash(), preimage2.Hash()} + + // First, we'll attempt to add the preimages to the database. + err = wCache.AddSha256Witnesses(preimages...) + if err != nil { + t.Fatalf("unable to add witness: %v", err) + } + + // With the preimages stored, we'll now attempt to look them up. + for i, hash := range hashes { + preimage := preimages[i] + + // We should get back the *exact* same preimage as we originally + // stored. + dbPreimage, err := wCache.LookupSha256Witness(hash) + if err != nil { + t.Fatalf("unable to look up witness: %v", err) + } + + if preimage != dbPreimage { + t.Fatalf("witnesses don't match: expected %x, got %x", + preimage[:], dbPreimage[:]) + } + } +} + +// TestWitnessCacheSha256Deletion tests that we're able to delete a single +// sha256 preimage, and also a class of witnesses from the cache. +func TestWitnessCacheSha256Deletion(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + wCache := cdb.NewWitnessCache() + + // We'll start by adding two preimages to the cache. + preimage1 := lntypes.Preimage(key) + hash1 := preimage1.Hash() + + preimage2 := lntypes.Preimage(rev) + hash2 := preimage2.Hash() + + if err := wCache.AddSha256Witnesses(preimage1); err != nil { + t.Fatalf("unable to add witness: %v", err) + } + + if err := wCache.AddSha256Witnesses(preimage2); err != nil { + t.Fatalf("unable to add witness: %v", err) + } + + // We'll now delete the first preimage. If we attempt to look it up, we + // should get ErrNoWitnesses. + err = wCache.DeleteSha256Witness(hash1) + if err != nil { + t.Fatalf("unable to delete witness: %v", err) + } + _, err = wCache.LookupSha256Witness(hash1) + if err != ErrNoWitnesses { + t.Fatalf("expected ErrNoWitnesses instead got: %v", err) + } + + // Next, we'll attempt to delete the entire witness class itself. When + // we try to lookup the second preimage, we should again get + // ErrNoWitnesses. + if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil { + t.Fatalf("unable to delete witness class: %v", err) + } + _, err = wCache.LookupSha256Witness(hash2) + if err != ErrNoWitnesses { + t.Fatalf("expected ErrNoWitnesses instead got: %v", err) + } +} + +// TestWitnessCacheUnknownWitness tests that we get an error if we attempt to +// query/add/delete an unknown witness. +func TestWitnessCacheUnknownWitness(t *testing.T) { + t.Parallel() + + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + wCache := cdb.NewWitnessCache() + + // We'll attempt to add a new, undefined witness type to the database. + // We should get an error. + err = wCache.legacyAddWitnesses(234, key[:]) + if err != ErrUnknownWitnessType { + t.Fatalf("expected ErrUnknownWitnessType, got %v", err) + } +} + +// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves +// identically to the insertion via the generalized interface. +func TestAddSha256Witnesses(t *testing.T) { + cdb, cleanUp, err := makeTestDB() + if err != nil { + t.Fatalf("unable to make test database: %v", err) + } + defer cleanUp() + + wCache := cdb.NewWitnessCache() + + // We'll start by adding a witnesses to the cache using the generic + // AddWitnesses method. + witness1 := rev[:] + preimage1 := lntypes.Preimage(rev) + hash1 := preimage1.Hash() + + witness2 := key[:] + preimage2 := lntypes.Preimage(key) + hash2 := preimage2.Hash() + + var ( + witnesses = [][]byte{witness1, witness2} + preimages = []lntypes.Preimage{preimage1, preimage2} + hashes = []lntypes.Hash{hash1, hash2} + ) + + err = wCache.legacyAddWitnesses(Sha256HashWitness, witnesses...) + if err != nil { + t.Fatalf("unable to add witness: %v", err) + } + + for i, hash := range hashes { + preimage := preimages[i] + + dbPreimage, err := wCache.LookupSha256Witness(hash) + if err != nil { + t.Fatalf("unable to lookup witness: %v", err) + } + + // Assert that the retrieved witness matches the original. + if dbPreimage != preimage { + t.Fatalf("retrieved witness mismatch, want: %x, "+ + "got: %x", preimage, dbPreimage) + } + + // We'll now delete the witness, as we'll be reinserting it + // using the specialized AddSha256Witnesses method. + err = wCache.DeleteSha256Witness(hash) + if err != nil { + t.Fatalf("unable to delete witness: %v", err) + } + } + + // Now, add the same witnesses using the type-safe interface for + // lntypes.Preimages.. + err = wCache.AddSha256Witnesses(preimages...) + if err != nil { + t.Fatalf("unable to add sha256 preimage: %v", err) + } + + // Finally, iterate over the keys and assert that the returned witnesses + // match the original witnesses. This asserts that the specialized + // insertion method behaves identically to the generalized interface. + for i, hash := range hashes { + preimage := preimages[i] + + dbPreimage, err := wCache.LookupSha256Witness(hash) + if err != nil { + t.Fatalf("unable to lookup witness: %v", err) + } + + // Assert that the retrieved witness matches the original. + if dbPreimage != preimage { + t.Fatalf("retrieved witness mismatch, want: %x, "+ + "got: %x", preimage, dbPreimage) + } + } +} + +// legacyAddWitnesses adds a batch of new witnesses of wType to the witness +// cache. The type of the witness will be used to map each witness to the key +// that will be used to look it up. All witnesses should be of the same +// WitnessType. +// +// NOTE: Previously this method exposed a generic interface for adding +// witnesses, which has since been deprecated in favor of a strongly typed +// interface for each witness class. We keep this method around to assert the +// correctness of specialized witness adding methods. +func (w *WitnessCache) legacyAddWitnesses(wType WitnessType, + witnesses ...[]byte) error { + + // Optimistically compute the witness keys before attempting to start + // the db transaction. + entries := make([]witnessEntry, 0, len(witnesses)) + for _, witness := range witnesses { + // Map each witness to its key by applying the appropriate + // transformation for the given witness type. + switch wType { + case Sha256HashWitness: + key := sha256.Sum256(witness) + entries = append(entries, witnessEntry{ + key: key[:], + witness: witness, + }) + default: + return ErrUnknownWitnessType + } + } + + return w.addWitnessEntries(wType, entries) +} From 6913cd64b680f409cd76b97836d24b4438224448 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Thu, 24 Oct 2019 12:25:28 +0200 Subject: [PATCH 2/6] channeldb: reference migrations in package This commit removes the migrations from channeldb and references those in the migrations_01_to_11 package. This creates a one-way dependency on the migrations. Future changes to channeldb won't be able to break migrations anymore. --- channeldb/db.go | 27 +- channeldb/migration_01_to_11/db.go | 93 +- channeldb/migration_01_to_11/meta.go | 2 +- channeldb/migration_01_to_11/meta_test.go | 369 ------- .../migration_10_route_tlv_records.go | 4 +- .../migration_11_invoices.go | 4 +- .../migration_11_invoices_test.go | 4 +- channeldb/migration_01_to_11/migrations.go | 36 +- .../migration_01_to_11/migrations_test.go | 10 +- .../migration_09_legacy_serialization.go | 497 --------- channeldb/migration_10_route_tlv_records.go | 236 ----- channeldb/migration_11_invoices.go | 230 ----- channeldb/migration_11_invoices_test.go | 193 ---- channeldb/migrations.go | 939 ----------------- channeldb/migrations_test.go | 952 ------------------ channeldb/payments_test.go | 81 -- 16 files changed, 43 insertions(+), 3634 deletions(-) delete mode 100644 channeldb/migration_09_legacy_serialization.go delete mode 100644 channeldb/migration_10_route_tlv_records.go delete mode 100644 channeldb/migration_11_invoices.go delete mode 100644 channeldb/migration_11_invoices_test.go delete mode 100644 channeldb/migrations.go delete mode 100644 channeldb/migrations_test.go diff --git a/channeldb/db.go b/channeldb/db.go index 7e8f9479..a67ccd2f 100644 --- a/channeldb/db.go +++ b/channeldb/db.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/coreos/bbolt" "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" "github.com/lightningnetwork/lnd/lnwire" ) @@ -47,19 +48,19 @@ var ( // for the update time of node and channel updates were // added. number: 1, - migration: migrateNodeAndEdgeUpdateIndex, + migration: migration_01_to_11.MigrateNodeAndEdgeUpdateIndex, }, { // The DB version that added the invoice event time // series. number: 2, - migration: migrateInvoiceTimeSeries, + migration: migration_01_to_11.MigrateInvoiceTimeSeries, }, { // The DB version that updated the embedded invoice in // outgoing payments to match the new format. number: 3, - migration: migrateInvoiceTimeSeriesOutgoingPayments, + migration: migration_01_to_11.MigrateInvoiceTimeSeriesOutgoingPayments, }, { // The version of the database where every channel @@ -67,53 +68,53 @@ var ( // a policy is unknown, this will be represented // by a special byte sequence. number: 4, - migration: migrateEdgePolicies, + migration: migration_01_to_11.MigrateEdgePolicies, }, { // The DB version where we persist each attempt to send // an HTLC to a payment hash, and track whether the // payment is in-flight, succeeded, or failed. number: 5, - migration: paymentStatusesMigration, + migration: migration_01_to_11.PaymentStatusesMigration, }, { // The DB version that properly prunes stale entries // from the edge update index. number: 6, - migration: migratePruneEdgeUpdateIndex, + migration: migration_01_to_11.MigratePruneEdgeUpdateIndex, }, { // The DB version that migrates the ChannelCloseSummary // to a format where optional fields are indicated with // boolean flags. number: 7, - migration: migrateOptionalChannelCloseSummaryFields, + migration: migration_01_to_11.MigrateOptionalChannelCloseSummaryFields, }, { // The DB version that changes the gossiper's message // store keys to account for the message's type and // ShortChannelID. number: 8, - migration: migrateGossipMessageStoreKeys, + migration: migration_01_to_11.MigrateGossipMessageStoreKeys, }, { // The DB version where the payments and payment // statuses are moved to being stored in a combined // bucket. number: 9, - migration: migrateOutgoingPayments, + migration: migration_01_to_11.MigrateOutgoingPayments, }, { // The DB version where we started to store legacy // payload information for all routes, as well as the // optional TLV records. number: 10, - migration: migrateRouteSerialization, + migration: migration_01_to_11.MigrateRouteSerialization, }, { // Add invoice htlc and cltv delta fields. number: 11, - migration: migrateInvoices, + migration: migration_01_to_11.MigrateInvoices, }, } @@ -266,10 +267,6 @@ func createChannelDB(dbPath string) error { return err } - if _, err := tx.CreateBucket(paymentBucket); err != nil { - return err - } - if _, err := tx.CreateBucket(nodeInfoBucket); err != nil { return err } diff --git a/channeldb/migration_01_to_11/db.go b/channeldb/migration_01_to_11/db.go index c4306400..e1057d65 100644 --- a/channeldb/migration_01_to_11/db.go +++ b/channeldb/migration_01_to_11/db.go @@ -32,91 +32,6 @@ type version struct { } var ( - // dbVersions is storing all versions of database. If current version - // of database don't match with latest version this list will be used - // for retrieving all migration function that are need to apply to the - // current db. - dbVersions = []version{ - { - // The base DB version requires no migration. - number: 0, - migration: nil, - }, - { - // The version of the database where two new indexes - // for the update time of node and channel updates were - // added. - number: 1, - migration: migrateNodeAndEdgeUpdateIndex, - }, - { - // The DB version that added the invoice event time - // series. - number: 2, - migration: migrateInvoiceTimeSeries, - }, - { - // The DB version that updated the embedded invoice in - // outgoing payments to match the new format. - number: 3, - migration: migrateInvoiceTimeSeriesOutgoingPayments, - }, - { - // The version of the database where every channel - // always has two entries in the edges bucket. If - // a policy is unknown, this will be represented - // by a special byte sequence. - number: 4, - migration: migrateEdgePolicies, - }, - { - // The DB version where we persist each attempt to send - // an HTLC to a payment hash, and track whether the - // payment is in-flight, succeeded, or failed. - number: 5, - migration: paymentStatusesMigration, - }, - { - // The DB version that properly prunes stale entries - // from the edge update index. - number: 6, - migration: migratePruneEdgeUpdateIndex, - }, - { - // The DB version that migrates the ChannelCloseSummary - // to a format where optional fields are indicated with - // boolean flags. - number: 7, - migration: migrateOptionalChannelCloseSummaryFields, - }, - { - // The DB version that changes the gossiper's message - // store keys to account for the message's type and - // ShortChannelID. - number: 8, - migration: migrateGossipMessageStoreKeys, - }, - { - // The DB version where the payments and payment - // statuses are moved to being stored in a combined - // bucket. - number: 9, - migration: migrateOutgoingPayments, - }, - { - // The DB version where we started to store legacy - // payload information for all routes, as well as the - // optional TLV records. - number: 10, - migration: migrateRouteSerialization, - }, - { - // Add invoice htlc and cltv delta fields. - number: 11, - migration: migrateInvoices, - }, - } - // Big endian is the preferred byte order, due to cursor scans over // integer keys iterating in order. byteOrder = binary.BigEndian @@ -169,12 +84,6 @@ func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { chanDB, opts.RejectCacheSize, opts.ChannelCacheSize, ) - // Synchronize the version of database and apply migrations if needed. - if err := chanDB.syncVersions(dbVersions); err != nil { - bdb.Close() - return nil, err - } - return chanDB, nil } @@ -318,7 +227,7 @@ func createChannelDB(dbPath string) error { } meta := &Meta{ - DbVersionNumber: getLatestDBVersion(dbVersions), + DbVersionNumber: 0, } return putMeta(meta, tx) }) diff --git a/channeldb/migration_01_to_11/meta.go b/channeldb/migration_01_to_11/meta.go index fbe7a0e4..a8f9bd41 100644 --- a/channeldb/migration_01_to_11/meta.go +++ b/channeldb/migration_01_to_11/meta.go @@ -44,7 +44,7 @@ func fetchMeta(meta *Meta, tx *bbolt.Tx) error { data := metaBucket.Get(dbVersionKey) if data == nil { - meta.DbVersionNumber = getLatestDBVersion(dbVersions) + meta.DbVersionNumber = 0 } else { meta.DbVersionNumber = byteOrder.Uint32(data) } diff --git a/channeldb/migration_01_to_11/meta_test.go b/channeldb/migration_01_to_11/meta_test.go index 27e9369c..be1af2f9 100644 --- a/channeldb/migration_01_to_11/meta_test.go +++ b/channeldb/migration_01_to_11/meta_test.go @@ -1,11 +1,8 @@ package migration_01_to_11 import ( - "bytes" - "io/ioutil" "testing" - "github.com/coreos/bbolt" "github.com/go-errors/errors" ) @@ -74,369 +71,3 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), log.Error(err) } } - -// TestVersionFetchPut checks the propernces of fetch/put methods -// and also initialization of meta data in case if don't have any in -// database. -func TestVersionFetchPut(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatal(err) - } - - meta, err := db.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != getLatestDBVersion(dbVersions) { - t.Fatal("initialization of meta information wasn't performed") - } - - newVersion := getLatestDBVersion(dbVersions) + 1 - meta.DbVersionNumber = newVersion - - if err := db.PutMeta(meta); err != nil { - t.Fatalf("update of meta failed %v", err) - } - - meta, err = db.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != newVersion { - t.Fatal("update of meta information wasn't performed") - } -} - -// TestOrderOfMigrations checks that migrations are applied in proper order. -func TestOrderOfMigrations(t *testing.T) { - t.Parallel() - - appliedMigration := -1 - versions := []version{ - {0, nil}, - {1, nil}, - {2, func(tx *bbolt.Tx) error { - appliedMigration = 2 - return nil - }}, - {3, func(tx *bbolt.Tx) error { - appliedMigration = 3 - return nil - }}, - } - - // Retrieve the migration that should be applied to db, as far as - // current version is 1, we skip zero and first versions. - migrations, _ := getMigrationsToApply(versions, 1) - - if len(migrations) != 2 { - t.Fatal("incorrect number of migrations to apply") - } - - // Apply first migration. - migrations[0](nil) - - // Check that first migration corresponds to the second version. - if appliedMigration != 2 { - t.Fatal("incorrect order of applying migrations") - } - - // Apply second migration. - migrations[1](nil) - - // Check that second migration corresponds to the third version. - if appliedMigration != 3 { - t.Fatal("incorrect order of applying migrations") - } -} - -// TestGlobalVersionList checks that there is no mistake in global version list -// in terms of version ordering. -func TestGlobalVersionList(t *testing.T) { - t.Parallel() - - if dbVersions == nil { - t.Fatal("can't find versions list") - } - - if len(dbVersions) == 0 { - t.Fatal("db versions list is empty") - } - - prev := dbVersions[0].number - for i := 1; i < len(dbVersions); i++ { - version := dbVersions[i].number - - if version == prev { - t.Fatal("duplicates db versions") - } - if version < prev { - t.Fatal("order of db versions is wrong") - } - - prev = version - } -} - -// TestMigrationWithPanic asserts that if migration logic panics, we will return -// to the original state unaltered. -func TestMigrationWithPanic(t *testing.T) { - t.Parallel() - - bucketPrefix := []byte("somebucket") - keyPrefix := []byte("someprefix") - beforeMigration := []byte("beforemigration") - afterMigration := []byte("aftermigration") - - beforeMigrationFunc := func(d *DB) { - // Insert data in database and in order then make sure that the - // key isn't changes in case of panic or fail. - d.Update(func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - bucket.Put(keyPrefix, beforeMigration) - return nil - }) - } - - // Create migration function which changes the initially created data and - // throw the panic, in this case we pretending that something goes. - migrationWithPanic := func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - bucket.Put(keyPrefix, afterMigration) - panic("panic!") - } - - // Check that version of database and data wasn't changed. - afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 0 { - t.Fatal("migration panicked but version is changed") - } - - err = d.Update(func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - value := bucket.Get(keyPrefix) - if !bytes.Equal(value, beforeMigration) { - return errors.New("migration failed but data is " + - "changed") - } - - return nil - }) - if err != nil { - t.Fatal(err) - } - } - - applyMigration(t, - beforeMigrationFunc, - afterMigrationFunc, - migrationWithPanic, - true) -} - -// TestMigrationWithFatal asserts that migrations which fail do not modify the -// database. -func TestMigrationWithFatal(t *testing.T) { - t.Parallel() - - bucketPrefix := []byte("somebucket") - keyPrefix := []byte("someprefix") - beforeMigration := []byte("beforemigration") - afterMigration := []byte("aftermigration") - - beforeMigrationFunc := func(d *DB) { - d.Update(func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - bucket.Put(keyPrefix, beforeMigration) - return nil - }) - } - - // Create migration function which changes the initially created data and - // return the error, in this case we pretending that something goes - // wrong. - migrationWithFatal := func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - bucket.Put(keyPrefix, afterMigration) - return errors.New("some error") - } - - // Check that version of database and initial data wasn't changed. - afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 0 { - t.Fatal("migration failed but version is changed") - } - - err = d.Update(func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - value := bucket.Get(keyPrefix) - if !bytes.Equal(value, beforeMigration) { - return errors.New("migration failed but data is " + - "changed") - } - - return nil - }) - if err != nil { - t.Fatal(err) - } - } - - applyMigration(t, - beforeMigrationFunc, - afterMigrationFunc, - migrationWithFatal, - true) -} - -// TestMigrationWithoutErrors asserts that a successful migration has its -// changes applied to the database. -func TestMigrationWithoutErrors(t *testing.T) { - t.Parallel() - - bucketPrefix := []byte("somebucket") - keyPrefix := []byte("someprefix") - beforeMigration := []byte("beforemigration") - afterMigration := []byte("aftermigration") - - // Populate database with initial data. - beforeMigrationFunc := func(d *DB) { - d.Update(func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - bucket.Put(keyPrefix, beforeMigration) - return nil - }) - } - - // Create migration function which changes the initially created data. - migrationWithoutErrors := func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - bucket.Put(keyPrefix, afterMigration) - return nil - } - - // Check that version of database and data was properly changed. - afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("version number isn't changed after " + - "successfully applied migration") - } - - err = d.Update(func(tx *bbolt.Tx) error { - bucket, err := tx.CreateBucketIfNotExists(bucketPrefix) - if err != nil { - return err - } - - value := bucket.Get(keyPrefix) - if !bytes.Equal(value, afterMigration) { - return errors.New("migration wasn't applied " + - "properly") - } - - return nil - }) - if err != nil { - t.Fatal(err) - } - } - - applyMigration(t, - beforeMigrationFunc, - afterMigrationFunc, - migrationWithoutErrors, - false) -} - -// TestMigrationReversion tests after performing a migration to a higher -// database version, opening the database with a lower latest db version returns -// ErrDBReversion. -func TestMigrationReversion(t *testing.T) { - t.Parallel() - - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - t.Fatalf("unable to create temp dir: %v", err) - } - - cdb, err := Open(tempDirName) - if err != nil { - t.Fatalf("unable to open channeldb: %v", err) - } - - // Update the database metadata to point to one more than the highest - // known version. - err = cdb.Update(func(tx *bbolt.Tx) error { - newMeta := &Meta{ - DbVersionNumber: getLatestDBVersion(dbVersions) + 1, - } - - return putMeta(newMeta, tx) - }) - - // Close the database. Even if we succeeded, our next step is to reopen. - cdb.Close() - - if err != nil { - t.Fatalf("unable to increase db version: %v", err) - } - - _, err = Open(tempDirName) - if err != ErrDBReversion { - t.Fatalf("unexpected error when opening channeldb, "+ - "want: %v, got: %v", ErrDBReversion, err) - } -} diff --git a/channeldb/migration_01_to_11/migration_10_route_tlv_records.go b/channeldb/migration_01_to_11/migration_10_route_tlv_records.go index a8478cda..648d85ad 100644 --- a/channeldb/migration_01_to_11/migration_10_route_tlv_records.go +++ b/channeldb/migration_01_to_11/migration_10_route_tlv_records.go @@ -8,10 +8,10 @@ import ( "github.com/lightningnetwork/lnd/routing/route" ) -// migrateRouteSerialization migrates the way we serialize routes across the +// MigrateRouteSerialization migrates the way we serialize routes across the // entire database. At the time of writing of this migration, this includes our // payment attempts, as well as the payment results in mission control. -func migrateRouteSerialization(tx *bbolt.Tx) error { +func MigrateRouteSerialization(tx *bbolt.Tx) error { // First, we'll do all the payment attempts. rootPaymentBucket := tx.Bucket(paymentsRootBucket) if rootPaymentBucket == nil { diff --git a/channeldb/migration_01_to_11/migration_11_invoices.go b/channeldb/migration_01_to_11/migration_11_invoices.go index 1ae969be..449e9b5d 100644 --- a/channeldb/migration_01_to_11/migration_11_invoices.go +++ b/channeldb/migration_01_to_11/migration_11_invoices.go @@ -14,9 +14,9 @@ import ( litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" ) -// migrateInvoices adds invoice htlcs and a separate cltv delta field to the +// MigrateInvoices adds invoice htlcs and a separate cltv delta field to the // invoices. -func migrateInvoices(tx *bbolt.Tx) error { +func MigrateInvoices(tx *bbolt.Tx) error { log.Infof("Migrating invoices to new invoice format") invoiceB := tx.Bucket(invoiceBucket) diff --git a/channeldb/migration_01_to_11/migration_11_invoices_test.go b/channeldb/migration_01_to_11/migration_11_invoices_test.go index 9c0c877a..31cfe48f 100644 --- a/channeldb/migration_01_to_11/migration_11_invoices_test.go +++ b/channeldb/migration_01_to_11/migration_11_invoices_test.go @@ -123,7 +123,7 @@ func TestMigrateInvoices(t *testing.T) { applyMigration(t, func(d *DB) { beforeMigrationFuncV11(t, d, invoices) }, afterMigrationFunc, - migrateInvoices, + MigrateInvoices, false) } @@ -149,7 +149,7 @@ func TestMigrateInvoicesHodl(t *testing.T) { applyMigration(t, func(d *DB) { beforeMigrationFuncV11(t, d, invoices) }, func(d *DB) {}, - migrateInvoices, + MigrateInvoices, true) } diff --git a/channeldb/migration_01_to_11/migrations.go b/channeldb/migration_01_to_11/migrations.go index 3e296d02..3f841009 100644 --- a/channeldb/migration_01_to_11/migrations.go +++ b/channeldb/migration_01_to_11/migrations.go @@ -12,12 +12,12 @@ import ( "github.com/lightningnetwork/lnd/routing/route" ) -// migrateNodeAndEdgeUpdateIndex is a migration function that will update the +// MigrateNodeAndEdgeUpdateIndex is a migration function that will update the // database from version 0 to version 1. In version 1, we add two new indexes // (one for nodes and one for edges) to keep track of the last time a node or // edge was updated on the network. These new indexes allow us to implement the // new graph sync protocol added. -func migrateNodeAndEdgeUpdateIndex(tx *bbolt.Tx) error { +func MigrateNodeAndEdgeUpdateIndex(tx *bbolt.Tx) error { // First, we'll populating the node portion of the new index. Before we // can add new values to the index, we'll first create the new bucket // where these items will be housed. @@ -118,11 +118,11 @@ func migrateNodeAndEdgeUpdateIndex(tx *bbolt.Tx) error { return nil } -// migrateInvoiceTimeSeries is a database migration that assigns all existing +// MigrateInvoiceTimeSeries is a database migration that assigns all existing // invoices an index in the add and/or the settle index. Additionally, all // existing invoices will have their bytes padded out in order to encode the // add+settle index as well as the amount paid. -func migrateInvoiceTimeSeries(tx *bbolt.Tx) error { +func MigrateInvoiceTimeSeries(tx *bbolt.Tx) error { invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) if err != nil { return err @@ -255,11 +255,11 @@ func migrateInvoiceTimeSeries(tx *bbolt.Tx) error { return nil } -// migrateInvoiceTimeSeriesOutgoingPayments is a follow up to the +// MigrateInvoiceTimeSeriesOutgoingPayments is a follow up to the // migrateInvoiceTimeSeries migration. As at the time of writing, the // OutgoingPayment struct embeddeds an instance of the Invoice struct. As a // result, we also need to migrate the internal invoice to the new format. -func migrateInvoiceTimeSeriesOutgoingPayments(tx *bbolt.Tx) error { +func MigrateInvoiceTimeSeriesOutgoingPayments(tx *bbolt.Tx) error { payBucket := tx.Bucket(paymentBucket) if payBucket == nil { return nil @@ -336,11 +336,11 @@ func migrateInvoiceTimeSeriesOutgoingPayments(tx *bbolt.Tx) error { return nil } -// migrateEdgePolicies is a migration function that will update the edges +// MigrateEdgePolicies is a migration function that will update the edges // bucket. It ensure that edges with unknown policies will also have an entry // in the bucket. After the migration, there will be two edge entries for // every channel, regardless of whether the policies are known. -func migrateEdgePolicies(tx *bbolt.Tx) error { +func MigrateEdgePolicies(tx *bbolt.Tx) error { nodes := tx.Bucket(nodeBucket) if nodes == nil { return nil @@ -409,10 +409,10 @@ func migrateEdgePolicies(tx *bbolt.Tx) error { return nil } -// paymentStatusesMigration is a database migration intended for adding payment +// PaymentStatusesMigration is a database migration intended for adding payment // statuses for each existing payment entity in bucket to be able control // transitions of statuses and prevent cases such as double payment -func paymentStatusesMigration(tx *bbolt.Tx) error { +func PaymentStatusesMigration(tx *bbolt.Tx) error { // Get the bucket dedicated to storing statuses of payments, // where a key is payment hash, value is payment status. paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) @@ -492,14 +492,14 @@ func paymentStatusesMigration(tx *bbolt.Tx) error { return nil } -// migratePruneEdgeUpdateIndex is a database migration that attempts to resolve +// MigratePruneEdgeUpdateIndex is a database migration that attempts to resolve // some lingering bugs with regards to edge policies and their update index. // Stale entries within the edge update index were not being properly pruned due // to a miscalculation on the offset of an edge's policy last update. This // migration also fixes the case where the public keys within edge policies were // being serialized with an extra byte, causing an even greater error when // attempting to perform the offset calculation described earlier. -func migratePruneEdgeUpdateIndex(tx *bbolt.Tx) error { +func MigratePruneEdgeUpdateIndex(tx *bbolt.Tx) error { // To begin the migration, we'll retrieve the update index bucket. If it // does not exist, we have nothing left to do so we can simply exit. edges := tx.Bucket(edgeBucket) @@ -610,10 +610,10 @@ func migratePruneEdgeUpdateIndex(tx *bbolt.Tx) error { return nil } -// migrateOptionalChannelCloseSummaryFields migrates the serialized format of +// MigrateOptionalChannelCloseSummaryFields migrates the serialized format of // ChannelCloseSummary to a format where optional fields' presence is indicated // with boolean markers. -func migrateOptionalChannelCloseSummaryFields(tx *bbolt.Tx) error { +func MigrateOptionalChannelCloseSummaryFields(tx *bbolt.Tx) error { closedChanBucket := tx.Bucket(closedChannelBucket) if closedChanBucket == nil { return nil @@ -669,10 +669,10 @@ func migrateOptionalChannelCloseSummaryFields(tx *bbolt.Tx) error { var messageStoreBucket = []byte("message-store") -// migrateGossipMessageStoreKeys migrates the key format for gossip messages +// MigrateGossipMessageStoreKeys migrates the key format for gossip messages // found in the message store to a new one that takes into consideration the of // the message being stored. -func migrateGossipMessageStoreKeys(tx *bbolt.Tx) error { +func MigrateGossipMessageStoreKeys(tx *bbolt.Tx) error { // We'll start by retrieving the bucket in which these messages are // stored within. If there isn't one, there's nothing left for us to do // so we can avoid the migration. @@ -739,7 +739,7 @@ func migrateGossipMessageStoreKeys(tx *bbolt.Tx) error { return nil } -// migrateOutgoingPayments moves the OutgoingPayments into a new bucket format +// MigrateOutgoingPayments moves the OutgoingPayments into a new bucket format // where they all reside in a top-level bucket indexed by the payment hash. In // this sub-bucket we store information relevant to this payment, such as the // payment status. @@ -748,7 +748,7 @@ func migrateGossipMessageStoreKeys(tx *bbolt.Tx) error { // InFlight (we have no PaymentAttemptInfo available for pre-migration // payments) we delete those statuses, so only Completed payments remain in the // new bucket structure. -func migrateOutgoingPayments(tx *bbolt.Tx) error { +func MigrateOutgoingPayments(tx *bbolt.Tx) error { log.Infof("Migrating outgoing payments to new bucket structure") oldPayments := tx.Bucket(paymentBucket) diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index 8a9076fb..cdaef57f 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -197,7 +197,7 @@ func TestPaymentStatusesMigration(t *testing.T) { applyMigration(t, beforeMigrationFunc, afterMigrationFunc, - paymentStatusesMigration, + PaymentStatusesMigration, false) } @@ -469,7 +469,7 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) { applyMigration(t, beforeMigrationFunc, afterMigrationFunc, - migrateOptionalChannelCloseSummaryFields, + MigrateOptionalChannelCloseSummaryFields, false) } } @@ -565,7 +565,7 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { applyMigration( t, beforeMigration, afterMigration, - migrateGossipMessageStoreKeys, false, + MigrateGossipMessageStoreKeys, false, ) } @@ -724,7 +724,7 @@ func TestOutgoingPaymentsMigration(t *testing.T) { applyMigration(t, beforeMigrationFunc, afterMigrationFunc, - migrateOutgoingPayments, + MigrateOutgoingPayments, false) } @@ -947,6 +947,6 @@ func TestPaymentRouteSerialization(t *testing.T) { applyMigration(t, beforeMigrationFunc, afterMigrationFunc, - migrateRouteSerialization, + MigrateRouteSerialization, false) } diff --git a/channeldb/migration_09_legacy_serialization.go b/channeldb/migration_09_legacy_serialization.go deleted file mode 100644 index 1205cf9b..00000000 --- a/channeldb/migration_09_legacy_serialization.go +++ /dev/null @@ -1,497 +0,0 @@ -package channeldb - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - "sort" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing/route" -) - -var ( - // paymentBucket is the name of the bucket within the database that - // stores all data related to payments. - // - // Within the payments bucket, each invoice is keyed by its invoice ID - // which is a monotonically increasing uint64. BoltDB's sequence - // feature is used for generating monotonically increasing id. - // - // NOTE: Deprecated. Kept around for migration purposes. - paymentBucket = []byte("payments") - - // paymentStatusBucket is the name of the bucket within the database - // that stores the status of a payment indexed by the payment's - // preimage. - // - // NOTE: Deprecated. Kept around for migration purposes. - paymentStatusBucket = []byte("payment-status") -) - -// outgoingPayment represents a successful payment between the daemon and a -// remote node. Details such as the total fee paid, and the time of the payment -// are stored. -// -// NOTE: Deprecated. Kept around for migration purposes. -type outgoingPayment struct { - Invoice - - // Fee is the total fee paid for the payment in milli-satoshis. - Fee lnwire.MilliSatoshi - - // TotalTimeLock is the total cumulative time-lock in the HTLC extended - // from the second-to-last hop to the destination. - TimeLockLength uint32 - - // Path encodes the path the payment took through the network. The path - // excludes the outgoing node and consists of the hex-encoded - // compressed public key of each of the nodes involved in the payment. - Path [][33]byte - - // PaymentPreimage is the preImage of a successful payment. This is used - // to calculate the PaymentHash as well as serve as a proof of payment. - PaymentPreimage [32]byte -} - -// addPayment saves a successful payment to the database. It is assumed that -// all payment are sent using unique payment hashes. -// -// NOTE: Deprecated. Kept around for migration purposes. -func (db *DB) addPayment(payment *outgoingPayment) error { - // Validate the field of the inner voice within the outgoing payment, - // these must also adhere to the same constraints as regular invoices. - if err := validateInvoice(&payment.Invoice); err != nil { - return err - } - - // We first serialize the payment before starting the database - // transaction so we can avoid creating a DB payment in the case of a - // serialization error. - var b bytes.Buffer - if err := serializeOutgoingPayment(&b, payment); err != nil { - return err - } - paymentBytes := b.Bytes() - - return db.Batch(func(tx *bbolt.Tx) error { - payments, err := tx.CreateBucketIfNotExists(paymentBucket) - if err != nil { - return err - } - - // Obtain the new unique sequence number for this payment. - paymentID, err := payments.NextSequence() - if err != nil { - return err - } - - // We use BigEndian for keys as it orders keys in - // ascending order. This allows bucket scans to order payments - // in the order in which they were created. - paymentIDBytes := make([]byte, 8) - binary.BigEndian.PutUint64(paymentIDBytes, paymentID) - - return payments.Put(paymentIDBytes, paymentBytes) - }) -} - -// fetchAllPayments returns all outgoing payments in DB. -// -// NOTE: Deprecated. Kept around for migration purposes. -func (db *DB) fetchAllPayments() ([]*outgoingPayment, error) { - var payments []*outgoingPayment - - err := db.View(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(paymentBucket) - if bucket == nil { - return ErrNoPaymentsCreated - } - - return bucket.ForEach(func(k, v []byte) error { - // If the value is nil, then we ignore it as it may be - // a sub-bucket. - if v == nil { - return nil - } - - r := bytes.NewReader(v) - payment, err := deserializeOutgoingPayment(r) - if err != nil { - return err - } - - payments = append(payments, payment) - return nil - }) - }) - if err != nil { - return nil, err - } - - return payments, nil -} - -// fetchPaymentStatus returns the payment status for outgoing payment. -// If status of the payment isn't found, it will default to "StatusUnknown". -// -// NOTE: Deprecated. Kept around for migration purposes. -func (db *DB) fetchPaymentStatus(paymentHash [32]byte) (PaymentStatus, error) { - var paymentStatus = StatusUnknown - err := db.View(func(tx *bbolt.Tx) error { - var err error - paymentStatus, err = fetchPaymentStatusTx(tx, paymentHash) - return err - }) - if err != nil { - return StatusUnknown, err - } - - return paymentStatus, nil -} - -// fetchPaymentStatusTx is a helper method that returns the payment status for -// outgoing payment. If status of the payment isn't found, it will default to -// "StatusUnknown". It accepts the boltdb transactions such that this method -// can be composed into other atomic operations. -// -// NOTE: Deprecated. Kept around for migration purposes. -func fetchPaymentStatusTx(tx *bbolt.Tx, paymentHash [32]byte) (PaymentStatus, error) { - // The default status for all payments that aren't recorded in database. - var paymentStatus = StatusUnknown - - bucket := tx.Bucket(paymentStatusBucket) - if bucket == nil { - return paymentStatus, nil - } - - paymentStatusBytes := bucket.Get(paymentHash[:]) - if paymentStatusBytes == nil { - return paymentStatus, nil - } - - paymentStatus.FromBytes(paymentStatusBytes) - - return paymentStatus, nil -} - -func serializeOutgoingPayment(w io.Writer, p *outgoingPayment) error { - var scratch [8]byte - - if err := serializeInvoiceLegacy(w, &p.Invoice); err != nil { - return err - } - - byteOrder.PutUint64(scratch[:], uint64(p.Fee)) - if _, err := w.Write(scratch[:]); err != nil { - return err - } - - // First write out the length of the bytes to prefix the value. - pathLen := uint32(len(p.Path)) - byteOrder.PutUint32(scratch[:4], pathLen) - if _, err := w.Write(scratch[:4]); err != nil { - return err - } - - // Then with the path written, we write out the series of public keys - // involved in the path. - for _, hop := range p.Path { - if _, err := w.Write(hop[:]); err != nil { - return err - } - } - - byteOrder.PutUint32(scratch[:4], p.TimeLockLength) - if _, err := w.Write(scratch[:4]); err != nil { - return err - } - - if _, err := w.Write(p.PaymentPreimage[:]); err != nil { - return err - } - - return nil -} - -func deserializeOutgoingPayment(r io.Reader) (*outgoingPayment, error) { - var scratch [8]byte - - p := &outgoingPayment{} - - inv, err := deserializeInvoiceLegacy(r) - if err != nil { - return nil, err - } - p.Invoice = inv - - if _, err := r.Read(scratch[:]); err != nil { - return nil, err - } - p.Fee = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) - - if _, err = r.Read(scratch[:4]); err != nil { - return nil, err - } - pathLen := byteOrder.Uint32(scratch[:4]) - - path := make([][33]byte, pathLen) - for i := uint32(0); i < pathLen; i++ { - if _, err := r.Read(path[i][:]); err != nil { - return nil, err - } - } - p.Path = path - - if _, err = r.Read(scratch[:4]); err != nil { - return nil, err - } - p.TimeLockLength = byteOrder.Uint32(scratch[:4]) - - if _, err := r.Read(p.PaymentPreimage[:]); err != nil { - return nil, err - } - - return p, nil -} - -// serializePaymentAttemptInfoMigration9 is the serializePaymentAttemptInfo -// version as existed when migration #9 was created. We keep this around, along -// with the methods below to ensure that clients that upgrade will use the -// correct version of this method. -func serializePaymentAttemptInfoMigration9(w io.Writer, a *PaymentAttemptInfo) error { - if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { - return err - } - - if err := serializeRouteMigration9(w, a.Route); err != nil { - return err - } - - return nil -} - -func serializeHopMigration9(w io.Writer, h *route.Hop) error { - if err := WriteElements(w, - h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, - h.AmtToForward, - ); err != nil { - return err - } - - return nil -} - -func serializeRouteMigration9(w io.Writer, r route.Route) error { - if err := WriteElements(w, - r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], - ); err != nil { - return err - } - - if err := WriteElements(w, uint32(len(r.Hops))); err != nil { - return err - } - - for _, h := range r.Hops { - if err := serializeHopMigration9(w, h); err != nil { - return err - } - } - - return nil -} - -func deserializePaymentAttemptInfoMigration9(r io.Reader) (*PaymentAttemptInfo, error) { - a := &PaymentAttemptInfo{} - err := ReadElements(r, &a.PaymentID, &a.SessionKey) - if err != nil { - return nil, err - } - a.Route, err = deserializeRouteMigration9(r) - if err != nil { - return nil, err - } - return a, nil -} - -func deserializeRouteMigration9(r io.Reader) (route.Route, error) { - rt := route.Route{} - if err := ReadElements(r, - &rt.TotalTimeLock, &rt.TotalAmount, - ); err != nil { - return rt, err - } - - var pub []byte - if err := ReadElements(r, &pub); err != nil { - return rt, err - } - copy(rt.SourcePubKey[:], pub) - - var numHops uint32 - if err := ReadElements(r, &numHops); err != nil { - return rt, err - } - - var hops []*route.Hop - for i := uint32(0); i < numHops; i++ { - hop, err := deserializeHopMigration9(r) - if err != nil { - return rt, err - } - hops = append(hops, hop) - } - rt.Hops = hops - - return rt, nil -} - -func deserializeHopMigration9(r io.Reader) (*route.Hop, error) { - h := &route.Hop{} - - var pub []byte - if err := ReadElements(r, &pub); err != nil { - return nil, err - } - copy(h.PubKeyBytes[:], pub) - - if err := ReadElements(r, - &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, - ); err != nil { - return nil, err - } - - return h, nil -} - -// fetchPaymentsMigration9 returns all sent payments found in the DB using the -// payment attempt info format that was present as of migration #9. We need -// this as otherwise, the current FetchPayments version will use the latest -// decoding format. Note that we only need this for the -// TestOutgoingPaymentsMigration migration test case. -func (db *DB) fetchPaymentsMigration9() ([]*Payment, error) { - var payments []*Payment - - err := db.View(func(tx *bbolt.Tx) error { - paymentsBucket := tx.Bucket(paymentsRootBucket) - if paymentsBucket == nil { - return nil - } - - return paymentsBucket.ForEach(func(k, v []byte) error { - bucket := paymentsBucket.Bucket(k) - if bucket == nil { - // We only expect sub-buckets to be found in - // this top-level bucket. - return fmt.Errorf("non bucket element in " + - "payments bucket") - } - - p, err := fetchPaymentMigration9(bucket) - if err != nil { - return err - } - - payments = append(payments, p) - - // For older versions of lnd, duplicate payments to a - // payment has was possible. These will be found in a - // sub-bucket indexed by their sequence number if - // available. - dup := bucket.Bucket(paymentDuplicateBucket) - if dup == nil { - return nil - } - - return dup.ForEach(func(k, v []byte) error { - subBucket := dup.Bucket(k) - if subBucket == nil { - // We one bucket for each duplicate to - // be found. - return fmt.Errorf("non bucket element" + - "in duplicate bucket") - } - - p, err := fetchPaymentMigration9(subBucket) - if err != nil { - return err - } - - payments = append(payments, p) - return nil - }) - }) - }) - if err != nil { - return nil, err - } - - // Before returning, sort the payments by their sequence number. - sort.Slice(payments, func(i, j int) bool { - return payments[i].sequenceNum < payments[j].sequenceNum - }) - - return payments, nil -} - -func fetchPaymentMigration9(bucket *bbolt.Bucket) (*Payment, error) { - var ( - err error - p = &Payment{} - ) - - seqBytes := bucket.Get(paymentSequenceKey) - if seqBytes == nil { - return nil, fmt.Errorf("sequence number not found") - } - - p.sequenceNum = binary.BigEndian.Uint64(seqBytes) - - // Get the payment status. - p.Status = fetchPaymentStatus(bucket) - - // Get the PaymentCreationInfo. - b := bucket.Get(paymentCreationInfoKey) - if b == nil { - return nil, fmt.Errorf("creation info not found") - } - - r := bytes.NewReader(b) - p.Info, err = deserializePaymentCreationInfo(r) - if err != nil { - return nil, err - - } - - // Get the PaymentAttemptInfo. This can be unset. - b = bucket.Get(paymentAttemptInfoKey) - if b != nil { - r = bytes.NewReader(b) - p.Attempt, err = deserializePaymentAttemptInfoMigration9(r) - if err != nil { - return nil, err - } - } - - // Get the payment preimage. This is only found for - // completed payments. - b = bucket.Get(paymentSettleInfoKey) - if b != nil { - var preimg lntypes.Preimage - copy(preimg[:], b[:]) - p.PaymentPreimage = &preimg - } - - // Get failure reason if available. - b = bucket.Get(paymentFailInfoKey) - if b != nil { - reason := FailureReason(b[0]) - p.Failure = &reason - } - - return p, nil -} diff --git a/channeldb/migration_10_route_tlv_records.go b/channeldb/migration_10_route_tlv_records.go deleted file mode 100644 index 2659c4a7..00000000 --- a/channeldb/migration_10_route_tlv_records.go +++ /dev/null @@ -1,236 +0,0 @@ -package channeldb - -import ( - "bytes" - "io" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/routing/route" -) - -// migrateRouteSerialization migrates the way we serialize routes across the -// entire database. At the time of writing of this migration, this includes our -// payment attempts, as well as the payment results in mission control. -func migrateRouteSerialization(tx *bbolt.Tx) error { - // First, we'll do all the payment attempts. - rootPaymentBucket := tx.Bucket(paymentsRootBucket) - if rootPaymentBucket == nil { - return nil - } - - // As we can't mutate a bucket while we're iterating over it with - // ForEach, we'll need to collect all the known payment hashes in - // memory first. - var payHashes [][]byte - err := rootPaymentBucket.ForEach(func(k, v []byte) error { - if v != nil { - return nil - } - - payHashes = append(payHashes, k) - return nil - }) - if err != nil { - return err - } - - // Now that we have all the payment hashes, we can carry out the - // migration itself. - for _, payHash := range payHashes { - payHashBucket := rootPaymentBucket.Bucket(payHash) - - // First, we'll migrate the main (non duplicate) payment to - // this hash. - err := migrateAttemptEncoding(tx, payHashBucket) - if err != nil { - return err - } - - // Now that we've migrated the main payment, we'll also check - // for any duplicate payments to the same payment hash. - dupBucket := payHashBucket.Bucket(paymentDuplicateBucket) - - // If there's no dup bucket, then we can move on to the next - // payment. - if dupBucket == nil { - continue - } - - // Otherwise, we'll now iterate through all the duplicate pay - // hashes and migrate those. - var dupSeqNos [][]byte - err = dupBucket.ForEach(func(k, v []byte) error { - dupSeqNos = append(dupSeqNos, k) - return nil - }) - if err != nil { - return err - } - - // Now in this second pass, we'll re-serialize their duplicate - // payment attempts under the new encoding. - for _, seqNo := range dupSeqNos { - dupPayHashBucket := dupBucket.Bucket(seqNo) - err := migrateAttemptEncoding(tx, dupPayHashBucket) - if err != nil { - return err - } - } - } - - log.Infof("Migration of route/hop serialization complete!") - - log.Infof("Migrating to new mission control store by clearing " + - "existing data") - - resultsKey := []byte("missioncontrol-results") - err = tx.DeleteBucket(resultsKey) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - log.Infof("Migration to new mission control completed!") - - return nil -} - -// migrateAttemptEncoding migrates payment attempts using the legacy format to -// the new format. -func migrateAttemptEncoding(tx *bbolt.Tx, payHashBucket *bbolt.Bucket) error { - payAttemptBytes := payHashBucket.Get(paymentAttemptInfoKey) - if payAttemptBytes == nil { - return nil - } - - // For our migration, we'll first read out the existing payment attempt - // using the legacy serialization of the attempt. - payAttemptReader := bytes.NewReader(payAttemptBytes) - payAttempt, err := deserializePaymentAttemptInfoLegacy( - payAttemptReader, - ) - if err != nil { - return err - } - - // Now that we have the old attempts, we'll explicitly mark this as - // needing a legacy payload, since after this migration, the modern - // payload will be the default if signalled. - for _, hop := range payAttempt.Route.Hops { - hop.LegacyPayload = true - } - - // Finally, we'll write out the payment attempt using the new encoding. - var b bytes.Buffer - err = serializePaymentAttemptInfo(&b, payAttempt) - if err != nil { - return err - } - - return payHashBucket.Put(paymentAttemptInfoKey, b.Bytes()) -} - -func deserializePaymentAttemptInfoLegacy(r io.Reader) (*PaymentAttemptInfo, error) { - a := &PaymentAttemptInfo{} - err := ReadElements(r, &a.PaymentID, &a.SessionKey) - if err != nil { - return nil, err - } - a.Route, err = deserializeRouteLegacy(r) - if err != nil { - return nil, err - } - return a, nil -} - -func serializePaymentAttemptInfoLegacy(w io.Writer, a *PaymentAttemptInfo) error { - if err := WriteElements(w, a.PaymentID, a.SessionKey); err != nil { - return err - } - - if err := serializeRouteLegacy(w, a.Route); err != nil { - return err - } - - return nil -} - -func deserializeHopLegacy(r io.Reader) (*route.Hop, error) { - h := &route.Hop{} - - var pub []byte - if err := ReadElements(r, &pub); err != nil { - return nil, err - } - copy(h.PubKeyBytes[:], pub) - - if err := ReadElements(r, - &h.ChannelID, &h.OutgoingTimeLock, &h.AmtToForward, - ); err != nil { - return nil, err - } - - return h, nil -} - -func serializeHopLegacy(w io.Writer, h *route.Hop) error { - if err := WriteElements(w, - h.PubKeyBytes[:], h.ChannelID, h.OutgoingTimeLock, - h.AmtToForward, - ); err != nil { - return err - } - - return nil -} - -func deserializeRouteLegacy(r io.Reader) (route.Route, error) { - rt := route.Route{} - if err := ReadElements(r, - &rt.TotalTimeLock, &rt.TotalAmount, - ); err != nil { - return rt, err - } - - var pub []byte - if err := ReadElements(r, &pub); err != nil { - return rt, err - } - copy(rt.SourcePubKey[:], pub) - - var numHops uint32 - if err := ReadElements(r, &numHops); err != nil { - return rt, err - } - - var hops []*route.Hop - for i := uint32(0); i < numHops; i++ { - hop, err := deserializeHopLegacy(r) - if err != nil { - return rt, err - } - hops = append(hops, hop) - } - rt.Hops = hops - - return rt, nil -} - -func serializeRouteLegacy(w io.Writer, r route.Route) error { - if err := WriteElements(w, - r.TotalTimeLock, r.TotalAmount, r.SourcePubKey[:], - ); err != nil { - return err - } - - if err := WriteElements(w, uint32(len(r.Hops))); err != nil { - return err - } - - for _, h := range r.Hops { - if err := serializeHopLegacy(w, h); err != nil { - return err - } - } - - return nil -} diff --git a/channeldb/migration_11_invoices.go b/channeldb/migration_11_invoices.go deleted file mode 100644 index b4e60733..00000000 --- a/channeldb/migration_11_invoices.go +++ /dev/null @@ -1,230 +0,0 @@ -package channeldb - -import ( - "bytes" - "encoding/binary" - "fmt" - "io" - - bitcoinCfg "github.com/btcsuite/btcd/chaincfg" - "github.com/btcsuite/btcd/wire" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/zpay32" - litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" -) - -// migrateInvoices adds invoice htlcs and a separate cltv delta field to the -// invoices. -func migrateInvoices(tx *bbolt.Tx) error { - log.Infof("Migrating invoices to new invoice format") - - invoiceB := tx.Bucket(invoiceBucket) - if invoiceB == nil { - return nil - } - - // Iterate through the entire key space of the top-level invoice bucket. - // If key with a non-nil value stores the next invoice ID which maps to - // the corresponding invoice. Store those keys first, because it isn't - // safe to modify the bucket inside a ForEach loop. - var invoiceKeys [][]byte - err := invoiceB.ForEach(func(k, v []byte) error { - if v == nil { - return nil - } - - invoiceKeys = append(invoiceKeys, k) - - return nil - }) - if err != nil { - return err - } - - nets := []*bitcoinCfg.Params{ - &bitcoinCfg.MainNetParams, &bitcoinCfg.SimNetParams, - &bitcoinCfg.RegressionNetParams, &bitcoinCfg.TestNet3Params, - } - - ltcNets := []*litecoinCfg.Params{ - &litecoinCfg.MainNetParams, &litecoinCfg.SimNetParams, - &litecoinCfg.RegressionNetParams, &litecoinCfg.TestNet4Params, - } - for _, net := range ltcNets { - var convertedNet bitcoinCfg.Params - convertedNet.Bech32HRPSegwit = net.Bech32HRPSegwit - nets = append(nets, &convertedNet) - } - - // Iterate over all stored keys and migrate the invoices. - for _, k := range invoiceKeys { - v := invoiceB.Get(k) - - // Deserialize the invoice with the deserializing function that - // was in use for this version of the database. - invoiceReader := bytes.NewReader(v) - invoice, err := deserializeInvoiceLegacy(invoiceReader) - if err != nil { - return err - } - - if invoice.Terms.State == ContractAccepted { - return fmt.Errorf("cannot upgrade with invoice(s) " + - "in accepted state, see release notes") - } - - // Try to decode the payment request for every possible net to - // avoid passing a the active network to channeldb. This would - // be a layering violation, while this migration is only running - // once and will likely be removed in the future. - var payReq *zpay32.Invoice - for _, net := range nets { - payReq, err = zpay32.Decode( - string(invoice.PaymentRequest), net, - ) - if err == nil { - break - } - } - if payReq == nil { - return fmt.Errorf("cannot decode payreq") - } - invoice.FinalCltvDelta = int32(payReq.MinFinalCLTVExpiry()) - invoice.Expiry = payReq.Expiry() - - // Serialize the invoice in the new format and use it to replace - // the old invoice in the database. - var buf bytes.Buffer - if err := serializeInvoice(&buf, &invoice); err != nil { - return err - } - - err = invoiceB.Put(k, buf.Bytes()) - if err != nil { - return err - } - } - - log.Infof("Migration of invoices completed!") - return nil -} - -func deserializeInvoiceLegacy(r io.Reader) (Invoice, error) { - var err error - invoice := Invoice{} - - // TODO(roasbeef): use read full everywhere - invoice.Memo, err = wire.ReadVarBytes(r, 0, MaxMemoSize, "") - if err != nil { - return invoice, err - } - invoice.Receipt, err = wire.ReadVarBytes(r, 0, MaxReceiptSize, "") - if err != nil { - return invoice, err - } - - invoice.PaymentRequest, err = wire.ReadVarBytes(r, 0, MaxPaymentRequestSize, "") - if err != nil { - return invoice, err - } - - birthBytes, err := wire.ReadVarBytes(r, 0, 300, "birth") - if err != nil { - return invoice, err - } - if err := invoice.CreationDate.UnmarshalBinary(birthBytes); err != nil { - return invoice, err - } - - settledBytes, err := wire.ReadVarBytes(r, 0, 300, "settled") - if err != nil { - return invoice, err - } - if err := invoice.SettleDate.UnmarshalBinary(settledBytes); err != nil { - return invoice, err - } - - if _, err := io.ReadFull(r, invoice.Terms.PaymentPreimage[:]); err != nil { - return invoice, err - } - var scratch [8]byte - if _, err := io.ReadFull(r, scratch[:]); err != nil { - return invoice, err - } - invoice.Terms.Value = lnwire.MilliSatoshi(byteOrder.Uint64(scratch[:])) - - if err := binary.Read(r, byteOrder, &invoice.Terms.State); err != nil { - return invoice, err - } - - if err := binary.Read(r, byteOrder, &invoice.AddIndex); err != nil { - return invoice, err - } - if err := binary.Read(r, byteOrder, &invoice.SettleIndex); err != nil { - return invoice, err - } - if err := binary.Read(r, byteOrder, &invoice.AmtPaid); err != nil { - return invoice, err - } - - return invoice, nil -} - -// serializeInvoiceLegacy serializes an invoice in the format of the previous db -// version. -func serializeInvoiceLegacy(w io.Writer, i *Invoice) error { - if err := wire.WriteVarBytes(w, 0, i.Memo[:]); err != nil { - return err - } - if err := wire.WriteVarBytes(w, 0, i.Receipt[:]); err != nil { - return err - } - if err := wire.WriteVarBytes(w, 0, i.PaymentRequest[:]); err != nil { - return err - } - - birthBytes, err := i.CreationDate.MarshalBinary() - if err != nil { - return err - } - - if err := wire.WriteVarBytes(w, 0, birthBytes); err != nil { - return err - } - - settleBytes, err := i.SettleDate.MarshalBinary() - if err != nil { - return err - } - - if err := wire.WriteVarBytes(w, 0, settleBytes); err != nil { - return err - } - - if _, err := w.Write(i.Terms.PaymentPreimage[:]); err != nil { - return err - } - - var scratch [8]byte - byteOrder.PutUint64(scratch[:], uint64(i.Terms.Value)) - if _, err := w.Write(scratch[:]); err != nil { - return err - } - - if err := binary.Write(w, byteOrder, i.Terms.State); err != nil { - return err - } - - if err := binary.Write(w, byteOrder, i.AddIndex); err != nil { - return err - } - if err := binary.Write(w, byteOrder, i.SettleIndex); err != nil { - return err - } - if err := binary.Write(w, byteOrder, int64(i.AmtPaid)); err != nil { - return err - } - - return nil -} diff --git a/channeldb/migration_11_invoices_test.go b/channeldb/migration_11_invoices_test.go deleted file mode 100644 index 34cb1a92..00000000 --- a/channeldb/migration_11_invoices_test.go +++ /dev/null @@ -1,193 +0,0 @@ -package channeldb - -import ( - "bytes" - "fmt" - "testing" - "time" - - "github.com/btcsuite/btcd/btcec" - bitcoinCfg "github.com/btcsuite/btcd/chaincfg" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/zpay32" - litecoinCfg "github.com/ltcsuite/ltcd/chaincfg" -) - -var ( - testPrivKeyBytes = []byte{ - 0x2b, 0xd8, 0x06, 0xc9, 0x7f, 0x0e, 0x00, 0xaf, - 0x1a, 0x1f, 0xc3, 0x32, 0x8f, 0xa7, 0x63, 0xa9, - 0x26, 0x97, 0x23, 0xc8, 0xdb, 0x8f, 0xac, 0x4f, - 0x93, 0xaf, 0x71, 0xdb, 0x18, 0x6d, 0x6e, 0x90, - } - - testCltvDelta = int32(50) -) - -// beforeMigrationFuncV11 insert the test invoices in the database. -func beforeMigrationFuncV11(t *testing.T, d *DB, invoices []Invoice) { - err := d.Update(func(tx *bbolt.Tx) error { - invoicesBucket, err := tx.CreateBucketIfNotExists( - invoiceBucket, - ) - if err != nil { - return err - } - - invoiceNum := uint32(1) - for _, invoice := range invoices { - var invoiceKey [4]byte - byteOrder.PutUint32(invoiceKey[:], invoiceNum) - invoiceNum++ - - var buf bytes.Buffer - err := serializeInvoiceLegacy(&buf, &invoice) // nolint:scopelint - if err != nil { - return err - } - - err = invoicesBucket.Put( - invoiceKey[:], buf.Bytes(), - ) - if err != nil { - return err - } - } - - return nil - }) - if err != nil { - t.Fatal(err) - } -} - -// TestMigrateInvoices checks that invoices are migrated correctly. -func TestMigrateInvoices(t *testing.T) { - t.Parallel() - - payReqBtc, err := getPayReq(&bitcoinCfg.MainNetParams) - if err != nil { - t.Fatal(err) - } - - var ltcNetParams bitcoinCfg.Params - ltcNetParams.Bech32HRPSegwit = litecoinCfg.MainNetParams.Bech32HRPSegwit - payReqLtc, err := getPayReq(<cNetParams) - if err != nil { - t.Fatal(err) - } - - invoices := []Invoice{ - { - PaymentRequest: []byte(payReqBtc), - }, - { - PaymentRequest: []byte(payReqLtc), - }, - } - - // Verify that all invoices were migrated. - afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("migration 'invoices' wasn't applied") - } - - dbInvoices, err := d.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch invoices: %v", err) - } - - if len(invoices) != len(dbInvoices) { - t.Fatalf("expected %d invoices, got %d", len(invoices), - len(dbInvoices)) - } - - for _, dbInvoice := range dbInvoices { - if dbInvoice.FinalCltvDelta != testCltvDelta { - t.Fatal("incorrect final cltv delta") - } - if dbInvoice.Expiry != 3600*time.Second { - t.Fatal("incorrect expiry") - } - if len(dbInvoice.Htlcs) != 0 { - t.Fatal("expected no htlcs after migration") - } - } - } - - applyMigration(t, - func(d *DB) { beforeMigrationFuncV11(t, d, invoices) }, - afterMigrationFunc, - migrateInvoices, - false) -} - -// TestMigrateInvoicesHodl checks that a hodl invoice in the accepted state -// fails the migration. -func TestMigrateInvoicesHodl(t *testing.T) { - t.Parallel() - - payReqBtc, err := getPayReq(&bitcoinCfg.MainNetParams) - if err != nil { - t.Fatal(err) - } - - invoices := []Invoice{ - { - PaymentRequest: []byte(payReqBtc), - Terms: ContractTerm{ - State: ContractAccepted, - }, - }, - } - - applyMigration(t, - func(d *DB) { beforeMigrationFuncV11(t, d, invoices) }, - func(d *DB) {}, - migrateInvoices, - true) -} - -// signDigestCompact generates a test signature to be used in the generation of -// test payment requests. -func signDigestCompact(hash []byte) ([]byte, error) { - // Should the signature reference a compressed public key or not. - isCompressedKey := true - - privKey, _ := btcec.PrivKeyFromBytes(btcec.S256(), testPrivKeyBytes) - - // btcec.SignCompact returns a pubkey-recoverable signature - sig, err := btcec.SignCompact( - btcec.S256(), privKey, hash, isCompressedKey, - ) - if err != nil { - return nil, fmt.Errorf("can't sign the hash: %v", err) - } - - return sig, nil -} - -// getPayReq creates a payment request for the given net. -func getPayReq(net *bitcoinCfg.Params) (string, error) { - options := []func(*zpay32.Invoice){ - zpay32.CLTVExpiry(uint64(testCltvDelta)), - zpay32.Description("test"), - } - - payReq, err := zpay32.NewInvoice( - net, [32]byte{}, time.Unix(1, 0), options..., - ) - if err != nil { - return "", err - } - return payReq.Encode( - zpay32.MessageSigner{ - SignCompact: signDigestCompact, - }, - ) -} diff --git a/channeldb/migrations.go b/channeldb/migrations.go deleted file mode 100644 index a78d1314..00000000 --- a/channeldb/migrations.go +++ /dev/null @@ -1,939 +0,0 @@ -package channeldb - -import ( - "bytes" - "crypto/sha256" - "encoding/binary" - "fmt" - - "github.com/btcsuite/btcd/btcec" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing/route" -) - -// migrateNodeAndEdgeUpdateIndex is a migration function that will update the -// database from version 0 to version 1. In version 1, we add two new indexes -// (one for nodes and one for edges) to keep track of the last time a node or -// edge was updated on the network. These new indexes allow us to implement the -// new graph sync protocol added. -func migrateNodeAndEdgeUpdateIndex(tx *bbolt.Tx) error { - // First, we'll populating the node portion of the new index. Before we - // can add new values to the index, we'll first create the new bucket - // where these items will be housed. - nodes, err := tx.CreateBucketIfNotExists(nodeBucket) - if err != nil { - return fmt.Errorf("unable to create node bucket: %v", err) - } - nodeUpdateIndex, err := nodes.CreateBucketIfNotExists( - nodeUpdateIndexBucket, - ) - if err != nil { - return fmt.Errorf("unable to create node update index: %v", err) - } - - log.Infof("Populating new node update index bucket") - - // Now that we know the bucket has been created, we'll iterate over the - // entire node bucket so we can add the (updateTime || nodePub) key - // into the node update index. - err = nodes.ForEach(func(nodePub, nodeInfo []byte) error { - if len(nodePub) != 33 { - return nil - } - - log.Tracef("Adding %x to node update index", nodePub) - - // The first 8 bytes of a node's serialize data is the update - // time, so we can extract that without decoding the entire - // structure. - updateTime := nodeInfo[:8] - - // Now that we have the update time, we can construct the key - // to insert into the index. - var indexKey [8 + 33]byte - copy(indexKey[:8], updateTime) - copy(indexKey[8:], nodePub) - - return nodeUpdateIndex.Put(indexKey[:], nil) - }) - if err != nil { - return fmt.Errorf("unable to update node indexes: %v", err) - } - - log.Infof("Populating new edge update index bucket") - - // With the set of nodes updated, we'll now update all edges to have a - // corresponding entry in the edge update index. - edges, err := tx.CreateBucketIfNotExists(edgeBucket) - if err != nil { - return fmt.Errorf("unable to create edge bucket: %v", err) - } - edgeUpdateIndex, err := edges.CreateBucketIfNotExists( - edgeUpdateIndexBucket, - ) - if err != nil { - return fmt.Errorf("unable to create edge update index: %v", err) - } - - // We'll now run through each edge policy in the database, and update - // the index to ensure each edge has the proper record. - err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error { - if len(edgeKey) != 41 { - return nil - } - - // Now that we know this is the proper record, we'll grab the - // channel ID (last 8 bytes of the key), and then decode the - // edge policy so we can access the update time. - chanID := edgeKey[33:] - edgePolicyReader := bytes.NewReader(edgePolicyBytes) - - edgePolicy, err := deserializeChanEdgePolicy( - edgePolicyReader, nodes, - ) - if err != nil { - return err - } - - log.Tracef("Adding chan_id=%v to edge update index", - edgePolicy.ChannelID) - - // We'll now construct the index key using the channel ID, and - // the last time it was updated: (updateTime || chanID). - var indexKey [8 + 8]byte - byteOrder.PutUint64( - indexKey[:], uint64(edgePolicy.LastUpdate.Unix()), - ) - copy(indexKey[8:], chanID) - - return edgeUpdateIndex.Put(indexKey[:], nil) - }) - if err != nil { - return fmt.Errorf("unable to update edge indexes: %v", err) - } - - log.Infof("Migration to node and edge update indexes complete!") - - return nil -} - -// migrateInvoiceTimeSeries is a database migration that assigns all existing -// invoices an index in the add and/or the settle index. Additionally, all -// existing invoices will have their bytes padded out in order to encode the -// add+settle index as well as the amount paid. -func migrateInvoiceTimeSeries(tx *bbolt.Tx) error { - invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) - if err != nil { - return err - } - - addIndex, err := invoices.CreateBucketIfNotExists( - addIndexBucket, - ) - if err != nil { - return err - } - settleIndex, err := invoices.CreateBucketIfNotExists( - settleIndexBucket, - ) - if err != nil { - return err - } - - log.Infof("Migrating invoice database to new time series format") - - // Now that we have all the buckets we need, we'll run through each - // invoice in the database, and update it to reflect the new format - // expected post migration. - // NOTE: we store the converted invoices and put them back into the - // database after the loop, since modifying the bucket within the - // ForEach loop is not safe. - var invoicesKeys [][]byte - var invoicesValues [][]byte - err = invoices.ForEach(func(invoiceNum, invoiceBytes []byte) error { - // If this is a sub bucket, then we'll skip it. - if invoiceBytes == nil { - return nil - } - - // First, we'll make a copy of the encoded invoice bytes. - invoiceBytesCopy := make([]byte, len(invoiceBytes)) - copy(invoiceBytesCopy, invoiceBytes) - - // With the bytes copied over, we'll append 24 additional - // bytes. We do this so we can decode the invoice under the new - // serialization format. - padding := bytes.Repeat([]byte{0}, 24) - invoiceBytesCopy = append(invoiceBytesCopy, padding...) - - invoiceReader := bytes.NewReader(invoiceBytesCopy) - invoice, err := deserializeInvoiceLegacy(invoiceReader) - if err != nil { - return fmt.Errorf("unable to decode invoice: %v", err) - } - - // Now that we have the fully decoded invoice, we can update - // the various indexes that we're added, and finally the - // invoice itself before re-inserting it. - - // First, we'll get the new sequence in the addIndex in order - // to create the proper mapping. - nextAddSeqNo, err := addIndex.NextSequence() - if err != nil { - return err - } - var seqNoBytes [8]byte - byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) - err = addIndex.Put(seqNoBytes[:], invoiceNum[:]) - if err != nil { - return err - } - - log.Tracef("Adding invoice (preimage=%x, add_index=%v) to add "+ - "time series", invoice.Terms.PaymentPreimage[:], - nextAddSeqNo) - - // Next, we'll check if the invoice has been settled or not. If - // so, then we'll also add it to the settle index. - var nextSettleSeqNo uint64 - if invoice.Terms.State == ContractSettled { - nextSettleSeqNo, err = settleIndex.NextSequence() - if err != nil { - return err - } - - var seqNoBytes [8]byte - byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) - err := settleIndex.Put(seqNoBytes[:], invoiceNum) - if err != nil { - return err - } - - invoice.AmtPaid = invoice.Terms.Value - - log.Tracef("Adding invoice (preimage=%x, "+ - "settle_index=%v) to add time series", - invoice.Terms.PaymentPreimage[:], - nextSettleSeqNo) - } - - // Finally, we'll update the invoice itself with the new - // indexing information as well as the amount paid if it has - // been settled or not. - invoice.AddIndex = nextAddSeqNo - invoice.SettleIndex = nextSettleSeqNo - - // We've fully migrated an invoice, so we'll now update the - // invoice in-place. - var b bytes.Buffer - if err := serializeInvoiceLegacy(&b, &invoice); err != nil { - return err - } - - // Save the key and value pending update for after the ForEach - // is done. - invoicesKeys = append(invoicesKeys, invoiceNum) - invoicesValues = append(invoicesValues, b.Bytes()) - return nil - }) - if err != nil { - return err - } - - // Now put the converted invoices into the DB. - for i := range invoicesKeys { - key := invoicesKeys[i] - value := invoicesValues[i] - if err := invoices.Put(key, value); err != nil { - return err - } - } - - log.Infof("Migration to invoice time series index complete!") - - return nil -} - -// migrateInvoiceTimeSeriesOutgoingPayments is a follow up to the -// migrateInvoiceTimeSeries migration. As at the time of writing, the -// OutgoingPayment struct embeddeds an instance of the Invoice struct. As a -// result, we also need to migrate the internal invoice to the new format. -func migrateInvoiceTimeSeriesOutgoingPayments(tx *bbolt.Tx) error { - payBucket := tx.Bucket(paymentBucket) - if payBucket == nil { - return nil - } - - log.Infof("Migrating invoice database to new outgoing payment format") - - // We store the keys and values we want to modify since it is not safe - // to modify them directly within the ForEach loop. - var paymentKeys [][]byte - var paymentValues [][]byte - err := payBucket.ForEach(func(payID, paymentBytes []byte) error { - log.Tracef("Migrating payment %x", payID[:]) - - // The internal invoices for each payment only contain a - // populated contract term, and creation date, as a result, - // most of the bytes will be "empty". - - // We'll calculate the end of the invoice index assuming a - // "minimal" index that's embedded within the greater - // OutgoingPayment. The breakdown is: - // 3 bytes empty var bytes, 16 bytes creation date, 16 bytes - // settled date, 32 bytes payment pre-image, 8 bytes value, 1 - // byte settled. - endOfInvoiceIndex := 1 + 1 + 1 + 16 + 16 + 32 + 8 + 1 - - // We'll now extract the prefix of the pure invoice embedded - // within. - invoiceBytes := paymentBytes[:endOfInvoiceIndex] - - // With the prefix extracted, we'll copy over the invoice, and - // also add padding for the new 24 bytes of fields, and finally - // append the remainder of the outgoing payment. - paymentCopy := make([]byte, len(invoiceBytes)) - copy(paymentCopy[:], invoiceBytes) - - padding := bytes.Repeat([]byte{0}, 24) - paymentCopy = append(paymentCopy, padding...) - paymentCopy = append( - paymentCopy, paymentBytes[endOfInvoiceIndex:]..., - ) - - // At this point, we now have the new format of the outgoing - // payments, we'll attempt to deserialize it to ensure the - // bytes are properly formatted. - paymentReader := bytes.NewReader(paymentCopy) - _, err := deserializeOutgoingPayment(paymentReader) - if err != nil { - return fmt.Errorf("unable to deserialize payment: %v", err) - } - - // Now that we know the modifications was successful, we'll - // store it to our slice of keys and values, and write it back - // to disk in the new format after the ForEach loop is over. - paymentKeys = append(paymentKeys, payID) - paymentValues = append(paymentValues, paymentCopy) - return nil - }) - if err != nil { - return err - } - - // Finally store the updated payments to the bucket. - for i := range paymentKeys { - key := paymentKeys[i] - value := paymentValues[i] - if err := payBucket.Put(key, value); err != nil { - return err - } - } - - log.Infof("Migration to outgoing payment invoices complete!") - - return nil -} - -// migrateEdgePolicies is a migration function that will update the edges -// bucket. It ensure that edges with unknown policies will also have an entry -// in the bucket. After the migration, there will be two edge entries for -// every channel, regardless of whether the policies are known. -func migrateEdgePolicies(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return nil - } - - edges := tx.Bucket(edgeBucket) - if edges == nil { - return nil - } - - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return nil - } - - // checkKey gets the policy from the database with a low-level call - // so that it is still possible to distinguish between unknown and - // not present. - checkKey := func(channelId uint64, keyBytes []byte) error { - var channelID [8]byte - byteOrder.PutUint64(channelID[:], channelId) - - _, err := fetchChanEdgePolicy(edges, - channelID[:], keyBytes, nodes) - - if err == ErrEdgeNotFound { - log.Tracef("Adding unknown edge policy present for node %x, channel %v", - keyBytes, channelId) - - err := putChanEdgePolicyUnknown(edges, channelId, keyBytes) - if err != nil { - return err - } - - return nil - } - - return err - } - - // Iterate over all channels and check both edge policies. - err := edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - infoReader := bytes.NewReader(edgeInfoBytes) - edgeInfo, err := deserializeChanEdgeInfo(infoReader) - if err != nil { - return err - } - - for _, key := range [][]byte{edgeInfo.NodeKey1Bytes[:], - edgeInfo.NodeKey2Bytes[:]} { - - if err := checkKey(edgeInfo.ChannelID, key); err != nil { - return err - } - } - - return nil - }) - - if err != nil { - return fmt.Errorf("unable to update edge policies: %v", err) - } - - log.Infof("Migration of edge policies complete!") - - return nil -} - -// paymentStatusesMigration is a database migration intended for adding payment -// statuses for each existing payment entity in bucket to be able control -// transitions of statuses and prevent cases such as double payment -func paymentStatusesMigration(tx *bbolt.Tx) error { - // Get the bucket dedicated to storing statuses of payments, - // where a key is payment hash, value is payment status. - paymentStatuses, err := tx.CreateBucketIfNotExists(paymentStatusBucket) - if err != nil { - return err - } - - log.Infof("Migrating database to support payment statuses") - - circuitAddKey := []byte("circuit-adds") - circuits := tx.Bucket(circuitAddKey) - if circuits != nil { - log.Infof("Marking all known circuits with status InFlight") - - err = circuits.ForEach(func(k, v []byte) error { - // Parse the first 8 bytes as the short chan ID for the - // circuit. We'll skip all short chan IDs are not - // locally initiated, which includes all non-zero short - // chan ids. - chanID := binary.BigEndian.Uint64(k[:8]) - if chanID != 0 { - return nil - } - - // The payment hash is the third item in the serialized - // payment circuit. The first two items are an AddRef - // (10 bytes) and the incoming circuit key (16 bytes). - const payHashOffset = 10 + 16 - - paymentHash := v[payHashOffset : payHashOffset+32] - - return paymentStatuses.Put( - paymentHash[:], StatusInFlight.Bytes(), - ) - }) - if err != nil { - return err - } - } - - log.Infof("Marking all existing payments with status Completed") - - // Get the bucket dedicated to storing payments - bucket := tx.Bucket(paymentBucket) - if bucket == nil { - return nil - } - - // For each payment in the bucket, deserialize the payment and mark it - // as completed. - err = bucket.ForEach(func(k, v []byte) error { - // Ignores if it is sub-bucket. - if v == nil { - return nil - } - - r := bytes.NewReader(v) - payment, err := deserializeOutgoingPayment(r) - if err != nil { - return err - } - - // Calculate payment hash for current payment. - paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) - - // Update status for current payment to completed. If it fails, - // the migration is aborted and the payment bucket is returned - // to its previous state. - return paymentStatuses.Put(paymentHash[:], StatusSucceeded.Bytes()) - }) - if err != nil { - return err - } - - log.Infof("Migration of payment statuses complete!") - - return nil -} - -// migratePruneEdgeUpdateIndex is a database migration that attempts to resolve -// some lingering bugs with regards to edge policies and their update index. -// Stale entries within the edge update index were not being properly pruned due -// to a miscalculation on the offset of an edge's policy last update. This -// migration also fixes the case where the public keys within edge policies were -// being serialized with an extra byte, causing an even greater error when -// attempting to perform the offset calculation described earlier. -func migratePruneEdgeUpdateIndex(tx *bbolt.Tx) error { - // To begin the migration, we'll retrieve the update index bucket. If it - // does not exist, we have nothing left to do so we can simply exit. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return nil - } - edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) - if edgeUpdateIndex == nil { - return nil - } - - // Retrieve some buckets that will be needed later on. These should - // already exist given the assumption that the buckets above do as - // well. - edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) - if err != nil { - return fmt.Errorf("error creating edge index bucket: %s", err) - } - if edgeIndex == nil { - return fmt.Errorf("unable to create/fetch edge index " + - "bucket") - } - nodes, err := tx.CreateBucketIfNotExists(nodeBucket) - if err != nil { - return fmt.Errorf("unable to make node bucket") - } - - log.Info("Migrating database to properly prune edge update index") - - // We'll need to properly prune all the outdated entries within the edge - // update index. To do so, we'll gather all of the existing policies - // within the graph to re-populate them later on. - var edgeKeys [][]byte - err = edges.ForEach(func(edgeKey, edgePolicyBytes []byte) error { - // All valid entries are indexed by a public key (33 bytes) - // followed by a channel ID (8 bytes), so we'll skip any entries - // with keys that do not match this. - if len(edgeKey) != 33+8 { - return nil - } - - edgeKeys = append(edgeKeys, edgeKey) - - return nil - }) - if err != nil { - return fmt.Errorf("unable to gather existing edge policies: %v", - err) - } - - log.Info("Constructing set of edge update entries to purge.") - - // Build the set of keys that we will remove from the edge update index. - // This will include all keys contained within the bucket. - var updateKeysToRemove [][]byte - err = edgeUpdateIndex.ForEach(func(updKey, _ []byte) error { - updateKeysToRemove = append(updateKeysToRemove, updKey) - return nil - }) - if err != nil { - return fmt.Errorf("unable to gather existing edge updates: %v", - err) - } - - log.Infof("Removing %d entries from edge update index.", - len(updateKeysToRemove)) - - // With the set of keys contained in the edge update index constructed, - // we'll proceed in purging all of them from the index. - for _, updKey := range updateKeysToRemove { - if err := edgeUpdateIndex.Delete(updKey); err != nil { - return err - } - } - - log.Infof("Repopulating edge update index with %d valid entries.", - len(edgeKeys)) - - // For each edge key, we'll retrieve the policy, deserialize it, and - // re-add it to the different buckets. By doing so, we'll ensure that - // all existing edge policies are serialized correctly within their - // respective buckets and that the correct entries are populated within - // the edge update index. - for _, edgeKey := range edgeKeys { - edgePolicyBytes := edges.Get(edgeKey) - - // Skip any entries with unknown policies as there will not be - // any entries for them in the edge update index. - if bytes.Equal(edgePolicyBytes[:], unknownPolicy) { - continue - } - - edgePolicy, err := deserializeChanEdgePolicy( - bytes.NewReader(edgePolicyBytes), nodes, - ) - if err != nil { - return err - } - - _, err = updateEdgePolicy(tx, edgePolicy) - if err != nil { - return err - } - } - - log.Info("Migration to properly prune edge update index complete!") - - return nil -} - -// migrateOptionalChannelCloseSummaryFields migrates the serialized format of -// ChannelCloseSummary to a format where optional fields' presence is indicated -// with boolean markers. -func migrateOptionalChannelCloseSummaryFields(tx *bbolt.Tx) error { - closedChanBucket := tx.Bucket(closedChannelBucket) - if closedChanBucket == nil { - return nil - } - - log.Info("Migrating to new closed channel format...") - - // We store the converted keys and values and put them back into the - // database after the loop, since modifying the bucket within the - // ForEach loop is not safe. - var closedChansKeys [][]byte - var closedChansValues [][]byte - err := closedChanBucket.ForEach(func(chanID, summary []byte) error { - r := bytes.NewReader(summary) - - // Read the old (v6) format from the database. - c, err := deserializeCloseChannelSummaryV6(r) - if err != nil { - return err - } - - // Serialize using the new format, and put back into the - // bucket. - var b bytes.Buffer - if err := serializeChannelCloseSummary(&b, c); err != nil { - return err - } - - // Now that we know the modifications was successful, we'll - // Store the key and value to our slices, and write it back to - // disk in the new format after the ForEach loop is over. - closedChansKeys = append(closedChansKeys, chanID) - closedChansValues = append(closedChansValues, b.Bytes()) - return nil - }) - if err != nil { - return fmt.Errorf("unable to update closed channels: %v", err) - } - - // Now put the new format back into the DB. - for i := range closedChansKeys { - key := closedChansKeys[i] - value := closedChansValues[i] - if err := closedChanBucket.Put(key, value); err != nil { - return err - } - } - - log.Info("Migration to new closed channel format complete!") - - return nil -} - -var messageStoreBucket = []byte("message-store") - -// migrateGossipMessageStoreKeys migrates the key format for gossip messages -// found in the message store to a new one that takes into consideration the of -// the message being stored. -func migrateGossipMessageStoreKeys(tx *bbolt.Tx) error { - // We'll start by retrieving the bucket in which these messages are - // stored within. If there isn't one, there's nothing left for us to do - // so we can avoid the migration. - messageStore := tx.Bucket(messageStoreBucket) - if messageStore == nil { - return nil - } - - log.Info("Migrating to the gossip message store new key format") - - // Otherwise we'll proceed with the migration. We'll start by coalescing - // all the current messages within the store, which are indexed by the - // public key of the peer which they should be sent to, followed by the - // short channel ID of the channel for which the message belongs to. We - // should only expect to find channel announcement signatures as that - // was the only support message type previously. - msgs := make(map[[33 + 8]byte]*lnwire.AnnounceSignatures) - err := messageStore.ForEach(func(k, v []byte) error { - var msgKey [33 + 8]byte - copy(msgKey[:], k) - - msg := &lnwire.AnnounceSignatures{} - if err := msg.Decode(bytes.NewReader(v), 0); err != nil { - return err - } - - msgs[msgKey] = msg - - return nil - - }) - if err != nil { - return err - } - - // Then, we'll go over all of our messages, remove their previous entry, - // and add another with the new key format. Once we've done this for - // every message, we can consider the migration complete. - for oldMsgKey, msg := range msgs { - if err := messageStore.Delete(oldMsgKey[:]); err != nil { - return err - } - - // Construct the new key for which we'll find this message with - // in the store. It'll be the same as the old, but we'll also - // include the message type. - var msgType [2]byte - binary.BigEndian.PutUint16(msgType[:], uint16(msg.MsgType())) - newMsgKey := append(oldMsgKey[:], msgType[:]...) - - // Serialize the message with its wire encoding. - var b bytes.Buffer - if _, err := lnwire.WriteMessage(&b, msg, 0); err != nil { - return err - } - - if err := messageStore.Put(newMsgKey, b.Bytes()); err != nil { - return err - } - } - - log.Info("Migration to the gossip message store new key format complete!") - - return nil -} - -// migrateOutgoingPayments moves the OutgoingPayments into a new bucket format -// where they all reside in a top-level bucket indexed by the payment hash. In -// this sub-bucket we store information relevant to this payment, such as the -// payment status. -// -// Since the router cannot handle resumed payments that have the status -// InFlight (we have no PaymentAttemptInfo available for pre-migration -// payments) we delete those statuses, so only Completed payments remain in the -// new bucket structure. -func migrateOutgoingPayments(tx *bbolt.Tx) error { - log.Infof("Migrating outgoing payments to new bucket structure") - - oldPayments := tx.Bucket(paymentBucket) - - // Return early if there are no payments to migrate. - if oldPayments == nil { - log.Infof("No outgoing payments found, nothing to migrate.") - return nil - } - - newPayments, err := tx.CreateBucket(paymentsRootBucket) - if err != nil { - return err - } - - // Helper method to get the source pubkey. We define it such that we - // only attempt to fetch it if needed. - sourcePub := func() ([33]byte, error) { - var pub [33]byte - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return pub, ErrGraphNotFound - } - - selfPub := nodes.Get(sourceKey) - if selfPub == nil { - return pub, ErrSourceNodeNotSet - } - copy(pub[:], selfPub[:]) - return pub, nil - } - - err = oldPayments.ForEach(func(k, v []byte) error { - // Ignores if it is sub-bucket. - if v == nil { - return nil - } - - // Read the old payment format. - r := bytes.NewReader(v) - payment, err := deserializeOutgoingPayment(r) - if err != nil { - return err - } - - // Calculate payment hash from the payment preimage. - paymentHash := sha256.Sum256(payment.PaymentPreimage[:]) - - // Now create and add a PaymentCreationInfo to the bucket. - c := &PaymentCreationInfo{ - PaymentHash: paymentHash, - Value: payment.Terms.Value, - CreationDate: payment.CreationDate, - PaymentRequest: payment.PaymentRequest, - } - - var infoBuf bytes.Buffer - if err := serializePaymentCreationInfo(&infoBuf, c); err != nil { - return err - } - - sourcePubKey, err := sourcePub() - if err != nil { - return err - } - - // Do the same for the PaymentAttemptInfo. - totalAmt := payment.Terms.Value + payment.Fee - rt := route.Route{ - TotalTimeLock: payment.TimeLockLength, - TotalAmount: totalAmt, - SourcePubKey: sourcePubKey, - Hops: []*route.Hop{}, - } - for _, hop := range payment.Path { - rt.Hops = append(rt.Hops, &route.Hop{ - PubKeyBytes: hop, - AmtToForward: totalAmt, - }) - } - - // Since the old format didn't store the fee for individual - // hops, we let the last hop eat the whole fee for the total to - // add up. - if len(rt.Hops) > 0 { - rt.Hops[len(rt.Hops)-1].AmtToForward = payment.Terms.Value - } - - // Since we don't have the session key for old payments, we - // create a random one to be able to serialize the attempt - // info. - priv, _ := btcec.NewPrivateKey(btcec.S256()) - s := &PaymentAttemptInfo{ - PaymentID: 0, // unknown. - SessionKey: priv, // unknown. - Route: rt, - } - - var attemptBuf bytes.Buffer - if err := serializePaymentAttemptInfoMigration9(&attemptBuf, s); err != nil { - return err - } - - // Reuse the existing payment sequence number. - var seqNum [8]byte - copy(seqNum[:], k) - - // Create a bucket indexed by the payment hash. - bucket, err := newPayments.CreateBucket(paymentHash[:]) - - // If the bucket already exists, it means that we are migrating - // from a database containing duplicate payments to a payment - // hash. To keep this information, we store such duplicate - // payments in a sub-bucket. - if err == bbolt.ErrBucketExists { - pHashBucket := newPayments.Bucket(paymentHash[:]) - - // Create a bucket for duplicate payments within this - // payment hash's bucket. - dup, err := pHashBucket.CreateBucketIfNotExists( - paymentDuplicateBucket, - ) - if err != nil { - return err - } - - // Each duplicate will get its own sub-bucket within - // this bucket, so use their sequence number to index - // them by. - bucket, err = dup.CreateBucket(seqNum[:]) - if err != nil { - return err - } - - } else if err != nil { - return err - } - - // Store the payment's information to the bucket. - err = bucket.Put(paymentSequenceKey, seqNum[:]) - if err != nil { - return err - } - - err = bucket.Put(paymentCreationInfoKey, infoBuf.Bytes()) - if err != nil { - return err - } - - err = bucket.Put(paymentAttemptInfoKey, attemptBuf.Bytes()) - if err != nil { - return err - } - - err = bucket.Put(paymentSettleInfoKey, payment.PaymentPreimage[:]) - if err != nil { - return err - } - - return nil - }) - if err != nil { - return err - } - - // To continue producing unique sequence numbers, we set the sequence - // of the new bucket to that of the old one. - seq := oldPayments.Sequence() - if err := newPayments.SetSequence(seq); err != nil { - return err - } - - // Now we delete the old buckets. Deleting the payment status buckets - // deletes all payment statuses other than Complete. - err = tx.DeleteBucket(paymentStatusBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - // Finally delete the old payment bucket. - err = tx.DeleteBucket(paymentBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - log.Infof("Migration of outgoing payment bucket structure completed!") - return nil -} diff --git a/channeldb/migrations_test.go b/channeldb/migrations_test.go deleted file mode 100644 index 93bf602f..00000000 --- a/channeldb/migrations_test.go +++ /dev/null @@ -1,952 +0,0 @@ -package channeldb - -import ( - "bytes" - "crypto/sha256" - "encoding/binary" - "fmt" - "math/rand" - "reflect" - "testing" - "time" - - "github.com/btcsuite/btcutil" - "github.com/coreos/bbolt" - "github.com/davecgh/go-spew/spew" - "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing/route" -) - -// TestPaymentStatusesMigration checks that already completed payments will have -// their payment statuses set to Completed after the migration. -func TestPaymentStatusesMigration(t *testing.T) { - t.Parallel() - - fakePayment := makeFakePayment() - paymentHash := sha256.Sum256(fakePayment.PaymentPreimage[:]) - - // Add fake payment to test database, verifying that it was created, - // that we have only one payment, and its status is not "Completed". - beforeMigrationFunc := func(d *DB) { - if err := d.addPayment(fakePayment); err != nil { - t.Fatalf("unable to add payment: %v", err) - } - - payments, err := d.fetchAllPayments() - if err != nil { - t.Fatalf("unable to fetch payments: %v", err) - } - - if len(payments) != 1 { - t.Fatalf("wrong qty of paymets: expected 1, got %v", - len(payments)) - } - - paymentStatus, err := d.fetchPaymentStatus(paymentHash) - if err != nil { - t.Fatalf("unable to fetch payment status: %v", err) - } - - // We should receive default status if we have any in database. - if paymentStatus != StatusUnknown { - t.Fatalf("wrong payment status: expected %v, got %v", - StatusUnknown.String(), paymentStatus.String()) - } - - // Lastly, we'll add a locally-sourced circuit and - // non-locally-sourced circuit to the circuit map. The - // locally-sourced payment should end up with an InFlight - // status, while the other should remain unchanged, which - // defaults to Grounded. - err = d.Update(func(tx *bbolt.Tx) error { - circuits, err := tx.CreateBucketIfNotExists( - []byte("circuit-adds"), - ) - if err != nil { - return err - } - - groundedKey := make([]byte, 16) - binary.BigEndian.PutUint64(groundedKey[:8], 1) - binary.BigEndian.PutUint64(groundedKey[8:], 1) - - // Generated using TestHalfCircuitSerialization with nil - // ErrorEncrypter, which is the case for locally-sourced - // payments. No payment status should end up being set - // for this circuit, since the short channel id of the - // key is non-zero (e.g., a forwarded circuit). This - // will default it to Grounded. - groundedCircuit := []byte{ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, - // start payment hash - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - // end payment hash - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, - 0x42, 0x40, 0x00, - } - - err = circuits.Put(groundedKey, groundedCircuit) - if err != nil { - return err - } - - inFlightKey := make([]byte, 16) - binary.BigEndian.PutUint64(inFlightKey[:8], 0) - binary.BigEndian.PutUint64(inFlightKey[8:], 1) - - // Generated using TestHalfCircuitSerialization with nil - // ErrorEncrypter, which is not the case for forwarded - // payments, but should have no impact on the - // correctness of the test. The payment status for this - // circuit should be set to InFlight, since the short - // channel id in the key is 0 (sourceHop). - inFlightCircuit := []byte{ - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x01, - // start payment hash - 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - // end payment hash - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, - 0x42, 0x40, 0x00, - } - - return circuits.Put(inFlightKey, inFlightCircuit) - }) - if err != nil { - t.Fatalf("unable to add circuit map entry: %v", err) - } - } - - // Verify that the created payment status is "Completed" for our one - // fake payment. - afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("migration 'paymentStatusesMigration' wasn't applied") - } - - // Check that our completed payments were migrated. - paymentStatus, err := d.fetchPaymentStatus(paymentHash) - if err != nil { - t.Fatalf("unable to fetch payment status: %v", err) - } - - if paymentStatus != StatusSucceeded { - t.Fatalf("wrong payment status: expected %v, got %v", - StatusSucceeded.String(), paymentStatus.String()) - } - - inFlightHash := [32]byte{ - 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - } - - // Check that the locally sourced payment was transitioned to - // InFlight. - paymentStatus, err = d.fetchPaymentStatus(inFlightHash) - if err != nil { - t.Fatalf("unable to fetch payment status: %v", err) - } - - if paymentStatus != StatusInFlight { - t.Fatalf("wrong payment status: expected %v, got %v", - StatusInFlight.String(), paymentStatus.String()) - } - - groundedHash := [32]byte{ - 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - } - - // Check that non-locally sourced payments remain in the default - // Grounded state. - paymentStatus, err = d.fetchPaymentStatus(groundedHash) - if err != nil { - t.Fatalf("unable to fetch payment status: %v", err) - } - - if paymentStatus != StatusUnknown { - t.Fatalf("wrong payment status: expected %v, got %v", - StatusUnknown.String(), paymentStatus.String()) - } - } - - applyMigration(t, - beforeMigrationFunc, - afterMigrationFunc, - paymentStatusesMigration, - false) -} - -// TestMigrateOptionalChannelCloseSummaryFields properly converts a -// ChannelCloseSummary to the v7 format, where optional fields have their -// presence indicated with boolean markers. -func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) { - t.Parallel() - - chanState, err := createTestChannelState(nil) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - var chanPointBuf bytes.Buffer - err = writeOutpoint(&chanPointBuf, &chanState.FundingOutpoint) - if err != nil { - t.Fatalf("unable to write outpoint: %v", err) - } - - chanID := chanPointBuf.Bytes() - - testCases := []struct { - closeSummary *ChannelCloseSummary - oldSerialization func(c *ChannelCloseSummary) []byte - }{ - { - // A close summary where none of the new fields are - // set. - closeSummary: &ChannelCloseSummary{ - ChanPoint: chanState.FundingOutpoint, - ShortChanID: chanState.ShortChanID(), - ChainHash: chanState.ChainHash, - ClosingTXID: testTx.TxHash(), - CloseHeight: 100, - RemotePub: chanState.IdentityPub, - Capacity: chanState.Capacity, - SettledBalance: btcutil.Amount(50000), - CloseType: RemoteForceClose, - IsPending: true, - - // The last fields will be unset. - RemoteCurrentRevocation: nil, - LocalChanConfig: ChannelConfig{}, - RemoteNextRevocation: nil, - }, - - // In the old format the last field written is the - // IsPendingField. It should be converted by adding an - // extra boolean marker at the end to indicate that the - // remaining fields are not there. - oldSerialization: func(cs *ChannelCloseSummary) []byte { - var buf bytes.Buffer - err := WriteElements(&buf, cs.ChanPoint, - cs.ShortChanID, cs.ChainHash, - cs.ClosingTXID, cs.CloseHeight, - cs.RemotePub, cs.Capacity, - cs.SettledBalance, cs.TimeLockedBalance, - cs.CloseType, cs.IsPending, - ) - if err != nil { - t.Fatal(err) - } - - // For the old format, these are all the fields - // that are written. - return buf.Bytes() - }, - }, - { - // A close summary where the new fields are present, - // but the optional RemoteNextRevocation field is not - // set. - closeSummary: &ChannelCloseSummary{ - ChanPoint: chanState.FundingOutpoint, - ShortChanID: chanState.ShortChanID(), - ChainHash: chanState.ChainHash, - ClosingTXID: testTx.TxHash(), - CloseHeight: 100, - RemotePub: chanState.IdentityPub, - Capacity: chanState.Capacity, - SettledBalance: btcutil.Amount(50000), - CloseType: RemoteForceClose, - IsPending: true, - RemoteCurrentRevocation: chanState.RemoteCurrentRevocation, - LocalChanConfig: chanState.LocalChanCfg, - - // RemoteNextRevocation is optional, and here - // it is not set. - RemoteNextRevocation: nil, - }, - - // In the old format the last field written is the - // LocalChanConfig. This indicates that the optional - // RemoteNextRevocation field is not present. It should - // be converted by adding boolean markers for all these - // fields. - oldSerialization: func(cs *ChannelCloseSummary) []byte { - var buf bytes.Buffer - err := WriteElements(&buf, cs.ChanPoint, - cs.ShortChanID, cs.ChainHash, - cs.ClosingTXID, cs.CloseHeight, - cs.RemotePub, cs.Capacity, - cs.SettledBalance, cs.TimeLockedBalance, - cs.CloseType, cs.IsPending, - ) - if err != nil { - t.Fatal(err) - } - - err = WriteElements(&buf, cs.RemoteCurrentRevocation) - if err != nil { - t.Fatal(err) - } - - err = writeChanConfig(&buf, &cs.LocalChanConfig) - if err != nil { - t.Fatal(err) - } - - // RemoteNextRevocation is not written. - return buf.Bytes() - }, - }, - { - // A close summary where all fields are present. - closeSummary: &ChannelCloseSummary{ - ChanPoint: chanState.FundingOutpoint, - ShortChanID: chanState.ShortChanID(), - ChainHash: chanState.ChainHash, - ClosingTXID: testTx.TxHash(), - CloseHeight: 100, - RemotePub: chanState.IdentityPub, - Capacity: chanState.Capacity, - SettledBalance: btcutil.Amount(50000), - CloseType: RemoteForceClose, - IsPending: true, - RemoteCurrentRevocation: chanState.RemoteCurrentRevocation, - LocalChanConfig: chanState.LocalChanCfg, - - // RemoteNextRevocation is optional, and in - // this case we set it. - RemoteNextRevocation: chanState.RemoteNextRevocation, - }, - - // In the old format all the fields are written. It - // should be converted by adding boolean markers for - // all these fields. - oldSerialization: func(cs *ChannelCloseSummary) []byte { - var buf bytes.Buffer - err := WriteElements(&buf, cs.ChanPoint, - cs.ShortChanID, cs.ChainHash, - cs.ClosingTXID, cs.CloseHeight, - cs.RemotePub, cs.Capacity, - cs.SettledBalance, cs.TimeLockedBalance, - cs.CloseType, cs.IsPending, - ) - if err != nil { - t.Fatal(err) - } - - err = WriteElements(&buf, cs.RemoteCurrentRevocation) - if err != nil { - t.Fatal(err) - } - - err = writeChanConfig(&buf, &cs.LocalChanConfig) - if err != nil { - t.Fatal(err) - } - - err = WriteElements(&buf, cs.RemoteNextRevocation) - if err != nil { - t.Fatal(err) - } - - return buf.Bytes() - }, - }, - } - - for _, test := range testCases { - - // Before the migration we must add the old format to the DB. - beforeMigrationFunc := func(d *DB) { - - // Get the old serialization format for this test's - // close summary, and it to the closed channel bucket. - old := test.oldSerialization(test.closeSummary) - err = d.Update(func(tx *bbolt.Tx) error { - closedChanBucket, err := tx.CreateBucketIfNotExists( - closedChannelBucket, - ) - if err != nil { - return err - } - return closedChanBucket.Put(chanID, old) - }) - if err != nil { - t.Fatalf("unable to add old serialization: %v", - err) - } - } - - // After the migration it should be found in the new format. - afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("migration wasn't applied") - } - - // We generate the new serialized version, to check - // against what is found in the DB. - var b bytes.Buffer - err = serializeChannelCloseSummary(&b, test.closeSummary) - if err != nil { - t.Fatalf("unable to serialize: %v", err) - } - newSerialization := b.Bytes() - - var dbSummary []byte - err = d.View(func(tx *bbolt.Tx) error { - closedChanBucket := tx.Bucket(closedChannelBucket) - if closedChanBucket == nil { - return errors.New("unable to find bucket") - } - - // Get the serialized verision from the DB and - // make sure it matches what we expected. - dbSummary = closedChanBucket.Get(chanID) - if !bytes.Equal(dbSummary, newSerialization) { - return fmt.Errorf("unexpected new " + - "serialization") - } - return nil - }) - if err != nil { - t.Fatalf("unable to view DB: %v", err) - } - - // Finally we fetch the deserialized summary from the - // DB and check that it is equal to our original one. - dbChannels, err := d.FetchClosedChannels(false) - if err != nil { - t.Fatalf("unable to fetch closed channels: %v", - err) - } - - if len(dbChannels) != 1 { - t.Fatalf("expected 1 closed channels, found %v", - len(dbChannels)) - } - - dbChan := dbChannels[0] - if !reflect.DeepEqual(dbChan, test.closeSummary) { - dbChan.RemotePub.Curve = nil - test.closeSummary.RemotePub.Curve = nil - t.Fatalf("not equal: %v vs %v", - spew.Sdump(dbChan), - spew.Sdump(test.closeSummary)) - } - - } - - applyMigration(t, - beforeMigrationFunc, - afterMigrationFunc, - migrateOptionalChannelCloseSummaryFields, - false) - } -} - -// TestMigrateGossipMessageStoreKeys ensures that the migration to the new -// gossip message store key format is successful/unsuccessful under various -// scenarios. -func TestMigrateGossipMessageStoreKeys(t *testing.T) { - t.Parallel() - - // Construct the message which we'll use to test the migration, along - // with its old and new key formats. - shortChanID := lnwire.ShortChannelID{BlockHeight: 10} - msg := &lnwire.AnnounceSignatures{ShortChannelID: shortChanID} - - var oldMsgKey [33 + 8]byte - copy(oldMsgKey[:33], pubKey.SerializeCompressed()) - binary.BigEndian.PutUint64(oldMsgKey[33:41], shortChanID.ToUint64()) - - var newMsgKey [33 + 8 + 2]byte - copy(newMsgKey[:41], oldMsgKey[:]) - binary.BigEndian.PutUint16(newMsgKey[41:43], uint16(msg.MsgType())) - - // Before the migration, we'll create the bucket where the messages - // should live and insert them. - beforeMigration := func(db *DB) { - var b bytes.Buffer - if err := msg.Encode(&b, 0); err != nil { - t.Fatalf("unable to serialize message: %v", err) - } - - err := db.Update(func(tx *bbolt.Tx) error { - messageStore, err := tx.CreateBucketIfNotExists( - messageStoreBucket, - ) - if err != nil { - return err - } - - return messageStore.Put(oldMsgKey[:], b.Bytes()) - }) - if err != nil { - t.Fatal(err) - } - } - - // After the migration, we'll make sure that: - // 1. We cannot find the message under its old key. - // 2. We can find the message under its new key. - // 3. The message matches the original. - afterMigration := func(db *DB) { - meta, err := db.FetchMeta(nil) - if err != nil { - t.Fatalf("unable to fetch db version: %v", err) - } - if meta.DbVersionNumber != 1 { - t.Fatalf("migration should have succeeded but didn't") - } - - var rawMsg []byte - err = db.View(func(tx *bbolt.Tx) error { - messageStore := tx.Bucket(messageStoreBucket) - if messageStore == nil { - return errors.New("message store bucket not " + - "found") - } - rawMsg = messageStore.Get(oldMsgKey[:]) - if rawMsg != nil { - t.Fatal("expected to not find message under " + - "old key, but did") - } - rawMsg = messageStore.Get(newMsgKey[:]) - if rawMsg == nil { - return fmt.Errorf("expected to find message " + - "under new key, but didn't") - } - - return nil - }) - if err != nil { - t.Fatal(err) - } - - gotMsg, err := lnwire.ReadMessage(bytes.NewReader(rawMsg), 0) - if err != nil { - t.Fatalf("unable to deserialize raw message: %v", err) - } - if !reflect.DeepEqual(msg, gotMsg) { - t.Fatalf("expected message: %v\ngot message: %v", - spew.Sdump(msg), spew.Sdump(gotMsg)) - } - } - - applyMigration( - t, beforeMigration, afterMigration, - migrateGossipMessageStoreKeys, false, - ) -} - -// TestOutgoingPaymentsMigration checks that OutgoingPayments are migrated to a -// new bucket structure after the migration. -func TestOutgoingPaymentsMigration(t *testing.T) { - t.Parallel() - - const numPayments = 4 - var oldPayments []*outgoingPayment - - // Add fake payments to test database, verifying that it was created. - beforeMigrationFunc := func(d *DB) { - for i := 0; i < numPayments; i++ { - var p *outgoingPayment - var err error - - // We fill the database with random payments. For the - // very last one we'll use a duplicate of the first, to - // ensure we are able to handle migration from a - // database that has copies. - if i < numPayments-1 { - p, err = makeRandomFakePayment() - if err != nil { - t.Fatalf("unable to create payment: %v", - err) - } - } else { - p = oldPayments[0] - } - - if err := d.addPayment(p); err != nil { - t.Fatalf("unable to add payment: %v", err) - } - - oldPayments = append(oldPayments, p) - } - - payments, err := d.fetchAllPayments() - if err != nil { - t.Fatalf("unable to fetch payments: %v", err) - } - - if len(payments) != numPayments { - t.Fatalf("wrong qty of paymets: expected %d got %v", - numPayments, len(payments)) - } - } - - // Verify that all payments were migrated. - afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("migration 'paymentStatusesMigration' wasn't applied") - } - - sentPayments, err := d.fetchPaymentsMigration9() - if err != nil { - t.Fatalf("unable to fetch sent payments: %v", err) - } - - if len(sentPayments) != numPayments { - t.Fatalf("expected %d payments, got %d", numPayments, - len(sentPayments)) - } - - graph := d.ChannelGraph() - sourceNode, err := graph.SourceNode() - if err != nil { - t.Fatalf("unable to fetch source node: %v", err) - } - - for i, p := range sentPayments { - // The payment status should be Completed. - if p.Status != StatusSucceeded { - t.Fatalf("expected Completed, got %v", p.Status) - } - - // Check that the sequence number is preserved. They - // start counting at 1. - if p.sequenceNum != uint64(i+1) { - t.Fatalf("expected seqnum %d, got %d", i, - p.sequenceNum) - } - - // Order of payments should be be preserved. - old := oldPayments[i] - - // Check the individial fields. - if p.Info.Value != old.Terms.Value { - t.Fatalf("value mismatch") - } - - if p.Info.CreationDate != old.CreationDate { - t.Fatalf("date mismatch") - } - - if !bytes.Equal(p.Info.PaymentRequest, old.PaymentRequest) { - t.Fatalf("payreq mismatch") - } - - if *p.PaymentPreimage != old.PaymentPreimage { - t.Fatalf("preimage mismatch") - } - - if p.Attempt.Route.TotalFees() != old.Fee { - t.Fatalf("Fee mismatch") - } - - if p.Attempt.Route.TotalAmount != old.Fee+old.Terms.Value { - t.Fatalf("Total amount mismatch") - } - - if p.Attempt.Route.TotalTimeLock != old.TimeLockLength { - t.Fatalf("timelock mismatch") - } - - if p.Attempt.Route.SourcePubKey != sourceNode.PubKeyBytes { - t.Fatalf("source mismatch: %x vs %x", - p.Attempt.Route.SourcePubKey[:], - sourceNode.PubKeyBytes[:]) - } - - for i, hop := range old.Path { - if hop != p.Attempt.Route.Hops[i].PubKeyBytes { - t.Fatalf("path mismatch") - } - } - } - - // Finally, check that the payment sequence number is updated - // to reflect the migrated payments. - err = d.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return fmt.Errorf("payments bucket not found") - } - - seq := payments.Sequence() - if seq != numPayments { - return fmt.Errorf("expected sequence to be "+ - "%d, got %d", numPayments, seq) - } - - return nil - }) - if err != nil { - t.Fatal(err) - } - } - - applyMigration(t, - beforeMigrationFunc, - afterMigrationFunc, - migrateOutgoingPayments, - false) -} - -func makeRandPaymentCreationInfo() (*PaymentCreationInfo, error) { - var payHash lntypes.Hash - if _, err := rand.Read(payHash[:]); err != nil { - return nil, err - } - - return &PaymentCreationInfo{ - PaymentHash: payHash, - Value: lnwire.MilliSatoshi(rand.Int63()), - CreationDate: time.Now(), - PaymentRequest: []byte("test"), - }, nil -} - -// TestPaymentRouteSerialization tests that we're able to properly migrate -// existing payments on disk that contain the traversed routes to the new -// routing format which supports the TLV payloads. We also test that the -// migration is able to handle duplicate payment attempts. -func TestPaymentRouteSerialization(t *testing.T) { - t.Parallel() - - legacyHop1 := &route.Hop{ - PubKeyBytes: route.NewVertex(pub), - ChannelID: 12345, - OutgoingTimeLock: 111, - LegacyPayload: true, - AmtToForward: 555, - } - legacyHop2 := &route.Hop{ - PubKeyBytes: route.NewVertex(pub), - ChannelID: 12345, - OutgoingTimeLock: 111, - LegacyPayload: true, - AmtToForward: 555, - } - legacyRoute := route.Route{ - TotalTimeLock: 123, - TotalAmount: 1234567, - SourcePubKey: route.NewVertex(pub), - Hops: []*route.Hop{legacyHop1, legacyHop2}, - } - - const numPayments = 4 - var oldPayments []*Payment - - sharedPayAttempt := PaymentAttemptInfo{ - PaymentID: 1, - SessionKey: priv, - Route: legacyRoute, - } - - // We'll first add a series of fake payments, using the existing legacy - // serialization format. - beforeMigrationFunc := func(d *DB) { - err := d.Update(func(tx *bbolt.Tx) error { - paymentsBucket, err := tx.CreateBucket( - paymentsRootBucket, - ) - if err != nil { - t.Fatalf("unable to create new payments "+ - "bucket: %v", err) - } - - for i := 0; i < numPayments; i++ { - var seqNum [8]byte - byteOrder.PutUint64(seqNum[:], uint64(i)) - - // All payments will be randomly generated, - // other than the final payment. We'll force - // the final payment to re-use an existing - // payment hash so we can insert it into the - // duplicate payment hash bucket. - var payInfo *PaymentCreationInfo - if i < numPayments-1 { - payInfo, err = makeRandPaymentCreationInfo() - if err != nil { - t.Fatalf("unable to create "+ - "payment: %v", err) - } - } else { - payInfo = oldPayments[0].Info - } - - // Next, legacy encoded when needed, we'll - // serialize the info and the attempt. - var payInfoBytes bytes.Buffer - err = serializePaymentCreationInfo( - &payInfoBytes, payInfo, - ) - if err != nil { - t.Fatalf("unable to encode pay "+ - "info: %v", err) - } - var payAttemptBytes bytes.Buffer - err = serializePaymentAttemptInfoLegacy( - &payAttemptBytes, &sharedPayAttempt, - ) - if err != nil { - t.Fatalf("unable to encode payment attempt: "+ - "%v", err) - } - - // Before we write to disk, we'll need to fetch - // the proper bucket. If this is the duplicate - // payment, then we'll grab the dup bucket, - // otherwise, we'll use the top level bucket. - var payHashBucket *bbolt.Bucket - if i < numPayments-1 { - payHashBucket, err = paymentsBucket.CreateBucket( - payInfo.PaymentHash[:], - ) - if err != nil { - t.Fatalf("unable to create payments bucket: %v", err) - } - } else { - payHashBucket = paymentsBucket.Bucket( - payInfo.PaymentHash[:], - ) - dupPayBucket, err := payHashBucket.CreateBucket( - paymentDuplicateBucket, - ) - if err != nil { - t.Fatalf("unable to create "+ - "dup hash bucket: %v", err) - } - - payHashBucket, err = dupPayBucket.CreateBucket( - seqNum[:], - ) - if err != nil { - t.Fatalf("unable to make dup "+ - "bucket: %v", err) - } - } - - err = payHashBucket.Put(paymentSequenceKey, seqNum[:]) - if err != nil { - t.Fatalf("unable to write seqno: %v", err) - } - - err = payHashBucket.Put( - paymentCreationInfoKey, payInfoBytes.Bytes(), - ) - if err != nil { - t.Fatalf("unable to write creation "+ - "info: %v", err) - } - - err = payHashBucket.Put( - paymentAttemptInfoKey, payAttemptBytes.Bytes(), - ) - if err != nil { - t.Fatalf("unable to write attempt "+ - "info: %v", err) - } - - oldPayments = append(oldPayments, &Payment{ - Info: payInfo, - Attempt: &sharedPayAttempt, - }) - } - - return nil - }) - if err != nil { - t.Fatalf("unable to create test payments: %v", err) - } - } - - afterMigrationFunc := func(d *DB) { - newPayments, err := d.FetchPayments() - if err != nil { - t.Fatalf("unable to fetch new payments: %v", err) - } - - if len(newPayments) != numPayments { - t.Fatalf("expected %d payments, got %d", numPayments, - len(newPayments)) - } - - for i, p := range newPayments { - // Order of payments should be be preserved. - old := oldPayments[i] - - if p.Attempt.PaymentID != old.Attempt.PaymentID { - t.Fatalf("wrong pay ID: expected %v, got %v", - p.Attempt.PaymentID, - old.Attempt.PaymentID) - } - - if p.Attempt.Route.TotalFees() != old.Attempt.Route.TotalFees() { - t.Fatalf("Fee mismatch") - } - - if p.Attempt.Route.TotalAmount != old.Attempt.Route.TotalAmount { - t.Fatalf("Total amount mismatch") - } - - if p.Attempt.Route.TotalTimeLock != old.Attempt.Route.TotalTimeLock { - t.Fatalf("timelock mismatch") - } - - if p.Attempt.Route.SourcePubKey != old.Attempt.Route.SourcePubKey { - t.Fatalf("source mismatch: %x vs %x", - p.Attempt.Route.SourcePubKey[:], - old.Attempt.Route.SourcePubKey[:]) - } - - for i, hop := range p.Attempt.Route.Hops { - if !reflect.DeepEqual(hop, legacyRoute.Hops[i]) { - t.Fatalf("hop mismatch") - } - } - } - } - - applyMigration(t, - beforeMigrationFunc, - afterMigrationFunc, - migrateRouteSerialization, - false) -} diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 8cc036fc..a792f965 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -12,7 +12,6 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/tlv" ) @@ -53,34 +52,6 @@ var ( } ) -func makeFakePayment() *outgoingPayment { - fakeInvoice := &Invoice{ - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), - Memo: []byte("fake memo"), - Receipt: []byte("fake receipt"), - PaymentRequest: []byte(""), - } - - copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) - fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) - - fakePath := make([][33]byte, 3) - for i := 0; i < 3; i++ { - copy(fakePath[i][:], bytes.Repeat([]byte{byte(i)}, 33)) - } - - fakePayment := &outgoingPayment{ - Invoice: *fakeInvoice, - Fee: 101, - Path: fakePath, - TimeLockLength: 1000, - } - copy(fakePayment.PaymentPreimage[:], rev[:]) - return fakePayment -} - func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) { var preimg lntypes.Preimage copy(preimg[:], rev[:]) @@ -114,58 +85,6 @@ func randomBytes(minLen, maxLen int) ([]byte, error) { return randBuf, nil } -func makeRandomFakePayment() (*outgoingPayment, error) { - var err error - fakeInvoice := &Invoice{ - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), - } - - fakeInvoice.Memo, err = randomBytes(1, 50) - if err != nil { - return nil, err - } - - fakeInvoice.Receipt, err = randomBytes(1, 50) - if err != nil { - return nil, err - } - - fakeInvoice.PaymentRequest, err = randomBytes(1, 50) - if err != nil { - return nil, err - } - - preImg, err := randomBytes(32, 33) - if err != nil { - return nil, err - } - copy(fakeInvoice.Terms.PaymentPreimage[:], preImg) - - fakeInvoice.Terms.Value = lnwire.MilliSatoshi(rand.Intn(10000)) - - fakePathLen := 1 + rand.Intn(5) - fakePath := make([][33]byte, fakePathLen) - for i := 0; i < fakePathLen; i++ { - b, err := randomBytes(33, 34) - if err != nil { - return nil, err - } - copy(fakePath[i][:], b) - } - - fakePayment := &outgoingPayment{ - Invoice: *fakeInvoice, - Fee: lnwire.MilliSatoshi(rand.Intn(1001)), - Path: fakePath, - TimeLockLength: uint32(rand.Intn(10000)), - } - copy(fakePayment.PaymentPreimage[:], fakeInvoice.Terms.PaymentPreimage[:]) - - return fakePayment, nil -} - func TestSentPaymentSerialization(t *testing.T) { t.Parallel() From 43449ca7a7da63539092ddb1cf25d81ed763c56e Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 29 Oct 2019 11:49:07 +0100 Subject: [PATCH 3/6] channeldb/migration_01_to_11: add references to untested migrations --- channeldb/migration_01_to_11/migrations_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index cdaef57f..598832c2 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -950,3 +950,14 @@ func TestPaymentRouteSerialization(t *testing.T) { MigrateRouteSerialization, false) } + +// TestNotCoveredMigrations only references migrations that are not referenced +// anywhere else in this package. This prevents false positives when linting +// with unused. +func TestNotCoveredMigrations(t *testing.T) { + _ = MigrateNodeAndEdgeUpdateIndex + _ = MigrateInvoiceTimeSeries + _ = MigrateInvoiceTimeSeriesOutgoingPayments + _ = MigrateEdgePolicies + _ = MigratePruneEdgeUpdateIndex +} From f5191440c5faf3078bdb77c78b9a4b17b551087f Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Tue, 29 Oct 2019 12:07:59 +0100 Subject: [PATCH 4/6] channeldb: initialize migrations logger --- channeldb/log.go | 2 ++ channeldb/migration_01_to_11/log.go | 20 +++----------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/channeldb/log.go b/channeldb/log.go index e0158d45..30ddff03 100644 --- a/channeldb/log.go +++ b/channeldb/log.go @@ -3,6 +3,7 @@ package channeldb import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/channeldb/migration_01_to_11" ) // log is a logger that is initialized with no output filters. This @@ -25,4 +26,5 @@ func DisableLog() { // using btclog. func UseLogger(logger btclog.Logger) { log = logger + migration_01_to_11.UseLogger(logger) } diff --git a/channeldb/migration_01_to_11/log.go b/channeldb/migration_01_to_11/log.go index 17958b19..b169b5af 100644 --- a/channeldb/migration_01_to_11/log.go +++ b/channeldb/migration_01_to_11/log.go @@ -2,27 +2,13 @@ package migration_01_to_11 import ( "github.com/btcsuite/btclog" - "github.com/lightningnetwork/lnd/build" ) -// log is a logger that is initialized with no output filters. This -// means the package will not perform any logging by default until the caller -// requests it. -var log btclog.Logger - -func init() { - UseLogger(build.NewSubLogger("CHDB", nil)) -} - -// DisableLog disables all library log output. Logging output is disabled -// by default until UseLogger is called. -func DisableLog() { - UseLogger(btclog.Disabled) -} +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled // UseLogger uses a specified Logger to output package logging info. -// This should be used in preference to SetLogWriter if the caller is also -// using btclog. func UseLogger(logger btclog.Logger) { log = logger } From 60503d6c44567f1aed3c9a7b51b1a725652db911 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Thu, 24 Oct 2019 12:45:07 +0200 Subject: [PATCH 5/6] channeldb/migration_01_to_11: remove unused code --- channeldb/migration_01_to_11/README.md | 24 - channeldb/migration_01_to_11/addr_test.go | 149 - channeldb/migration_01_to_11/channel.go | 2066 ----------- channeldb/migration_01_to_11/channel_cache.go | 50 - .../migration_01_to_11/channel_cache_test.go | 105 - channeldb/migration_01_to_11/channel_test.go | 820 ----- channeldb/migration_01_to_11/codec.go | 6 - channeldb/migration_01_to_11/db.go | 786 ----- channeldb/migration_01_to_11/db_test.go | 471 --- channeldb/migration_01_to_11/doc.go | 1 - channeldb/migration_01_to_11/error.go | 63 - channeldb/migration_01_to_11/fees.go | 1 - .../migration_01_to_11/forwarding_log.go | 274 -- .../migration_01_to_11/forwarding_log_test.go | 265 -- .../migration_01_to_11/forwarding_package.go | 928 ----- .../forwarding_package_test.go | 815 ----- channeldb/migration_01_to_11/graph.go | 2883 +-------------- channeldb/migration_01_to_11/graph_test.go | 3140 ----------------- channeldb/migration_01_to_11/invoice_test.go | 694 ---- channeldb/migration_01_to_11/invoices.go | 770 ---- channeldb/migration_01_to_11/nodes.go | 316 -- channeldb/migration_01_to_11/nodes_test.go | 140 - channeldb/migration_01_to_11/options.go | 21 - .../migration_01_to_11/payment_control.go | 474 --- .../payment_control_test.go | 550 --- channeldb/migration_01_to_11/payments.go | 42 - channeldb/migration_01_to_11/payments_test.go | 216 -- channeldb/migration_01_to_11/reject_cache.go | 95 - .../migration_01_to_11/reject_cache_test.go | 107 - channeldb/migration_01_to_11/waitingproof.go | 251 -- .../migration_01_to_11/waitingproof_test.go | 59 - channeldb/migration_01_to_11/witness_cache.go | 229 -- .../migration_01_to_11/witness_cache_test.go | 238 -- 33 files changed, 1 insertion(+), 17048 deletions(-) delete mode 100644 channeldb/migration_01_to_11/README.md delete mode 100644 channeldb/migration_01_to_11/addr_test.go delete mode 100644 channeldb/migration_01_to_11/channel_cache.go delete mode 100644 channeldb/migration_01_to_11/channel_cache_test.go delete mode 100644 channeldb/migration_01_to_11/db_test.go delete mode 100644 channeldb/migration_01_to_11/doc.go delete mode 100644 channeldb/migration_01_to_11/fees.go delete mode 100644 channeldb/migration_01_to_11/forwarding_log.go delete mode 100644 channeldb/migration_01_to_11/forwarding_log_test.go delete mode 100644 channeldb/migration_01_to_11/forwarding_package.go delete mode 100644 channeldb/migration_01_to_11/forwarding_package_test.go delete mode 100644 channeldb/migration_01_to_11/invoice_test.go delete mode 100644 channeldb/migration_01_to_11/nodes.go delete mode 100644 channeldb/migration_01_to_11/nodes_test.go delete mode 100644 channeldb/migration_01_to_11/payment_control_test.go delete mode 100644 channeldb/migration_01_to_11/reject_cache.go delete mode 100644 channeldb/migration_01_to_11/reject_cache_test.go delete mode 100644 channeldb/migration_01_to_11/waitingproof.go delete mode 100644 channeldb/migration_01_to_11/waitingproof_test.go delete mode 100644 channeldb/migration_01_to_11/witness_cache.go delete mode 100644 channeldb/migration_01_to_11/witness_cache_test.go diff --git a/channeldb/migration_01_to_11/README.md b/channeldb/migration_01_to_11/README.md deleted file mode 100644 index 7e3a81ef..00000000 --- a/channeldb/migration_01_to_11/README.md +++ /dev/null @@ -1,24 +0,0 @@ -channeldb -========== - -[![Build Status](http://img.shields.io/travis/lightningnetwork/lnd.svg)](https://travis-ci.org/lightningnetwork/lnd) -[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/lightningnetwork/lnd/blob/master/LICENSE) -[![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/lightningnetwork/lnd/channeldb) - -The channeldb implements the persistent storage engine for `lnd` and -generically a data storage layer for the required state within the Lightning -Network. The backing storage engine is -[boltdb](https://github.com/coreos/bbolt), an embedded pure-go key-value store -based off of LMDB. - -The package implements an object-oriented storage model with queries and -mutations flowing through a particular object instance rather than the database -itself. The storage implemented by the objects includes: open channels, past -commitment revocation states, the channel graph which includes authenticated -node and channel announcements, outgoing payments, and invoices - -## Installation and Updating - -```bash -$ go get -u github.com/lightningnetwork/lnd/channeldb -``` diff --git a/channeldb/migration_01_to_11/addr_test.go b/channeldb/migration_01_to_11/addr_test.go deleted file mode 100644 index 8cdf99c3..00000000 --- a/channeldb/migration_01_to_11/addr_test.go +++ /dev/null @@ -1,149 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "net" - "strings" - "testing" - - "github.com/lightningnetwork/lnd/tor" -) - -type unknownAddrType struct{} - -func (t unknownAddrType) Network() string { return "unknown" } -func (t unknownAddrType) String() string { return "unknown" } - -var testIP4 = net.ParseIP("192.168.1.1") -var testIP6 = net.ParseIP("2001:0db8:0000:0000:0000:ff00:0042:8329") - -var addrTests = []struct { - expAddr net.Addr - serErr string -}{ - // Valid addresses. - { - expAddr: &net.TCPAddr{ - IP: testIP4, - Port: 12345, - }, - }, - { - expAddr: &net.TCPAddr{ - IP: testIP6, - Port: 65535, - }, - }, - { - expAddr: &tor.OnionAddr{ - OnionService: "3g2upl4pq6kufc4m.onion", - Port: 9735, - }, - }, - { - expAddr: &tor.OnionAddr{ - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.onion", - Port: 80, - }, - }, - - // Invalid addresses. - { - expAddr: unknownAddrType{}, - serErr: ErrUnknownAddressType.Error(), - }, - { - expAddr: &net.TCPAddr{ - // Remove last byte of IPv4 address. - IP: testIP4[:len(testIP4)-1], - Port: 12345, - }, - serErr: "unable to encode", - }, - { - expAddr: &net.TCPAddr{ - // Add an extra byte of IPv4 address. - IP: append(testIP4, 0xff), - Port: 12345, - }, - serErr: "unable to encode", - }, - { - expAddr: &net.TCPAddr{ - // Remove last byte of IPv6 address. - IP: testIP6[:len(testIP6)-1], - Port: 65535, - }, - serErr: "unable to encode", - }, - { - expAddr: &net.TCPAddr{ - // Add an extra byte to the IPv6 address. - IP: append(testIP6, 0xff), - Port: 65535, - }, - serErr: "unable to encode", - }, - { - expAddr: &tor.OnionAddr{ - // Invalid suffix. - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyd.inion", - Port: 80, - }, - serErr: "invalid suffix", - }, - { - expAddr: &tor.OnionAddr{ - // Invalid length. - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyy.onion", - Port: 80, - }, - serErr: "unknown onion service length", - }, - { - expAddr: &tor.OnionAddr{ - // Invalid encoding. - OnionService: "vww6ybal4bd7szmgncyruucpgfkqahzddi37ktceo3ah7ngmcopnpyyA.onion", - Port: 80, - }, - serErr: "illegal base32", - }, -} - -// TestAddrSerialization tests that the serialization method used by channeldb -// for net.Addr's works as intended. -func TestAddrSerialization(t *testing.T) { - t.Parallel() - - var b bytes.Buffer - for _, test := range addrTests { - err := serializeAddr(&b, test.expAddr) - switch { - case err == nil && test.serErr != "": - t.Fatalf("expected serialization err for addr %v", - test.expAddr) - - case err != nil && test.serErr == "": - t.Fatalf("unexpected serialization err for addr %v: %v", - test.expAddr, err) - - case err != nil && !strings.Contains(err.Error(), test.serErr): - t.Fatalf("unexpected serialization err for addr %v, "+ - "want: %v, got %v", test.expAddr, test.serErr, - err) - - case err != nil: - continue - } - - addr, err := deserializeAddr(&b) - if err != nil { - t.Fatalf("unable to deserialize address: %v", err) - } - - if addr.String() != test.expAddr.String() { - t.Fatalf("expected address %v after serialization, "+ - "got %v", addr, test.expAddr) - } - } -} diff --git a/channeldb/migration_01_to_11/channel.go b/channeldb/migration_01_to_11/channel.go index 23d66852..e67c0c69 100644 --- a/channeldb/migration_01_to_11/channel.go +++ b/channeldb/migration_01_to_11/channel.go @@ -1,12 +1,9 @@ package migration_01_to_11 import ( - "bytes" - "encoding/binary" "errors" "fmt" "io" - "net" "strconv" "strings" "sync" @@ -15,8 +12,6 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" @@ -36,90 +31,6 @@ var ( // // TODO(roasbeef): flesh out comment openChannelBucket = []byte("open-chan-bucket") - - // chanInfoKey can be accessed within the bucket for a channel - // (identified by its chanPoint). This key stores all the static - // information for a channel which is decided at the end of the - // funding flow. - chanInfoKey = []byte("chan-info-key") - - // chanCommitmentKey can be accessed within the sub-bucket for a - // particular channel. This key stores the up to date commitment state - // for a particular channel party. Appending a 0 to the end of this key - // indicates it's the commitment for the local party, and appending a 1 - // to the end of this key indicates it's the commitment for the remote - // party. - chanCommitmentKey = []byte("chan-commitment-key") - - // revocationStateKey stores their current revocation hash, our - // preimage producer and their preimage store. - revocationStateKey = []byte("revocation-state-key") - - // dataLossCommitPointKey stores the commitment point received from the - // remote peer during a channel sync in case we have lost channel state. - dataLossCommitPointKey = []byte("data-loss-commit-point-key") - - // closingTxKey points to a the closing tx that we broadcasted when - // moving the channel to state CommitBroadcasted. - closingTxKey = []byte("closing-tx-key") - - // commitDiffKey stores the current pending commitment state we've - // extended to the remote party (if any). Each time we propose a new - // state, we store the information necessary to reconstruct this state - // from the prior commitment. This allows us to resync the remote party - // to their expected state in the case of message loss. - // - // TODO(roasbeef): rename to commit chain? - commitDiffKey = []byte("commit-diff-key") - - // revocationLogBucket is dedicated for storing the necessary delta - // state between channel updates required to re-construct a past state - // in order to punish a counterparty attempting a non-cooperative - // channel closure. This key should be accessed from within the - // sub-bucket of a target channel, identified by its channel point. - revocationLogBucket = []byte("revocation-log-key") -) - -var ( - // ErrNoCommitmentsFound is returned when a channel has not set - // commitment states. - ErrNoCommitmentsFound = fmt.Errorf("no commitments found") - - // ErrNoChanInfoFound is returned when a particular channel does not - // have any channels state. - ErrNoChanInfoFound = fmt.Errorf("no chan info found") - - // ErrNoRevocationsFound is returned when revocation state for a - // particular channel cannot be found. - ErrNoRevocationsFound = fmt.Errorf("no revocations found") - - // ErrNoPendingCommit is returned when there is not a pending - // commitment for a remote party. A new commitment is written to disk - // each time we write a new state in order to be properly fault - // tolerant. - ErrNoPendingCommit = fmt.Errorf("no pending commits found") - - // ErrInvalidCircuitKeyLen signals that a circuit key could not be - // decoded because the byte slice is of an invalid length. - ErrInvalidCircuitKeyLen = fmt.Errorf( - "length of serialized circuit key must be 16 bytes") - - // ErrNoCommitPoint is returned when no data loss commit point is found - // in the database. - ErrNoCommitPoint = fmt.Errorf("no commit point found") - - // ErrNoCloseTx is returned when no closing tx is found for a channel - // in the state CommitBroadcasted. - ErrNoCloseTx = fmt.Errorf("no closing tx found") - - // ErrNoRestoredChannelMutation is returned when a caller attempts to - // mutate a channel that's been recovered. - ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " + - "channel state") - - // ErrChanBorked is returned when a caller attempts to mutate a borked - // channel. - ErrChanBorked = fmt.Errorf("cannot mutate borked channel") ) // ChannelType is an enum-like type that describes one of several possible @@ -136,30 +47,8 @@ const ( // SingleFunder represents a channel wherein one party solely funds the // entire capacity of the channel. SingleFunder ChannelType = 0 - - // DualFunder represents a channel wherein both parties contribute - // funds towards the total capacity of the channel. The channel may be - // funded symmetrically or asymmetrically. - DualFunder ChannelType = 1 - - // SingleFunderTweakless is similar to the basic SingleFunder channel - // type, but it omits the tweak for one's key in the commitment - // transaction of the remote party. - SingleFunderTweakless ChannelType = 2 ) -// IsSingleFunder returns true if the channel type if one of the known single -// funder variants. -func (c ChannelType) IsSingleFunder() bool { - return c == SingleFunder || c == SingleFunderTweakless -} - -// IsTweakless returns true if the target channel uses a commitment that -// doesn't tweak the key for the remote party. -func (c ChannelType) IsTweakless() bool { - return c == SingleFunderTweakless -} - // ChannelConstraints represents a set of constraints meant to allow a node to // limit their exposure, enact flow control and ensure that all HTLCs are // economically relevant. This struct will be mirrored for both sides of the @@ -444,10 +333,6 @@ type OpenChannel struct { // negotiate fees, or close the channel. IsInitiator bool - // chanStatus is the current status of this channel. If it is not in - // the state Default, it should not be used for forwarding payments. - chanStatus ChannelStatus - // FundingBroadcastHeight is the height in which the funding // transaction was broadcast. This value can be used by higher level // sub-systems to determine if a channel is stale and/or should have @@ -519,11 +404,6 @@ type OpenChannel struct { // implementation of secret store is shachain store. RevocationStore shachain.Store - // Packager is used to create and update forwarding packages for this - // channel, which encodes all necessary information to recover from - // failures and reforward HTLCs that were not fully processed. - Packager FwdPackager - // FundingTxn is the transaction containing this channel's funding // outpoint. Upon restarts, this txn will be rebroadcast if the channel // is found to be pending. @@ -548,657 +428,6 @@ func (c *OpenChannel) ShortChanID() lnwire.ShortChannelID { return c.ShortChannelID } -// ChanStatus returns the current ChannelStatus of this channel. -func (c *OpenChannel) ChanStatus() ChannelStatus { - c.RLock() - defer c.RUnlock() - - return c.chanStatus -} - -// ApplyChanStatus allows the caller to modify the internal channel state in a -// thead-safe manner. -func (c *OpenChannel) ApplyChanStatus(status ChannelStatus) error { - c.Lock() - defer c.Unlock() - - return c.putChanStatus(status) -} - -// ClearChanStatus allows the caller to clear a particular channel status from -// the primary channel status bit field. After this method returns, a call to -// HasChanStatus(status) should return false. -func (c *OpenChannel) ClearChanStatus(status ChannelStatus) error { - c.Lock() - defer c.Unlock() - - return c.clearChanStatus(status) -} - -// HasChanStatus returns true if the internal bitfield channel status of the -// target channel has the specified status bit set. -func (c *OpenChannel) HasChanStatus(status ChannelStatus) bool { - c.RLock() - defer c.RUnlock() - - return c.hasChanStatus(status) -} - -func (c *OpenChannel) hasChanStatus(status ChannelStatus) bool { - return c.chanStatus&status == status -} - -// RefreshShortChanID updates the in-memory short channel ID using the latest -// value observed on disk. -func (c *OpenChannel) RefreshShortChanID() error { - c.Lock() - defer c.Unlock() - - var sid lnwire.ShortChannelID - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - sid = channel.ShortChannelID - - return nil - }) - if err != nil { - return err - } - - c.ShortChannelID = sid - c.Packager = NewChannelPackager(sid) - - return nil -} - -// fetchChanBucket is a helper function that returns the bucket where a -// channel's data resides in given: the public key for the node, the outpoint, -// and the chainhash that the channel resides on. -func fetchChanBucket(tx *bbolt.Tx, nodeKey *btcec.PublicKey, - outPoint *wire.OutPoint, chainHash chainhash.Hash) (*bbolt.Bucket, error) { - - // First fetch the top level bucket which stores all data related to - // current, active channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return nil, ErrNoChanDBExists - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - nodePub := nodeKey.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(nodePub) - if nodeChanBucket == nil { - return nil, ErrNoActiveChannels - } - - // We'll then recurse down an additional layer in order to fetch the - // bucket for this particular chain. - chainBucket := nodeChanBucket.Bucket(chainHash[:]) - if chainBucket == nil { - return nil, ErrNoActiveChannels - } - - // With the bucket for the node and chain fetched, we can now go down - // another level, for this channel itself. - var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, outPoint); err != nil { - return nil, err - } - chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) - if chanBucket == nil { - return nil, ErrChannelNotFound - } - - return chanBucket, nil -} - -// fullSync syncs the contents of an OpenChannel while re-using an existing -// database transaction. -func (c *OpenChannel) fullSync(tx *bbolt.Tx) error { - // First fetch the top level bucket which stores all data related to - // current, active channels. - openChanBucket, err := tx.CreateBucketIfNotExists(openChannelBucket) - if err != nil { - return err - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - nodePub := c.IdentityPub.SerializeCompressed() - nodeChanBucket, err := openChanBucket.CreateBucketIfNotExists(nodePub) - if err != nil { - return err - } - - // We'll then recurse down an additional layer in order to fetch the - // bucket for this particular chain. - chainBucket, err := nodeChanBucket.CreateBucketIfNotExists(c.ChainHash[:]) - if err != nil { - return err - } - - // With the bucket for the node fetched, we can now go down another - // level, creating the bucket for this channel itself. - var chanPointBuf bytes.Buffer - if err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint); err != nil { - return err - } - chanBucket, err := chainBucket.CreateBucket( - chanPointBuf.Bytes(), - ) - switch { - case err == bbolt.ErrBucketExists: - // If this channel already exists, then in order to avoid - // overriding it, we'll return an error back up to the caller. - return ErrChanAlreadyExists - case err != nil: - return err - } - - return putOpenChannel(chanBucket, c) -} - -// MarkAsOpen marks a channel as fully open given a locator that uniquely -// describes its location within the chain. -func (c *OpenChannel) MarkAsOpen(openLoc lnwire.ShortChannelID) error { - c.Lock() - defer c.Unlock() - - if err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - channel.IsPending = false - channel.ShortChannelID = openLoc - - return putOpenChannel(chanBucket, channel) - }); err != nil { - return err - } - - c.IsPending = false - c.ShortChannelID = openLoc - c.Packager = NewChannelPackager(openLoc) - - return nil -} - -// MarkDataLoss marks sets the channel status to LocalDataLoss and stores the -// passed commitPoint for use to retrieve funds in case the remote force closes -// the channel. -func (c *OpenChannel) MarkDataLoss(commitPoint *btcec.PublicKey) error { - c.Lock() - defer c.Unlock() - - var b bytes.Buffer - if err := WriteElement(&b, commitPoint); err != nil { - return err - } - - putCommitPoint := func(chanBucket *bbolt.Bucket) error { - return chanBucket.Put(dataLossCommitPointKey, b.Bytes()) - } - - return c.putChanStatus(ChanStatusLocalDataLoss, putCommitPoint) -} - -// DataLossCommitPoint retrieves the stored commit point set during -// MarkDataLoss. If not found ErrNoCommitPoint is returned. -func (c *OpenChannel) DataLossCommitPoint() (*btcec.PublicKey, error) { - var commitPoint *btcec.PublicKey - - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoCommitPoint - default: - return err - } - - bs := chanBucket.Get(dataLossCommitPointKey) - if bs == nil { - return ErrNoCommitPoint - } - r := bytes.NewReader(bs) - if err := ReadElements(r, &commitPoint); err != nil { - return err - } - - return nil - }) - if err != nil { - return nil, err - } - - return commitPoint, nil -} - -// MarkBorked marks the event when the channel as reached an irreconcilable -// state, such as a channel breach or state desynchronization. Borked channels -// should never be added to the switch. -func (c *OpenChannel) MarkBorked() error { - c.Lock() - defer c.Unlock() - - return c.putChanStatus(ChanStatusBorked) -} - -// ChanSyncMsg returns the ChannelReestablish message that should be sent upon -// reconnection with the remote peer that we're maintaining this channel with. -// The information contained within this message is necessary to re-sync our -// commitment chains in the case of a last or only partially processed message. -// When the remote party receiver this message one of three things may happen: -// -// 1. We're fully synced and no messages need to be sent. -// 2. We didn't get the last CommitSig message they sent, to they'll re-send -// it. -// 3. We didn't get the last RevokeAndAck message they sent, so they'll -// re-send it. -// -// If this is a restored channel, having status ChanStatusRestored, then we'll -// modify our typical chan sync message to ensure they force close even if -// we're on the very first state. -func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { - c.Lock() - defer c.Unlock() - - // The remote commitment height that we'll send in the - // ChannelReestablish message is our current commitment height plus - // one. If the receiver thinks that our commitment height is actually - // *equal* to this value, then they'll re-send the last commitment that - // they sent but we never fully processed. - localHeight := c.LocalCommitment.CommitHeight - nextLocalCommitHeight := localHeight + 1 - - // The second value we'll send is the height of the remote commitment - // from our PoV. If the receiver thinks that their height is actually - // *one plus* this value, then they'll re-send their last revocation. - remoteChainTipHeight := c.RemoteCommitment.CommitHeight - - // If this channel has undergone a commitment update, then in order to - // prove to the remote party our knowledge of their prior commitment - // state, we'll also send over the last commitment secret that the - // remote party sent. - var lastCommitSecret [32]byte - if remoteChainTipHeight != 0 { - remoteSecret, err := c.RevocationStore.LookUp( - remoteChainTipHeight - 1, - ) - if err != nil { - return nil, err - } - lastCommitSecret = [32]byte(*remoteSecret) - } - - // Additionally, we'll send over the current unrevoked commitment on - // our local commitment transaction. - currentCommitSecret, err := c.RevocationProducer.AtIndex( - localHeight, - ) - if err != nil { - return nil, err - } - - // If we've restored this channel, then we'll purposefully give them an - // invalid LocalUnrevokedCommitPoint so they'll force close the channel - // allowing us to sweep our funds. - if c.hasChanStatus(ChanStatusRestored) { - currentCommitSecret[0] ^= 1 - - // If this is a tweakless channel, then we'll purposefully send - // a next local height taht's invalid to trigger a force close - // on their end. We do this as tweakless channels don't require - // that the commitment point is valid, only that it's present. - if c.ChanType.IsTweakless() { - nextLocalCommitHeight = 0 - } - } - - return &lnwire.ChannelReestablish{ - ChanID: lnwire.NewChanIDFromOutPoint( - &c.FundingOutpoint, - ), - NextLocalCommitHeight: nextLocalCommitHeight, - RemoteCommitTailHeight: remoteChainTipHeight, - LastRemoteCommitSecret: lastCommitSecret, - LocalUnrevokedCommitPoint: input.ComputeCommitmentPoint( - currentCommitSecret[:], - ), - }, nil -} - -// isBorked returns true if the channel has been marked as borked in the -// database. This requires an existing database transaction to already be -// active. -// -// NOTE: The primary mutex should already be held before this method is called. -func (c *OpenChannel) isBorked(chanBucket *bbolt.Bucket) (bool, error) { - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return false, err - } - - return channel.chanStatus != ChanStatusDefault, nil -} - -// MarkCommitmentBroadcasted marks the channel as a commitment transaction has -// been broadcast, either our own or the remote, and we should watch the chain -// for it to confirm before taking any further action. It takes as argument the -// closing tx _we believe_ will appear in the chain. This is only used to -// republish this tx at startup to ensure propagation, and we should still -// handle the case where a different tx actually hits the chain. -func (c *OpenChannel) MarkCommitmentBroadcasted(closeTx *wire.MsgTx) error { - c.Lock() - defer c.Unlock() - - var b bytes.Buffer - if err := WriteElement(&b, closeTx); err != nil { - return err - } - - putClosingTx := func(chanBucket *bbolt.Bucket) error { - return chanBucket.Put(closingTxKey, b.Bytes()) - } - - return c.putChanStatus(ChanStatusCommitBroadcasted, putClosingTx) -} - -// BroadcastedCommitment retrieves the stored closing tx set during -// MarkCommitmentBroadcasted. If not found ErrNoCloseTx is returned. -func (c *OpenChannel) BroadcastedCommitment() (*wire.MsgTx, error) { - var closeTx *wire.MsgTx - - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoCloseTx - default: - return err - } - - bs := chanBucket.Get(closingTxKey) - if bs == nil { - return ErrNoCloseTx - } - r := bytes.NewReader(bs) - return ReadElement(r, &closeTx) - }) - if err != nil { - return nil, err - } - - return closeTx, nil -} - -// putChanStatus appends the given status to the channel. fs is an optional -// list of closures that are given the chanBucket in order to atomically add -// extra information together with the new status. -func (c *OpenChannel) putChanStatus(status ChannelStatus, - fs ...func(*bbolt.Bucket) error) error { - - if err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - // Add this status to the existing bitvector found in the DB. - status = channel.chanStatus | status - channel.chanStatus = status - - if err := putOpenChannel(chanBucket, channel); err != nil { - return err - } - - for _, f := range fs { - if err := f(chanBucket); err != nil { - return err - } - } - - return nil - }); err != nil { - return err - } - - // Update the in-memory representation to keep it in sync with the DB. - c.chanStatus = status - - return nil -} - -func (c *OpenChannel) clearChanStatus(status ChannelStatus) error { - if err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - channel, err := fetchOpenChannel(chanBucket, &c.FundingOutpoint) - if err != nil { - return err - } - - // Unset this bit in the bitvector on disk. - status = channel.chanStatus & ^status - channel.chanStatus = status - - return putOpenChannel(chanBucket, channel) - }); err != nil { - return err - } - - // Update the in-memory representation to keep it in sync with the DB. - c.chanStatus = status - - return nil -} - -// putChannel serializes, and stores the current state of the channel in its -// entirety. -func putOpenChannel(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - // First, we'll write out all the relatively static fields, that are - // decided upon initial channel creation. - if err := putChanInfo(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan info: %v", err) - } - - // With the static channel info written out, we'll now write out the - // current commitment state for both parties. - if err := putChanCommitments(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan commitments: %v", err) - } - - // Finally, we'll write out the revocation state for both parties - // within a distinct key space. - if err := putChanRevocationState(chanBucket, channel); err != nil { - return fmt.Errorf("unable to store chan revocations: %v", err) - } - - return nil -} - -// fetchOpenChannel retrieves, and deserializes (including decrypting -// sensitive) the complete channel currently active with the passed nodeID. -func fetchOpenChannel(chanBucket *bbolt.Bucket, - chanPoint *wire.OutPoint) (*OpenChannel, error) { - - channel := &OpenChannel{ - FundingOutpoint: *chanPoint, - } - - // First, we'll read all the static information that changes less - // frequently from disk. - if err := fetchChanInfo(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan info: %v", err) - } - - // With the static information read, we'll now read the current - // commitment state for both sides of the channel. - if err := fetchChanCommitments(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan commitments: %v", err) - } - - // Finally, we'll retrieve the current revocation state so we can - // properly - if err := fetchChanRevocationState(chanBucket, channel); err != nil { - return nil, fmt.Errorf("unable to fetch chan revocations: %v", err) - } - - channel.Packager = NewChannelPackager(channel.ShortChannelID) - - return channel, nil -} - -// SyncPending writes the contents of the channel to the database while it's in -// the pending (waiting for funding confirmation) state. The IsPending flag -// will be set to true. When the channel's funding transaction is confirmed, -// the channel should be marked as "open" and the IsPending flag set to false. -// Note that this function also creates a LinkNode relationship between this -// newly created channel and a new LinkNode instance. This allows listing all -// channels in the database globally, or according to the LinkNode they were -// created with. -// -// TODO(roasbeef): addr param should eventually be an lnwire.NetAddress type -// that includes service bits. -func (c *OpenChannel) SyncPending(addr net.Addr, pendingHeight uint32) error { - c.Lock() - defer c.Unlock() - - c.FundingBroadcastHeight = pendingHeight - - return c.Db.Update(func(tx *bbolt.Tx) error { - return syncNewChannel(tx, c, []net.Addr{addr}) - }) -} - -// syncNewChannel will write the passed channel to disk, and also create a -// LinkNode (if needed) for the channel peer. -func syncNewChannel(tx *bbolt.Tx, c *OpenChannel, addrs []net.Addr) error { - // First, sync all the persistent channel state to disk. - if err := c.fullSync(tx); err != nil { - return err - } - - nodeInfoBucket, err := tx.CreateBucketIfNotExists(nodeInfoBucket) - if err != nil { - return err - } - - // If a LinkNode for this identity public key already exists, - // then we can exit early. - nodePub := c.IdentityPub.SerializeCompressed() - if nodeInfoBucket.Get(nodePub) != nil { - return nil - } - - // Next, we need to establish a (possibly) new LinkNode relationship - // for this channel. The LinkNode metadata contains reachability, - // up-time, and service bits related information. - linkNode := c.Db.NewLinkNode(wire.MainNet, c.IdentityPub, addrs...) - - // TODO(roasbeef): do away with link node all together? - - return putLinkNode(nodeInfoBucket, linkNode) -} - -// UpdateCommitment updates the commitment state for the specified party -// (remote or local). The commitment stat completely describes the balance -// state at this point in the commitment chain. This method its to be called on -// two occasions: when we revoke our prior commitment state, and when the -// remote party revokes their prior commitment state. -func (c *OpenChannel) UpdateCommitment(newCommitment *ChannelCommitment) error { - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state as all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return ErrNoRestoredChannelMutation - } - - err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - if err = putChanInfo(chanBucket, c); err != nil { - return fmt.Errorf("unable to store chan info: %v", err) - } - - // With the proper bucket fetched, we'll now write the latest - // commitment state to disk for the target party. - err = putChanCommitment( - chanBucket, newCommitment, true, - ) - if err != nil { - return fmt.Errorf("unable to store chan "+ - "revocations: %v", err) - } - - return nil - }) - if err != nil { - return err - } - - c.LocalCommitment = *newCommitment - - return nil -} - // HTLC is the on-disk representation of a hash time-locked contract. HTLCs are // contained within ChannelDeltas which encode the current state of the // commitment between state updates. @@ -1247,101 +476,6 @@ type HTLC struct { LogIndex uint64 } -// SerializeHtlcs writes out the passed set of HTLC's into the passed writer -// using the current default on-disk serialization format. -// -// NOTE: This API is NOT stable, the on-disk format will likely change in the -// future. -func SerializeHtlcs(b io.Writer, htlcs ...HTLC) error { - numHtlcs := uint16(len(htlcs)) - if err := WriteElement(b, numHtlcs); err != nil { - return err - } - - for _, htlc := range htlcs { - if err := WriteElements(b, - htlc.Signature, htlc.RHash, htlc.Amt, htlc.RefundTimeout, - htlc.OutputIndex, htlc.Incoming, htlc.OnionBlob[:], - htlc.HtlcIndex, htlc.LogIndex, - ); err != nil { - return err - } - } - - return nil -} - -// DeserializeHtlcs attempts to read out a slice of HTLC's from the passed -// io.Reader. The bytes within the passed reader MUST have been previously -// written to using the SerializeHtlcs function. -// -// NOTE: This API is NOT stable, the on-disk format will likely change in the -// future. -func DeserializeHtlcs(r io.Reader) ([]HTLC, error) { - var numHtlcs uint16 - if err := ReadElement(r, &numHtlcs); err != nil { - return nil, err - } - - var htlcs []HTLC - if numHtlcs == 0 { - return htlcs, nil - } - - htlcs = make([]HTLC, numHtlcs) - for i := uint16(0); i < numHtlcs; i++ { - if err := ReadElements(r, - &htlcs[i].Signature, &htlcs[i].RHash, &htlcs[i].Amt, - &htlcs[i].RefundTimeout, &htlcs[i].OutputIndex, - &htlcs[i].Incoming, &htlcs[i].OnionBlob, - &htlcs[i].HtlcIndex, &htlcs[i].LogIndex, - ); err != nil { - return htlcs, err - } - } - - return htlcs, nil -} - -// Copy returns a full copy of the target HTLC. -func (h *HTLC) Copy() HTLC { - clone := HTLC{ - Incoming: h.Incoming, - Amt: h.Amt, - RefundTimeout: h.RefundTimeout, - OutputIndex: h.OutputIndex, - } - copy(clone.Signature[:], h.Signature) - copy(clone.RHash[:], h.RHash[:]) - - return clone -} - -// LogUpdate represents a pending update to the remote commitment chain. The -// log update may be an add, fail, or settle entry. We maintain this data in -// order to be able to properly retransmit our proposed -// state if necessary. -type LogUpdate struct { - // LogIndex is the log index of this proposed commitment update entry. - LogIndex uint64 - - // UpdateMsg is the update message that was included within the our - // local update log. The LogIndex value denotes the log index of this - // update which will be used when restoring our local update log if - // we're left with a dangling update on restart. - UpdateMsg lnwire.Message -} - -// Encode writes a log update to the provided io.Writer. -func (l *LogUpdate) Encode(w io.Writer) error { - return WriteElements(w, l.LogIndex, l.UpdateMsg) -} - -// Decode reads a log update from the provided io.Reader. -func (l *LogUpdate) Decode(r io.Reader) error { - return ReadElements(r, &l.LogIndex, &l.UpdateMsg) -} - // CircuitKey is used by a channel to uniquely identify the HTLCs it receives // from the switch, and is used to purge our in-memory state of HTLCs that have // already been processed by a link. Two list of CircuitKeys are included in @@ -1360,723 +494,20 @@ type CircuitKey struct { HtlcID uint64 } -// SetBytes deserializes the given bytes into this CircuitKey. -func (k *CircuitKey) SetBytes(bs []byte) error { - if len(bs) != 16 { - return ErrInvalidCircuitKeyLen - } - - k.ChanID = lnwire.NewShortChanIDFromInt( - binary.BigEndian.Uint64(bs[:8])) - k.HtlcID = binary.BigEndian.Uint64(bs[8:]) - - return nil -} - -// Bytes returns the serialized bytes for this circuit key. -func (k CircuitKey) Bytes() []byte { - var bs = make([]byte, 16) - binary.BigEndian.PutUint64(bs[:8], k.ChanID.ToUint64()) - binary.BigEndian.PutUint64(bs[8:], k.HtlcID) - return bs -} - -// Encode writes a CircuitKey to the provided io.Writer. -func (k *CircuitKey) Encode(w io.Writer) error { - var scratch [16]byte - binary.BigEndian.PutUint64(scratch[:8], k.ChanID.ToUint64()) - binary.BigEndian.PutUint64(scratch[8:], k.HtlcID) - - _, err := w.Write(scratch[:]) - return err -} - -// Decode reads a CircuitKey from the provided io.Reader. -func (k *CircuitKey) Decode(r io.Reader) error { - var scratch [16]byte - - if _, err := io.ReadFull(r, scratch[:]); err != nil { - return err - } - k.ChanID = lnwire.NewShortChanIDFromInt( - binary.BigEndian.Uint64(scratch[:8])) - k.HtlcID = binary.BigEndian.Uint64(scratch[8:]) - - return nil -} - // String returns a string representation of the CircuitKey. func (k CircuitKey) String() string { return fmt.Sprintf("(Chan ID=%s, HTLC ID=%d)", k.ChanID, k.HtlcID) } -// CommitDiff represents the delta needed to apply the state transition between -// two subsequent commitment states. Given state N and state N+1, one is able -// to apply the set of messages contained within the CommitDiff to N to arrive -// at state N+1. Each time a new commitment is extended, we'll write a new -// commitment (along with the full commitment state) to disk so we can -// re-transmit the state in the case of a connection loss or message drop. -type CommitDiff struct { - // ChannelCommitment is the full commitment state that one would arrive - // at by applying the set of messages contained in the UpdateDiff to - // the prior accepted commitment. - Commitment ChannelCommitment - - // LogUpdates is the set of messages sent prior to the commitment state - // transition in question. Upon reconnection, if we detect that they - // don't have the commitment, then we re-send this along with the - // proper signature. - LogUpdates []LogUpdate - - // CommitSig is the exact CommitSig message that should be sent after - // the set of LogUpdates above has been retransmitted. The signatures - // within this message should properly cover the new commitment state - // and also the HTLC's within the new commitment state. - CommitSig *lnwire.CommitSig - - // OpenedCircuitKeys is a set of unique identifiers for any downstream - // Add packets included in this commitment txn. After a restart, this - // set of htlcs is acked from the link's incoming mailbox to ensure - // there isn't an attempt to re-add them to this commitment txn. - OpenedCircuitKeys []CircuitKey - - // ClosedCircuitKeys records the unique identifiers for any settle/fail - // packets that were resolved by this commitment txn. After a restart, - // this is used to ensure those circuits are removed from the circuit - // map, and the downstream packets in the link's mailbox are removed. - ClosedCircuitKeys []CircuitKey - - // AddAcks specifies the locations (commit height, pkg index) of any - // Adds that were failed/settled in this commit diff. This will ack - // entries in *this* channel's forwarding packages. - // - // NOTE: This value is not serialized, it is used to atomically mark the - // resolution of adds, such that they will not be reprocessed after a - // restart. - AddAcks []AddRef - - // SettleFailAcks specifies the locations (chan id, commit height, pkg - // index) of any Settles or Fails that were locked into this commit - // diff, and originate from *another* channel, i.e. the outgoing link. - // - // NOTE: This value is not serialized, it is used to atomically acks - // settles and fails from the forwarding packages of other channels, - // such that they will not be reforwarded internally after a restart. - SettleFailAcks []SettleFailRef -} - -func serializeCommitDiff(w io.Writer, diff *CommitDiff) error { - if err := serializeChanCommit(w, &diff.Commitment); err != nil { - return err - } - - if err := diff.CommitSig.Encode(w, 0); err != nil { - return err - } - - numUpdates := uint16(len(diff.LogUpdates)) - if err := binary.Write(w, byteOrder, numUpdates); err != nil { - return err - } - - for _, diff := range diff.LogUpdates { - err := WriteElements(w, diff.LogIndex, diff.UpdateMsg) - if err != nil { - return err - } - } - - numOpenRefs := uint16(len(diff.OpenedCircuitKeys)) - if err := binary.Write(w, byteOrder, numOpenRefs); err != nil { - return err - } - - for _, openRef := range diff.OpenedCircuitKeys { - err := WriteElements(w, openRef.ChanID, openRef.HtlcID) - if err != nil { - return err - } - } - - numClosedRefs := uint16(len(diff.ClosedCircuitKeys)) - if err := binary.Write(w, byteOrder, numClosedRefs); err != nil { - return err - } - - for _, closedRef := range diff.ClosedCircuitKeys { - err := WriteElements(w, closedRef.ChanID, closedRef.HtlcID) - if err != nil { - return err - } - } - - return nil -} - -func deserializeCommitDiff(r io.Reader) (*CommitDiff, error) { - var ( - d CommitDiff - err error - ) - - d.Commitment, err = deserializeChanCommit(r) - if err != nil { - return nil, err - } - - d.CommitSig = &lnwire.CommitSig{} - if err := d.CommitSig.Decode(r, 0); err != nil { - return nil, err - } - - var numUpdates uint16 - if err := binary.Read(r, byteOrder, &numUpdates); err != nil { - return nil, err - } - - d.LogUpdates = make([]LogUpdate, numUpdates) - for i := 0; i < int(numUpdates); i++ { - err := ReadElements(r, - &d.LogUpdates[i].LogIndex, &d.LogUpdates[i].UpdateMsg, - ) - if err != nil { - return nil, err - } - } - - var numOpenRefs uint16 - if err := binary.Read(r, byteOrder, &numOpenRefs); err != nil { - return nil, err - } - - d.OpenedCircuitKeys = make([]CircuitKey, numOpenRefs) - for i := 0; i < int(numOpenRefs); i++ { - err := ReadElements(r, - &d.OpenedCircuitKeys[i].ChanID, - &d.OpenedCircuitKeys[i].HtlcID) - if err != nil { - return nil, err - } - } - - var numClosedRefs uint16 - if err := binary.Read(r, byteOrder, &numClosedRefs); err != nil { - return nil, err - } - - d.ClosedCircuitKeys = make([]CircuitKey, numClosedRefs) - for i := 0; i < int(numClosedRefs); i++ { - err := ReadElements(r, - &d.ClosedCircuitKeys[i].ChanID, - &d.ClosedCircuitKeys[i].HtlcID) - if err != nil { - return nil, err - } - } - - return &d, nil -} - -// AppendRemoteCommitChain appends a new CommitDiff to the end of the -// commitment chain for the remote party. This method is to be used once we -// have prepared a new commitment state for the remote party, but before we -// transmit it to the remote party. The contents of the argument should be -// sufficient to retransmit the updates and signature needed to reconstruct the -// state in full, in the case that we need to retransmit. -func (c *OpenChannel) AppendRemoteCommitChain(diff *CommitDiff) error { - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state at all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return ErrNoRestoredChannelMutation - } - - return c.Db.Update(func(tx *bbolt.Tx) error { - // First, we'll grab the writable bucket where this channel's - // data resides. - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - // Any outgoing settles and fails necessarily have a - // corresponding adds in this channel's forwarding packages. - // Mark all of these as being fully processed in our forwarding - // package, which prevents us from reprocessing them after - // startup. - err = c.Packager.AckAddHtlcs(tx, diff.AddAcks...) - if err != nil { - return err - } - - // Additionally, we ack from any fails or settles that are - // persisted in another channel's forwarding package. This - // prevents the same fails and settles from being retransmitted - // after restarts. The actual fail or settle we need to - // propagate to the remote party is now in the commit diff. - err = c.Packager.AckSettleFails(tx, diff.SettleFailAcks...) - if err != nil { - return err - } - - // TODO(roasbeef): use seqno to derive key for later LCP - - // With the bucket retrieved, we'll now serialize the commit - // diff itself, and write it to disk. - var b bytes.Buffer - if err := serializeCommitDiff(&b, diff); err != nil { - return err - } - return chanBucket.Put(commitDiffKey, b.Bytes()) - }) -} - -// RemoteCommitChainTip returns the "tip" of the current remote commitment -// chain. This value will be non-nil iff, we've created a new commitment for -// the remote party that they haven't yet ACK'd. In this case, their commitment -// chain will have a length of two: their current unrevoked commitment, and -// this new pending commitment. Once they revoked their prior state, we'll swap -// these pointers, causing the tip and the tail to point to the same entry. -func (c *OpenChannel) RemoteCommitChainTip() (*CommitDiff, error) { - var cd *CommitDiff - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - switch err { - case nil: - case ErrNoChanDBExists, ErrNoActiveChannels, ErrChannelNotFound: - return ErrNoPendingCommit - default: - return err - } - - tipBytes := chanBucket.Get(commitDiffKey) - if tipBytes == nil { - return ErrNoPendingCommit - } - - tipReader := bytes.NewReader(tipBytes) - dcd, err := deserializeCommitDiff(tipReader) - if err != nil { - return err - } - - cd = dcd - return nil - }) - if err != nil { - return nil, err - } - - return cd, err -} - -// InsertNextRevocation inserts the _next_ commitment point (revocation) into -// the database, and also modifies the internal RemoteNextRevocation attribute -// to point to the passed key. This method is to be using during final channel -// set up, _after_ the channel has been fully confirmed. -// -// NOTE: If this method isn't called, then the target channel won't be able to -// propose new states for the commitment state of the remote party. -func (c *OpenChannel) InsertNextRevocation(revKey *btcec.PublicKey) error { - c.Lock() - defer c.Unlock() - - c.RemoteNextRevocation = revKey - - err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return putChanRevocationState(chanBucket, c) - }) - if err != nil { - return err - } - - return nil -} - -// AdvanceCommitChainTail records the new state transition within an on-disk -// append-only log which records all state transitions by the remote peer. In -// the case of an uncooperative broadcast of a prior state by the remote peer, -// this log can be consulted in order to reconstruct the state needed to -// rectify the situation. This method will add the current commitment for the -// remote party to the revocation log, and promote the current pending -// commitment to the current remote commitment. -func (c *OpenChannel) AdvanceCommitChainTail(fwdPkg *FwdPkg) error { - c.Lock() - defer c.Unlock() - - // If this is a restored channel, then we want to avoid mutating the - // state at all, as it's impossible to do so in a protocol compliant - // manner. - if c.hasChanStatus(ChanStatusRestored) { - return ErrNoRestoredChannelMutation - } - - var newRemoteCommit *ChannelCommitment - - err := c.Db.Update(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - // If the channel is marked as borked, then for safety reasons, - // we shouldn't attempt any further updates. - isBorked, err := c.isBorked(chanBucket) - if err != nil { - return err - } - if isBorked { - return ErrChanBorked - } - - // Persist the latest preimage state to disk as the remote peer - // has just added to our local preimage store, and given us a - // new pending revocation key. - if err := putChanRevocationState(chanBucket, c); err != nil { - return err - } - - // With the current preimage producer/store state updated, - // append a new log entry recording this the delta of this - // state transition. - // - // TODO(roasbeef): could make the deltas relative, would save - // space, but then tradeoff for more disk-seeks to recover the - // full state. - logKey := revocationLogBucket - logBucket, err := chanBucket.CreateBucketIfNotExists(logKey) - if err != nil { - return err - } - - // Before we append this revoked state to the revocation log, - // we'll swap out what's currently the tail of the commit tip, - // with the current locked-in commitment for the remote party. - tipBytes := chanBucket.Get(commitDiffKey) - tipReader := bytes.NewReader(tipBytes) - newCommit, err := deserializeCommitDiff(tipReader) - if err != nil { - return err - } - err = putChanCommitment( - chanBucket, &newCommit.Commitment, false, - ) - if err != nil { - return err - } - if err := chanBucket.Delete(commitDiffKey); err != nil { - return err - } - - // With the commitment pointer swapped, we can now add the - // revoked (prior) state to the revocation log. - // - // TODO(roasbeef): store less - err = appendChannelLogEntry(logBucket, &c.RemoteCommitment) - if err != nil { - return err - } - - // Lastly, we write the forwarding package to disk so that we - // can properly recover from failures and reforward HTLCs that - // have not received a corresponding settle/fail. - if err := c.Packager.AddFwdPkg(tx, fwdPkg); err != nil { - return err - } - - newRemoteCommit = &newCommit.Commitment - - return nil - }) - if err != nil { - return err - } - - // With the db transaction complete, we'll swap over the in-memory - // pointer of the new remote commitment, which was previously the tip - // of the commit chain. - c.RemoteCommitment = *newRemoteCommit - - return nil -} - -// NextLocalHtlcIndex returns the next unallocated local htlc index. To ensure -// this always returns the next index that has been not been allocated, this -// will first try to examine any pending commitments, before falling back to the -// last locked-in local commitment. -func (c *OpenChannel) NextLocalHtlcIndex() (uint64, error) { - // First, load the most recent commit diff that we initiated for the - // remote party. If no pending commit is found, this is not treated as - // a critical error, since we can always fall back. - pendingRemoteCommit, err := c.RemoteCommitChainTip() - if err != nil && err != ErrNoPendingCommit { - return 0, err - } - - // If a pending commit was found, its local htlc index will be at least - // as large as the one on our local commitment. - if pendingRemoteCommit != nil { - return pendingRemoteCommit.Commitment.LocalHtlcIndex, nil - } - - // Otherwise, fallback to using the local htlc index of our commitment. - return c.LocalCommitment.LocalHtlcIndex, nil -} - -// LoadFwdPkgs scans the forwarding log for any packages that haven't been -// processed, and returns their deserialized log updates in map indexed by the -// remote commitment height at which the updates were locked in. -func (c *OpenChannel) LoadFwdPkgs() ([]*FwdPkg, error) { - c.RLock() - defer c.RUnlock() - - var fwdPkgs []*FwdPkg - if err := c.Db.View(func(tx *bbolt.Tx) error { - var err error - fwdPkgs, err = c.Packager.LoadFwdPkgs(tx) - return err - }); err != nil { - return nil, err - } - - return fwdPkgs, nil -} - -// AckAddHtlcs updates the AckAddFilter containing any of the provided AddRefs -// indicating that a response to this Add has been committed to the remote party. -// Doing so will prevent these Add HTLCs from being reforwarded internally. -func (c *OpenChannel) AckAddHtlcs(addRefs ...AddRef) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - return c.Packager.AckAddHtlcs(tx, addRefs...) - }) -} - -// AckSettleFails updates the SettleFailFilter containing any of the provided -// SettleFailRefs, indicating that the response has been delivered to the -// incoming link, corresponding to a particular AddRef. Doing so will prevent -// the responses from being retransmitted internally. -func (c *OpenChannel) AckSettleFails(settleFailRefs ...SettleFailRef) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - return c.Packager.AckSettleFails(tx, settleFailRefs...) - }) -} - -// SetFwdFilter atomically sets the forwarding filter for the forwarding package -// identified by `height`. -func (c *OpenChannel) SetFwdFilter(height uint64, fwdFilter *PkgFilter) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - return c.Packager.SetFwdFilter(tx, height, fwdFilter) - }) -} - -// RemoveFwdPkg atomically removes a forwarding package specified by the remote -// commitment height. -// -// NOTE: This method should only be called on packages marked FwdStateCompleted. -func (c *OpenChannel) RemoveFwdPkg(height uint64) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - return c.Packager.RemovePkg(tx, height) - }) -} - -// RevocationLogTail returns the "tail", or the end of the current revocation -// log. This entry represents the last previous state for the remote node's -// commitment chain. The ChannelDelta returned by this method will always lag -// one state behind the most current (unrevoked) state of the remote node's -// commitment chain. -func (c *OpenChannel) RevocationLogTail() (*ChannelCommitment, error) { - c.RLock() - defer c.RUnlock() - - // If we haven't created any state updates yet, then we'll exit early as - // there's nothing to be found on disk in the revocation bucket. - if c.RemoteCommitment.CommitHeight == 0 { - return nil, nil - } - - var commit ChannelCommitment - if err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - logBucket := chanBucket.Bucket(revocationLogBucket) - if logBucket == nil { - return ErrNoPastDeltas - } - - // Once we have the bucket that stores the revocation log from - // this channel, we'll jump to the _last_ key in bucket. As we - // store the update number on disk in a big-endian format, - // this will retrieve the latest entry. - cursor := logBucket.Cursor() - _, tailLogEntry := cursor.Last() - logEntryReader := bytes.NewReader(tailLogEntry) - - // Once we have the entry, we'll decode it into the channel - // delta pointer we created above. - var dbErr error - commit, dbErr = deserializeChanCommit(logEntryReader) - if dbErr != nil { - return dbErr - } - - return nil - }); err != nil { - return nil, err - } - - return &commit, nil -} - -// CommitmentHeight returns the current commitment height. The commitment -// height represents the number of updates to the commitment state to date. -// This value is always monotonically increasing. This method is provided in -// order to allow multiple instances of a particular open channel to obtain a -// consistent view of the number of channel updates to date. -func (c *OpenChannel) CommitmentHeight() (uint64, error) { - c.RLock() - defer c.RUnlock() - - var height uint64 - err := c.Db.View(func(tx *bbolt.Tx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - commit, err := fetchChanCommitment(chanBucket, true) - if err != nil { - return err - } - - height = commit.CommitHeight - return nil - }) - if err != nil { - return 0, err - } - - return height, nil -} - -// FindPreviousState scans through the append-only log in an attempt to recover -// the previous channel state indicated by the update number. This method is -// intended to be used for obtaining the relevant data needed to claim all -// funds rightfully spendable in the case of an on-chain broadcast of the -// commitment transaction. -func (c *OpenChannel) FindPreviousState(updateNum uint64) (*ChannelCommitment, error) { - c.RLock() - defer c.RUnlock() - - var commit ChannelCommitment - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - logBucket := chanBucket.Bucket(revocationLogBucket) - if logBucket == nil { - return ErrNoPastDeltas - } - - c, err := fetchChannelLogEntry(logBucket, updateNum) - if err != nil { - return err - } - - commit = c - return nil - }) - if err != nil { - return nil, err - } - - return &commit, nil -} - // ClosureType is an enum like structure that details exactly _how_ a channel // was closed. Three closure types are currently possible: none, cooperative, // local force close, remote force close, and (remote) breach. type ClosureType uint8 const ( - // CooperativeClose indicates that a channel has been closed - // cooperatively. This means that both channel peers were online and - // signed a new transaction paying out the settled balance of the - // contract. - CooperativeClose ClosureType = 0 - - // LocalForceClose indicates that we have unilaterally broadcast our - // current commitment state on-chain. - LocalForceClose ClosureType = 1 - // RemoteForceClose indicates that the remote peer has unilaterally // broadcast their current commitment state on-chain. RemoteForceClose ClosureType = 4 - - // BreachClose indicates that the remote peer attempted to broadcast a - // prior _revoked_ channel state. - BreachClose ClosureType = 2 - - // FundingCanceled indicates that the channel never was fully opened - // before it was marked as closed in the database. This can happen if - // we or the remote fail at some point during the opening workflow, or - // we timeout waiting for the funding transaction to be confirmed. - FundingCanceled ClosureType = 3 - - // Abandoned indicates that the channel state was removed without - // any further actions. This is intended to clean up unusable - // channels during development. - Abandoned ClosureType = 5 ) // ChannelCloseSummary contains the final state of a channel at the point it @@ -2160,214 +591,6 @@ type ChannelCloseSummary struct { LastChanSyncMsg *lnwire.ChannelReestablish } -// CloseChannel closes a previously active Lightning channel. Closing a channel -// entails deleting all saved state within the database concerning this -// channel. This method also takes a struct that summarizes the state of the -// channel at closing, this compact representation will be the only component -// of a channel left over after a full closing. -func (c *OpenChannel) CloseChannel(summary *ChannelCloseSummary) error { - c.Lock() - defer c.Unlock() - - return c.Db.Update(func(tx *bbolt.Tx) error { - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoChanDBExists - } - - nodePub := c.IdentityPub.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(nodePub) - if nodeChanBucket == nil { - return ErrNoActiveChannels - } - - chainBucket := nodeChanBucket.Bucket(c.ChainHash[:]) - if chainBucket == nil { - return ErrNoActiveChannels - } - - var chanPointBuf bytes.Buffer - err := writeOutpoint(&chanPointBuf, &c.FundingOutpoint) - if err != nil { - return err - } - chanBucket := chainBucket.Bucket(chanPointBuf.Bytes()) - if chanBucket == nil { - return ErrNoActiveChannels - } - - // Before we delete the channel state, we'll read out the full - // details, as we'll also store portions of this information - // for record keeping. - chanState, err := fetchOpenChannel( - chanBucket, &c.FundingOutpoint, - ) - if err != nil { - return err - } - - // Now that the index to this channel has been deleted, purge - // the remaining channel metadata from the database. - err = deleteOpenChannel(chanBucket, chanPointBuf.Bytes()) - if err != nil { - return err - } - - // With the base channel data deleted, attempt to delete the - // information stored within the revocation log. - logBucket := chanBucket.Bucket(revocationLogBucket) - if logBucket != nil { - err = chanBucket.DeleteBucket(revocationLogBucket) - if err != nil { - return err - } - } - - err = chainBucket.DeleteBucket(chanPointBuf.Bytes()) - if err != nil { - return err - } - - // Finally, create a summary of this channel in the closed - // channel bucket for this node. - return putChannelCloseSummary( - tx, chanPointBuf.Bytes(), summary, chanState, - ) - }) -} - -// ChannelSnapshot is a frozen snapshot of the current channel state. A -// snapshot is detached from the original channel that generated it, providing -// read-only access to the current or prior state of an active channel. -// -// TODO(roasbeef): remove all together? pretty much just commitment -type ChannelSnapshot struct { - // RemoteIdentity is the identity public key of the remote node that we - // are maintaining the open channel with. - RemoteIdentity btcec.PublicKey - - // ChanPoint is the outpoint that created the channel. This output is - // found within the funding transaction and uniquely identified the - // channel on the resident chain. - ChannelPoint wire.OutPoint - - // ChainHash is the genesis hash of the chain that the channel resides - // within. - ChainHash chainhash.Hash - - // Capacity is the total capacity of the channel. - Capacity btcutil.Amount - - // TotalMSatSent is the total number of milli-satoshis we've sent - // within this channel. - TotalMSatSent lnwire.MilliSatoshi - - // TotalMSatReceived is the total number of milli-satoshis we've - // received within this channel. - TotalMSatReceived lnwire.MilliSatoshi - - // ChannelCommitment is the current up-to-date commitment for the - // target channel. - ChannelCommitment -} - -// Snapshot returns a read-only snapshot of the current channel state. This -// snapshot includes information concerning the current settled balance within -// the channel, metadata detailing total flows, and any outstanding HTLCs. -func (c *OpenChannel) Snapshot() *ChannelSnapshot { - c.RLock() - defer c.RUnlock() - - localCommit := c.LocalCommitment - snapshot := &ChannelSnapshot{ - RemoteIdentity: *c.IdentityPub, - ChannelPoint: c.FundingOutpoint, - Capacity: c.Capacity, - TotalMSatSent: c.TotalMSatSent, - TotalMSatReceived: c.TotalMSatReceived, - ChainHash: c.ChainHash, - ChannelCommitment: ChannelCommitment{ - LocalBalance: localCommit.LocalBalance, - RemoteBalance: localCommit.RemoteBalance, - CommitHeight: localCommit.CommitHeight, - CommitFee: localCommit.CommitFee, - }, - } - - // Copy over the current set of HTLCs to ensure the caller can't mutate - // our internal state. - snapshot.Htlcs = make([]HTLC, len(localCommit.Htlcs)) - for i, h := range localCommit.Htlcs { - snapshot.Htlcs[i] = h.Copy() - } - - return snapshot -} - -// LatestCommitments returns the two latest commitments for both the local and -// remote party. These commitments are read from disk to ensure that only the -// latest fully committed state is returned. The first commitment returned is -// the local commitment, and the second returned is the remote commitment. -func (c *OpenChannel) LatestCommitments() (*ChannelCommitment, *ChannelCommitment, error) { - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return fetchChanCommitments(chanBucket, c) - }) - if err != nil { - return nil, nil, err - } - - return &c.LocalCommitment, &c.RemoteCommitment, nil -} - -// RemoteRevocationStore returns the most up to date commitment version of the -// revocation storage tree for the remote party. This method can be used when -// acting on a possible contract breach to ensure, that the caller has the most -// up to date information required to deliver justice. -func (c *OpenChannel) RemoteRevocationStore() (shachain.Store, error) { - err := c.Db.View(func(tx *bbolt.Tx) error { - chanBucket, err := fetchChanBucket( - tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash, - ) - if err != nil { - return err - } - - return fetchChanRevocationState(chanBucket, c) - }) - if err != nil { - return nil, err - } - - return c.RevocationStore, nil -} - -func putChannelCloseSummary(tx *bbolt.Tx, chanID []byte, - summary *ChannelCloseSummary, lastChanState *OpenChannel) error { - - closedChanBucket, err := tx.CreateBucketIfNotExists(closedChannelBucket) - if err != nil { - return err - } - - summary.RemoteCurrentRevocation = lastChanState.RemoteCurrentRevocation - summary.RemoteNextRevocation = lastChanState.RemoteNextRevocation - summary.LocalChanConfig = lastChanState.LocalChanCfg - - var b bytes.Buffer - if err := serializeChannelCloseSummary(&b, summary); err != nil { - return err - } - - return closedChanBucket.Put(chanID, b.Bytes()) -} - func serializeChannelCloseSummary(w io.Writer, cs *ChannelCloseSummary) error { err := WriteElements(w, cs.ChanPoint, cs.ShortChanID, cs.ChainHash, cs.ClosingTXID, @@ -2517,113 +740,6 @@ func writeChanConfig(b io.Writer, c *ChannelConfig) error { ) } -func putChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - var w bytes.Buffer - if err := WriteElements(&w, - channel.ChanType, channel.ChainHash, channel.FundingOutpoint, - channel.ShortChannelID, channel.IsPending, channel.IsInitiator, - channel.chanStatus, channel.FundingBroadcastHeight, - channel.NumConfsRequired, channel.ChannelFlags, - channel.IdentityPub, channel.Capacity, channel.TotalMSatSent, - channel.TotalMSatReceived, - ); err != nil { - return err - } - - // For single funder channels that we initiated, write the funding txn. - if channel.ChanType.IsSingleFunder() && channel.IsInitiator && - !channel.hasChanStatus(ChanStatusRestored) { - - if err := WriteElement(&w, channel.FundingTxn); err != nil { - return err - } - } - - if err := writeChanConfig(&w, &channel.LocalChanCfg); err != nil { - return err - } - if err := writeChanConfig(&w, &channel.RemoteChanCfg); err != nil { - return err - } - - return chanBucket.Put(chanInfoKey, w.Bytes()) -} - -func serializeChanCommit(w io.Writer, c *ChannelCommitment) error { - if err := WriteElements(w, - c.CommitHeight, c.LocalLogIndex, c.LocalHtlcIndex, - c.RemoteLogIndex, c.RemoteHtlcIndex, c.LocalBalance, - c.RemoteBalance, c.CommitFee, c.FeePerKw, c.CommitTx, - c.CommitSig, - ); err != nil { - return err - } - - return SerializeHtlcs(w, c.Htlcs...) -} - -func putChanCommitment(chanBucket *bbolt.Bucket, c *ChannelCommitment, - local bool) error { - - var commitKey []byte - if local { - commitKey = append(chanCommitmentKey, byte(0x00)) - } else { - commitKey = append(chanCommitmentKey, byte(0x01)) - } - - var b bytes.Buffer - if err := serializeChanCommit(&b, c); err != nil { - return err - } - - return chanBucket.Put(commitKey, b.Bytes()) -} - -func putChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - // If this is a restored channel, then we don't have any commitments to - // write. - if channel.hasChanStatus(ChanStatusRestored) { - return nil - } - - err := putChanCommitment( - chanBucket, &channel.LocalCommitment, true, - ) - if err != nil { - return err - } - - return putChanCommitment( - chanBucket, &channel.RemoteCommitment, false, - ) -} - -func putChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - - var b bytes.Buffer - err := WriteElements( - &b, channel.RemoteCurrentRevocation, channel.RevocationProducer, - channel.RevocationStore, - ) - if err != nil { - return err - } - - // TODO(roasbeef): don't keep producer on disk - - // If the next revocation is present, which is only the case after the - // FundingLocked message has been sent, then we'll write it to disk. - if channel.RemoteNextRevocation != nil { - err = WriteElements(&b, channel.RemoteNextRevocation) - if err != nil { - return err - } - } - - return chanBucket.Put(revocationStateKey, b.Bytes()) -} - func readChanConfig(b io.Reader, c *ChannelConfig) error { return ReadElements(b, &c.DustLimit, &c.MaxPendingAmount, &c.ChanReserve, @@ -2633,185 +749,3 @@ func readChanConfig(b io.Reader, c *ChannelConfig) error { &c.HtlcBasePoint, ) } - -func fetchChanInfo(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - infoBytes := chanBucket.Get(chanInfoKey) - if infoBytes == nil { - return ErrNoChanInfoFound - } - r := bytes.NewReader(infoBytes) - - if err := ReadElements(r, - &channel.ChanType, &channel.ChainHash, &channel.FundingOutpoint, - &channel.ShortChannelID, &channel.IsPending, &channel.IsInitiator, - &channel.chanStatus, &channel.FundingBroadcastHeight, - &channel.NumConfsRequired, &channel.ChannelFlags, - &channel.IdentityPub, &channel.Capacity, &channel.TotalMSatSent, - &channel.TotalMSatReceived, - ); err != nil { - return err - } - - // For single funder channels that we initiated, read the funding txn. - if channel.ChanType.IsSingleFunder() && channel.IsInitiator && - !channel.hasChanStatus(ChanStatusRestored) { - - if err := ReadElement(r, &channel.FundingTxn); err != nil { - return err - } - } - - if err := readChanConfig(r, &channel.LocalChanCfg); err != nil { - return err - } - if err := readChanConfig(r, &channel.RemoteChanCfg); err != nil { - return err - } - - channel.Packager = NewChannelPackager(channel.ShortChannelID) - - return nil -} - -func deserializeChanCommit(r io.Reader) (ChannelCommitment, error) { - var c ChannelCommitment - - err := ReadElements(r, - &c.CommitHeight, &c.LocalLogIndex, &c.LocalHtlcIndex, &c.RemoteLogIndex, - &c.RemoteHtlcIndex, &c.LocalBalance, &c.RemoteBalance, - &c.CommitFee, &c.FeePerKw, &c.CommitTx, &c.CommitSig, - ) - if err != nil { - return c, err - } - - c.Htlcs, err = DeserializeHtlcs(r) - if err != nil { - return c, err - } - - return c, nil -} - -func fetchChanCommitment(chanBucket *bbolt.Bucket, local bool) (ChannelCommitment, error) { - var commitKey []byte - if local { - commitKey = append(chanCommitmentKey, byte(0x00)) - } else { - commitKey = append(chanCommitmentKey, byte(0x01)) - } - - commitBytes := chanBucket.Get(commitKey) - if commitBytes == nil { - return ChannelCommitment{}, ErrNoCommitmentsFound - } - - r := bytes.NewReader(commitBytes) - return deserializeChanCommit(r) -} - -func fetchChanCommitments(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - var err error - - // If this is a restored channel, then we don't have any commitments to - // read. - if channel.hasChanStatus(ChanStatusRestored) { - return nil - } - - channel.LocalCommitment, err = fetchChanCommitment(chanBucket, true) - if err != nil { - return err - } - channel.RemoteCommitment, err = fetchChanCommitment(chanBucket, false) - if err != nil { - return err - } - - return nil -} - -func fetchChanRevocationState(chanBucket *bbolt.Bucket, channel *OpenChannel) error { - revBytes := chanBucket.Get(revocationStateKey) - if revBytes == nil { - return ErrNoRevocationsFound - } - r := bytes.NewReader(revBytes) - - err := ReadElements( - r, &channel.RemoteCurrentRevocation, &channel.RevocationProducer, - &channel.RevocationStore, - ) - if err != nil { - return err - } - - // If there aren't any bytes left in the buffer, then we don't yet have - // the next remote revocation, so we can exit early here. - if r.Len() == 0 { - return nil - } - - // Otherwise we'll read the next revocation for the remote party which - // is always the last item within the buffer. - return ReadElements(r, &channel.RemoteNextRevocation) -} - -func deleteOpenChannel(chanBucket *bbolt.Bucket, chanPointBytes []byte) error { - - if err := chanBucket.Delete(chanInfoKey); err != nil { - return err - } - - err := chanBucket.Delete(append(chanCommitmentKey, byte(0x00))) - if err != nil { - return err - } - err = chanBucket.Delete(append(chanCommitmentKey, byte(0x01))) - if err != nil { - return err - } - - if err := chanBucket.Delete(revocationStateKey); err != nil { - return err - } - - if diff := chanBucket.Get(commitDiffKey); diff != nil { - return chanBucket.Delete(commitDiffKey) - } - - return nil - -} - -// makeLogKey converts a uint64 into an 8 byte array. -func makeLogKey(updateNum uint64) [8]byte { - var key [8]byte - byteOrder.PutUint64(key[:], updateNum) - return key -} - -func appendChannelLogEntry(log *bbolt.Bucket, - commit *ChannelCommitment) error { - - var b bytes.Buffer - if err := serializeChanCommit(&b, commit); err != nil { - return err - } - - logEntrykey := makeLogKey(commit.CommitHeight) - return log.Put(logEntrykey[:], b.Bytes()) -} - -func fetchChannelLogEntry(log *bbolt.Bucket, - updateNum uint64) (ChannelCommitment, error) { - - logEntrykey := makeLogKey(updateNum) - commitBytes := log.Get(logEntrykey[:]) - if commitBytes == nil { - return ChannelCommitment{}, fmt.Errorf("log entry not found") - } - - commitReader := bytes.NewReader(commitBytes) - return deserializeChanCommit(commitReader) -} diff --git a/channeldb/migration_01_to_11/channel_cache.go b/channeldb/migration_01_to_11/channel_cache.go deleted file mode 100644 index 5d391e00..00000000 --- a/channeldb/migration_01_to_11/channel_cache.go +++ /dev/null @@ -1,50 +0,0 @@ -package migration_01_to_11 - -// channelCache is an in-memory cache used to improve the performance of -// ChanUpdatesInHorizon. It caches the chan info and edge policies for a -// particular channel. -type channelCache struct { - n int - channels map[uint64]ChannelEdge -} - -// newChannelCache creates a new channelCache with maximum capacity of n -// channels. -func newChannelCache(n int) *channelCache { - return &channelCache{ - n: n, - channels: make(map[uint64]ChannelEdge), - } -} - -// get returns the channel from the cache, if it exists. -func (c *channelCache) get(chanid uint64) (ChannelEdge, bool) { - channel, ok := c.channels[chanid] - return channel, ok -} - -// insert adds the entry to the channel cache. If an entry for chanid already -// exists, it will be replaced with the new entry. If the entry doesn't exist, -// it will be inserted to the cache, performing a random eviction if the cache -// is at capacity. -func (c *channelCache) insert(chanid uint64, channel ChannelEdge) { - // If entry exists, replace it. - if _, ok := c.channels[chanid]; ok { - c.channels[chanid] = channel - return - } - - // Otherwise, evict an entry at random and insert. - if len(c.channels) == c.n { - for id := range c.channels { - delete(c.channels, id) - break - } - } - c.channels[chanid] = channel -} - -// remove deletes an edge for chanid from the cache, if it exists. -func (c *channelCache) remove(chanid uint64) { - delete(c.channels, chanid) -} diff --git a/channeldb/migration_01_to_11/channel_cache_test.go b/channeldb/migration_01_to_11/channel_cache_test.go deleted file mode 100644 index b2929635..00000000 --- a/channeldb/migration_01_to_11/channel_cache_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package migration_01_to_11 - -import ( - "reflect" - "testing" -) - -// TestChannelCache checks the behavior of the channelCache with respect to -// insertion, eviction, and removal of cache entries. -func TestChannelCache(t *testing.T) { - const cacheSize = 100 - - // Create a new channel cache with the configured max size. - c := newChannelCache(cacheSize) - - // As a sanity check, assert that querying the empty cache does not - // return an entry. - _, ok := c.get(0) - if ok { - t.Fatalf("channel cache should be empty") - } - - // Now, fill up the cache entirely. - for i := uint64(0); i < cacheSize; i++ { - c.insert(i, channelForInt(i)) - } - - // Assert that the cache has all of the entries just inserted, since no - // eviction should occur until we try to surpass the max size. - assertHasChanEntries(t, c, 0, cacheSize) - - // Now, insert a new element that causes the cache to evict an element. - c.insert(cacheSize, channelForInt(cacheSize)) - - // Assert that the cache has this last entry, as the cache should evict - // some prior element and not the newly inserted one. - assertHasChanEntries(t, c, cacheSize, cacheSize) - - // Iterate over all inserted elements and construct a set of the evicted - // elements. - evicted := make(map[uint64]struct{}) - for i := uint64(0); i < cacheSize+1; i++ { - _, ok := c.get(i) - if !ok { - evicted[i] = struct{}{} - } - } - - // Assert that exactly one element has been evicted. - numEvicted := len(evicted) - if numEvicted != 1 { - t.Fatalf("expected one evicted entry, got: %d", numEvicted) - } - - // Remove the highest item which initially caused the eviction and - // reinsert the element that was evicted prior. - c.remove(cacheSize) - for i := range evicted { - c.insert(i, channelForInt(i)) - } - - // Since the removal created an extra slot, the last insertion should - // not have caused an eviction and the entries for all channels in the - // original set that filled the cache should be present. - assertHasChanEntries(t, c, 0, cacheSize) - - // Finally, reinsert the existing set back into the cache and test that - // the cache still has all the entries. If the randomized eviction were - // happening on inserts for existing cache items, we expect this to fail - // with high probability. - for i := uint64(0); i < cacheSize; i++ { - c.insert(i, channelForInt(i)) - } - assertHasChanEntries(t, c, 0, cacheSize) - -} - -// assertHasEntries queries the edge cache for all channels in the range [start, -// end), asserting that they exist and their value matches the entry produced by -// entryForInt. -func assertHasChanEntries(t *testing.T, c *channelCache, start, end uint64) { - t.Helper() - - for i := start; i < end; i++ { - entry, ok := c.get(i) - if !ok { - t.Fatalf("channel cache should contain chan %d", i) - } - - expEntry := channelForInt(i) - if !reflect.DeepEqual(entry, expEntry) { - t.Fatalf("entry mismatch, want: %v, got: %v", - expEntry, entry) - } - } -} - -// channelForInt generates a unique ChannelEdge given an integer. -func channelForInt(i uint64) ChannelEdge { - return ChannelEdge{ - Info: &ChannelEdgeInfo{ - ChannelID: i, - }, - } -} diff --git a/channeldb/migration_01_to_11/channel_test.go b/channeldb/migration_01_to_11/channel_test.go index 53fb39d7..1380828e 100644 --- a/channeldb/migration_01_to_11/channel_test.go +++ b/channeldb/migration_01_to_11/channel_test.go @@ -4,18 +4,13 @@ import ( "bytes" "io/ioutil" "math/rand" - "net" "os" - "reflect" - "runtime" - "testing" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" _ "github.com/btcsuite/btcwallet/walletdb/bdb" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/shachain" @@ -66,8 +61,6 @@ var ( LockTime: 5, } privKey, pubKey = btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - - wireSig, _ = lnwire.NewSigFromSignature(testSig) ) // makeTestDB creates a new instance of the ChannelDB for testing purposes. A @@ -223,819 +216,6 @@ func createTestChannelState(cdb *DB) (*OpenChannel, error) { RevocationProducer: producer, RevocationStore: store, Db: cdb, - Packager: NewChannelPackager(chanID), FundingTxn: testTx, }, nil } - -func TestOpenChannelPutGetDelete(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // Create the test channel state, then add an additional fake HTLC - // before syncing to disk. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - state.LocalCommitment.Htlcs = []HTLC{ - { - Signature: testSig.Serialize(), - Incoming: true, - Amt: 10, - RHash: key, - RefundTimeout: 1, - OnionBlob: []byte("onionblob"), - }, - } - state.RemoteCommitment.Htlcs = []HTLC{ - { - Signature: testSig.Serialize(), - Incoming: false, - Amt: 10, - RHash: key, - RefundTimeout: 1, - OnionBlob: []byte("onionblob"), - }, - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := state.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - openChannels, err := cdb.FetchOpenChannels(state.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch open channel: %v", err) - } - - newState := openChannels[0] - - // The decoded channel state should be identical to what we stored - // above. - if !reflect.DeepEqual(state, newState) { - t.Fatalf("channel state doesn't match:: %v vs %v", - spew.Sdump(state), spew.Sdump(newState)) - } - - // We'll also test that the channel is properly able to hot swap the - // next revocation for the state machine. This tests the initial - // post-funding revocation exchange. - nextRevKey, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - t.Fatalf("unable to create new private key: %v", err) - } - if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil { - t.Fatalf("unable to update revocation: %v", err) - } - - openChannels, err = cdb.FetchOpenChannels(state.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch open channel: %v", err) - } - updatedChan := openChannels[0] - - // Ensure that the revocation was set properly. - if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) { - t.Fatalf("next revocation wasn't updated") - } - - // Finally to wrap up the test, delete the state of the channel within - // the database. This involves "closing" the channel which removes all - // written state, and creates a small "summary" elsewhere within the - // database. - closeSummary := &ChannelCloseSummary{ - ChanPoint: state.FundingOutpoint, - RemotePub: state.IdentityPub, - SettledBalance: btcutil.Amount(500), - TimeLockedBalance: btcutil.Amount(10000), - IsPending: false, - CloseType: CooperativeClose, - } - if err := state.CloseChannel(closeSummary); err != nil { - t.Fatalf("unable to close channel: %v", err) - } - - // As the channel is now closed, attempting to fetch all open channels - // for our fake node ID should return an empty slice. - openChans, err := cdb.FetchOpenChannels(state.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch open channels: %v", err) - } - if len(openChans) != 0 { - t.Fatalf("all channels not deleted, found %v", len(openChans)) - } - - // Additionally, attempting to fetch all the open channels globally - // should yield no results. - openChans, err = cdb.FetchAllChannels() - if err != nil { - t.Fatal("unable to fetch all open chans") - } - if len(openChans) != 0 { - t.Fatalf("all channels not deleted, found %v", len(openChans)) - } -} - -func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) { - if !reflect.DeepEqual(a, b) { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: commitments don't match: %v vs %v", - line, spew.Sdump(a), spew.Sdump(b)) - } -} - -func TestChannelStateTransition(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First create a minimal channel, then perform a full sync in order to - // persist the data. - channel, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := channel.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - // Add some HTLCs which were added during this new state transition. - // Half of the HTLCs are incoming, while the other half are outgoing. - var ( - htlcs []HTLC - htlcAmt lnwire.MilliSatoshi - ) - for i := uint32(0); i < 10; i++ { - var incoming bool - if i > 5 { - incoming = true - } - htlc := HTLC{ - Signature: testSig.Serialize(), - Incoming: incoming, - Amt: 10, - RHash: key, - RefundTimeout: i, - OutputIndex: int32(i * 3), - LogIndex: uint64(i * 2), - HtlcIndex: uint64(i), - } - htlc.OnionBlob = make([]byte, 10) - copy(htlc.OnionBlob[:], bytes.Repeat([]byte{2}, 10)) - htlcs = append(htlcs, htlc) - htlcAmt += htlc.Amt - } - - // Create a new channel delta which includes the above HTLCs, some - // balance updates, and an increment of the current commitment height. - // Additionally, modify the signature and commitment transaction. - newSequence := uint32(129498) - newSig := bytes.Repeat([]byte{3}, 71) - newTx := channel.LocalCommitment.CommitTx.Copy() - newTx.TxIn[0].Sequence = newSequence - commitment := ChannelCommitment{ - CommitHeight: 1, - LocalLogIndex: 2, - LocalHtlcIndex: 1, - RemoteLogIndex: 2, - RemoteHtlcIndex: 1, - LocalBalance: lnwire.MilliSatoshi(1e8), - RemoteBalance: lnwire.MilliSatoshi(1e8), - CommitFee: 55, - FeePerKw: 99, - CommitTx: newTx, - CommitSig: newSig, - Htlcs: htlcs, - } - - // First update the local node's broadcastable state and also add a - // CommitDiff remote node's as well in order to simulate a proper state - // transition. - if err := channel.UpdateCommitment(&commitment); err != nil { - t.Fatalf("unable to update commitment: %v", err) - } - - // The balances, new update, the HTLCs and the changes to the fake - // commitment transaction along with the modified signature should all - // have been updated. - updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch updated channel: %v", err) - } - assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment) - numDiskUpdates, err := updatedChannel[0].CommitmentHeight() - if err != nil { - t.Fatalf("unable to read commitment height from disk: %v", err) - } - if numDiskUpdates != uint64(commitment.CommitHeight) { - t.Fatalf("num disk updates doesn't match: %v vs %v", - numDiskUpdates, commitment.CommitHeight) - } - - // Attempting to query for a commitment diff should return - // ErrNoPendingCommit as we haven't yet created a new state for them. - _, err = channel.RemoteCommitChainTip() - if err != ErrNoPendingCommit { - t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) - } - - // To simulate us extending a new state to the remote party, we'll also - // create a new commit diff for them. - remoteCommit := commitment - remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8) - remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8) - remoteCommit.CommitHeight = 1 - commitDiff := &CommitDiff{ - Commitment: remoteCommit, - CommitSig: &lnwire.CommitSig{ - ChanID: lnwire.ChannelID(key), - CommitSig: wireSig, - HtlcSigs: []lnwire.Sig{ - wireSig, - wireSig, - }, - }, - LogUpdates: []LogUpdate{ - { - LogIndex: 1, - UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 1, - Amount: lnwire.NewMSatFromSatoshis(100), - Expiry: 25, - }, - }, - { - LogIndex: 2, - UpdateMsg: &lnwire.UpdateAddHTLC{ - ID: 2, - Amount: lnwire.NewMSatFromSatoshis(200), - Expiry: 50, - }, - }, - }, - OpenedCircuitKeys: []CircuitKey{}, - ClosedCircuitKeys: []CircuitKey{}, - } - copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], - bytes.Repeat([]byte{1}, 32)) - copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:], - bytes.Repeat([]byte{2}, 32)) - if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { - t.Fatalf("unable to add to commit chain: %v", err) - } - - // The commitment tip should now match the commitment that we just - // inserted. - diskCommitDiff, err := channel.RemoteCommitChainTip() - if err != nil { - t.Fatalf("unable to fetch commit diff: %v", err) - } - if !reflect.DeepEqual(commitDiff, diskCommitDiff) { - t.Fatalf("commit diffs don't match: %v vs %v", spew.Sdump(remoteCommit), - spew.Sdump(diskCommitDiff)) - } - - // We'll save the old remote commitment as this will be added to the - // revocation log shortly. - oldRemoteCommit := channel.RemoteCommitment - - // Next, write to the log which tracks the necessary revocation state - // needed to rectify any fishy behavior by the remote party. Modify the - // current uncollapsed revocation state to simulate a state transition - // by the remote party. - channel.RemoteCurrentRevocation = channel.RemoteNextRevocation - newPriv, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - t.Fatalf("unable to generate key: %v", err) - } - channel.RemoteNextRevocation = newPriv.PubKey() - - fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, - diskCommitDiff.LogUpdates, nil) - - err = channel.AdvanceCommitChainTail(fwdPkg) - if err != nil { - t.Fatalf("unable to append to revocation log: %v", err) - } - - // At this point, the remote commit chain should be nil, and the posted - // remote commitment should match the one we added as a diff above. - if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit { - t.Fatalf("expected ErrNoPendingCommit, instead got %v", err) - } - - // We should be able to fetch the channel delta created above by its - // update number with all the state properly reconstructed. - diskPrevCommit, err := channel.FindPreviousState( - oldRemoteCommit.CommitHeight, - ) - if err != nil { - t.Fatalf("unable to fetch past delta: %v", err) - } - - // The two deltas (the original vs the on-disk version) should - // identical, and all HTLC data should properly be retained. - assertCommitmentEqual(t, &oldRemoteCommit, diskPrevCommit) - - // The state number recovered from the tail of the revocation log - // should be identical to this current state. - logTail, err := channel.RevocationLogTail() - if err != nil { - t.Fatalf("unable to retrieve log: %v", err) - } - if logTail.CommitHeight != oldRemoteCommit.CommitHeight { - t.Fatal("update number doesn't match") - } - - oldRemoteCommit = channel.RemoteCommitment - - // Next modify the posted diff commitment slightly, then create a new - // commitment diff and advance the tail. - commitDiff.Commitment.CommitHeight = 2 - commitDiff.Commitment.LocalBalance -= htlcAmt - commitDiff.Commitment.RemoteBalance += htlcAmt - commitDiff.LogUpdates = []LogUpdate{} - if err := channel.AppendRemoteCommitChain(commitDiff); err != nil { - t.Fatalf("unable to add to commit chain: %v", err) - } - - fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight, nil, nil) - - err = channel.AdvanceCommitChainTail(fwdPkg) - if err != nil { - t.Fatalf("unable to append to revocation log: %v", err) - } - - // Once again, fetch the state and ensure it has been properly updated. - prevCommit, err := channel.FindPreviousState(oldRemoteCommit.CommitHeight) - if err != nil { - t.Fatalf("unable to fetch past delta: %v", err) - } - assertCommitmentEqual(t, &oldRemoteCommit, prevCommit) - - // Once again, state number recovered from the tail of the revocation - // log should be identical to this current state. - logTail, err = channel.RevocationLogTail() - if err != nil { - t.Fatalf("unable to retrieve log: %v", err) - } - if logTail.CommitHeight != oldRemoteCommit.CommitHeight { - t.Fatal("update number doesn't match") - } - - // The revocation state stored on-disk should now also be identical. - updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch updated channel: %v", err) - } - if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) { - t.Fatalf("revocation state was not synced") - } - if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) { - t.Fatalf("revocation state was not synced") - } - - // Now attempt to delete the channel from the database. - closeSummary := &ChannelCloseSummary{ - ChanPoint: channel.FundingOutpoint, - RemotePub: channel.IdentityPub, - SettledBalance: btcutil.Amount(500), - TimeLockedBalance: btcutil.Amount(10000), - IsPending: false, - CloseType: RemoteForceClose, - } - if err := updatedChannel[0].CloseChannel(closeSummary); err != nil { - t.Fatalf("unable to delete updated channel: %v", err) - } - - // If we attempt to fetch the target channel again, it shouldn't be - // found. - channels, err := cdb.FetchOpenChannels(channel.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch updated channels: %v", err) - } - if len(channels) != 0 { - t.Fatalf("%v channels, found, but none should be", - len(channels)) - } - - // Attempting to find previous states on the channel should fail as the - // revocation log has been deleted. - _, err = updatedChannel[0].FindPreviousState(oldRemoteCommit.CommitHeight) - if err == nil { - t.Fatal("revocation log search should have failed") - } -} - -func TestFetchPendingChannels(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // Create first test channel state - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - - const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - pendingChannels, err := cdb.FetchPendingChannels() - if err != nil { - t.Fatalf("unable to list pending channels: %v", err) - } - - if len(pendingChannels) != 1 { - t.Fatalf("incorrect number of pending channels: expecting %v,"+ - "got %v", 1, len(pendingChannels)) - } - - // The broadcast height of the pending channel should have been set - // properly. - if pendingChannels[0].FundingBroadcastHeight != broadcastHeight { - t.Fatalf("broadcast height mismatch: expected %v, got %v", - pendingChannels[0].FundingBroadcastHeight, - broadcastHeight) - } - - chanOpenLoc := lnwire.ShortChannelID{ - BlockHeight: 5, - TxIndex: 10, - TxPosition: 15, - } - err = pendingChannels[0].MarkAsOpen(chanOpenLoc) - if err != nil { - t.Fatalf("unable to mark channel as open: %v", err) - } - - if pendingChannels[0].IsPending { - t.Fatalf("channel marked open should no longer be pending") - } - - if pendingChannels[0].ShortChanID() != chanOpenLoc { - t.Fatalf("channel opening height not updated: expected %v, "+ - "got %v", spew.Sdump(pendingChannels[0].ShortChanID()), - chanOpenLoc) - } - - // Next, we'll re-fetch the channel to ensure that the open height was - // properly set. - openChans, err := cdb.FetchAllChannels() - if err != nil { - t.Fatalf("unable to fetch channels: %v", err) - } - if openChans[0].ShortChanID() != chanOpenLoc { - t.Fatalf("channel opening heights don't match: expected %v, "+ - "got %v", spew.Sdump(openChans[0].ShortChanID()), - chanOpenLoc) - } - if openChans[0].FundingBroadcastHeight != broadcastHeight { - t.Fatalf("broadcast height mismatch: expected %v, got %v", - openChans[0].FundingBroadcastHeight, - broadcastHeight) - } - - pendingChannels, err = cdb.FetchPendingChannels() - if err != nil { - t.Fatalf("unable to list pending channels: %v", err) - } - - if len(pendingChannels) != 0 { - t.Fatalf("incorrect number of pending channels: expecting %v,"+ - "got %v", 0, len(pendingChannels)) - } -} - -func TestFetchClosedChannels(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First create a test channel, that we'll be closing within this pull - // request. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - // Next sync the channel to disk, marking it as being in a pending open - // state. - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - // Next, simulate the confirmation of the channel by marking it as - // pending within the database. - chanOpenLoc := lnwire.ShortChannelID{ - BlockHeight: 5, - TxIndex: 10, - TxPosition: 15, - } - err = state.MarkAsOpen(chanOpenLoc) - if err != nil { - t.Fatalf("unable to mark channel as open: %v", err) - } - - // Next, close the channel by including a close channel summary in the - // database. - summary := &ChannelCloseSummary{ - ChanPoint: state.FundingOutpoint, - ClosingTXID: rev, - RemotePub: state.IdentityPub, - Capacity: state.Capacity, - SettledBalance: state.LocalCommitment.LocalBalance.ToSatoshis(), - TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000, - CloseType: RemoteForceClose, - IsPending: true, - LocalChanConfig: state.LocalChanCfg, - } - if err := state.CloseChannel(summary); err != nil { - t.Fatalf("unable to close channel: %v", err) - } - - // Query the database to ensure that the channel has now been properly - // closed. We should get the same result whether querying for pending - // channels only, or not. - pendingClosed, err := cdb.FetchClosedChannels(true) - if err != nil { - t.Fatalf("failed fetching closed channels: %v", err) - } - if len(pendingClosed) != 1 { - t.Fatalf("incorrect number of pending closed channels: expecting %v,"+ - "got %v", 1, len(pendingClosed)) - } - if !reflect.DeepEqual(summary, pendingClosed[0]) { - t.Fatalf("database summaries don't match: expected %v got %v", - spew.Sdump(summary), spew.Sdump(pendingClosed[0])) - } - closed, err := cdb.FetchClosedChannels(false) - if err != nil { - t.Fatalf("failed fetching all closed channels: %v", err) - } - if len(closed) != 1 { - t.Fatalf("incorrect number of closed channels: expecting %v, "+ - "got %v", 1, len(closed)) - } - if !reflect.DeepEqual(summary, closed[0]) { - t.Fatalf("database summaries don't match: expected %v got %v", - spew.Sdump(summary), spew.Sdump(closed[0])) - } - - // Mark the channel as fully closed. - err = cdb.MarkChanFullyClosed(&state.FundingOutpoint) - if err != nil { - t.Fatalf("failed fully closing channel: %v", err) - } - - // The channel should no longer be considered pending, but should still - // be retrieved when fetching all the closed channels. - closed, err = cdb.FetchClosedChannels(false) - if err != nil { - t.Fatalf("failed fetching closed channels: %v", err) - } - if len(closed) != 1 { - t.Fatalf("incorrect number of closed channels: expecting %v, "+ - "got %v", 1, len(closed)) - } - pendingClose, err := cdb.FetchClosedChannels(true) - if err != nil { - t.Fatalf("failed fetching channels pending close: %v", err) - } - if len(pendingClose) != 0 { - t.Fatalf("incorrect number of closed channels: expecting %v, "+ - "got %v", 0, len(closed)) - } -} - -// TestFetchWaitingCloseChannels ensures that the correct channels that are -// waiting to be closed are returned. -func TestFetchWaitingCloseChannels(t *testing.T) { - t.Parallel() - - const numChannels = 2 - const broadcastHeight = 99 - addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 18555} - - // We'll start by creating two channels within our test database. One of - // them will have their funding transaction confirmed on-chain, while - // the other one will remain unconfirmed. - db, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - channels := make([]*OpenChannel, numChannels) - for i := 0; i < numChannels; i++ { - channel, err := createTestChannelState(db) - if err != nil { - t.Fatalf("unable to create channel: %v", err) - } - err = channel.SyncPending(addr, broadcastHeight) - if err != nil { - t.Fatalf("unable to sync channel: %v", err) - } - channels[i] = channel - } - - // We'll only confirm the first one. - channelConf := lnwire.ShortChannelID{ - BlockHeight: broadcastHeight + 1, - TxIndex: 10, - TxPosition: 15, - } - if err := channels[0].MarkAsOpen(channelConf); err != nil { - t.Fatalf("unable to mark channel as open: %v", err) - } - - // Then, we'll mark the channels as if their commitments were broadcast. - // This would happen in the event of a force close and should make the - // channels enter a state of waiting close. - for _, channel := range channels { - closeTx := wire.NewMsgTx(2) - closeTx.AddTxIn( - &wire.TxIn{ - PreviousOutPoint: channel.FundingOutpoint, - }, - ) - if err := channel.MarkCommitmentBroadcasted(closeTx); err != nil { - t.Fatalf("unable to mark commitment broadcast: %v", err) - } - } - - // Now, we'll fetch all the channels waiting to be closed from the - // database. We should expect to see both channels above, even if any of - // them haven't had their funding transaction confirm on-chain. - waitingCloseChannels, err := db.FetchWaitingCloseChannels() - if err != nil { - t.Fatalf("unable to fetch all waiting close channels: %v", err) - } - if len(waitingCloseChannels) != 2 { - t.Fatalf("expected %d channels waiting to be closed, got %d", 2, - len(waitingCloseChannels)) - } - expectedChannels := make(map[wire.OutPoint]struct{}) - for _, channel := range channels { - expectedChannels[channel.FundingOutpoint] = struct{}{} - } - for _, channel := range waitingCloseChannels { - if _, ok := expectedChannels[channel.FundingOutpoint]; !ok { - t.Fatalf("expected channel %v to be waiting close", - channel.FundingOutpoint) - } - - // Finally, make sure we can retrieve the closing tx for the - // channel. - closeTx, err := channel.BroadcastedCommitment() - if err != nil { - t.Fatalf("Unable to retrieve commitment: %v", err) - } - - if closeTx.TxIn[0].PreviousOutPoint != channel.FundingOutpoint { - t.Fatalf("expected outpoint %v, got %v", - channel.FundingOutpoint, - closeTx.TxIn[0].PreviousOutPoint) - } - } -} - -// TestRefreshShortChanID asserts that RefreshShortChanID updates the in-memory -// short channel ID of another OpenChannel to reflect a preceding call to -// MarkOpen on a different OpenChannel. -func TestRefreshShortChanID(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First create a test channel. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - - // Mark the channel as pending within the channeldb. - const broadcastHeight = 99 - if err := state.SyncPending(addr, broadcastHeight); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - - // Next, locate the pending channel with the database. - pendingChannels, err := cdb.FetchPendingChannels() - if err != nil { - t.Fatalf("unable to load pending channels; %v", err) - } - - var pendingChannel *OpenChannel - for _, channel := range pendingChannels { - if channel.FundingOutpoint == state.FundingOutpoint { - pendingChannel = channel - break - } - } - if pendingChannel == nil { - t.Fatalf("unable to find pending channel with funding "+ - "outpoint=%v: %v", state.FundingOutpoint, err) - } - - // Next, simulate the confirmation of the channel by marking it as - // pending within the database. - chanOpenLoc := lnwire.ShortChannelID{ - BlockHeight: 105, - TxIndex: 10, - TxPosition: 15, - } - - err = state.MarkAsOpen(chanOpenLoc) - if err != nil { - t.Fatalf("unable to mark channel open: %v", err) - } - - // The short_chan_id of the receiver to MarkAsOpen should reflect the - // open location, but the other pending channel should remain unchanged. - if state.ShortChanID() == pendingChannel.ShortChanID() { - t.Fatalf("pending channel short_chan_ID should not have been " + - "updated before refreshing short_chan_id") - } - - // Now that the receiver's short channel id has been updated, check to - // ensure that the channel packager's source has been updated as well. - // This ensures that the packager will read and write to buckets - // corresponding to the new short chan id, instead of the prior. - if state.Packager.(*ChannelPackager).source != chanOpenLoc { - t.Fatalf("channel packager source was not updated: want %v, "+ - "got %v", chanOpenLoc, - state.Packager.(*ChannelPackager).source) - } - - // Now, refresh the short channel ID of the pending channel. - err = pendingChannel.RefreshShortChanID() - if err != nil { - t.Fatalf("unable to refresh short_chan_id: %v", err) - } - - // This should result in both OpenChannel's now having the same - // ShortChanID. - if state.ShortChanID() != pendingChannel.ShortChanID() { - t.Fatalf("expected pending channel short_chan_id to be "+ - "refreshed: want %v, got %v", state.ShortChanID(), - pendingChannel.ShortChanID()) - } - - // Check to ensure that the _other_ OpenChannel channel packager's - // source has also been updated after the refresh. This ensures that the - // other packagers will read and write to buckets corresponding to the - // updated short chan id. - if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc { - t.Fatalf("channel packager source was not updated: want %v, "+ - "got %v", chanOpenLoc, - pendingChannel.Packager.(*ChannelPackager).source) - } -} diff --git a/channeldb/migration_01_to_11/codec.go b/channeldb/migration_01_to_11/codec.go index cfef35e0..1727c8c9 100644 --- a/channeldb/migration_01_to_11/codec.go +++ b/channeldb/migration_01_to_11/codec.go @@ -48,12 +48,6 @@ type UnknownElementType struct { element interface{} } -// NewUnknownElementType creates a new UnknownElementType error from the passed -// method name and element. -func NewUnknownElementType(method string, el interface{}) UnknownElementType { - return UnknownElementType{method: method, element: el} -} - // Error returns the name of the method that encountered the error, as well as // the type that was unsupported. func (e UnknownElementType) Error() string { diff --git a/channeldb/migration_01_to_11/db.go b/channeldb/migration_01_to_11/db.go index e1057d65..623b33bc 100644 --- a/channeldb/migration_01_to_11/db.go +++ b/channeldb/migration_01_to_11/db.go @@ -4,16 +4,11 @@ import ( "bytes" "encoding/binary" "fmt" - "net" "os" "path/filepath" "time" - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/wire" "github.com/coreos/bbolt" - "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/lnwire" ) const ( @@ -87,57 +82,6 @@ func Open(dbPath string, modifiers ...OptionModifier) (*DB, error) { return chanDB, nil } -// Path returns the file path to the channel database. -func (d *DB) Path() string { - return d.dbPath -} - -// Wipe completely deletes all saved state within all used buckets within the -// database. The deletion is done in a single transaction, therefore this -// operation is fully atomic. -func (d *DB) Wipe() error { - return d.Update(func(tx *bbolt.Tx) error { - err := tx.DeleteBucket(openChannelBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - err = tx.DeleteBucket(closedChannelBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - err = tx.DeleteBucket(invoiceBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - err = tx.DeleteBucket(nodeInfoBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - err = tx.DeleteBucket(nodeBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - err = tx.DeleteBucket(edgeBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - err = tx.DeleteBucket(edgeIndexBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - err = tx.DeleteBucket(graphMetaBucket) - if err != nil && err != bbolt.ErrBucketNotFound { - return err - } - - return nil - }) -} - // createChannelDB creates and initializes a fresh version of channeldb. In // the case that the target path has not yet been created or doesn't yet exist, // then the path is created. Additionally, all required top-level buckets used @@ -163,14 +107,6 @@ func createChannelDB(dbPath string) error { return err } - if _, err := tx.CreateBucket(forwardingLogBucket); err != nil { - return err - } - - if _, err := tx.CreateBucket(fwdPackagesKey); err != nil { - return err - } - if _, err := tx.CreateBucket(invoiceBucket); err != nil { return err } @@ -179,10 +115,6 @@ func createChannelDB(dbPath string) error { return err } - if _, err := tx.CreateBucket(nodeInfoBucket); err != nil { - return err - } - nodes, err := tx.CreateBucket(nodeBucket) if err != nil { return err @@ -249,359 +181,6 @@ func fileExists(path string) bool { return true } -// FetchOpenChannels starts a new database transaction and returns all stored -// currently active/open channels associated with the target nodeID. In the case -// that no active channels are known to have been created with this node, then a -// zero-length slice is returned. -func (d *DB) FetchOpenChannels(nodeID *btcec.PublicKey) ([]*OpenChannel, error) { - var channels []*OpenChannel - err := d.View(func(tx *bbolt.Tx) error { - var err error - channels, err = d.fetchOpenChannels(tx, nodeID) - return err - }) - - return channels, err -} - -// fetchOpenChannels uses and existing database transaction and returns all -// stored currently active/open channels associated with the target nodeID. In -// the case that no active channels are known to have been created with this -// node, then a zero-length slice is returned. -func (d *DB) fetchOpenChannels(tx *bbolt.Tx, - nodeID *btcec.PublicKey) ([]*OpenChannel, error) { - - // Get the bucket dedicated to storing the metadata for open channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return nil, nil - } - - // Within this top level bucket, fetch the bucket dedicated to storing - // open channel data specific to the remote node. - pub := nodeID.SerializeCompressed() - nodeChanBucket := openChanBucket.Bucket(pub) - if nodeChanBucket == nil { - return nil, nil - } - - // Next, we'll need to go down an additional layer in order to retrieve - // the channels for each chain the node knows of. - var channels []*OpenChannel - err := nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so ignore it. - if v != nil { - return nil - } - - // If we've found a valid chainhash bucket, then we'll retrieve - // that so we can extract all the channels. - chainBucket := nodeChanBucket.Bucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read bucket for chain=%x", - chainHash[:]) - } - - // Finally, we both of the necessary buckets retrieved, fetch - // all the active channels related to this node. - nodeChannels, err := d.fetchNodeChannels(chainBucket) - if err != nil { - return fmt.Errorf("unable to read channel for "+ - "chain_hash=%x, node_key=%x: %v", - chainHash[:], pub, err) - } - - channels = append(channels, nodeChannels...) - return nil - }) - - return channels, err -} - -// fetchNodeChannels retrieves all active channels from the target chainBucket -// which is under a node's dedicated channel bucket. This function is typically -// used to fetch all the active channels related to a particular node. -func (d *DB) fetchNodeChannels(chainBucket *bbolt.Bucket) ([]*OpenChannel, error) { - - var channels []*OpenChannel - - // A node may have channels on several chains, so for each known chain, - // we'll extract all the channels. - err := chainBucket.ForEach(func(chanPoint, v []byte) error { - // If there's a value, it's not a bucket so ignore it. - if v != nil { - return nil - } - - // Once we've found a valid channel bucket, we'll extract it - // from the node's chain bucket. - chanBucket := chainBucket.Bucket(chanPoint) - - var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(chanPoint), &outPoint) - if err != nil { - return err - } - oChannel, err := fetchOpenChannel(chanBucket, &outPoint) - if err != nil { - return fmt.Errorf("unable to read channel data for "+ - "chan_point=%v: %v", outPoint, err) - } - oChannel.Db = d - - channels = append(channels, oChannel) - - return nil - }) - if err != nil { - return nil, err - } - - return channels, nil -} - -// FetchChannel attempts to locate a channel specified by the passed channel -// point. If the channel cannot be found, then an error will be returned. -func (d *DB) FetchChannel(chanPoint wire.OutPoint) (*OpenChannel, error) { - var ( - targetChan *OpenChannel - targetChanPoint bytes.Buffer - ) - - if err := writeOutpoint(&targetChanPoint, &chanPoint); err != nil { - return nil, err - } - - // chanScan will traverse the following bucket structure: - // * nodePub => chainHash => chanPoint - // - // At each level we go one further, ensuring that we're traversing the - // proper key (that's actually a bucket). By only reading the bucket - // structure and skipping fully decoding each channel, we save a good - // bit of CPU as we don't need to do things like decompress public - // keys. - chanScan := func(tx *bbolt.Tx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoActiveChannels - } - - // Within the node channel bucket, are the set of node pubkeys - // we have channels with, we don't know the entire set, so - // we'll check them all. - return openChanBucket.ForEach(func(nodePub, v []byte) error { - // Ensure that this is a key the same size as a pubkey, - // and also that it leads directly to a bucket. - if len(nodePub) != 33 || v != nil { - return nil - } - - nodeChanBucket := openChanBucket.Bucket(nodePub) - if nodeChanBucket == nil { - return nil - } - - // The next layer down is all the chains that this node - // has channels on with us. - return nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so - // ignore it. - if v != nil { - return nil - } - - chainBucket := nodeChanBucket.Bucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read "+ - "bucket for chain=%x", chainHash[:]) - } - - // Finally we reach the leaf bucket that stores - // all the chanPoints for this node. - chanBucket := chainBucket.Bucket( - targetChanPoint.Bytes(), - ) - if chanBucket == nil { - return nil - } - - channel, err := fetchOpenChannel( - chanBucket, &chanPoint, - ) - if err != nil { - return err - } - - targetChan = channel - targetChan.Db = d - - return nil - }) - }) - } - - err := d.View(chanScan) - if err != nil { - return nil, err - } - - if targetChan != nil { - return targetChan, nil - } - - // If we can't find the channel, then we return with an error, as we - // have nothing to backup. - return nil, ErrChannelNotFound -} - -// FetchAllChannels attempts to retrieve all open channels currently stored -// within the database, including pending open, fully open and channels waiting -// for a closing transaction to confirm. -func (d *DB) FetchAllChannels() ([]*OpenChannel, error) { - var channels []*OpenChannel - - // TODO(halseth): fetch all in one db tx. - openChannels, err := d.FetchAllOpenChannels() - if err != nil { - return nil, err - } - channels = append(channels, openChannels...) - - pendingChannels, err := d.FetchPendingChannels() - if err != nil { - return nil, err - } - channels = append(channels, pendingChannels...) - - waitingClose, err := d.FetchWaitingCloseChannels() - if err != nil { - return nil, err - } - channels = append(channels, waitingClose...) - - return channels, nil -} - -// FetchAllOpenChannels will return all channels that have the funding -// transaction confirmed, and is not waiting for a closing transaction to be -// confirmed. -func (d *DB) FetchAllOpenChannels() ([]*OpenChannel, error) { - return fetchChannels(d, false, false) -} - -// FetchPendingChannels will return channels that have completed the process of -// generating and broadcasting funding transactions, but whose funding -// transactions have yet to be confirmed on the blockchain. -func (d *DB) FetchPendingChannels() ([]*OpenChannel, error) { - return fetchChannels(d, true, false) -} - -// FetchWaitingCloseChannels will return all channels that have been opened, -// but are now waiting for a closing transaction to be confirmed. -// -// NOTE: This includes channels that are also pending to be opened. -func (d *DB) FetchWaitingCloseChannels() ([]*OpenChannel, error) { - waitingClose, err := fetchChannels(d, false, true) - if err != nil { - return nil, err - } - pendingWaitingClose, err := fetchChannels(d, true, true) - if err != nil { - return nil, err - } - - return append(waitingClose, pendingWaitingClose...), nil -} - -// fetchChannels attempts to retrieve channels currently stored in the -// database. The pending parameter determines whether only pending channels -// will be returned, or only open channels will be returned. The waitingClose -// parameter determines whether only channels waiting for a closing transaction -// to be confirmed should be returned. If no active channels exist within the -// network, then ErrNoActiveChannels is returned. -func fetchChannels(d *DB, pending, waitingClose bool) ([]*OpenChannel, error) { - var channels []*OpenChannel - - err := d.View(func(tx *bbolt.Tx) error { - // Get the bucket dedicated to storing the metadata for open - // channels. - openChanBucket := tx.Bucket(openChannelBucket) - if openChanBucket == nil { - return ErrNoActiveChannels - } - - // Next, fetch the bucket dedicated to storing metadata related - // to all nodes. All keys within this bucket are the serialized - // public keys of all our direct counterparties. - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return fmt.Errorf("node bucket not created") - } - - // Finally for each node public key in the bucket, fetch all - // the channels related to this particular node. - return nodeMetaBucket.ForEach(func(k, v []byte) error { - nodeChanBucket := openChanBucket.Bucket(k) - if nodeChanBucket == nil { - return nil - } - - return nodeChanBucket.ForEach(func(chainHash, v []byte) error { - // If there's a value, it's not a bucket so - // ignore it. - if v != nil { - return nil - } - - // If we've found a valid chainhash bucket, - // then we'll retrieve that so we can extract - // all the channels. - chainBucket := nodeChanBucket.Bucket(chainHash) - if chainBucket == nil { - return fmt.Errorf("unable to read "+ - "bucket for chain=%x", chainHash[:]) - } - - nodeChans, err := d.fetchNodeChannels(chainBucket) - if err != nil { - return fmt.Errorf("unable to read "+ - "channel for chain_hash=%x, "+ - "node_key=%x: %v", chainHash[:], k, err) - } - for _, channel := range nodeChans { - if channel.IsPending != pending { - continue - } - - // If the channel is in any other state - // than Default, then it means it is - // waiting to be closed. - channelWaitingClose := - channel.ChanStatus() != ChanStatusDefault - - // Only include it if we requested - // channels with the same waitingClose - // status. - if channelWaitingClose != waitingClose { - continue - } - - channels = append(channels, channel) - } - return nil - }) - - }) - }) - if err != nil { - return nil, err - } - - return channels, nil -} - // FetchClosedChannels attempts to fetch all closed channels from the database. // The pendingOnly bool toggles if channels that aren't yet fully closed should // be returned in the response or not. When a channel was cooperatively closed, @@ -641,371 +220,6 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro return chanSummaries, nil } -// ErrClosedChannelNotFound signals that a closed channel could not be found in -// the channeldb. -var ErrClosedChannelNotFound = errors.New("unable to find closed channel summary") - -// FetchClosedChannel queries for a channel close summary using the channel -// point of the channel in question. -func (d *DB) FetchClosedChannel(chanID *wire.OutPoint) (*ChannelCloseSummary, error) { - var chanSummary *ChannelCloseSummary - if err := d.View(func(tx *bbolt.Tx) error { - closeBucket := tx.Bucket(closedChannelBucket) - if closeBucket == nil { - return ErrClosedChannelNotFound - } - - var b bytes.Buffer - var err error - if err = writeOutpoint(&b, chanID); err != nil { - return err - } - - summaryBytes := closeBucket.Get(b.Bytes()) - if summaryBytes == nil { - return ErrClosedChannelNotFound - } - - summaryReader := bytes.NewReader(summaryBytes) - chanSummary, err = deserializeCloseChannelSummary(summaryReader) - - return err - }); err != nil { - return nil, err - } - - return chanSummary, nil -} - -// FetchClosedChannelForID queries for a channel close summary using the -// channel ID of the channel in question. -func (d *DB) FetchClosedChannelForID(cid lnwire.ChannelID) ( - *ChannelCloseSummary, error) { - - var chanSummary *ChannelCloseSummary - if err := d.View(func(tx *bbolt.Tx) error { - closeBucket := tx.Bucket(closedChannelBucket) - if closeBucket == nil { - return ErrClosedChannelNotFound - } - - // The first 30 bytes of the channel ID and outpoint will be - // equal. - cursor := closeBucket.Cursor() - op, c := cursor.Seek(cid[:30]) - - // We scan over all possible candidates for this channel ID. - for ; op != nil && bytes.Compare(cid[:30], op[:30]) <= 0; op, c = cursor.Next() { - var outPoint wire.OutPoint - err := readOutpoint(bytes.NewReader(op), &outPoint) - if err != nil { - return err - } - - // If the found outpoint does not correspond to this - // channel ID, we continue. - if !cid.IsChanPoint(&outPoint) { - continue - } - - // Deserialize the close summary and return. - r := bytes.NewReader(c) - chanSummary, err = deserializeCloseChannelSummary(r) - if err != nil { - return err - } - - return nil - } - return ErrClosedChannelNotFound - }); err != nil { - return nil, err - } - - return chanSummary, nil -} - -// MarkChanFullyClosed marks a channel as fully closed within the database. A -// channel should be marked as fully closed if the channel was initially -// cooperatively closed and it's reached a single confirmation, or after all -// the pending funds in a channel that has been forcibly closed have been -// swept. -func (d *DB) MarkChanFullyClosed(chanPoint *wire.OutPoint) error { - return d.Update(func(tx *bbolt.Tx) error { - var b bytes.Buffer - if err := writeOutpoint(&b, chanPoint); err != nil { - return err - } - - chanID := b.Bytes() - - closedChanBucket, err := tx.CreateBucketIfNotExists( - closedChannelBucket, - ) - if err != nil { - return err - } - - chanSummaryBytes := closedChanBucket.Get(chanID) - if chanSummaryBytes == nil { - return fmt.Errorf("no closed channel for "+ - "chan_point=%v found", chanPoint) - } - - chanSummaryReader := bytes.NewReader(chanSummaryBytes) - chanSummary, err := deserializeCloseChannelSummary( - chanSummaryReader, - ) - if err != nil { - return err - } - - chanSummary.IsPending = false - - var newSummary bytes.Buffer - err = serializeChannelCloseSummary(&newSummary, chanSummary) - if err != nil { - return err - } - - err = closedChanBucket.Put(chanID, newSummary.Bytes()) - if err != nil { - return err - } - - // Now that the channel is closed, we'll check if we have any - // other open channels with this peer. If we don't we'll - // garbage collect it to ensure we don't establish persistent - // connections to peers without open channels. - return d.pruneLinkNode(tx, chanSummary.RemotePub) - }) -} - -// pruneLinkNode determines whether we should garbage collect a link node from -// the database due to no longer having any open channels with it. If there are -// any left, then this acts as a no-op. -func (d *DB) pruneLinkNode(tx *bbolt.Tx, remotePub *btcec.PublicKey) error { - openChannels, err := d.fetchOpenChannels(tx, remotePub) - if err != nil { - return fmt.Errorf("unable to fetch open channels for peer %x: "+ - "%v", remotePub.SerializeCompressed(), err) - } - - if len(openChannels) > 0 { - return nil - } - - log.Infof("Pruning link node %x with zero open channels from database", - remotePub.SerializeCompressed()) - - return d.deleteLinkNode(tx, remotePub) -} - -// PruneLinkNodes attempts to prune all link nodes found within the databse with -// whom we no longer have any open channels with. -func (d *DB) PruneLinkNodes() error { - return d.Update(func(tx *bbolt.Tx) error { - linkNodes, err := d.fetchAllLinkNodes(tx) - if err != nil { - return err - } - - for _, linkNode := range linkNodes { - err := d.pruneLinkNode(tx, linkNode.IdentityPub) - if err != nil { - return err - } - } - - return nil - }) -} - -// ChannelShell is a shell of a channel that is meant to be used for channel -// recovery purposes. It contains a minimal OpenChannel instance along with -// addresses for that target node. -type ChannelShell struct { - // NodeAddrs the set of addresses that this node has known to be - // reachable at in the past. - NodeAddrs []net.Addr - - // Chan is a shell of an OpenChannel, it contains only the items - // required to restore the channel on disk. - Chan *OpenChannel -} - -// RestoreChannelShells is a method that allows the caller to reconstruct the -// state of an OpenChannel from the ChannelShell. We'll attempt to write the -// new channel to disk, create a LinkNode instance with the passed node -// addresses, and finally create an edge within the graph for the channel as -// well. This method is idempotent, so repeated calls with the same set of -// channel shells won't modify the database after the initial call. -func (d *DB) RestoreChannelShells(channelShells ...*ChannelShell) error { - chanGraph := d.ChannelGraph() - - // TODO(conner): find way to do this w/o accessing internal members? - chanGraph.cacheMu.Lock() - defer chanGraph.cacheMu.Unlock() - - var chansRestored []uint64 - err := d.Update(func(tx *bbolt.Tx) error { - for _, channelShell := range channelShells { - channel := channelShell.Chan - - // When we make a channel, we mark that the channel has - // been restored, this will signal to other sub-systems - // to not attempt to use the channel as if it was a - // regular one. - channel.chanStatus |= ChanStatusRestored - - // First, we'll attempt to create a new open channel - // and link node for this channel. If the channel - // already exists, then in order to ensure this method - // is idempotent, we'll continue to the next step. - channel.Db = d - err := syncNewChannel( - tx, channel, channelShell.NodeAddrs, - ) - if err != nil { - return err - } - - // Next, we'll create an active edge in the graph - // database for this channel in order to restore our - // partial view of the network. - // - // TODO(roasbeef): if we restore *after* the channel - // has been closed on chain, then need to inform the - // router that it should try and prune these values as - // we can detect them - edgeInfo := ChannelEdgeInfo{ - ChannelID: channel.ShortChannelID.ToUint64(), - ChainHash: channel.ChainHash, - ChannelPoint: channel.FundingOutpoint, - Capacity: channel.Capacity, - } - - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - selfNode, err := chanGraph.sourceNode(nodes) - if err != nil { - return err - } - - // Depending on which pub key is smaller, we'll assign - // our roles as "node1" and "node2". - chanPeer := channel.IdentityPub.SerializeCompressed() - selfIsSmaller := bytes.Compare( - selfNode.PubKeyBytes[:], chanPeer, - ) == -1 - if selfIsSmaller { - copy(edgeInfo.NodeKey1Bytes[:], selfNode.PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], chanPeer) - } else { - copy(edgeInfo.NodeKey1Bytes[:], chanPeer) - copy(edgeInfo.NodeKey2Bytes[:], selfNode.PubKeyBytes[:]) - } - - // With the edge info shell constructed, we'll now add - // it to the graph. - err = chanGraph.addChannelEdge(tx, &edgeInfo) - if err != nil && err != ErrEdgeAlreadyExist { - return err - } - - // Similarly, we'll construct a channel edge shell and - // add that itself to the graph. - chanEdge := ChannelEdgePolicy{ - ChannelID: edgeInfo.ChannelID, - LastUpdate: time.Now(), - } - - // If their pubkey is larger, then we'll flip the - // direction bit to indicate that us, the "second" node - // is updating their policy. - if !selfIsSmaller { - chanEdge.ChannelFlags |= lnwire.ChanUpdateDirection - } - - _, err = updateEdgePolicy(tx, &chanEdge) - if err != nil { - return err - } - - chansRestored = append(chansRestored, edgeInfo.ChannelID) - } - - return nil - }) - if err != nil { - return err - } - - for _, chanid := range chansRestored { - chanGraph.rejectCache.remove(chanid) - chanGraph.chanCache.remove(chanid) - } - - return nil -} - -// AddrsForNode consults the graph and channel database for all addresses known -// to the passed node public key. -func (d *DB) AddrsForNode(nodePub *btcec.PublicKey) ([]net.Addr, error) { - var ( - linkNode *LinkNode - graphNode LightningNode - ) - - dbErr := d.View(func(tx *bbolt.Tx) error { - var err error - - linkNode, err = fetchLinkNode(tx, nodePub) - if err != nil { - return err - } - - // We'll also query the graph for this peer to see if they have - // any addresses that we don't currently have stored within the - // link node database. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - compressedPubKey := nodePub.SerializeCompressed() - graphNode, err = fetchLightningNode(nodes, compressedPubKey) - if err != nil && err != ErrGraphNodeNotFound { - // If the node isn't found, then that's OK, as we still - // have the link node data. - return err - } - - return nil - }) - if dbErr != nil { - return nil, dbErr - } - - // Now that we have both sources of addrs for this node, we'll use a - // map to de-duplicate any addresses between the two sources, and - // produce a final list of the combined addrs. - addrs := make(map[string]net.Addr) - for _, addr := range linkNode.Addresses { - addrs[addr.String()] = addr - } - for _, addr := range graphNode.Addresses { - addrs[addr.String()] = addr - } - dedupedAddrs := make([]net.Addr, 0, len(addrs)) - for _, addr := range addrs { - dedupedAddrs = append(dedupedAddrs, addr) - } - - return dedupedAddrs, nil -} - // syncVersions function is used for safe db version synchronization. It // applies migration functions to the current database and recovers the // previous state of db if at least one error/panic appeared during migration. diff --git a/channeldb/migration_01_to_11/db_test.go b/channeldb/migration_01_to_11/db_test.go deleted file mode 100644 index 721546e7..00000000 --- a/channeldb/migration_01_to_11/db_test.go +++ /dev/null @@ -1,471 +0,0 @@ -package migration_01_to_11 - -import ( - "io/ioutil" - "math" - "math/rand" - "net" - "os" - "path/filepath" - "reflect" - "testing" - - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/wire" - "github.com/btcsuite/btcutil" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/keychain" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/shachain" -) - -func TestOpenWithCreate(t *testing.T) { - t.Parallel() - - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - t.Fatalf("unable to create temp dir: %v", err) - } - defer os.RemoveAll(tempDirName) - - // Next, open thereby creating channeldb for the first time. - dbPath := filepath.Join(tempDirName, "cdb") - cdb, err := Open(dbPath) - if err != nil { - t.Fatalf("unable to create channeldb: %v", err) - } - if err := cdb.Close(); err != nil { - t.Fatalf("unable to close channeldb: %v", err) - } - - // The path should have been successfully created. - if !fileExists(dbPath) { - t.Fatalf("channeldb failed to create data directory") - } -} - -// TestWipe tests that the database wipe operation completes successfully -// and that the buckets are deleted. It also checks that attempts to fetch -// information while the buckets are not set return the correct errors. -func TestWipe(t *testing.T) { - t.Parallel() - - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - t.Fatalf("unable to create temp dir: %v", err) - } - defer os.RemoveAll(tempDirName) - - // Next, open thereby creating channeldb for the first time. - dbPath := filepath.Join(tempDirName, "cdb") - cdb, err := Open(dbPath) - if err != nil { - t.Fatalf("unable to create channeldb: %v", err) - } - defer cdb.Close() - - if err := cdb.Wipe(); err != nil { - t.Fatalf("unable to wipe channeldb: %v", err) - } - // Check correct errors are returned - _, err = cdb.FetchAllOpenChannels() - if err != ErrNoActiveChannels { - t.Fatalf("fetching open channels: expected '%v' instead got '%v'", - ErrNoActiveChannels, err) - } - _, err = cdb.FetchClosedChannels(false) - if err != ErrNoClosedChannels { - t.Fatalf("fetching closed channels: expected '%v' instead got '%v'", - ErrNoClosedChannels, err) - } -} - -// TestFetchClosedChannelForID tests that we are able to properly retrieve a -// ChannelCloseSummary from the DB given a ChannelID. -func TestFetchClosedChannelForID(t *testing.T) { - t.Parallel() - - const numChans = 101 - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // Create the test channel state, that we will mutate the index of the - // funding point. - state, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - // Now run through the number of channels, and modify the outpoint index - // to create new channel IDs. - for i := uint32(0); i < numChans; i++ { - // Save the open channel to disk. - state.FundingOutpoint.Index = i - - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18556, - } - if err := state.SyncPending(addr, 101); err != nil { - t.Fatalf("unable to save and serialize channel "+ - "state: %v", err) - } - - // Close the channel. To make sure we retrieve the correct - // summary later, we make them differ in the SettledBalance. - closeSummary := &ChannelCloseSummary{ - ChanPoint: state.FundingOutpoint, - RemotePub: state.IdentityPub, - SettledBalance: btcutil.Amount(500 + i), - } - if err := state.CloseChannel(closeSummary); err != nil { - t.Fatalf("unable to close channel: %v", err) - } - } - - // Now run though them all again and make sure we are able to retrieve - // summaries from the DB. - for i := uint32(0); i < numChans; i++ { - state.FundingOutpoint.Index = i - - // We calculate the ChannelID and use it to fetch the summary. - cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) - fetchedSummary, err := cdb.FetchClosedChannelForID(cid) - if err != nil { - t.Fatalf("unable to fetch close summary: %v", err) - } - - // Make sure we retrieved the correct one by checking the - // SettledBalance. - if fetchedSummary.SettledBalance != btcutil.Amount(500+i) { - t.Fatalf("summaries don't match: expected %v got %v", - btcutil.Amount(500+i), - fetchedSummary.SettledBalance) - } - } - - // As a final test we make sure that we get ErrClosedChannelNotFound - // for a ChannelID we didn't add to the DB. - state.FundingOutpoint.Index++ - cid := lnwire.NewChanIDFromOutPoint(&state.FundingOutpoint) - _, err = cdb.FetchClosedChannelForID(cid) - if err != ErrClosedChannelNotFound { - t.Fatalf("expected ErrClosedChannelNotFound, instead got: %v", err) - } -} - -// TestAddrsForNode tests the we're able to properly obtain all the addresses -// for a target node. -func TestAddrsForNode(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - graph := cdb.ChannelGraph() - - // We'll make a test vertex to insert into the database, as the source - // node, but this node will only have half the number of addresses it - // usually does. - testNode, err := createTestVertex(cdb) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - testNode.Addresses = []net.Addr{testAddr} - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // Next, we'll make a link node with the same pubkey, but with an - // additional address. - nodePub, err := testNode.PubKey() - if err != nil { - t.Fatalf("unable to recv node pub: %v", err) - } - linkNode := cdb.NewLinkNode( - wire.MainNet, nodePub, anotherAddr, - ) - if err := linkNode.Sync(); err != nil { - t.Fatalf("unable to sync link node: %v", err) - } - - // Now that we've created a link node, as well as a vertex for the - // node, we'll query for all its addresses. - nodeAddrs, err := cdb.AddrsForNode(nodePub) - if err != nil { - t.Fatalf("unable to obtain node addrs: %v", err) - } - - expectedAddrs := make(map[string]struct{}) - expectedAddrs[testAddr.String()] = struct{}{} - expectedAddrs[anotherAddr.String()] = struct{}{} - - // Finally, ensure that all the expected addresses are found. - if len(nodeAddrs) != len(expectedAddrs) { - t.Fatalf("expected %v addrs, got %v", - len(expectedAddrs), len(nodeAddrs)) - } - for _, addr := range nodeAddrs { - if _, ok := expectedAddrs[addr.String()]; !ok { - t.Fatalf("unexpected addr: %v", addr) - } - } -} - -// TestFetchChannel tests that we're able to fetch an arbitrary channel from -// disk. -func TestFetchChannel(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // Create the test channel state that we'll sync to the database - // shortly. - channelState, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - - // Mark the channel as pending, then immediately mark it as open to it - // can be fully visible. - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - } - if err := channelState.SyncPending(addr, 9); err != nil { - t.Fatalf("unable to save and serialize channel state: %v", err) - } - err = channelState.MarkAsOpen(lnwire.NewShortChanIDFromInt(99)) - if err != nil { - t.Fatalf("unable to mark channel open: %v", err) - } - - // Next, attempt to fetch the channel by its chan point. - dbChannel, err := cdb.FetchChannel(channelState.FundingOutpoint) - if err != nil { - t.Fatalf("unable to fetch channel: %v", err) - } - - // The decoded channel state should be identical to what we stored - // above. - if !reflect.DeepEqual(channelState, dbChannel) { - t.Fatalf("channel state doesn't match:: %v vs %v", - spew.Sdump(channelState), spew.Sdump(dbChannel)) - } - - // If we attempt to query for a non-exist ante channel, then we should - // get an error. - channelState2, err := createTestChannelState(cdb) - if err != nil { - t.Fatalf("unable to create channel state: %v", err) - } - channelState2.FundingOutpoint.Index ^= 1 - - _, err = cdb.FetchChannel(channelState2.FundingOutpoint) - if err == nil { - t.Fatalf("expected query to fail") - } -} - -func genRandomChannelShell() (*ChannelShell, error) { - var testPriv [32]byte - if _, err := rand.Read(testPriv[:]); err != nil { - return nil, err - } - - _, pub := btcec.PrivKeyFromBytes(btcec.S256(), testPriv[:]) - - var chanPoint wire.OutPoint - if _, err := rand.Read(chanPoint.Hash[:]); err != nil { - return nil, err - } - - pub.Curve = nil - - chanPoint.Index = uint32(rand.Intn(math.MaxUint16)) - - chanStatus := ChanStatusDefault | ChanStatusRestored - - var shaChainPriv [32]byte - if _, err := rand.Read(testPriv[:]); err != nil { - return nil, err - } - revRoot, err := chainhash.NewHash(shaChainPriv[:]) - if err != nil { - return nil, err - } - shaChainProducer := shachain.NewRevocationProducer(*revRoot) - - return &ChannelShell{ - NodeAddrs: []net.Addr{&net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 18555, - }}, - Chan: &OpenChannel{ - chanStatus: chanStatus, - ChainHash: rev, - FundingOutpoint: chanPoint, - ShortChannelID: lnwire.NewShortChanIDFromInt( - uint64(rand.Int63()), - ), - IdentityPub: pub, - LocalChanCfg: ChannelConfig{ - ChannelConstraints: ChannelConstraints{ - CsvDelay: uint16(rand.Int63()), - }, - PaymentBasePoint: keychain.KeyDescriptor{ - KeyLocator: keychain.KeyLocator{ - Family: keychain.KeyFamily(rand.Int63()), - Index: uint32(rand.Int63()), - }, - }, - }, - RemoteCurrentRevocation: pub, - IsPending: false, - RevocationStore: shachain.NewRevocationStore(), - RevocationProducer: shaChainProducer, - }, - }, nil -} - -// TestRestoreChannelShells tests that we're able to insert a partially channel -// populated to disk. This is useful for channel recovery purposes. We should -// find the new channel shell on disk, and also the db should be populated with -// an edge for that channel. -func TestRestoreChannelShells(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First, we'll make our channel shell, it will only have the minimal - // amount of information required for us to initiate the data loss - // protection feature. - channelShell, err := genRandomChannelShell() - if err != nil { - t.Fatalf("unable to gen channel shell: %v", err) - } - - graph := cdb.ChannelGraph() - - // Before we can restore the channel, we'll need to make a source node - // in the graph as the channel edge we create will need to have a - // origin. - testNode, err := createTestVertex(cdb) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // With the channel shell constructed, we'll now insert it into the - // database with the restoration method. - if err := cdb.RestoreChannelShells(channelShell); err != nil { - t.Fatalf("unable to restore channel shell: %v", err) - } - - // Now that the channel has been inserted, we'll attempt to query for - // it to ensure we can properly locate it via various means. - // - // First, we'll attempt to query for all channels that we have with the - // node public key that was restored. - nodeChans, err := cdb.FetchOpenChannels(channelShell.Chan.IdentityPub) - if err != nil { - t.Fatalf("unable find channel: %v", err) - } - - // We should now find a single channel from the database. - if len(nodeChans) != 1 { - t.Fatalf("unable to find restored channel by node "+ - "pubkey: %v", err) - } - - // Ensure that it isn't possible to modify the commitment state machine - // of this restored channel. - channel := nodeChans[0] - err = channel.UpdateCommitment(nil) - if err != ErrNoRestoredChannelMutation { - t.Fatalf("able to mutate restored channel") - } - err = channel.AppendRemoteCommitChain(nil) - if err != ErrNoRestoredChannelMutation { - t.Fatalf("able to mutate restored channel") - } - err = channel.AdvanceCommitChainTail(nil) - if err != ErrNoRestoredChannelMutation { - t.Fatalf("able to mutate restored channel") - } - - // That single channel should have the proper channel point, and also - // the expected set of flags to indicate that it was a restored - // channel. - if nodeChans[0].FundingOutpoint != channelShell.Chan.FundingOutpoint { - t.Fatalf("wrong funding outpoint: expected %v, got %v", - nodeChans[0].FundingOutpoint, - channelShell.Chan.FundingOutpoint) - } - if !nodeChans[0].HasChanStatus(ChanStatusRestored) { - t.Fatalf("node has wrong status flags: %v", - nodeChans[0].chanStatus) - } - - // We should also be able to find the channel if we query for it - // directly. - _, err = cdb.FetchChannel(channelShell.Chan.FundingOutpoint) - if err != nil { - t.Fatalf("unable to fetch channel: %v", err) - } - - // We should also be able to find the link node that was inserted by - // its public key. - linkNode, err := cdb.FetchLinkNode(channelShell.Chan.IdentityPub) - if err != nil { - t.Fatalf("unable to fetch link node: %v", err) - } - - // The node should have the same address, as specified in the channel - // shell. - if reflect.DeepEqual(linkNode.Addresses, channelShell.NodeAddrs) { - t.Fatalf("addr mismach: expected %v, got %v", - linkNode.Addresses, channelShell.NodeAddrs) - } - - // Finally, we'll ensure that the edge for the channel was properly - // inserted. - chanInfos, err := graph.FetchChanInfos( - []uint64{channelShell.Chan.ShortChannelID.ToUint64()}, - ) - if err != nil { - t.Fatalf("unable to find edges: %v", err) - } - - if len(chanInfos) != 1 { - t.Fatalf("wrong amount of chan infos: expected %v got %v", - len(chanInfos), 1) - } - - // We should only find a single edge. - if chanInfos[0].Policy1 != nil && chanInfos[0].Policy2 != nil { - t.Fatalf("only a single edge should be inserted: %v", err) - } -} diff --git a/channeldb/migration_01_to_11/doc.go b/channeldb/migration_01_to_11/doc.go deleted file mode 100644 index c90412f2..00000000 --- a/channeldb/migration_01_to_11/doc.go +++ /dev/null @@ -1 +0,0 @@ -package migration_01_to_11 diff --git a/channeldb/migration_01_to_11/error.go b/channeldb/migration_01_to_11/error.go index f264fb70..232aaa2b 100644 --- a/channeldb/migration_01_to_11/error.go +++ b/channeldb/migration_01_to_11/error.go @@ -1,55 +1,23 @@ package migration_01_to_11 import ( - "errors" "fmt" ) var ( - // ErrNoChanDBExists is returned when a channel bucket hasn't been - // created. - ErrNoChanDBExists = fmt.Errorf("channel db has not yet been created") // ErrDBReversion is returned when detecting an attempt to revert to a // prior database version. ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version") - // ErrLinkNodesNotFound is returned when node info bucket hasn't been - // created. - ErrLinkNodesNotFound = fmt.Errorf("no link nodes exist") - - // ErrNoActiveChannels is returned when there is no active (open) - // channels within the database. - ErrNoActiveChannels = fmt.Errorf("no active channels exist") - - // ErrNoPastDeltas is returned when the channel delta bucket hasn't been - // created. - ErrNoPastDeltas = fmt.Errorf("channel has no recorded deltas") - - // ErrInvoiceNotFound is returned when a targeted invoice can't be - // found. - ErrInvoiceNotFound = fmt.Errorf("unable to locate invoice") - // ErrNoInvoicesCreated is returned when we don't have invoices in // our database to return. ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices") - // ErrDuplicateInvoice is returned when an invoice with the target - // payment hash already exists. - ErrDuplicateInvoice = fmt.Errorf("invoice with payment hash already exists") - // ErrNoPaymentsCreated is returned when bucket of payments hasn't been // created. ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") - // ErrNodeNotFound is returned when node bucket exists, but node with - // specific identity can't be found. - ErrNodeNotFound = fmt.Errorf("link node with target identity not found") - - // ErrChannelNotFound is returned when we attempt to locate a channel - // for a specific chain, but it is not found. - ErrChannelNotFound = fmt.Errorf("channel not found") - // ErrMetaNotFound is returned when meta bucket hasn't been // created. ErrMetaNotFound = fmt.Errorf("unable to locate meta information") @@ -58,22 +26,11 @@ var ( // graph doesn't exist. ErrGraphNotFound = fmt.Errorf("graph bucket not initialized") - // ErrGraphNeverPruned is returned when graph was never pruned. - ErrGraphNeverPruned = fmt.Errorf("graph never pruned") - // ErrSourceNodeNotSet is returned if the source node of the graph // hasn't been added The source node is the center node within a // star-graph. ErrSourceNodeNotSet = fmt.Errorf("source node does not exist") - // ErrGraphNodesNotFound is returned in case none of the nodes has - // been added in graph node bucket. - ErrGraphNodesNotFound = fmt.Errorf("no graph nodes exist") - - // ErrGraphNoEdgesFound is returned in case of none of the channel/edges - // has been added in graph edge bucket. - ErrGraphNoEdgesFound = fmt.Errorf("no graph edges exist") - // ErrGraphNodeNotFound is returned when we're unable to find the target // node. ErrGraphNodeNotFound = fmt.Errorf("unable to find node") @@ -82,17 +39,6 @@ var ( // can't be found. ErrEdgeNotFound = fmt.Errorf("edge not found") - // ErrZombieEdge is an error returned when we attempt to look up an edge - // but it is marked as a zombie within the zombie index. - ErrZombieEdge = errors.New("edge marked as zombie") - - // ErrEdgeAlreadyExist is returned when edge with specific - // channel id can't be added because it already exist. - ErrEdgeAlreadyExist = fmt.Errorf("edge already exist") - - // ErrNodeAliasNotFound is returned when alias for node can't be found. - ErrNodeAliasNotFound = fmt.Errorf("alias for node not found") - // ErrUnknownAddressType is returned when a node's addressType is not // an expected value. ErrUnknownAddressType = fmt.Errorf("address type cannot be resolved") @@ -101,20 +47,11 @@ var ( // channels it has closed, but it hasn't yet closed any channels. ErrNoClosedChannels = fmt.Errorf("no channel have been closed yet") - // ErrNoForwardingEvents is returned in the case that a query fails due - // to the log not having any recorded events. - ErrNoForwardingEvents = fmt.Errorf("no recorded forwarding events") - // ErrEdgePolicyOptionalFieldNotFound is an error returned if a channel // policy field is not found in the db even though its message flags // indicate it should be. ErrEdgePolicyOptionalFieldNotFound = fmt.Errorf("optional field not " + "present") - - // ErrChanAlreadyExists is return when the caller attempts to create a - // channel with a channel point that is already present in the - // database. - ErrChanAlreadyExists = fmt.Errorf("channel already exists") ) // ErrTooManyExtraOpaqueBytes creates an error which should be returned if the diff --git a/channeldb/migration_01_to_11/fees.go b/channeldb/migration_01_to_11/fees.go deleted file mode 100644 index c90412f2..00000000 --- a/channeldb/migration_01_to_11/fees.go +++ /dev/null @@ -1 +0,0 @@ -package migration_01_to_11 diff --git a/channeldb/migration_01_to_11/forwarding_log.go b/channeldb/migration_01_to_11/forwarding_log.go deleted file mode 100644 index 6b9f8f5d..00000000 --- a/channeldb/migration_01_to_11/forwarding_log.go +++ /dev/null @@ -1,274 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "io" - "sort" - "time" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lnwire" -) - -var ( - // forwardingLogBucket is the bucket that we'll use to store the - // forwarding log. The forwarding log contains a time series database - // of the forwarding history of a lightning daemon. Each key within the - // bucket is a timestamp (in nano seconds since the unix epoch), and - // the value a slice of a forwarding event for that timestamp. - forwardingLogBucket = []byte("circuit-fwd-log") -) - -const ( - // forwardingEventSize is the size of a forwarding event. The breakdown - // is as follows: - // - // * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in - // || 8 byte value out - // - // From the value in and value out, callers can easily compute the - // total fee extract from a forwarding event. - forwardingEventSize = 32 - - // MaxResponseEvents is the max number of forwarding events that will - // be returned by a single query response. This size was selected to - // safely remain under gRPC's 4MiB message size response limit. As each - // full forwarding event (including the timestamp) is 40 bytes, we can - // safely return 50k entries in a single response. - MaxResponseEvents = 50000 -) - -// ForwardingLog returns an instance of the ForwardingLog object backed by the -// target database instance. -func (d *DB) ForwardingLog() *ForwardingLog { - return &ForwardingLog{ - db: d, - } -} - -// ForwardingLog is a time series database that logs the fulfilment of payment -// circuits by a lightning network daemon. The log contains a series of -// forwarding events which map a timestamp to a forwarding event. A forwarding -// event describes which channels were used to create+settle a circuit, and the -// amount involved. Subtracting the outgoing amount from the incoming amount -// reveals the fee charged for the forwarding service. -type ForwardingLog struct { - db *DB -} - -// ForwardingEvent is an event in the forwarding log's time series. Each -// forwarding event logs the creation and tear-down of a payment circuit. A -// circuit is created once an incoming HTLC has been fully forwarded, and -// destroyed once the payment has been settled. -type ForwardingEvent struct { - // Timestamp is the settlement time of this payment circuit. - Timestamp time.Time - - // IncomingChanID is the incoming channel ID of the payment circuit. - IncomingChanID lnwire.ShortChannelID - - // OutgoingChanID is the outgoing channel ID of the payment circuit. - OutgoingChanID lnwire.ShortChannelID - - // AmtIn is the amount of the incoming HTLC. Subtracting this from the - // outgoing amount gives the total fees of this payment circuit. - AmtIn lnwire.MilliSatoshi - - // AmtOut is the amount of the outgoing HTLC. Subtracting the incoming - // amount from this gives the total fees for this payment circuit. - AmtOut lnwire.MilliSatoshi -} - -// encodeForwardingEvent writes out the target forwarding event to the passed -// io.Writer, using the expected DB format. Note that the timestamp isn't -// serialized as this will be the key value within the bucket. -func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { - return WriteElements( - w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut, - ) -} - -// decodeForwardingEvent attempts to decode the raw bytes of a serialized -// forwarding event into the target ForwardingEvent. Note that the timestamp -// won't be decoded, as the caller is expected to set this due to the bucket -// structure of the forwarding log. -func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error { - return ReadElements( - r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut, - ) -} - -// AddForwardingEvents adds a series of forwarding events to the database. -// Before inserting, the set of events will be sorted according to their -// timestamp. This ensures that all writes to disk are sequential. -func (f *ForwardingLog) AddForwardingEvents(events []ForwardingEvent) error { - // Before we create the database transaction, we'll ensure that the set - // of forwarding events are properly sorted according to their - // timestamp. - sort.Slice(events, func(i, j int) bool { - return events[i].Timestamp.Before(events[j].Timestamp) - }) - - var timestamp [8]byte - - return f.db.Batch(func(tx *bbolt.Tx) error { - // First, we'll fetch the bucket that stores our time series - // log. - logBucket, err := tx.CreateBucketIfNotExists( - forwardingLogBucket, - ) - if err != nil { - return err - } - - // With the bucket obtained, we can now begin to write out the - // series of events. - for _, event := range events { - var eventBytes [forwardingEventSize]byte - eventBuf := bytes.NewBuffer(eventBytes[0:0:forwardingEventSize]) - - // First, we'll serialize this timestamp into our - // timestamp buffer. - byteOrder.PutUint64( - timestamp[:], uint64(event.Timestamp.UnixNano()), - ) - - // With the key encoded, we'll then encode the event - // into our buffer, then write it out to disk. - err := encodeForwardingEvent(eventBuf, &event) - if err != nil { - return err - } - err = logBucket.Put(timestamp[:], eventBuf.Bytes()) - if err != nil { - return err - } - } - - return nil - }) -} - -// ForwardingEventQuery represents a query to the forwarding log payment -// circuit time series database. The query allows a caller to retrieve all -// records for a particular time slice, offset in that time slice, limiting the -// total number of responses returned. -type ForwardingEventQuery struct { - // StartTime is the start time of the time slice. - StartTime time.Time - - // EndTime is the end time of the time slice. - EndTime time.Time - - // IndexOffset is the offset within the time slice to start at. This - // can be used to start the response at a particular record. - IndexOffset uint32 - - // NumMaxEvents is the max number of events to return. - NumMaxEvents uint32 -} - -// ForwardingLogTimeSlice is the response to a forwarding query. It includes -// the original query, the set events that match the query, and an integer -// which represents the offset index of the last item in the set of retuned -// events. This integer allows callers to resume their query using this offset -// in the event that the query's response exceeds the max number of returnable -// events. -type ForwardingLogTimeSlice struct { - ForwardingEventQuery - - // ForwardingEvents is the set of events in our time series that answer - // the query embedded above. - ForwardingEvents []ForwardingEvent - - // LastIndexOffset is the index of the last element in the set of - // returned ForwardingEvents above. Callers can use this to resume - // their query in the event that the time slice has too many events to - // fit into a single response. - LastIndexOffset uint32 -} - -// Query allows a caller to query the forwarding event time series for a -// particular time slice. The caller can control the precise time as well as -// the number of events to be returned. -// -// TODO(roasbeef): rename? -func (f *ForwardingLog) Query(q ForwardingEventQuery) (ForwardingLogTimeSlice, error) { - resp := ForwardingLogTimeSlice{ - ForwardingEventQuery: q, - } - - // If the user provided an index offset, then we'll not know how many - // records we need to skip. We'll also keep track of the record offset - // as that's part of the final return value. - recordsToSkip := q.IndexOffset - recordOffset := q.IndexOffset - - err := f.db.View(func(tx *bbolt.Tx) error { - // If the bucket wasn't found, then there aren't any events to - // be returned. - logBucket := tx.Bucket(forwardingLogBucket) - if logBucket == nil { - return ErrNoForwardingEvents - } - - // We'll be using a cursor to seek into the database, so we'll - // populate byte slices that represent the start of the key - // space we're interested in, and the end. - var startTime, endTime [8]byte - byteOrder.PutUint64(startTime[:], uint64(q.StartTime.UnixNano())) - byteOrder.PutUint64(endTime[:], uint64(q.EndTime.UnixNano())) - - // If we know that a set of log events exists, then we'll begin - // our seek through the log in order to satisfy the query. - // We'll continue until either we reach the end of the range, - // or reach our max number of events. - logCursor := logBucket.Cursor() - timestamp, events := logCursor.Seek(startTime[:]) - for ; timestamp != nil && bytes.Compare(timestamp, endTime[:]) <= 0; timestamp, events = logCursor.Next() { - // If our current return payload exceeds the max number - // of events, then we'll exit now. - if uint32(len(resp.ForwardingEvents)) >= q.NumMaxEvents { - return nil - } - - // If we're not yet past the user defined offset, then - // we'll continue to seek forward. - if recordsToSkip > 0 { - recordsToSkip-- - continue - } - - currentTime := time.Unix( - 0, int64(byteOrder.Uint64(timestamp)), - ) - - // At this point, we've skipped enough records to start - // to collate our query. For each record, we'll - // increment the final record offset so the querier can - // utilize pagination to seek further. - readBuf := bytes.NewReader(events) - for readBuf.Len() != 0 { - var event ForwardingEvent - err := decodeForwardingEvent(readBuf, &event) - if err != nil { - return err - } - - event.Timestamp = currentTime - resp.ForwardingEvents = append(resp.ForwardingEvents, event) - - recordOffset++ - } - } - - return nil - }) - if err != nil && err != ErrNoForwardingEvents { - return ForwardingLogTimeSlice{}, err - } - - resp.LastIndexOffset = recordOffset - - return resp, nil -} diff --git a/channeldb/migration_01_to_11/forwarding_log_test.go b/channeldb/migration_01_to_11/forwarding_log_test.go deleted file mode 100644 index 9e0de7c4..00000000 --- a/channeldb/migration_01_to_11/forwarding_log_test.go +++ /dev/null @@ -1,265 +0,0 @@ -package migration_01_to_11 - -import ( - "math/rand" - "reflect" - "testing" - - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/lnwire" - - "time" -) - -// TestForwardingLogBasicStorageAndQuery tests that we're able to store and -// then query for items that have previously been added to the event log. -func TestForwardingLogBasicStorageAndQuery(t *testing.T) { - t.Parallel() - - // First, we'll set up a test database, and use that to instantiate the - // forwarding event log that we'll be using for the duration of the - // test. - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - log := ForwardingLog{ - db: db, - } - - initialTime := time.Unix(1234, 0) - timestamp := time.Unix(1234, 0) - - // We'll create 100 random events, which each event being spaced 10 - // minutes after the prior event. - numEvents := 100 - events := make([]ForwardingEvent, numEvents) - for i := 0; i < numEvents; i++ { - events[i] = ForwardingEvent{ - Timestamp: timestamp, - IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - AmtIn: lnwire.MilliSatoshi(rand.Int63()), - AmtOut: lnwire.MilliSatoshi(rand.Int63()), - } - - timestamp = timestamp.Add(time.Minute * 10) - } - - // Now that all of our set of events constructed, we'll add them to the - // database in a batch manner. - if err := log.AddForwardingEvents(events); err != nil { - t.Fatalf("unable to add events: %v", err) - } - - // With our events added we'll now construct a basic query to retrieve - // all of the events. - eventQuery := ForwardingEventQuery{ - StartTime: initialTime, - EndTime: timestamp, - IndexOffset: 0, - NumMaxEvents: 1000, - } - timeSlice, err := log.Query(eventQuery) - if err != nil { - t.Fatalf("unable to query for events: %v", err) - } - - // The set of returned events should match identically, as they should - // be returned in sorted order. - if !reflect.DeepEqual(events, timeSlice.ForwardingEvents) { - t.Fatalf("event mismatch: expected %v vs %v", - spew.Sdump(events), spew.Sdump(timeSlice.ForwardingEvents)) - } - - // The offset index of the final entry should be numEvents, so the - // number of total events we've written. - if timeSlice.LastIndexOffset != uint32(numEvents) { - t.Fatalf("wrong final offset: expected %v, got %v", - timeSlice.LastIndexOffset, numEvents) - } -} - -// TestForwardingLogQueryOptions tests that the query offset works properly. So -// if we add a series of events, then we should be able to seek within the -// timeslice accordingly. This exercises the index offset and num max event -// field in the query, and also the last index offset field int he response. -func TestForwardingLogQueryOptions(t *testing.T) { - t.Parallel() - - // First, we'll set up a test database, and use that to instantiate the - // forwarding event log that we'll be using for the duration of the - // test. - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - log := ForwardingLog{ - db: db, - } - - initialTime := time.Unix(1234, 0) - endTime := time.Unix(1234, 0) - - // We'll create 20 random events, which each event being spaced 10 - // minutes after the prior event. - numEvents := 20 - events := make([]ForwardingEvent, numEvents) - for i := 0; i < numEvents; i++ { - events[i] = ForwardingEvent{ - Timestamp: endTime, - IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - AmtIn: lnwire.MilliSatoshi(rand.Int63()), - AmtOut: lnwire.MilliSatoshi(rand.Int63()), - } - - endTime = endTime.Add(time.Minute * 10) - } - - // Now that all of our set of events constructed, we'll add them to the - // database in a batch manner. - if err := log.AddForwardingEvents(events); err != nil { - t.Fatalf("unable to add events: %v", err) - } - - // With all of our events added, we should be able to query for the - // first 10 events using the max event query field. - eventQuery := ForwardingEventQuery{ - StartTime: initialTime, - EndTime: endTime, - IndexOffset: 0, - NumMaxEvents: 10, - } - timeSlice, err := log.Query(eventQuery) - if err != nil { - t.Fatalf("unable to query for events: %v", err) - } - - // We should get exactly 10 events back. - if len(timeSlice.ForwardingEvents) != 10 { - t.Fatalf("wrong number of events: expected %v, got %v", 10, - len(timeSlice.ForwardingEvents)) - } - - // The set of events returned should be the first 10 events that we - // added. - if !reflect.DeepEqual(events[:10], timeSlice.ForwardingEvents) { - t.Fatalf("wrong response: expected %v, got %v", - spew.Sdump(events[:10]), - spew.Sdump(timeSlice.ForwardingEvents)) - } - - // The final offset should be the exact number of events returned. - if timeSlice.LastIndexOffset != 10 { - t.Fatalf("wrong index offset: expected %v, got %v", 10, - timeSlice.LastIndexOffset) - } - - // If we use the final offset to query again, then we should get 10 - // more events, that are the last 10 events we wrote. - eventQuery.IndexOffset = 10 - timeSlice, err = log.Query(eventQuery) - if err != nil { - t.Fatalf("unable to query for events: %v", err) - } - - // We should get exactly 10 events back once again. - if len(timeSlice.ForwardingEvents) != 10 { - t.Fatalf("wrong number of events: expected %v, got %v", 10, - len(timeSlice.ForwardingEvents)) - } - - // The events that we got back should be the last 10 events that we - // wrote out. - if !reflect.DeepEqual(events[10:], timeSlice.ForwardingEvents) { - t.Fatalf("wrong response: expected %v, got %v", - spew.Sdump(events[10:]), - spew.Sdump(timeSlice.ForwardingEvents)) - } - - // Finally, the last index offset should be 20, or the number of - // records we've written out. - if timeSlice.LastIndexOffset != 20 { - t.Fatalf("wrong index offset: expected %v, got %v", 20, - timeSlice.LastIndexOffset) - } -} - -// TestForwardingLogQueryLimit tests that we're able to properly limit the -// number of events that are returned as part of a query. -func TestForwardingLogQueryLimit(t *testing.T) { - t.Parallel() - - // First, we'll set up a test database, and use that to instantiate the - // forwarding event log that we'll be using for the duration of the - // test. - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - log := ForwardingLog{ - db: db, - } - - initialTime := time.Unix(1234, 0) - endTime := time.Unix(1234, 0) - - // We'll create 200 random events, which each event being spaced 10 - // minutes after the prior event. - numEvents := 200 - events := make([]ForwardingEvent, numEvents) - for i := 0; i < numEvents; i++ { - events[i] = ForwardingEvent{ - Timestamp: endTime, - IncomingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), - AmtIn: lnwire.MilliSatoshi(rand.Int63()), - AmtOut: lnwire.MilliSatoshi(rand.Int63()), - } - - endTime = endTime.Add(time.Minute * 10) - } - - // Now that all of our set of events constructed, we'll add them to the - // database in a batch manner. - if err := log.AddForwardingEvents(events); err != nil { - t.Fatalf("unable to add events: %v", err) - } - - // Once the events have been written out, we'll issue a query over the - // entire range, but restrict the number of events to the first 100. - eventQuery := ForwardingEventQuery{ - StartTime: initialTime, - EndTime: endTime, - IndexOffset: 0, - NumMaxEvents: 100, - } - timeSlice, err := log.Query(eventQuery) - if err != nil { - t.Fatalf("unable to query for events: %v", err) - } - - // We should get exactly 100 events back. - if len(timeSlice.ForwardingEvents) != 100 { - t.Fatalf("wrong number of events: expected %v, got %v", 10, - len(timeSlice.ForwardingEvents)) - } - - // The set of events returned should be the first 100 events that we - // added. - if !reflect.DeepEqual(events[:100], timeSlice.ForwardingEvents) { - t.Fatalf("wrong response: expected %v, got %v", - spew.Sdump(events[:100]), - spew.Sdump(timeSlice.ForwardingEvents)) - } - - // The final offset should be the exact number of events returned. - if timeSlice.LastIndexOffset != 100 { - t.Fatalf("wrong index offset: expected %v, got %v", 100, - timeSlice.LastIndexOffset) - } -} diff --git a/channeldb/migration_01_to_11/forwarding_package.go b/channeldb/migration_01_to_11/forwarding_package.go deleted file mode 100644 index cbbf90cf..00000000 --- a/channeldb/migration_01_to_11/forwarding_package.go +++ /dev/null @@ -1,928 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "io" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lnwire" -) - -// ErrCorruptedFwdPkg signals that the on-disk structure of the forwarding -// package has potentially been mangled. -var ErrCorruptedFwdPkg = errors.New("fwding package db has been corrupted") - -// FwdState is an enum used to describe the lifecycle of a FwdPkg. -type FwdState byte - -const ( - // FwdStateLockedIn is the starting state for all forwarding packages. - // Packages in this state have not yet committed to the exact set of - // Adds to forward to the switch. - FwdStateLockedIn FwdState = iota - - // FwdStateProcessed marks the state in which all Adds have been - // locally processed and the forwarding decision to the switch has been - // persisted. - FwdStateProcessed - - // FwdStateCompleted signals that all Adds have been acked, and that all - // settles and fails have been delivered to their sources. Packages in - // this state can be removed permanently. - FwdStateCompleted -) - -var ( - // fwdPackagesKey is the root-level bucket that all forwarding packages - // are written. This bucket is further subdivided based on the short - // channel ID of each channel. - fwdPackagesKey = []byte("fwd-packages") - - // addBucketKey is the bucket to which all Add log updates are written. - addBucketKey = []byte("add-updates") - - // failSettleBucketKey is the bucket to which all Settle/Fail log - // updates are written. - failSettleBucketKey = []byte("fail-settle-updates") - - // fwdFilterKey is a key used to write the set of Adds that passed - // validation and are to be forwarded to the switch. - // NOTE: The presence of this key within a forwarding package indicates - // that the package has reached FwdStateProcessed. - fwdFilterKey = []byte("fwd-filter-key") - - // ackFilterKey is a key used to access the PkgFilter indicating which - // Adds have received a Settle/Fail. This response may come from a - // number of sources, including: exitHop settle/fails, switch failures, - // chain arbiter interjections, as well as settle/fails from the - // next hop in the route. - ackFilterKey = []byte("ack-filter-key") - - // settleFailFilterKey is a key used to access the PkgFilter indicating - // which Settles/Fails in have been received and processed by the link - // that originally received the Add. - settleFailFilterKey = []byte("settle-fail-filter-key") -) - -// PkgFilter is used to compactly represent a particular subset of the Adds in a -// forwarding package. Each filter is represented as a simple, statically-sized -// bitvector, where the elements are intended to be the indices of the Adds as -// they are written in the FwdPkg. -type PkgFilter struct { - count uint16 - filter []byte -} - -// NewPkgFilter initializes an empty PkgFilter supporting `count` elements. -func NewPkgFilter(count uint16) *PkgFilter { - // We add 7 to ensure that the integer division yields properly rounded - // values. - filterLen := (count + 7) / 8 - - return &PkgFilter{ - count: count, - filter: make([]byte, filterLen), - } -} - -// Count returns the number of elements represented by this PkgFilter. -func (f *PkgFilter) Count() uint16 { - return f.count -} - -// Set marks the `i`-th element as included by this filter. -// NOTE: It is assumed that i is always less than count. -func (f *PkgFilter) Set(i uint16) { - byt := i / 8 - bit := i % 8 - - // Set the i-th bit in the filter. - // TODO(conner): ignore if > count to prevent panic? - f.filter[byt] |= byte(1 << (7 - bit)) -} - -// Contains queries the filter for membership of index `i`. -// NOTE: It is assumed that i is always less than count. -func (f *PkgFilter) Contains(i uint16) bool { - byt := i / 8 - bit := i % 8 - - // Read the i-th bit in the filter. - // TODO(conner): ignore if > count to prevent panic? - return f.filter[byt]&(1<<(7-bit)) != 0 -} - -// Equal checks two PkgFilters for equality. -func (f *PkgFilter) Equal(f2 *PkgFilter) bool { - if f == f2 { - return true - } - if f.count != f2.count { - return false - } - - return bytes.Equal(f.filter, f2.filter) -} - -// IsFull returns true if every element in the filter has been Set, and false -// otherwise. -func (f *PkgFilter) IsFull() bool { - // Batch validate bytes that are fully used. - for i := uint16(0); i < f.count/8; i++ { - if f.filter[i] != 0xFF { - return false - } - } - - // If the count is not a multiple of 8, check that the filter contains - // all remaining bits. - rem := f.count % 8 - for idx := f.count - rem; idx < f.count; idx++ { - if !f.Contains(idx) { - return false - } - } - - return true -} - -// Size returns number of bytes produced when the PkgFilter is serialized. -func (f *PkgFilter) Size() uint16 { - // 2 bytes for uint16 `count`, then round up number of bytes required to - // represent `count` bits. - return 2 + (f.count+7)/8 -} - -// Encode writes the filter to the provided io.Writer. -func (f *PkgFilter) Encode(w io.Writer) error { - if err := binary.Write(w, binary.BigEndian, f.count); err != nil { - return err - } - - _, err := w.Write(f.filter) - - return err -} - -// Decode reads the filter from the provided io.Reader. -func (f *PkgFilter) Decode(r io.Reader) error { - if err := binary.Read(r, binary.BigEndian, &f.count); err != nil { - return err - } - - f.filter = make([]byte, f.Size()-2) - _, err := io.ReadFull(r, f.filter) - - return err -} - -// FwdPkg records all adds, settles, and fails that were locked in as a result -// of the remote peer sending us a revocation. Each package is identified by -// the short chanid and remote commitment height corresponding to the revocation -// that locked in the HTLCs. For everything except a locally initiated payment, -// settles and fails in a forwarding package must have a corresponding Add in -// another package, and can be removed individually once the source link has -// received the fail/settle. -// -// Adds cannot be removed, as we need to present the same batch of Adds to -// properly handle replay protection. Instead, we use a PkgFilter to mark that -// we have finished processing a particular Add. A FwdPkg should only be deleted -// after the AckFilter is full and all settles and fails have been persistently -// removed. -type FwdPkg struct { - // Source identifies the channel that wrote this forwarding package. - Source lnwire.ShortChannelID - - // Height is the height of the remote commitment chain that locked in - // this forwarding package. - Height uint64 - - // State signals the persistent condition of the package and directs how - // to reprocess the package in the event of failures. - State FwdState - - // Adds contains all add messages which need to be processed and - // forwarded to the switch. Adds does not change over the life of a - // forwarding package. - Adds []LogUpdate - - // FwdFilter is a filter containing the indices of all Adds that were - // forwarded to the switch. - FwdFilter *PkgFilter - - // AckFilter is a filter containing the indices of all Adds for which - // the source has received a settle or fail and is reflected in the next - // commitment txn. A package should not be removed until IsFull() - // returns true. - AckFilter *PkgFilter - - // SettleFails contains all settle and fail messages that should be - // forwarded to the switch. - SettleFails []LogUpdate - - // SettleFailFilter is a filter containing the indices of all Settle or - // Fails originating in this package that have been received and locked - // into the incoming link's commitment state. - SettleFailFilter *PkgFilter -} - -// NewFwdPkg initializes a new forwarding package in FwdStateLockedIn. This -// should be used to create a package at the time we receive a revocation. -func NewFwdPkg(source lnwire.ShortChannelID, height uint64, - addUpdates, settleFailUpdates []LogUpdate) *FwdPkg { - - nAddUpdates := uint16(len(addUpdates)) - nSettleFailUpdates := uint16(len(settleFailUpdates)) - - return &FwdPkg{ - Source: source, - Height: height, - State: FwdStateLockedIn, - Adds: addUpdates, - FwdFilter: NewPkgFilter(nAddUpdates), - AckFilter: NewPkgFilter(nAddUpdates), - SettleFails: settleFailUpdates, - SettleFailFilter: NewPkgFilter(nSettleFailUpdates), - } -} - -// ID returns an unique identifier for this package, used to ensure that sphinx -// replay processing of this batch is idempotent. -func (f *FwdPkg) ID() []byte { - var id = make([]byte, 16) - byteOrder.PutUint64(id[:8], f.Source.ToUint64()) - byteOrder.PutUint64(id[8:], f.Height) - return id -} - -// String returns a human-readable description of the forwarding package. -func (f *FwdPkg) String() string { - return fmt.Sprintf("%T(src=%v, height=%v, nadds=%v, nfailsettles=%v)", - f, f.Source, f.Height, len(f.Adds), len(f.SettleFails)) -} - -// AddRef is used to identify a particular Add in a FwdPkg. The short channel ID -// is assumed to be that of the packager. -type AddRef struct { - // Height is the remote commitment height that locked in the Add. - Height uint64 - - // Index is the index of the Add within the fwd pkg's Adds. - // - // NOTE: This index is static over the lifetime of a forwarding package. - Index uint16 -} - -// Encode serializes the AddRef to the given io.Writer. -func (a *AddRef) Encode(w io.Writer) error { - if err := binary.Write(w, binary.BigEndian, a.Height); err != nil { - return err - } - - return binary.Write(w, binary.BigEndian, a.Index) -} - -// Decode deserializes the AddRef from the given io.Reader. -func (a *AddRef) Decode(r io.Reader) error { - if err := binary.Read(r, binary.BigEndian, &a.Height); err != nil { - return err - } - - return binary.Read(r, binary.BigEndian, &a.Index) -} - -// SettleFailRef is used to locate a Settle/Fail in another channel's FwdPkg. A -// channel does not remove its own Settle/Fail htlcs, so the source is provided -// to locate a db bucket belonging to another channel. -type SettleFailRef struct { - // Source identifies the outgoing link that locked in the settle or - // fail. This is then used by the *incoming* link to find the settle - // fail in another link's forwarding packages. - Source lnwire.ShortChannelID - - // Height is the remote commitment height that locked in this - // Settle/Fail. - Height uint64 - - // Index is the index of the Add with the fwd pkg's SettleFails. - // - // NOTE: This index is static over the lifetime of a forwarding package. - Index uint16 -} - -// SettleFailAcker is a generic interface providing the ability to acknowledge -// settle/fail HTLCs stored in forwarding packages. -type SettleFailAcker interface { - // AckSettleFails atomically updates the settle-fail filters in *other* - // channels' forwarding packages. - AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error -} - -// GlobalFwdPkgReader is an interface used to retrieve the forwarding packages -// of any active channel. -type GlobalFwdPkgReader interface { - // LoadChannelFwdPkgs loads all known forwarding packages for the given - // channel. - LoadChannelFwdPkgs(tx *bbolt.Tx, - source lnwire.ShortChannelID) ([]*FwdPkg, error) -} - -// FwdOperator defines the interfaces for managing forwarding packages that are -// external to a particular channel. This interface is used by the switch to -// read forwarding packages from arbitrary channels, and acknowledge settles and -// fails for locally-sourced payments. -type FwdOperator interface { - // GlobalFwdPkgReader provides read access to all known forwarding - // packages - GlobalFwdPkgReader - - // SettleFailAcker grants the ability to acknowledge settles or fails - // residing in arbitrary forwarding packages. - SettleFailAcker -} - -// SwitchPackager is a concrete implementation of the FwdOperator interface. -// A SwitchPackager offers the ability to read any forwarding package, and ack -// arbitrary settle and fail HTLCs. -type SwitchPackager struct{} - -// NewSwitchPackager instantiates a new SwitchPackager. -func NewSwitchPackager() *SwitchPackager { - return &SwitchPackager{} -} - -// AckSettleFails atomically updates the settle-fail filters in *other* -// channels' forwarding packages, to mark that the switch has received a settle -// or fail residing in the forwarding package of a link. -func (*SwitchPackager) AckSettleFails(tx *bbolt.Tx, - settleFailRefs ...SettleFailRef) error { - - return ackSettleFails(tx, settleFailRefs) -} - -// LoadChannelFwdPkgs loads all forwarding packages for a particular channel. -func (*SwitchPackager) LoadChannelFwdPkgs(tx *bbolt.Tx, - source lnwire.ShortChannelID) ([]*FwdPkg, error) { - - return loadChannelFwdPkgs(tx, source) -} - -// FwdPackager supports all operations required to modify fwd packages, such as -// creation, updates, reading, and removal. The interfaces are broken down in -// this way to support future delegation of the subinterfaces. -type FwdPackager interface { - // AddFwdPkg serializes and writes a FwdPkg for this channel at the - // remote commitment height included in the forwarding package. - AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error - - // SetFwdFilter looks up the forwarding package at the remote `height` - // and sets the `fwdFilter`, marking the Adds for which: - // 1) We are not the exit node - // 2) Passed all validation - // 3) Should be forwarded to the switch immediately after a failure - SetFwdFilter(tx *bbolt.Tx, height uint64, fwdFilter *PkgFilter) error - - // AckAddHtlcs atomically updates the add filters in this channel's - // forwarding packages to mark the resolution of an Add that was - // received from the remote party. - AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error - - // SettleFailAcker allows a link to acknowledge settle/fail HTLCs - // belonging to other channels. - SettleFailAcker - - // LoadFwdPkgs loads all known forwarding packages owned by this - // channel. - LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) - - // RemovePkg deletes a forwarding package owned by this channel at - // the provided remote `height`. - RemovePkg(tx *bbolt.Tx, height uint64) error -} - -// ChannelPackager is used by a channel to manage the lifecycle of its forwarding -// packages. The packager is tied to a particular source channel ID, allowing it -// to create and edit its own packages. Each packager also has the ability to -// remove fail/settle htlcs that correspond to an add contained in one of -// source's packages. -type ChannelPackager struct { - source lnwire.ShortChannelID -} - -// NewChannelPackager creates a new packager for a single channel. -func NewChannelPackager(source lnwire.ShortChannelID) *ChannelPackager { - return &ChannelPackager{ - source: source, - } -} - -// AddFwdPkg writes a newly locked in forwarding package to disk. -func (*ChannelPackager) AddFwdPkg(tx *bbolt.Tx, fwdPkg *FwdPkg) error { - fwdPkgBkt, err := tx.CreateBucketIfNotExists(fwdPackagesKey) - if err != nil { - return err - } - - source := makeLogKey(fwdPkg.Source.ToUint64()) - sourceBkt, err := fwdPkgBkt.CreateBucketIfNotExists(source[:]) - if err != nil { - return err - } - - heightKey := makeLogKey(fwdPkg.Height) - heightBkt, err := sourceBkt.CreateBucketIfNotExists(heightKey[:]) - if err != nil { - return err - } - - // Write ADD updates we received at this commit height. - addBkt, err := heightBkt.CreateBucketIfNotExists(addBucketKey) - if err != nil { - return err - } - - // Write SETTLE/FAIL updates we received at this commit height. - failSettleBkt, err := heightBkt.CreateBucketIfNotExists(failSettleBucketKey) - if err != nil { - return err - } - - for i := range fwdPkg.Adds { - err = putLogUpdate(addBkt, uint16(i), &fwdPkg.Adds[i]) - if err != nil { - return err - } - } - - // Persist the initialized pkg filter, which will be used to determine - // when we can remove this forwarding package from disk. - var ackFilterBuf bytes.Buffer - if err := fwdPkg.AckFilter.Encode(&ackFilterBuf); err != nil { - return err - } - - if err := heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()); err != nil { - return err - } - - for i := range fwdPkg.SettleFails { - err = putLogUpdate(failSettleBkt, uint16(i), &fwdPkg.SettleFails[i]) - if err != nil { - return err - } - } - - var settleFailFilterBuf bytes.Buffer - err = fwdPkg.SettleFailFilter.Encode(&settleFailFilterBuf) - if err != nil { - return err - } - - return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) -} - -// putLogUpdate writes an htlc to the provided `bkt`, using `index` as the key. -func putLogUpdate(bkt *bbolt.Bucket, idx uint16, htlc *LogUpdate) error { - var b bytes.Buffer - if err := htlc.Encode(&b); err != nil { - return err - } - - return bkt.Put(uint16Key(idx), b.Bytes()) -} - -// LoadFwdPkgs scans the forwarding log for any packages that haven't been -// processed, and returns their deserialized log updates in a map indexed by the -// remote commitment height at which the updates were locked in. -func (p *ChannelPackager) LoadFwdPkgs(tx *bbolt.Tx) ([]*FwdPkg, error) { - return loadChannelFwdPkgs(tx, p.source) -} - -// loadChannelFwdPkgs loads all forwarding packages owned by `source`. -func loadChannelFwdPkgs(tx *bbolt.Tx, source lnwire.ShortChannelID) ([]*FwdPkg, error) { - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return nil, nil - } - - sourceKey := makeLogKey(source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) - if sourceBkt == nil { - return nil, nil - } - - var heights []uint64 - if err := sourceBkt.ForEach(func(k, _ []byte) error { - if len(k) != 8 { - return ErrCorruptedFwdPkg - } - - heights = append(heights, byteOrder.Uint64(k)) - - return nil - }); err != nil { - return nil, err - } - - // Load the forwarding package for each retrieved height. - fwdPkgs := make([]*FwdPkg, 0, len(heights)) - for _, height := range heights { - fwdPkg, err := loadFwdPkg(fwdPkgBkt, source, height) - if err != nil { - return nil, err - } - - fwdPkgs = append(fwdPkgs, fwdPkg) - } - - return fwdPkgs, nil -} - -// loadFwPkg reads the packager's fwd pkg at a given height, and determines the -// appropriate FwdState. -func loadFwdPkg(fwdPkgBkt *bbolt.Bucket, source lnwire.ShortChannelID, - height uint64) (*FwdPkg, error) { - - sourceKey := makeLogKey(source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) - if sourceBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.Bucket(heightKey[:]) - if heightBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - // Load ADDs from disk. - addBkt := heightBkt.Bucket(addBucketKey) - if addBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - adds, err := loadHtlcs(addBkt) - if err != nil { - return nil, err - } - - // Load ack filter from disk. - ackFilterBytes := heightBkt.Get(ackFilterKey) - if ackFilterBytes == nil { - return nil, ErrCorruptedFwdPkg - } - ackFilterReader := bytes.NewReader(ackFilterBytes) - - ackFilter := &PkgFilter{} - if err := ackFilter.Decode(ackFilterReader); err != nil { - return nil, err - } - - // Load SETTLE/FAILs from disk. - failSettleBkt := heightBkt.Bucket(failSettleBucketKey) - if failSettleBkt == nil { - return nil, ErrCorruptedFwdPkg - } - - failSettles, err := loadHtlcs(failSettleBkt) - if err != nil { - return nil, err - } - - // Load settle fail filter from disk. - settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) - if settleFailFilterBytes == nil { - return nil, ErrCorruptedFwdPkg - } - settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) - - settleFailFilter := &PkgFilter{} - if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { - return nil, err - } - - // Initialize the fwding package, which always starts in the - // FwdStateLockedIn. We can determine what state the package was left in - // by examining constraints on the information loaded from disk. - fwdPkg := &FwdPkg{ - Source: source, - State: FwdStateLockedIn, - Height: height, - Adds: adds, - AckFilter: ackFilter, - SettleFails: failSettles, - SettleFailFilter: settleFailFilter, - } - - // Check to see if we have written the set exported filter adds to - // disk. If we haven't, processing of this package was never started, or - // failed during the last attempt. - fwdFilterBytes := heightBkt.Get(fwdFilterKey) - if fwdFilterBytes == nil { - nAdds := uint16(len(adds)) - fwdPkg.FwdFilter = NewPkgFilter(nAdds) - return fwdPkg, nil - } - - fwdFilterReader := bytes.NewReader(fwdFilterBytes) - fwdPkg.FwdFilter = &PkgFilter{} - if err := fwdPkg.FwdFilter.Decode(fwdFilterReader); err != nil { - return nil, err - } - - // Otherwise, a complete round of processing was completed, and we - // advance the package to FwdStateProcessed. - fwdPkg.State = FwdStateProcessed - - // If every add, settle, and fail has been fully acknowledged, we can - // safely set the package's state to FwdStateCompleted, signalling that - // it can be garbage collected. - if fwdPkg.AckFilter.IsFull() && fwdPkg.SettleFailFilter.IsFull() { - fwdPkg.State = FwdStateCompleted - } - - return fwdPkg, nil -} - -// loadHtlcs retrieves all serialized htlcs in a bucket, returning -// them in order of the indexes they were written under. -func loadHtlcs(bkt *bbolt.Bucket) ([]LogUpdate, error) { - var htlcs []LogUpdate - if err := bkt.ForEach(func(_, v []byte) error { - var htlc LogUpdate - if err := htlc.Decode(bytes.NewReader(v)); err != nil { - return err - } - - htlcs = append(htlcs, htlc) - - return nil - }); err != nil { - return nil, err - } - - return htlcs, nil -} - -// SetFwdFilter writes the set of indexes corresponding to Adds at the -// `height` that are to be forwarded to the switch. Calling this method causes -// the forwarding package at `height` to be in FwdStateProcessed. We write this -// forwarding decision so that we always arrive at the same behavior for HTLCs -// leaving this channel. After a restart, we skip validation of these Adds, -// since they are assumed to have already been validated, and make the switch or -// outgoing link responsible for handling replays. -func (p *ChannelPackager) SetFwdFilter(tx *bbolt.Tx, height uint64, - fwdFilter *PkgFilter) error { - - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - source := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(source[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.Bucket(heightKey[:]) - if heightBkt == nil { - return ErrCorruptedFwdPkg - } - - // If the fwd filter has already been written, we return early to avoid - // modifying the persistent state. - forwardedAddsBytes := heightBkt.Get(fwdFilterKey) - if forwardedAddsBytes != nil { - return nil - } - - // Otherwise we serialize and write the provided fwd filter. - var b bytes.Buffer - if err := fwdFilter.Encode(&b); err != nil { - return err - } - - return heightBkt.Put(fwdFilterKey, b.Bytes()) -} - -// AckAddHtlcs accepts a list of references to add htlcs, and updates the -// AckAddFilter of those forwarding packages to indicate that a settle or fail -// has been received in response to the add. -func (p *ChannelPackager) AckAddHtlcs(tx *bbolt.Tx, addRefs ...AddRef) error { - if len(addRefs) == 0 { - return nil - } - - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - sourceKey := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(sourceKey[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - // Organize the forward references such that we just get a single slice - // of indexes for each unique height. - heightDiffs := make(map[uint64][]uint16) - for _, addRef := range addRefs { - heightDiffs[addRef.Height] = append( - heightDiffs[addRef.Height], - addRef.Index, - ) - } - - // Load each height bucket once and remove all acked htlcs at that - // height. - for height, indexes := range heightDiffs { - err := ackAddHtlcsAtHeight(sourceBkt, height, indexes) - if err != nil { - return err - } - } - - return nil -} - -// ackAddHtlcsAtHeight updates the AddAckFilter of a single forwarding package -// with a list of indexes, writing the resulting filter back in its place. -func ackAddHtlcsAtHeight(sourceBkt *bbolt.Bucket, height uint64, - indexes []uint16) error { - - heightKey := makeLogKey(height) - heightBkt := sourceBkt.Bucket(heightKey[:]) - if heightBkt == nil { - // If the height bucket isn't found, this could be because the - // forwarding package was already removed. We'll return nil to - // signal that the operation is successful, as there is nothing - // to ack. - return nil - } - - // Load ack filter from disk. - ackFilterBytes := heightBkt.Get(ackFilterKey) - if ackFilterBytes == nil { - return ErrCorruptedFwdPkg - } - - ackFilter := &PkgFilter{} - ackFilterReader := bytes.NewReader(ackFilterBytes) - if err := ackFilter.Decode(ackFilterReader); err != nil { - return err - } - - // Update the ack filter for this height. - for _, index := range indexes { - ackFilter.Set(index) - } - - // Write the resulting filter to disk. - var ackFilterBuf bytes.Buffer - if err := ackFilter.Encode(&ackFilterBuf); err != nil { - return err - } - - return heightBkt.Put(ackFilterKey, ackFilterBuf.Bytes()) -} - -// AckSettleFails persistently acknowledges settles or fails from a remote forwarding -// package. This should only be called after the source of the Add has locked in -// the settle/fail, or it becomes otherwise safe to forgo retransmitting the -// settle/fail after a restart. -func (p *ChannelPackager) AckSettleFails(tx *bbolt.Tx, settleFailRefs ...SettleFailRef) error { - return ackSettleFails(tx, settleFailRefs) -} - -// ackSettleFails persistently acknowledges a batch of settle fail references. -func ackSettleFails(tx *bbolt.Tx, settleFailRefs []SettleFailRef) error { - if len(settleFailRefs) == 0 { - return nil - } - - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return ErrCorruptedFwdPkg - } - - // Organize the forward references such that we just get a single slice - // of indexes for each unique destination-height pair. - destHeightDiffs := make(map[lnwire.ShortChannelID]map[uint64][]uint16) - for _, settleFailRef := range settleFailRefs { - destHeights, ok := destHeightDiffs[settleFailRef.Source] - if !ok { - destHeights = make(map[uint64][]uint16) - destHeightDiffs[settleFailRef.Source] = destHeights - } - - destHeights[settleFailRef.Height] = append( - destHeights[settleFailRef.Height], - settleFailRef.Index, - ) - } - - // With the references organized by destination and height, we now load - // each remote bucket, and update the settle fail filter for any - // settle/fail htlcs. - for dest, destHeights := range destHeightDiffs { - destKey := makeLogKey(dest.ToUint64()) - destBkt := fwdPkgBkt.Bucket(destKey[:]) - if destBkt == nil { - // If the destination bucket is not found, this is - // likely the result of the destination channel being - // closed and having it's forwarding packages wiped. We - // won't treat this as an error, because the response - // will no longer be retransmitted internally. - continue - } - - for height, indexes := range destHeights { - err := ackSettleFailsAtHeight(destBkt, height, indexes) - if err != nil { - return err - } - } - } - - return nil -} - -// ackSettleFailsAtHeight given a destination bucket, acks the provided indexes -// at particular a height by updating the settle fail filter. -func ackSettleFailsAtHeight(destBkt *bbolt.Bucket, height uint64, - indexes []uint16) error { - - heightKey := makeLogKey(height) - heightBkt := destBkt.Bucket(heightKey[:]) - if heightBkt == nil { - // If the height bucket isn't found, this could be because the - // forwarding package was already removed. We'll return nil to - // signal that the operation is as there is nothing to ack. - return nil - } - - // Load ack filter from disk. - settleFailFilterBytes := heightBkt.Get(settleFailFilterKey) - if settleFailFilterBytes == nil { - return ErrCorruptedFwdPkg - } - - settleFailFilter := &PkgFilter{} - settleFailFilterReader := bytes.NewReader(settleFailFilterBytes) - if err := settleFailFilter.Decode(settleFailFilterReader); err != nil { - return err - } - - // Update the ack filter for this height. - for _, index := range indexes { - settleFailFilter.Set(index) - } - - // Write the resulting filter to disk. - var settleFailFilterBuf bytes.Buffer - if err := settleFailFilter.Encode(&settleFailFilterBuf); err != nil { - return err - } - - return heightBkt.Put(settleFailFilterKey, settleFailFilterBuf.Bytes()) -} - -// RemovePkg deletes the forwarding package at the given height from the -// packager's source bucket. -func (p *ChannelPackager) RemovePkg(tx *bbolt.Tx, height uint64) error { - fwdPkgBkt := tx.Bucket(fwdPackagesKey) - if fwdPkgBkt == nil { - return nil - } - - sourceBytes := makeLogKey(p.source.ToUint64()) - sourceBkt := fwdPkgBkt.Bucket(sourceBytes[:]) - if sourceBkt == nil { - return ErrCorruptedFwdPkg - } - - heightKey := makeLogKey(height) - - return sourceBkt.DeleteBucket(heightKey[:]) -} - -// uint16Key writes the provided 16-bit unsigned integer to a 2-byte slice. -func uint16Key(i uint16) []byte { - key := make([]byte, 2) - byteOrder.PutUint16(key, i) - return key -} - -// Compile-time constraint to ensure that ChannelPackager implements the public -// FwdPackager interface. -var _ FwdPackager = (*ChannelPackager)(nil) - -// Compile-time constraint to ensure that SwitchPackager implements the public -// FwdOperator interface. -var _ FwdOperator = (*SwitchPackager)(nil) diff --git a/channeldb/migration_01_to_11/forwarding_package_test.go b/channeldb/migration_01_to_11/forwarding_package_test.go deleted file mode 100644 index 1128aad3..00000000 --- a/channeldb/migration_01_to_11/forwarding_package_test.go +++ /dev/null @@ -1,815 +0,0 @@ -package migration_01_to_11_test - -import ( - "bytes" - "io/ioutil" - "path/filepath" - "runtime" - "testing" - - "github.com/btcsuite/btcd/wire" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/channeldb" - "github.com/lightningnetwork/lnd/lnwire" -) - -// TestPkgFilterBruteForce tests the behavior of a pkg filter up to size 1000, -// which is greater than the number of HTLCs we permit on a commitment txn. -// This should encapsulate every potential filter used in practice. -func TestPkgFilterBruteForce(t *testing.T) { - t.Parallel() - - checkPkgFilterRange(t, 1000) -} - -// checkPkgFilterRange verifies the behavior of a pkg filter when doing a linear -// insertion of `high` elements. This is primarily to test that IsFull functions -// properly for all relevant sizes of `high`. -func checkPkgFilterRange(t *testing.T, high int) { - for i := uint16(0); i < uint16(high); i++ { - f := channeldb.NewPkgFilter(i) - - if f.Count() != i { - t.Fatalf("pkg filter count=%d is actually %d", - i, f.Count()) - } - checkPkgFilterEncodeDecode(t, i, f) - - for j := uint16(0); j < i; j++ { - if f.Contains(j) { - t.Fatalf("pkg filter count=%d contains %d "+ - "before being added", i, j) - } - - f.Set(j) - checkPkgFilterEncodeDecode(t, i, f) - - if !f.Contains(j) { - t.Fatalf("pkg filter count=%d missing %d "+ - "after being added", i, j) - } - - if j < i-1 && f.IsFull() { - t.Fatalf("pkg filter count=%d already full", i) - } - } - - if !f.IsFull() { - t.Fatalf("pkg filter count=%d not full", i) - } - checkPkgFilterEncodeDecode(t, i, f) - } -} - -// TestPkgFilterRand uses a random permutation to verify the proper behavior of -// the pkg filter if the entries are not inserted in-order. -func TestPkgFilterRand(t *testing.T) { - t.Parallel() - - checkPkgFilterRand(t, 3, 17) -} - -// checkPkgFilterRand checks the behavior of a pkg filter by randomly inserting -// indices and asserting the invariants. The order in which indices are inserted -// is parameterized by a base `b` coprime to `p`, and using modular -// exponentiation to generate all elements in [1,p). -func checkPkgFilterRand(t *testing.T, b, p uint16) { - f := channeldb.NewPkgFilter(p) - var j = b - for i := uint16(1); i < p; i++ { - if f.Contains(j) { - t.Fatalf("pkg filter contains %d-%d "+ - "before being added", i, j) - } - - f.Set(j) - checkPkgFilterEncodeDecode(t, i, f) - - if !f.Contains(j) { - t.Fatalf("pkg filter missing %d-%d "+ - "after being added", i, j) - } - - if i < p-1 && f.IsFull() { - t.Fatalf("pkg filter %d already full", i) - } - checkPkgFilterEncodeDecode(t, i, f) - - j = (b * j) % p - } - - // Set 0 independently, since it will never be emitted by the generator. - f.Set(0) - checkPkgFilterEncodeDecode(t, p, f) - - if !f.IsFull() { - t.Fatalf("pkg filter count=%d not full", p) - } - checkPkgFilterEncodeDecode(t, p, f) -} - -// checkPkgFilterEncodeDecode tests the serialization of a pkg filter by: -// 1) writing it to a buffer -// 2) verifying the number of bytes written matches the filter's Size() -// 3) reconstructing the filter decoding the bytes -// 4) checking that the two filters are the same according to Equal -func checkPkgFilterEncodeDecode(t *testing.T, i uint16, f *channeldb.PkgFilter) { - var b bytes.Buffer - if err := f.Encode(&b); err != nil { - t.Fatalf("unable to serialize pkg filter: %v", err) - } - - // +2 for uint16 length - size := uint16(len(b.Bytes())) - if size != f.Size() { - t.Fatalf("pkg filter count=%d serialized size differs, "+ - "Size(): %d, len(bytes): %v", i, f.Size(), size) - } - - reader := bytes.NewReader(b.Bytes()) - - f2 := &channeldb.PkgFilter{} - if err := f2.Decode(reader); err != nil { - t.Fatalf("unable to deserialize pkg filter: %v", err) - } - - if !f.Equal(f2) { - t.Fatalf("pkg filter count=%v does is not equal "+ - "after deserialization, want: %v, got %v", - i, f, f2) - } -} - -var ( - chanID = lnwire.NewChanIDFromOutPoint(&wire.OutPoint{}) - - adds = []channeldb.LogUpdate{ - { - LogIndex: 0, - UpdateMsg: &lnwire.UpdateAddHTLC{ - ChanID: chanID, - ID: 1, - Amount: 100, - Expiry: 1000, - PaymentHash: [32]byte{0}, - }, - }, - { - LogIndex: 1, - UpdateMsg: &lnwire.UpdateAddHTLC{ - ChanID: chanID, - ID: 1, - Amount: 101, - Expiry: 1001, - PaymentHash: [32]byte{1}, - }, - }, - } - - settleFails = []channeldb.LogUpdate{ - { - LogIndex: 2, - UpdateMsg: &lnwire.UpdateFulfillHTLC{ - ChanID: chanID, - ID: 0, - PaymentPreimage: [32]byte{0}, - }, - }, - { - LogIndex: 3, - UpdateMsg: &lnwire.UpdateFailHTLC{ - ChanID: chanID, - ID: 1, - Reason: []byte{}, - }, - }, - } -) - -// TestPackagerEmptyFwdPkg checks that the state transitions exhibited by a -// forwarding package that contains no adds, fails or settles. We expect that -// the fwdpkg reaches FwdStateCompleted immediately after writing the forwarding -// decision via SetFwdFilter. -func TestPackagerEmptyFwdPkg(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package with no htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, nil) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. With no HTLCs, - // the ack filter will have no elements, and should always return true. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Now, write the forwarding decision. In this case, its just an empty - // fwd filter. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - // We should still have one package on disk. Since the forwarding - // decision has been written, it will minimally be in FwdStateProcessed. - // However with no htlcs, it should leap frog to FwdStateCompleted. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, 0) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// TestPackagerOnlyAdds checks that the fwdpkg does not reach FwdStateCompleted -// as soon as all the adds in the package have been acked using AckAddHtlcs. -func TestPackagerOnlyAdds(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package that only has add - // htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, nil) - - nAdds := len(adds) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. The package - // has unacked add HTLCs, so the ack filter should not be full. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - // Now, write the forwarding decision. Since we have not explicitly - // added any adds to the fwdfilter, this would indicate that all of the - // adds were 1) settled locally by this link (exit hop), or 2) the htlc - // was failed locally. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - for i := range adds { - // We should still have one package on disk. Since the forwarding - // decision has been written, it will minimally be in FwdStateProcessed. - // However not allf of the HTLCs have been acked, so should not - // have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - addRef := channeldb.AddRef{ - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckAddHtlcs(tx, addRef) - }); err != nil { - t.Fatalf("unable to ack add htlc: %v", err) - } - } - - // We should still have one package on disk. Now that all adds have been - // acked, the ack filter should return true and the package should be - // FwdStateCompleted since there are no other settle/fail packets. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, 0) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// TestPackagerOnlySettleFails asserts that the fwdpkg remains in -// FwdStateProcessed after writing the forwarding decision when there are no -// adds in the fwdpkg. We expect this because an empty FwdFilter will always -// return true, but we are still waiting for the remaining fails and settles to -// be deleted. -func TestPackagerOnlySettleFails(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package that only has add - // htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, nil, settleFails) - - nSettleFails := len(settleFails) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. The package - // has unacked add HTLCs, so the ack filter should not be full. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Now, write the forwarding decision. Since we have not explicitly - // added any adds to the fwdfilter, this would indicate that all of the - // adds were 1) settled locally by this link (exit hop), or 2) the htlc - // was failed locally. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - for i := range settleFails { - // We should still have one package on disk. Since the - // forwarding decision has been written, it will minimally be in - // FwdStateProcessed. However, not all of the HTLCs have been - // acked, so should not have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], false) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - failSettleRef := channeldb.SettleFailRef{ - Source: shortChanID, - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckSettleFails(tx, failSettleRef) - }); err != nil { - t.Fatalf("unable to ack add htlc: %v", err) - } - } - - // We should still have one package on disk. Now that all settles and - // fails have been removed, package should be FwdStateCompleted since - // there are no other add packets. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], 0, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], true) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// TestPackagerAddsThenSettleFails writes a fwdpkg containing both adds and -// settle/fails, then checks the behavior when the adds are acked before any of -// the settle fails. Here we expect pkg to remain in FwdStateProcessed while the -// remainder of the fail/settles are being deleted. -func TestPackagerAddsThenSettleFails(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package that only has add - // htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) - - nAdds := len(adds) - nSettleFails := len(settleFails) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. The package - // has unacked add HTLCs, so the ack filter should not be full. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - // Now, write the forwarding decision. Since we have not explicitly - // added any adds to the fwdfilter, this would indicate that all of the - // adds were 1) settled locally by this link (exit hop), or 2) the htlc - // was failed locally. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - for i := range adds { - // We should still have one package on disk. Since the forwarding - // decision has been written, it will minimally be in FwdStateProcessed. - // However not allf of the HTLCs have been acked, so should not - // have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], false) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - addRef := channeldb.AddRef{ - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckAddHtlcs(tx, addRef) - }); err != nil { - t.Fatalf("unable to ack add htlc: %v", err) - } - } - - for i := range settleFails { - // We should still have one package on disk. Since the - // forwarding decision has been written, it will minimally be in - // FwdStateProcessed. However not allf of the HTLCs have been - // acked, so should not have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], false) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - failSettleRef := channeldb.SettleFailRef{ - Source: shortChanID, - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckSettleFails(tx, failSettleRef) - }); err != nil { - t.Fatalf("unable to remove settle/fail htlc: %v", err) - } - } - - // We should still have one package on disk. Now that all settles and - // fails have been removed, package should be FwdStateCompleted since - // there are no other add packets. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], true) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// TestPackagerSettleFailsThenAdds writes a fwdpkg with both adds and -// settle/fails, then checks the behavior when the settle/fails are removed -// before any of the adds have been acked. This should cause the fwdpkg to -// remain in FwdStateProcessed until the final ack is recorded, at which point -// it should be promoted directly to FwdStateCompleted.since all adds have been -// removed. -func TestPackagerSettleFailsThenAdds(t *testing.T) { - t.Parallel() - - db := makeFwdPkgDB(t, "") - - shortChanID := lnwire.NewShortChanIDFromInt(1) - packager := channeldb.NewChannelPackager(shortChanID) - - // To begin, there should be no forwarding packages on disk. - fwdPkgs := loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } - - // Next, create and write a new forwarding package that has both add - // and settle/fail htlcs. - fwdPkg := channeldb.NewFwdPkg(shortChanID, 0, adds, settleFails) - - nAdds := len(adds) - nSettleFails := len(settleFails) - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AddFwdPkg(tx, fwdPkg) - }); err != nil { - t.Fatalf("unable to add fwd pkg: %v", err) - } - - // There should now be one fwdpkg on disk. Since no forwarding decision - // has been written, we expect it to be FwdStateLockedIn. The package - // has unacked add HTLCs, so the ack filter should not be full. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateLockedIn) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - // Now, write the forwarding decision. Since we have not explicitly - // added any adds to the fwdfilter, this would indicate that all of the - // adds were 1) settled locally by this link (exit hop), or 2) the htlc - // was failed locally. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.SetFwdFilter(tx, fwdPkg.Height, fwdPkg.FwdFilter) - }); err != nil { - t.Fatalf("unable to set fwdfiter: %v", err) - } - - // Simulate another channel deleting the settle/fails it received from - // the original fwd pkg. - // TODO(conner): use different packager/s? - for i := range settleFails { - // We should still have one package on disk. Since the - // forwarding decision has been written, it will minimally be in - // FwdStateProcessed. However none all of the add HTLCs have - // been acked, so should not have advanced further. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], false) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - failSettleRef := channeldb.SettleFailRef{ - Source: shortChanID, - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckSettleFails(tx, failSettleRef) - }); err != nil { - t.Fatalf("unable to remove settle/fail htlc: %v", err) - } - } - - // Now simulate this channel receiving a fail/settle for the adds in the - // fwdpkg. - for i := range adds { - // Again, we should still have one package on disk and be in - // FwdStateProcessed. This should not change until all of the - // add htlcs have been acked. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateProcessed) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], true) - assertAckFilterIsFull(t, fwdPkgs[0], false) - - addRef := channeldb.AddRef{ - Height: fwdPkg.Height, - Index: uint16(i), - } - - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.AckAddHtlcs(tx, addRef) - }); err != nil { - t.Fatalf("unable to ack add htlc: %v", err) - } - } - - // We should still have one package on disk. Now that all settles and - // fails have been removed, package should be FwdStateCompleted since - // there are no other add packets. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 1 { - t.Fatalf("expected 1 fwdpkg, instead found %d", len(fwdPkgs)) - } - assertFwdPkgState(t, fwdPkgs[0], channeldb.FwdStateCompleted) - assertFwdPkgNumAddsSettleFails(t, fwdPkgs[0], nAdds, nSettleFails) - assertSettleFailFilterIsFull(t, fwdPkgs[0], true) - assertAckFilterIsFull(t, fwdPkgs[0], true) - - // Lastly, remove the completed forwarding package from disk. - if err := db.Update(func(tx *bbolt.Tx) error { - return packager.RemovePkg(tx, fwdPkg.Height) - }); err != nil { - t.Fatalf("unable to remove fwdpkg: %v", err) - } - - // Check that the fwd package was actually removed. - fwdPkgs = loadFwdPkgs(t, db, packager) - if len(fwdPkgs) != 0 { - t.Fatalf("no forwarding packages should exist, found %d", len(fwdPkgs)) - } -} - -// assertFwdPkgState checks the current state of a fwdpkg meets our -// expectations. -func assertFwdPkgState(t *testing.T, fwdPkg *channeldb.FwdPkg, - state channeldb.FwdState) { - _, _, line, _ := runtime.Caller(1) - if fwdPkg.State != state { - t.Fatalf("line %d: expected fwdpkg in state %v, found %v", - line, state, fwdPkg.State) - } -} - -// assertFwdPkgNumAddsSettleFails checks that the number of adds and -// settle/fail log updates are correct. -func assertFwdPkgNumAddsSettleFails(t *testing.T, fwdPkg *channeldb.FwdPkg, - expectedNumAdds, expectedNumSettleFails int) { - _, _, line, _ := runtime.Caller(1) - if len(fwdPkg.Adds) != expectedNumAdds { - t.Fatalf("line %d: expected fwdpkg to have %d adds, found %d", - line, expectedNumAdds, len(fwdPkg.Adds)) - } - - if len(fwdPkg.SettleFails) != expectedNumSettleFails { - t.Fatalf("line %d: expected fwdpkg to have %d settle/fails, found %d", - line, expectedNumSettleFails, len(fwdPkg.SettleFails)) - } -} - -// assertAckFilterIsFull checks whether or not a fwdpkg's ack filter matches our -// expected full-ness. -func assertAckFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { - _, _, line, _ := runtime.Caller(1) - if fwdPkg.AckFilter.IsFull() != expected { - t.Fatalf("line %d: expected fwdpkg ack filter IsFull to be %v, "+ - "found %v", line, expected, fwdPkg.AckFilter.IsFull()) - } -} - -// assertSettleFailFilterIsFull checks whether or not a fwdpkg's settle fail -// filter matches our expected full-ness. -func assertSettleFailFilterIsFull(t *testing.T, fwdPkg *channeldb.FwdPkg, expected bool) { - _, _, line, _ := runtime.Caller(1) - if fwdPkg.SettleFailFilter.IsFull() != expected { - t.Fatalf("line %d: expected fwdpkg settle/fail filter IsFull to be %v, "+ - "found %v", line, expected, fwdPkg.SettleFailFilter.IsFull()) - } -} - -// loadFwdPkgs is a helper method that reads all forwarding packages for a -// particular packager. -func loadFwdPkgs(t *testing.T, db *bbolt.DB, - packager channeldb.FwdPackager) []*channeldb.FwdPkg { - - var fwdPkgs []*channeldb.FwdPkg - if err := db.View(func(tx *bbolt.Tx) error { - var err error - fwdPkgs, err = packager.LoadFwdPkgs(tx) - return err - }); err != nil { - t.Fatalf("unable to load fwd pkgs: %v", err) - } - - return fwdPkgs -} - -// makeFwdPkgDB initializes a test database for forwarding packages. If the -// provided path is an empty, it will create a temp dir/file to use. -func makeFwdPkgDB(t *testing.T, path string) *bbolt.DB { - if path == "" { - var err error - path, err = ioutil.TempDir("", "fwdpkgdb") - if err != nil { - t.Fatalf("unable to create temp path: %v", err) - } - - path = filepath.Join(path, "fwdpkg.db") - } - - db, err := bbolt.Open(path, 0600, nil) - if err != nil { - t.Fatalf("unable to open boltdb: %v", err) - } - - return db -} diff --git a/channeldb/migration_01_to_11/graph.go b/channeldb/migration_01_to_11/graph.go index d90863c6..8e8f4a4a 100644 --- a/channeldb/migration_01_to_11/graph.go +++ b/channeldb/migration_01_to_11/graph.go @@ -2,20 +2,15 @@ package migration_01_to_11 import ( "bytes" - "crypto/sha256" "encoding/binary" - "errors" "fmt" "image/color" "io" - "math" "net" - "sync" "time" "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/coreos/bbolt" @@ -74,11 +69,6 @@ var ( // lookup of incoming channel edges. unknownPolicy = []byte{} - // chanStart is an array of all zero bytes which is used to perform - // range scans within the edgeBucket to obtain all of the outgoing - // edges for a particular node. - chanStart [8]byte - // edgeIndexBucket is an index which can be used to iterate all edges // in the bucket, grouping them according to their in/out nodes. // Additionally, the items in this bucket also contain the complete @@ -155,9 +145,6 @@ const ( // would be possible for a node to create a ton of updates and slowly // fill our disk, and also waste bandwidth due to relaying. MaxAllowedExtraOpaqueBytes = 10000 - - // feeRateParts is the total number of parts used to express fee rates. - feeRateParts = 1e6 ) // ChannelGraph is a persistent, on-disk graph representation of the Lightning @@ -172,200 +159,16 @@ const ( // for that edge. type ChannelGraph struct { db *DB - - cacheMu sync.RWMutex - rejectCache *rejectCache - chanCache *channelCache } // newChannelGraph allocates a new ChannelGraph backed by a DB instance. The // returned instance has its own unique reject cache and channel cache. func newChannelGraph(db *DB, rejectCacheSize, chanCacheSize int) *ChannelGraph { return &ChannelGraph{ - db: db, - rejectCache: newRejectCache(rejectCacheSize), - chanCache: newChannelCache(chanCacheSize), + db: db, } } -// Database returns a pointer to the underlying database. -func (c *ChannelGraph) Database() *DB { - return c.db -} - -// ForEachChannel iterates through all the channel edges stored within the -// graph and invokes the passed callback for each edge. The callback takes two -// edges as since this is a directed graph, both the in/out edges are visited. -// If the callback returns an error, then the transaction is aborted and the -// iteration stops early. -// -// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer -// for that particular channel edge routing policy will be passed into the -// callback. -func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - // TODO(roasbeef): ptr map to reduce # of allocs? no duplicates - - return c.db.View(func(tx *bbolt.Tx) error { - // First, grab the node bucket. This will be used to populate - // the Node pointers in each edge read from disk. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // Next, grab the edge bucket which stores the edges, and also - // the index itself so we can group the directed edges together - // logically. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // For each edge pair within the edge index, we fetch each edge - // itself and also the node information in order to fully - // populated the object. - return edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - infoReader := bytes.NewReader(edgeInfoBytes) - edgeInfo, err := deserializeChanEdgeInfo(infoReader) - if err != nil { - return err - } - edgeInfo.db = c.db - - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, - ) - if err != nil { - return err - } - - // With both edges read, execute the call back. IF this - // function returns an error then the transaction will - // be aborted. - return cb(&edgeInfo, edge1, edge2) - }) - }) -} - -// ForEachNodeChannel iterates through all channels of a given node, executing the -// passed callback with an edge info structure and the policies of each end -// of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the -// connecting node. If the callback returns an error, then the iteration is -// halted with the error propagated back up to the caller. -// -// Unknown policies are passed into the callback as nil values. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (c *ChannelGraph) ForEachNodeChannel(tx *bbolt.Tx, nodePub []byte, - cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, - *ChannelEdgePolicy) error) error { - - db := c.db - - return nodeTraversal(tx, nodePub, db, cb) -} - -// DisabledChannelIDs returns the channel ids of disabled channels. -// A channel is disabled when two of the associated ChanelEdgePolicies -// have their disabled bit on. -func (c *ChannelGraph) DisabledChannelIDs() ([]uint64, error) { - var disabledChanIDs []uint64 - chanEdgeFound := make(map[uint64]struct{}) - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - - disabledEdgePolicyIndex := edges.Bucket(disabledEdgePolicyBucket) - if disabledEdgePolicyIndex == nil { - return nil - } - - // We iterate over all disabled policies and we add each channel that - // has more than one disabled policy to disabledChanIDs array. - return disabledEdgePolicyIndex.ForEach(func(k, v []byte) error { - chanID := byteOrder.Uint64(k[:8]) - _, edgeFound := chanEdgeFound[chanID] - if edgeFound { - delete(chanEdgeFound, chanID) - disabledChanIDs = append(disabledChanIDs, chanID) - return nil - } - - chanEdgeFound[chanID] = struct{}{} - return nil - }) - }) - if err != nil { - return nil, err - } - - return disabledChanIDs, nil -} - -// ForEachNode iterates through all the stored vertices/nodes in the graph, -// executing the passed callback with each node encountered. If the callback -// returns an error, then the transaction is aborted and the iteration stops -// early. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal -// -// TODO(roasbeef): add iterator interface to allow for memory efficient graph -// traversal when graph gets mega -func (c *ChannelGraph) ForEachNode(tx *bbolt.Tx, cb func(*bbolt.Tx, *LightningNode) error) error { - traversal := func(tx *bbolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - return nodes.ForEach(func(pubKey, nodeBytes []byte) error { - // If this is the source key, then we skip this - // iteration as the value for this key is a pubKey - // rather than raw node information. - if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { - return nil - } - - nodeReader := bytes.NewReader(nodeBytes) - node, err := deserializeLightningNode(nodeReader) - if err != nil { - return err - } - node.db = c.db - - // Execute the callback, the transaction will abort if - // this returns an error. - return cb(tx, &node) - }) - } - - // If no transaction was provided, then we'll create a new transaction - // to execute the transaction within. - if tx == nil { - return c.db.View(traversal) - } - - // Otherwise, we re-use the existing transaction to execute the graph - // traversal. - return traversal(tx) -} - // SourceNode returns the source node of the graph. The source node is treated // as the center node within a star-graph. This method may be used to kick off // a path finding algorithm in order to explore the reachability of another @@ -442,20 +245,6 @@ func (c *ChannelGraph) SetSourceNode(node *LightningNode) error { }) } -// AddLightningNode adds a vertex/node to the graph database. If the node is not -// in the database from before, this will add a new, unconnected one to the -// graph. If it is present from before, this will update that node's -// information. Note that this method is expected to only be called to update -// an already present node from a node announcement, or to insert a node found -// in a channel update. -// -// TODO(roasbeef): also need sig of announcement -func (c *ChannelGraph) AddLightningNode(node *LightningNode) error { - return c.db.Update(func(tx *bbolt.Tx) error { - return addLightningNode(tx, node) - }) -} - func addLightningNode(tx *bbolt.Tx, node *LightningNode) error { nodes, err := tx.CreateBucketIfNotExists(nodeBucket) if err != nil { @@ -477,1487 +266,6 @@ func addLightningNode(tx *bbolt.Tx, node *LightningNode) error { return putLightningNode(nodes, aliases, updateIndex, node) } -// LookupAlias attempts to return the alias as advertised by the target node. -// TODO(roasbeef): currently assumes that aliases are unique... -func (c *ChannelGraph) LookupAlias(pub *btcec.PublicKey) (string, error) { - var alias string - - err := c.db.View(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - - aliases := nodes.Bucket(aliasIndexBucket) - if aliases == nil { - return ErrGraphNodesNotFound - } - - nodePub := pub.SerializeCompressed() - a := aliases.Get(nodePub) - if a == nil { - return ErrNodeAliasNotFound - } - - // TODO(roasbeef): should actually be using the utf-8 - // package... - alias = string(a) - return nil - }) - if err != nil { - return "", err - } - - return alias, nil -} - -// DeleteLightningNode starts a new database transaction to remove a vertex/node -// from the database according to the node's public key. -func (c *ChannelGraph) DeleteLightningNode(nodePub *btcec.PublicKey) error { - // TODO(roasbeef): ensure dangling edges are removed... - return c.db.Update(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodeNotFound - } - - return c.deleteLightningNode( - nodes, nodePub.SerializeCompressed(), - ) - }) -} - -// deleteLightningNode uses an existing database transaction to remove a -// vertex/node from the database according to the node's public key. -func (c *ChannelGraph) deleteLightningNode(nodes *bbolt.Bucket, - compressedPubKey []byte) error { - - aliases := nodes.Bucket(aliasIndexBucket) - if aliases == nil { - return ErrGraphNodesNotFound - } - - if err := aliases.Delete(compressedPubKey); err != nil { - return err - } - - // Before we delete the node, we'll fetch its current state so we can - // determine when its last update was to clear out the node update - // index. - node, err := fetchLightningNode(nodes, compressedPubKey) - if err != nil { - return err - } - - if err := nodes.Delete(compressedPubKey); err != nil { - - return err - } - - // Finally, we'll delete the index entry for the node within the - // nodeUpdateIndexBucket as this node is no longer active, so we don't - // need to track its last update. - nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) - if nodeUpdateIndex == nil { - return ErrGraphNodesNotFound - } - - // In order to delete the entry, we'll need to reconstruct the key for - // its last update. - updateUnix := uint64(node.LastUpdate.Unix()) - var indexKey [8 + 33]byte - byteOrder.PutUint64(indexKey[:8], updateUnix) - copy(indexKey[8:], compressedPubKey) - - return nodeUpdateIndex.Delete(indexKey[:]) -} - -// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An -// undirected edge from the two target nodes are created. The information -// stored denotes the static attributes of the channel, such as the channelID, -// the keys involved in creation of the channel, and the set of features that -// the channel supports. The chanPoint and chanID are used to uniquely identify -// the edge globally within the database. -func (c *ChannelGraph) AddChannelEdge(edge *ChannelEdgeInfo) error { - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - err := c.db.Update(func(tx *bbolt.Tx) error { - return c.addChannelEdge(tx, edge) - }) - if err != nil { - return err - } - - c.rejectCache.remove(edge.ChannelID) - c.chanCache.remove(edge.ChannelID) - - return nil -} - -// addChannelEdge is the private form of AddChannelEdge that allows callers to -// utilize an existing db transaction. -func (c *ChannelGraph) addChannelEdge(tx *bbolt.Tx, edge *ChannelEdgeInfo) error { - // Construct the channel's primary key which is the 8-byte channel ID. - var chanKey [8]byte - binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) - - nodes, err := tx.CreateBucketIfNotExists(nodeBucket) - if err != nil { - return err - } - edges, err := tx.CreateBucketIfNotExists(edgeBucket) - if err != nil { - return err - } - edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) - if err != nil { - return err - } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) - if err != nil { - return err - } - - // First, attempt to check if this edge has already been created. If - // so, then we can exit early as this method is meant to be idempotent. - if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo != nil { - return ErrEdgeAlreadyExist - } - - // Before we insert the channel into the database, we'll ensure that - // both nodes already exist in the channel graph. If either node - // doesn't, then we'll insert a "shell" node that just includes its - // public key, so subsequent validation and queries can work properly. - _, node1Err := fetchLightningNode(nodes, edge.NodeKey1Bytes[:]) - switch { - case node1Err == ErrGraphNodeNotFound: - node1Shell := LightningNode{ - PubKeyBytes: edge.NodeKey1Bytes, - HaveNodeAnnouncement: false, - } - err := addLightningNode(tx, &node1Shell) - if err != nil { - return fmt.Errorf("unable to create shell node "+ - "for: %x", edge.NodeKey1Bytes) - - } - case node1Err != nil: - return err - } - - _, node2Err := fetchLightningNode(nodes, edge.NodeKey2Bytes[:]) - switch { - case node2Err == ErrGraphNodeNotFound: - node2Shell := LightningNode{ - PubKeyBytes: edge.NodeKey2Bytes, - HaveNodeAnnouncement: false, - } - err := addLightningNode(tx, &node2Shell) - if err != nil { - return fmt.Errorf("unable to create shell node "+ - "for: %x", edge.NodeKey2Bytes) - - } - case node2Err != nil: - return err - } - - // If the edge hasn't been created yet, then we'll first add it to the - // edge index in order to associate the edge between two nodes and also - // store the static components of the channel. - if err := putChanEdgeInfo(edgeIndex, edge, chanKey); err != nil { - return err - } - - // Mark edge policies for both sides as unknown. This is to enable - // efficient incoming channel lookup for a node. - for _, key := range []*[33]byte{&edge.NodeKey1Bytes, - &edge.NodeKey2Bytes} { - - err := putChanEdgePolicyUnknown(edges, edge.ChannelID, - key[:]) - if err != nil { - return err - } - } - - // Finally we add it to the channel index which maps channel points - // (outpoints) to the shorter channel ID's. - var b bytes.Buffer - if err := writeOutpoint(&b, &edge.ChannelPoint); err != nil { - return err - } - return chanIndex.Put(b.Bytes(), chanKey[:]) -} - -// HasChannelEdge returns true if the database knows of a channel edge with the -// passed channel ID, and false otherwise. If an edge with that ID is found -// within the graph, then two time stamps representing the last time the edge -// was updated for both directed edges are returned along with the boolean. If -// it is not found, then the zombie index is checked and its result is returned -// as the second boolean. -func (c *ChannelGraph) HasChannelEdge( - chanID uint64) (time.Time, time.Time, bool, bool, error) { - - var ( - upd1Time time.Time - upd2Time time.Time - exists bool - isZombie bool - ) - - // We'll query the cache with the shared lock held to allow multiple - // readers to access values in the cache concurrently if they exist. - c.cacheMu.RLock() - if entry, ok := c.rejectCache.get(chanID); ok { - c.cacheMu.RUnlock() - upd1Time = time.Unix(entry.upd1Time, 0) - upd2Time = time.Unix(entry.upd2Time, 0) - exists, isZombie = entry.flags.unpack() - return upd1Time, upd2Time, exists, isZombie, nil - } - c.cacheMu.RUnlock() - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - // The item was not found with the shared lock, so we'll acquire the - // exclusive lock and check the cache again in case another method added - // the entry to the cache while no lock was held. - if entry, ok := c.rejectCache.get(chanID); ok { - upd1Time = time.Unix(entry.upd1Time, 0) - upd2Time = time.Unix(entry.upd2Time, 0) - exists, isZombie = entry.flags.unpack() - return upd1Time, upd2Time, exists, isZombie, nil - } - - if err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - var channelID [8]byte - byteOrder.PutUint64(channelID[:], chanID) - - // If the edge doesn't exist, then we'll also check our zombie - // index. - if edgeIndex.Get(channelID[:]) == nil { - exists = false - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex != nil { - isZombie, _, _ = isZombieEdge( - zombieIndex, chanID, - ) - } - - return nil - } - - exists = true - isZombie = false - - // If the channel has been found in the graph, then retrieve - // the edges itself so we can return the last updated - // timestamps. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodeNotFound - } - - e1, e2, err := fetchChanEdgePolicies(edgeIndex, edges, nodes, - channelID[:], c.db) - if err != nil { - return err - } - - // As we may have only one of the edges populated, only set the - // update time if the edge was found in the database. - if e1 != nil { - upd1Time = e1.LastUpdate - } - if e2 != nil { - upd2Time = e2.LastUpdate - } - - return nil - }); err != nil { - return time.Time{}, time.Time{}, exists, isZombie, err - } - - c.rejectCache.insert(chanID, rejectCacheEntry{ - upd1Time: upd1Time.Unix(), - upd2Time: upd2Time.Unix(), - flags: packRejectFlags(exists, isZombie), - }) - - return upd1Time, upd2Time, exists, isZombie, nil -} - -// UpdateChannelEdge retrieves and update edge of the graph database. Method -// only reserved for updating an edge info after its already been created. -// In order to maintain this constraints, we return an error in the scenario -// that an edge info hasn't yet been created yet, but someone attempts to update -// it. -func (c *ChannelGraph) UpdateChannelEdge(edge *ChannelEdgeInfo) error { - // Construct the channel's primary key which is the 8-byte channel ID. - var chanKey [8]byte - binary.BigEndian.PutUint64(chanKey[:], edge.ChannelID) - - return c.db.Update(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edge == nil { - return ErrEdgeNotFound - } - - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrEdgeNotFound - } - - if edgeInfo := edgeIndex.Get(chanKey[:]); edgeInfo == nil { - return ErrEdgeNotFound - } - - return putChanEdgeInfo(edgeIndex, edge, chanKey) - }) -} - -const ( - // pruneTipBytes is the total size of the value which stores a prune - // entry of the graph in the prune log. The "prune tip" is the last - // entry in the prune log, and indicates if the channel graph is in - // sync with the current UTXO state. The structure of the value - // is: blockHash, taking 32 bytes total. - pruneTipBytes = 32 -) - -// PruneGraph prunes newly closed channels from the channel graph in response -// to a new block being solved on the network. Any transactions which spend the -// funding output of any known channels within he graph will be deleted. -// Additionally, the "prune tip", or the last block which has been used to -// prune the graph is stored so callers can ensure the graph is fully in sync -// with the current UTXO state. A slice of channels that have been closed by -// the target block are returned if the function succeeds without error. -func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, - blockHash *chainhash.Hash, blockHeight uint32) ([]*ChannelEdgeInfo, error) { - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - var chansClosed []*ChannelEdgeInfo - - err := c.db.Update(func(tx *bbolt.Tx) error { - // First grab the edges bucket which houses the information - // we'd like to delete - edges, err := tx.CreateBucketIfNotExists(edgeBucket) - if err != nil { - return err - } - - // Next grab the two edge indexes which will also need to be updated. - edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) - if err != nil { - return err - } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) - if err != nil { - return err - } - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrSourceNodeNotSet - } - zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) - if err != nil { - return err - } - - // For each of the outpoints that have been spent within the - // block, we attempt to delete them from the graph as if that - // outpoint was a channel, then it has now been closed. - for _, chanPoint := range spentOutputs { - // TODO(roasbeef): load channel bloom filter, continue - // if NOT if filter - - var opBytes bytes.Buffer - if err := writeOutpoint(&opBytes, chanPoint); err != nil { - return err - } - - // First attempt to see if the channel exists within - // the database, if not, then we can exit early. - chanID := chanIndex.Get(opBytes.Bytes()) - if chanID == nil { - continue - } - - // However, if it does, then we'll read out the full - // version so we can add it to the set of deleted - // channels. - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - return err - } - - // Attempt to delete the channel, an ErrEdgeNotFound - // will be returned if that outpoint isn't known to be - // a channel. If no error is returned, then a channel - // was successfully pruned. - err = delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, - chanID, false, - ) - if err != nil && err != ErrEdgeNotFound { - return err - } - - chansClosed = append(chansClosed, &edgeInfo) - } - - metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) - if err != nil { - return err - } - - pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) - if err != nil { - return err - } - - // With the graph pruned, add a new entry to the prune log, - // which can be used to check if the graph is fully synced with - // the current UTXO state. - var blockHeightBytes [4]byte - byteOrder.PutUint32(blockHeightBytes[:], blockHeight) - - var newTip [pruneTipBytes]byte - copy(newTip[:], blockHash[:]) - - err = pruneBucket.Put(blockHeightBytes[:], newTip[:]) - if err != nil { - return err - } - - // Now that the graph has been pruned, we'll also attempt to - // prune any nodes that have had a channel closed within the - // latest block. - return c.pruneGraphNodes(nodes, edgeIndex) - }) - if err != nil { - return nil, err - } - - for _, channel := range chansClosed { - c.rejectCache.remove(channel.ChannelID) - c.chanCache.remove(channel.ChannelID) - } - - return chansClosed, nil -} - -// PruneGraphNodes is a garbage collection method which attempts to prune out -// any nodes from the channel graph that are currently unconnected. This ensure -// that we only maintain a graph of reachable nodes. In the event that a pruned -// node gains more channels, it will be re-added back to the graph. -func (c *ChannelGraph) PruneGraphNodes() error { - return c.db.Update(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNotFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - return c.pruneGraphNodes(nodes, edgeIndex) - }) -} - -// pruneGraphNodes attempts to remove any nodes from the graph who have had a -// channel closed within the current block. If the node still has existing -// channels in the graph, this will act as a no-op. -func (c *ChannelGraph) pruneGraphNodes(nodes *bbolt.Bucket, - edgeIndex *bbolt.Bucket) error { - - log.Trace("Pruning nodes from graph with no open channels") - - // We'll retrieve the graph's source node to ensure we don't remove it - // even if it no longer has any open channels. - sourceNode, err := c.sourceNode(nodes) - if err != nil { - return err - } - - // We'll use this map to keep count the number of references to a node - // in the graph. A node should only be removed once it has no more - // references in the graph. - nodeRefCounts := make(map[[33]byte]int) - err = nodes.ForEach(func(pubKey, nodeBytes []byte) error { - // If this is the source key, then we skip this - // iteration as the value for this key is a pubKey - // rather than raw node information. - if bytes.Equal(pubKey, sourceKey) || len(pubKey) != 33 { - return nil - } - - var nodePub [33]byte - copy(nodePub[:], pubKey) - nodeRefCounts[nodePub] = 0 - - return nil - }) - if err != nil { - return err - } - - // To ensure we never delete the source node, we'll start off by - // bumping its ref count to 1. - nodeRefCounts[sourceNode.PubKeyBytes] = 1 - - // Next, we'll run through the edgeIndex which maps a channel ID to the - // edge info. We'll use this scan to populate our reference count map - // above. - err = edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - // The first 66 bytes of the edge info contain the pubkeys of - // the nodes that this edge attaches. We'll extract them, and - // add them to the ref count map. - var node1, node2 [33]byte - copy(node1[:], edgeInfoBytes[:33]) - copy(node2[:], edgeInfoBytes[33:]) - - // With the nodes extracted, we'll increase the ref count of - // each of the nodes. - nodeRefCounts[node1]++ - nodeRefCounts[node2]++ - - return nil - }) - if err != nil { - return err - } - - // Finally, we'll make a second pass over the set of nodes, and delete - // any nodes that have a ref count of zero. - var numNodesPruned int - for nodePubKey, refCount := range nodeRefCounts { - // If the ref count of the node isn't zero, then we can safely - // skip it as it still has edges to or from it within the - // graph. - if refCount != 0 { - continue - } - - // If we reach this point, then there are no longer any edges - // that connect this node, so we can delete it. - if err := c.deleteLightningNode(nodes, nodePubKey[:]); err != nil { - log.Warnf("Unable to prune node %x from the "+ - "graph: %v", nodePubKey, err) - continue - } - - log.Infof("Pruned unconnected node %x from channel graph", - nodePubKey[:]) - - numNodesPruned++ - } - - if numNodesPruned > 0 { - log.Infof("Pruned %v unconnected nodes from the channel graph", - numNodesPruned) - } - - return nil -} - -// DisconnectBlockAtHeight is used to indicate that the block specified -// by the passed height has been disconnected from the main chain. This -// will "rewind" the graph back to the height below, deleting channels -// that are no longer confirmed from the graph. The prune log will be -// set to the last prune height valid for the remaining chain. -// Channels that were removed from the graph resulting from the -// disconnected block are returned. -func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ([]*ChannelEdgeInfo, - error) { - - // Every channel having a ShortChannelID starting at 'height' - // will no longer be confirmed. - startShortChanID := lnwire.ShortChannelID{ - BlockHeight: height, - } - - // Delete everything after this height from the db. - endShortChanID := lnwire.ShortChannelID{ - BlockHeight: math.MaxUint32 & 0x00ffffff, - TxIndex: math.MaxUint32 & 0x00ffffff, - TxPosition: math.MaxUint16, - } - // The block height will be the 3 first bytes of the channel IDs. - var chanIDStart [8]byte - byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64()) - var chanIDEnd [8]byte - byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64()) - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - // Keep track of the channels that are removed from the graph. - var removedChans []*ChannelEdgeInfo - - if err := c.db.Update(func(tx *bbolt.Tx) error { - edges, err := tx.CreateBucketIfNotExists(edgeBucket) - if err != nil { - return err - } - edgeIndex, err := edges.CreateBucketIfNotExists(edgeIndexBucket) - if err != nil { - return err - } - chanIndex, err := edges.CreateBucketIfNotExists(channelPointBucket) - if err != nil { - return err - } - zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) - if err != nil { - return err - } - nodes, err := tx.CreateBucketIfNotExists(nodeBucket) - if err != nil { - return err - } - - // Scan from chanIDStart to chanIDEnd, deleting every - // found edge. - // NOTE: we must delete the edges after the cursor loop, since - // modifying the bucket while traversing is not safe. - var keys [][]byte - cursor := edgeIndex.Cursor() - for k, v := cursor.Seek(chanIDStart[:]); k != nil && - bytes.Compare(k, chanIDEnd[:]) <= 0; k, v = cursor.Next() { - - edgeInfoReader := bytes.NewReader(v) - edgeInfo, err := deserializeChanEdgeInfo(edgeInfoReader) - if err != nil { - return err - } - - keys = append(keys, k) - removedChans = append(removedChans, &edgeInfo) - } - - for _, k := range keys { - err = delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, - k, false, - ) - if err != nil && err != ErrEdgeNotFound { - return err - } - } - - // Delete all the entries in the prune log having a height - // greater or equal to the block disconnected. - metaBucket, err := tx.CreateBucketIfNotExists(graphMetaBucket) - if err != nil { - return err - } - - pruneBucket, err := metaBucket.CreateBucketIfNotExists(pruneLogBucket) - if err != nil { - return err - } - - var pruneKeyStart [4]byte - byteOrder.PutUint32(pruneKeyStart[:], height) - - var pruneKeyEnd [4]byte - byteOrder.PutUint32(pruneKeyEnd[:], math.MaxUint32) - - // To avoid modifying the bucket while traversing, we delete - // the keys in a second loop. - var pruneKeys [][]byte - pruneCursor := pruneBucket.Cursor() - for k, _ := pruneCursor.Seek(pruneKeyStart[:]); k != nil && - bytes.Compare(k, pruneKeyEnd[:]) <= 0; k, _ = pruneCursor.Next() { - - pruneKeys = append(pruneKeys, k) - } - - for _, k := range pruneKeys { - if err := pruneBucket.Delete(k); err != nil { - return err - } - } - - return nil - }); err != nil { - return nil, err - } - - for _, channel := range removedChans { - c.rejectCache.remove(channel.ChannelID) - c.chanCache.remove(channel.ChannelID) - } - - return removedChans, nil -} - -// PruneTip returns the block height and hash of the latest block that has been -// used to prune channels in the graph. Knowing the "prune tip" allows callers -// to tell if the graph is currently in sync with the current best known UTXO -// state. -func (c *ChannelGraph) PruneTip() (*chainhash.Hash, uint32, error) { - var ( - tipHash chainhash.Hash - tipHeight uint32 - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - graphMeta := tx.Bucket(graphMetaBucket) - if graphMeta == nil { - return ErrGraphNotFound - } - pruneBucket := graphMeta.Bucket(pruneLogBucket) - if pruneBucket == nil { - return ErrGraphNeverPruned - } - - pruneCursor := pruneBucket.Cursor() - - // The prune key with the largest block height will be our - // prune tip. - k, v := pruneCursor.Last() - if k == nil { - return ErrGraphNeverPruned - } - - // Once we have the prune tip, the value will be the block hash, - // and the key the block height. - copy(tipHash[:], v[:]) - tipHeight = byteOrder.Uint32(k[:]) - - return nil - }) - if err != nil { - return nil, 0, err - } - - return &tipHash, tipHeight, nil -} - -// DeleteChannelEdges removes edges with the given channel IDs from the database -// and marks them as zombies. This ensures that we're unable to re-add it to our -// database once again. If an edge does not exist within the database, then -// ErrEdgeNotFound will be returned. -func (c *ChannelGraph) DeleteChannelEdges(chanIDs ...uint64) error { - // TODO(roasbeef): possibly delete from node bucket if node has no more - // channels - // TODO(roasbeef): don't delete both edges? - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - err := c.db.Update(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrEdgeNotFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrEdgeNotFound - } - chanIndex := edges.Bucket(channelPointBucket) - if chanIndex == nil { - return ErrEdgeNotFound - } - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodeNotFound - } - zombieIndex, err := edges.CreateBucketIfNotExists(zombieBucket) - if err != nil { - return err - } - - var rawChanID [8]byte - for _, chanID := range chanIDs { - byteOrder.PutUint64(rawChanID[:], chanID) - err := delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, - rawChanID[:], true, - ) - if err != nil { - return err - } - } - - return nil - }) - if err != nil { - return err - } - - for _, chanID := range chanIDs { - c.rejectCache.remove(chanID) - c.chanCache.remove(chanID) - } - - return nil -} - -// ChannelID attempt to lookup the 8-byte compact channel ID which maps to the -// passed channel point (outpoint). If the passed channel doesn't exist within -// the database, then ErrEdgeNotFound is returned. -func (c *ChannelGraph) ChannelID(chanPoint *wire.OutPoint) (uint64, error) { - var chanID uint64 - if err := c.db.View(func(tx *bbolt.Tx) error { - var err error - chanID, err = getChanID(tx, chanPoint) - return err - }); err != nil { - return 0, err - } - - return chanID, nil -} - -// getChanID returns the assigned channel ID for a given channel point. -func getChanID(tx *bbolt.Tx, chanPoint *wire.OutPoint) (uint64, error) { - var b bytes.Buffer - if err := writeOutpoint(&b, chanPoint); err != nil { - return 0, err - } - - edges := tx.Bucket(edgeBucket) - if edges == nil { - return 0, ErrGraphNoEdgesFound - } - chanIndex := edges.Bucket(channelPointBucket) - if chanIndex == nil { - return 0, ErrGraphNoEdgesFound - } - - chanIDBytes := chanIndex.Get(b.Bytes()) - if chanIDBytes == nil { - return 0, ErrEdgeNotFound - } - - chanID := byteOrder.Uint64(chanIDBytes) - - return chanID, nil -} - -// TODO(roasbeef): allow updates to use Batch? - -// HighestChanID returns the "highest" known channel ID in the channel graph. -// This represents the "newest" channel from the PoV of the chain. This method -// can be used by peers to quickly determine if they're graphs are in sync. -func (c *ChannelGraph) HighestChanID() (uint64, error) { - var cid uint64 - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // In order to find the highest chan ID, we'll fetch a cursor - // and use that to seek to the "end" of our known rage. - cidCursor := edgeIndex.Cursor() - - lastChanID, _ := cidCursor.Last() - - // If there's no key, then this means that we don't actually - // know of any channels, so we'll return a predicable error. - if lastChanID == nil { - return ErrGraphNoEdgesFound - } - - // Otherwise, we'll de serialize the channel ID and return it - // to the caller. - cid = byteOrder.Uint64(lastChanID) - return nil - }) - if err != nil && err != ErrGraphNoEdgesFound { - return 0, err - } - - return cid, nil -} - -// ChannelEdge represents the complete set of information for a channel edge in -// the known channel graph. This struct couples the core information of the -// edge as well as each of the known advertised edge policies. -type ChannelEdge struct { - // Info contains all the static information describing the channel. - Info *ChannelEdgeInfo - - // Policy1 points to the "first" edge policy of the channel containing - // the dynamic information required to properly route through the edge. - Policy1 *ChannelEdgePolicy - - // Policy2 points to the "second" edge policy of the channel containing - // the dynamic information required to properly route through the edge. - Policy2 *ChannelEdgePolicy -} - -// ChanUpdatesInHorizon returns all the known channel edges which have at least -// one edge that has an update timestamp within the specified horizon. -func (c *ChannelGraph) ChanUpdatesInHorizon(startTime, endTime time.Time) ([]ChannelEdge, error) { - // To ensure we don't return duplicate ChannelEdges, we'll use an - // additional map to keep track of the edges already seen to prevent - // re-adding it. - edgesSeen := make(map[uint64]struct{}) - edgesToCache := make(map[uint64]ChannelEdge) - var edgesInHorizon []ChannelEdge - - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - var hits int - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) - if edgeUpdateIndex == nil { - return ErrGraphNoEdgesFound - } - - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - - // We'll now obtain a cursor to perform a range query within - // the index to find all channels within the horizon. - updateCursor := edgeUpdateIndex.Cursor() - - var startTimeBytes, endTimeBytes [8 + 8]byte - byteOrder.PutUint64( - startTimeBytes[:8], uint64(startTime.Unix()), - ) - byteOrder.PutUint64( - endTimeBytes[:8], uint64(endTime.Unix()), - ) - - // With our start and end times constructed, we'll step through - // the index collecting the info and policy of each update of - // each channel that has a last update within the time range. - for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && - bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { - - // We have a new eligible entry, so we'll slice of the - // chan ID so we can query it in the DB. - chanID := indexKey[8:] - - // If we've already retrieved the info and policies for - // this edge, then we can skip it as we don't need to do - // so again. - chanIDInt := byteOrder.Uint64(chanID) - if _, ok := edgesSeen[chanIDInt]; ok { - continue - } - - if channel, ok := c.chanCache.get(chanIDInt); ok { - hits++ - edgesSeen[chanIDInt] = struct{}{} - edgesInHorizon = append(edgesInHorizon, channel) - continue - } - - // First, we'll fetch the static edge information. - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - chanID := byteOrder.Uint64(chanID) - return fmt.Errorf("unable to fetch info for "+ - "edge with chan_id=%v: %v", chanID, err) - } - edgeInfo.db = c.db - - // With the static information obtained, we'll now - // fetch the dynamic policy info. - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, - ) - if err != nil { - chanID := byteOrder.Uint64(chanID) - return fmt.Errorf("unable to fetch policies "+ - "for edge with chan_id=%v: %v", chanID, - err) - } - - // Finally, we'll collate this edge with the rest of - // edges to be returned. - edgesSeen[chanIDInt] = struct{}{} - channel := ChannelEdge{ - Info: &edgeInfo, - Policy1: edge1, - Policy2: edge2, - } - edgesInHorizon = append(edgesInHorizon, channel) - edgesToCache[chanIDInt] = channel - } - - return nil - }) - switch { - case err == ErrGraphNoEdgesFound: - fallthrough - case err == ErrGraphNodesNotFound: - break - - case err != nil: - return nil, err - } - - // Insert any edges loaded from disk into the cache. - for chanid, channel := range edgesToCache { - c.chanCache.insert(chanid, channel) - } - - log.Debugf("ChanUpdatesInHorizon hit percentage: %f (%d/%d)", - float64(hits)/float64(len(edgesInHorizon)), hits, - len(edgesInHorizon)) - - return edgesInHorizon, nil -} - -// NodeUpdatesInHorizon returns all the known lightning node which have an -// update timestamp within the passed range. This method can be used by two -// nodes to quickly determine if they have the same set of up to date node -// announcements. -func (c *ChannelGraph) NodeUpdatesInHorizon(startTime, endTime time.Time) ([]LightningNode, error) { - var nodesInHorizon []LightningNode - - err := c.db.View(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - - nodeUpdateIndex := nodes.Bucket(nodeUpdateIndexBucket) - if nodeUpdateIndex == nil { - return ErrGraphNodesNotFound - } - - // We'll now obtain a cursor to perform a range query within - // the index to find all node announcements within the horizon. - updateCursor := nodeUpdateIndex.Cursor() - - var startTimeBytes, endTimeBytes [8 + 33]byte - byteOrder.PutUint64( - startTimeBytes[:8], uint64(startTime.Unix()), - ) - byteOrder.PutUint64( - endTimeBytes[:8], uint64(endTime.Unix()), - ) - - // With our start and end times constructed, we'll step through - // the index collecting info for each node within the time - // range. - for indexKey, _ := updateCursor.Seek(startTimeBytes[:]); indexKey != nil && - bytes.Compare(indexKey, endTimeBytes[:]) <= 0; indexKey, _ = updateCursor.Next() { - - nodePub := indexKey[8:] - node, err := fetchLightningNode(nodes, nodePub) - if err != nil { - return err - } - node.db = c.db - - nodesInHorizon = append(nodesInHorizon, node) - } - - return nil - }) - switch { - case err == ErrGraphNoEdgesFound: - fallthrough - case err == ErrGraphNodesNotFound: - break - - case err != nil: - return nil, err - } - - return nodesInHorizon, nil -} - -// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan -// ID's that we don't know and are not known zombies of the passed set. In other -// words, we perform a set difference of our set of chan ID's and the ones -// passed in. This method can be used by callers to determine the set of -// channels another peer knows of that we don't. -func (c *ChannelGraph) FilterKnownChanIDs(chanIDs []uint64) ([]uint64, error) { - var newChanIDs []uint64 - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // Fetch the zombie index, it may not exist if no edges have - // ever been marked as zombies. If the index has been - // initialized, we will use it later to skip known zombie edges. - zombieIndex := edges.Bucket(zombieBucket) - - // We'll run through the set of chanIDs and collate only the - // set of channel that are unable to be found within our db. - var cidBytes [8]byte - for _, cid := range chanIDs { - byteOrder.PutUint64(cidBytes[:], cid) - - // If the edge is already known, skip it. - if v := edgeIndex.Get(cidBytes[:]); v != nil { - continue - } - - // If the edge is a known zombie, skip it. - if zombieIndex != nil { - isZombie, _, _ := isZombieEdge(zombieIndex, cid) - if isZombie { - continue - } - } - - newChanIDs = append(newChanIDs, cid) - } - - return nil - }) - switch { - // If we don't know of any edges yet, then we'll return the entire set - // of chan IDs specified. - case err == ErrGraphNoEdgesFound: - return chanIDs, nil - - case err != nil: - return nil, err - } - - return newChanIDs, nil -} - -// FilterChannelRange returns the channel ID's of all known channels which were -// mined in a block height within the passed range. This method can be used to -// quickly share with a peer the set of channels we know of within a particular -// range to catch them up after a period of time offline. -func (c *ChannelGraph) FilterChannelRange(startHeight, endHeight uint32) ([]uint64, error) { - var chanIDs []uint64 - - startChanID := &lnwire.ShortChannelID{ - BlockHeight: startHeight, - } - - endChanID := lnwire.ShortChannelID{ - BlockHeight: endHeight, - TxIndex: math.MaxUint32 & 0x00ffffff, - TxPosition: math.MaxUint16, - } - - // As we need to perform a range scan, we'll convert the starting and - // ending height to their corresponding values when encoded using short - // channel ID's. - var chanIDStart, chanIDEnd [8]byte - byteOrder.PutUint64(chanIDStart[:], startChanID.ToUint64()) - byteOrder.PutUint64(chanIDEnd[:], endChanID.ToUint64()) - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - cursor := edgeIndex.Cursor() - - // We'll now iterate through the database, and find each - // channel ID that resides within the specified range. - var cid uint64 - for k, _ := cursor.Seek(chanIDStart[:]); k != nil && - bytes.Compare(k, chanIDEnd[:]) <= 0; k, _ = cursor.Next() { - - // This channel ID rests within the target range, so - // we'll convert it into an integer and add it to our - // returned set. - cid = byteOrder.Uint64(k) - chanIDs = append(chanIDs, cid) - } - - return nil - }) - switch { - // If we don't know of any channels yet, then there's nothing to - // filter, so we'll return an empty slice. - case err == ErrGraphNoEdgesFound: - return chanIDs, nil - - case err != nil: - return nil, err - } - - return chanIDs, nil -} - -// FetchChanInfos returns the set of channel edges that correspond to the passed -// channel ID's. If an edge is the query is unknown to the database, it will -// skipped and the result will contain only those edges that exist at the time -// of the query. This can be used to respond to peer queries that are seeking to -// fill in gaps in their view of the channel graph. -func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { - // TODO(roasbeef): sort cids? - - var ( - chanEdges []ChannelEdge - cidBytes [8]byte - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - for _, cid := range chanIDs { - byteOrder.PutUint64(cidBytes[:], cid) - - // First, we'll fetch the static edge information. If - // the edge is unknown, we will skip the edge and - // continue gathering all known edges. - edgeInfo, err := fetchChanEdgeInfo( - edgeIndex, cidBytes[:], - ) - switch { - case err == ErrEdgeNotFound: - continue - case err != nil: - return err - } - edgeInfo.db = c.db - - // With the static information obtained, we'll now - // fetch the dynamic policy info. - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, cidBytes[:], c.db, - ) - if err != nil { - return err - } - - chanEdges = append(chanEdges, ChannelEdge{ - Info: &edgeInfo, - Policy1: edge1, - Policy2: edge2, - }) - } - return nil - }) - if err != nil { - return nil, err - } - - return chanEdges, nil -} - -func delEdgeUpdateIndexEntry(edgesBucket *bbolt.Bucket, chanID uint64, - edge1, edge2 *ChannelEdgePolicy) error { - - // First, we'll fetch the edge update index bucket which currently - // stores an entry for the channel we're about to delete. - updateIndex := edgesBucket.Bucket(edgeUpdateIndexBucket) - if updateIndex == nil { - // No edges in bucket, return early. - return nil - } - - // Now that we have the bucket, we'll attempt to construct a template - // for the index key: updateTime || chanid. - var indexKey [8 + 8]byte - byteOrder.PutUint64(indexKey[8:], chanID) - - // With the template constructed, we'll attempt to delete an entry that - // would have been created by both edges: we'll alternate the update - // times, as one may had overridden the other. - if edge1 != nil { - byteOrder.PutUint64(indexKey[:8], uint64(edge1.LastUpdate.Unix())) - if err := updateIndex.Delete(indexKey[:]); err != nil { - return err - } - } - - // We'll also attempt to delete the entry that may have been created by - // the second edge. - if edge2 != nil { - byteOrder.PutUint64(indexKey[:8], uint64(edge2.LastUpdate.Unix())) - if err := updateIndex.Delete(indexKey[:]); err != nil { - return err - } - } - - return nil -} - -func delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, - nodes *bbolt.Bucket, chanID []byte, isZombie bool) error { - - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - return err - } - - // We'll also remove the entry in the edge update index bucket before - // we delete the edges themselves so we can access their last update - // times. - cid := byteOrder.Uint64(chanID) - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, nil, - ) - if err != nil { - return err - } - err = delEdgeUpdateIndexEntry(edges, cid, edge1, edge2) - if err != nil { - return err - } - - // The edge key is of the format pubKey || chanID. First we construct - // the latter half, populating the channel ID. - var edgeKey [33 + 8]byte - copy(edgeKey[33:], chanID) - - // With the latter half constructed, copy over the first public key to - // delete the edge in this direction, then the second to delete the - // edge in the opposite direction. - copy(edgeKey[:33], edgeInfo.NodeKey1Bytes[:]) - if edges.Get(edgeKey[:]) != nil { - if err := edges.Delete(edgeKey[:]); err != nil { - return err - } - } - copy(edgeKey[:33], edgeInfo.NodeKey2Bytes[:]) - if edges.Get(edgeKey[:]) != nil { - if err := edges.Delete(edgeKey[:]); err != nil { - return err - } - } - - // As part of deleting the edge we also remove all disabled entries - // from the edgePolicyDisabledIndex bucket. We do that for both directions. - updateEdgePolicyDisabledIndex(edges, cid, false, false) - updateEdgePolicyDisabledIndex(edges, cid, true, false) - - // With the edge data deleted, we can purge the information from the two - // edge indexes. - if err := edgeIndex.Delete(chanID); err != nil { - return err - } - var b bytes.Buffer - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { - return err - } - if err := chanIndex.Delete(b.Bytes()); err != nil { - return err - } - - // Finally, we'll mark the edge as a zombie within our index if it's - // being removed due to the channel becoming a zombie. We do this to - // ensure we don't store unnecessary data for spent channels. - if !isZombie { - return nil - } - - return markEdgeZombie( - zombieIndex, byteOrder.Uint64(chanID), edgeInfo.NodeKey1Bytes, - edgeInfo.NodeKey2Bytes, - ) -} - -// UpdateEdgePolicy updates the edge routing policy for a single directed edge -// within the database for the referenced channel. The `flags` attribute within -// the ChannelEdgePolicy determines which of the directed edges are being -// updated. If the flag is 1, then the first node's information is being -// updated, otherwise it's the second node's information. The node ordering is -// determined by the lexicographical ordering of the identity public keys of -// the nodes on either side of the channel. -func (c *ChannelGraph) UpdateEdgePolicy(edge *ChannelEdgePolicy) error { - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - var isUpdate1 bool - err := c.db.Update(func(tx *bbolt.Tx) error { - var err error - isUpdate1, err = updateEdgePolicy(tx, edge) - return err - }) - if err != nil { - return err - } - - // If an entry for this channel is found in reject cache, we'll modify - // the entry with the updated timestamp for the direction that was just - // written. If the edge doesn't exist, we'll load the cache entry lazily - // during the next query for this edge. - if entry, ok := c.rejectCache.get(edge.ChannelID); ok { - if isUpdate1 { - entry.upd1Time = edge.LastUpdate.Unix() - } else { - entry.upd2Time = edge.LastUpdate.Unix() - } - c.rejectCache.insert(edge.ChannelID, entry) - } - - // If an entry for this channel is found in channel cache, we'll modify - // the entry with the updated policy for the direction that was just - // written. If the edge doesn't exist, we'll defer loading the info and - // policies and lazily read from disk during the next query. - if channel, ok := c.chanCache.get(edge.ChannelID); ok { - if isUpdate1 { - channel.Policy1 = edge - } else { - channel.Policy2 = edge - } - c.chanCache.insert(edge.ChannelID, channel) - } - - return nil -} - // updateEdgePolicy attempts to update an edge's policy within the relevant // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged @@ -2083,297 +391,6 @@ func (l *LightningNode) PubKey() (*btcec.PublicKey, error) { return key, nil } -// AuthSig is a signature under the advertised public key which serves to -// authenticate the attributes announced by this node. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (l *LightningNode) AuthSig() (*btcec.Signature, error) { - return btcec.ParseSignature(l.AuthSigBytes, btcec.S256()) -} - -// AddPubKey is a setter-link method that can be used to swap out the public -// key for a node. -func (l *LightningNode) AddPubKey(key *btcec.PublicKey) { - l.pubKey = key - copy(l.PubKeyBytes[:], key.SerializeCompressed()) -} - -// NodeAnnouncement retrieves the latest node announcement of the node. -func (l *LightningNode) NodeAnnouncement(signed bool) (*lnwire.NodeAnnouncement, - error) { - - if !l.HaveNodeAnnouncement { - return nil, fmt.Errorf("node does not have node announcement") - } - - alias, err := lnwire.NewNodeAlias(l.Alias) - if err != nil { - return nil, err - } - - nodeAnn := &lnwire.NodeAnnouncement{ - Features: l.Features.RawFeatureVector, - NodeID: l.PubKeyBytes, - RGBColor: l.Color, - Alias: alias, - Addresses: l.Addresses, - Timestamp: uint32(l.LastUpdate.Unix()), - ExtraOpaqueData: l.ExtraOpaqueData, - } - - if !signed { - return nodeAnn, nil - } - - sig, err := lnwire.NewSigFromRawSignature(l.AuthSigBytes) - if err != nil { - return nil, err - } - - nodeAnn.Signature = sig - - return nodeAnn, nil -} - -// isPublic determines whether the node is seen as public within the graph from -// the source node's point of view. An existing database transaction can also be -// specified. -func (l *LightningNode) isPublic(tx *bbolt.Tx, sourcePubKey []byte) (bool, error) { - // In order to determine whether this node is publicly advertised within - // the graph, we'll need to look at all of its edges and check whether - // they extend to any other node than the source node. errDone will be - // used to terminate the check early. - nodeIsPublic := false - errDone := errors.New("done") - err := l.ForEachChannel(tx, func(_ *bbolt.Tx, info *ChannelEdgeInfo, - _, _ *ChannelEdgePolicy) error { - - // If this edge doesn't extend to the source node, we'll - // terminate our search as we can now conclude that the node is - // publicly advertised within the graph due to the local node - // knowing of the current edge. - if !bytes.Equal(info.NodeKey1Bytes[:], sourcePubKey) && - !bytes.Equal(info.NodeKey2Bytes[:], sourcePubKey) { - - nodeIsPublic = true - return errDone - } - - // Since the edge _does_ extend to the source node, we'll also - // need to ensure that this is a public edge. - if info.AuthProof != nil { - nodeIsPublic = true - return errDone - } - - // Otherwise, we'll continue our search. - return nil - }) - if err != nil && err != errDone { - return false, err - } - - return nodeIsPublic, nil -} - -// FetchLightningNode attempts to look up a target node by its identity public -// key. If the node isn't found in the database, then ErrGraphNodeNotFound is -// returned. -func (c *ChannelGraph) FetchLightningNode(pub *btcec.PublicKey) (*LightningNode, error) { - var node *LightningNode - nodePub := pub.SerializeCompressed() - err := c.db.View(func(tx *bbolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // If a key for this serialized public key isn't found, then - // the target node doesn't exist within the database. - nodeBytes := nodes.Get(nodePub) - if nodeBytes == nil { - return ErrGraphNodeNotFound - } - - // If the node is found, then we can de deserialize the node - // information to return to the user. - nodeReader := bytes.NewReader(nodeBytes) - n, err := deserializeLightningNode(nodeReader) - if err != nil { - return err - } - n.db = c.db - - node = &n - - return nil - }) - if err != nil { - return nil, err - } - - return node, nil -} - -// HasLightningNode determines if the graph has a vertex identified by the -// target node identity public key. If the node exists in the database, a -// timestamp of when the data for the node was lasted updated is returned along -// with a true boolean. Otherwise, an empty time.Time is returned with a false -// boolean. -func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, error) { - var ( - updateTime time.Time - exists bool - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // If a key for this serialized public key isn't found, we can - // exit early. - nodeBytes := nodes.Get(nodePub[:]) - if nodeBytes == nil { - exists = false - return nil - } - - // Otherwise we continue on to obtain the time stamp - // representing the last time the data for this node was - // updated. - nodeReader := bytes.NewReader(nodeBytes) - node, err := deserializeLightningNode(nodeReader) - if err != nil { - return err - } - - exists = true - updateTime = node.LastUpdate - return nil - }) - if err != nil { - return time.Time{}, exists, err - } - - return updateTime, exists, nil -} - -// nodeTraversal is used to traverse all channels of a node given by its -// public key and passes channel information into the specified callback. -func nodeTraversal(tx *bbolt.Tx, nodePub []byte, db *DB, - cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - - traversal := func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNotFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // In order to reach all the edges for this node, we take - // advantage of the construction of the key-space within the - // edge bucket. The keys are stored in the form: pubKey || - // chanID. Therefore, starting from a chanID of zero, we can - // scan forward in the bucket, grabbing all the edges for the - // node. Once the prefix no longer matches, then we know we're - // done. - var nodeStart [33 + 8]byte - copy(nodeStart[:], nodePub) - copy(nodeStart[33:], chanStart[:]) - - // Starting from the key pubKey || 0, we seek forward in the - // bucket until the retrieved key no longer has the public key - // as its prefix. This indicates that we've stepped over into - // another node's edges, so we can terminate our scan. - edgeCursor := edges.Cursor() - for nodeEdge, _ := edgeCursor.Seek(nodeStart[:]); bytes.HasPrefix(nodeEdge, nodePub); nodeEdge, _ = edgeCursor.Next() { - // If the prefix still matches, the channel id is - // returned in nodeEdge. Channel id is used to lookup - // the node at the other end of the channel and both - // edge policies. - chanID := nodeEdge[33:] - edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - return err - } - edgeInfo.db = db - - outgoingPolicy, err := fetchChanEdgePolicy( - edges, chanID, nodePub, nodes, - ) - if err != nil { - return err - } - - otherNode, err := edgeInfo.OtherNodeKeyBytes(nodePub) - if err != nil { - return err - } - - incomingPolicy, err := fetchChanEdgePolicy( - edges, chanID, otherNode[:], nodes, - ) - if err != nil { - return err - } - - // Finally, we execute the callback. - err = cb(tx, &edgeInfo, outgoingPolicy, incomingPolicy) - if err != nil { - return err - } - } - - return nil - } - - // If no transaction was provided, then we'll create a new transaction - // to execute the transaction within. - if tx == nil { - return db.View(traversal) - } - - // Otherwise, we re-use the existing transaction to execute the graph - // traversal. - return traversal(tx) -} - -// ForEachChannel iterates through all channels of this node, executing the -// passed callback with an edge info structure and the policies of each end -// of the channel. The first edge policy is the outgoing edge *to* the -// the connecting node, while the second is the incoming edge *from* the -// connecting node. If the callback returns an error, then the iteration is -// halted with the error propagated back up to the caller. -// -// Unknown policies are passed into the callback as nil values. -// -// If the caller wishes to re-use an existing boltdb transaction, then it -// should be passed as the first argument. Otherwise the first argument should -// be nil and a fresh transaction will be created to execute the graph -// traversal. -func (l *LightningNode) ForEachChannel(tx *bbolt.Tx, - cb func(*bbolt.Tx, *ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - - nodePub := l.PubKeyBytes[:] - db := l.db - - return nodeTraversal(tx, nodePub, db, cb) -} - // ChannelEdgeInfo represents a fully authenticated channel along with all its // unique attributes. Once an authenticated channel announcement has been // processed on the network, then an instance of ChannelEdgeInfo encapsulating @@ -2395,19 +412,15 @@ type ChannelEdgeInfo struct { // NodeKey1Bytes is the raw public key of the first node. NodeKey1Bytes [33]byte - nodeKey1 *btcec.PublicKey // NodeKey2Bytes is the raw public key of the first node. NodeKey2Bytes [33]byte - nodeKey2 *btcec.PublicKey // BitcoinKey1Bytes is the raw public key of the first node. BitcoinKey1Bytes [33]byte - bitcoinKey1 *btcec.PublicKey // BitcoinKey2Bytes is the raw public key of the first node. BitcoinKey2Bytes [33]byte - bitcoinKey2 *btcec.PublicKey // Features is an opaque byte slice that encodes the set of channel // specific features that this channel edge supports. @@ -2433,173 +446,6 @@ type ChannelEdgeInfo struct { // and ensure we're able to make upgrades to the network in a forwards // compatible manner. ExtraOpaqueData []byte - - db *DB -} - -// AddNodeKeys is a setter-like method that can be used to replace the set of -// keys for the target ChannelEdgeInfo. -func (c *ChannelEdgeInfo) AddNodeKeys(nodeKey1, nodeKey2, bitcoinKey1, - bitcoinKey2 *btcec.PublicKey) { - - c.nodeKey1 = nodeKey1 - copy(c.NodeKey1Bytes[:], c.nodeKey1.SerializeCompressed()) - - c.nodeKey2 = nodeKey2 - copy(c.NodeKey2Bytes[:], nodeKey2.SerializeCompressed()) - - c.bitcoinKey1 = bitcoinKey1 - copy(c.BitcoinKey1Bytes[:], c.bitcoinKey1.SerializeCompressed()) - - c.bitcoinKey2 = bitcoinKey2 - copy(c.BitcoinKey2Bytes[:], bitcoinKey2.SerializeCompressed()) -} - -// NodeKey1 is the identity public key of the "first" node that was involved in -// the creation of this channel. A node is considered "first" if the -// lexicographical ordering the its serialized public key is "smaller" than -// that of the other node involved in channel creation. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) NodeKey1() (*btcec.PublicKey, error) { - if c.nodeKey1 != nil { - return c.nodeKey1, nil - } - - key, err := btcec.ParsePubKey(c.NodeKey1Bytes[:], btcec.S256()) - if err != nil { - return nil, err - } - c.nodeKey1 = key - - return key, nil -} - -// NodeKey2 is the identity public key of the "second" node that was -// involved in the creation of this channel. A node is considered -// "second" if the lexicographical ordering the its serialized public -// key is "larger" than that of the other node involved in channel -// creation. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) NodeKey2() (*btcec.PublicKey, error) { - if c.nodeKey2 != nil { - return c.nodeKey2, nil - } - - key, err := btcec.ParsePubKey(c.NodeKey2Bytes[:], btcec.S256()) - if err != nil { - return nil, err - } - c.nodeKey2 = key - - return key, nil -} - -// BitcoinKey1 is the Bitcoin multi-sig key belonging to the first -// node, that was involved in the funding transaction that originally -// created the channel that this struct represents. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) BitcoinKey1() (*btcec.PublicKey, error) { - if c.bitcoinKey1 != nil { - return c.bitcoinKey1, nil - } - - key, err := btcec.ParsePubKey(c.BitcoinKey1Bytes[:], btcec.S256()) - if err != nil { - return nil, err - } - c.bitcoinKey1 = key - - return key, nil -} - -// BitcoinKey2 is the Bitcoin multi-sig key belonging to the second -// node, that was involved in the funding transaction that originally -// created the channel that this struct represents. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the pubkey if absolutely necessary. -func (c *ChannelEdgeInfo) BitcoinKey2() (*btcec.PublicKey, error) { - if c.bitcoinKey2 != nil { - return c.bitcoinKey2, nil - } - - key, err := btcec.ParsePubKey(c.BitcoinKey2Bytes[:], btcec.S256()) - if err != nil { - return nil, err - } - c.bitcoinKey2 = key - - return key, nil -} - -// OtherNodeKeyBytes returns the node key bytes of the other end of -// the channel. -func (c *ChannelEdgeInfo) OtherNodeKeyBytes(thisNodeKey []byte) ( - [33]byte, error) { - - switch { - case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): - return c.NodeKey2Bytes, nil - case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): - return c.NodeKey1Bytes, nil - default: - return [33]byte{}, fmt.Errorf("node not participating in this channel") - } -} - -// FetchOtherNode attempts to fetch the full LightningNode that's opposite of -// the target node in the channel. This is useful when one knows the pubkey of -// one of the nodes, and wishes to obtain the full LightningNode for the other -// end of the channel. -func (c *ChannelEdgeInfo) FetchOtherNode(tx *bbolt.Tx, thisNodeKey []byte) (*LightningNode, error) { - - // Ensure that the node passed in is actually a member of the channel. - var targetNodeBytes [33]byte - switch { - case bytes.Equal(c.NodeKey1Bytes[:], thisNodeKey): - targetNodeBytes = c.NodeKey2Bytes - case bytes.Equal(c.NodeKey2Bytes[:], thisNodeKey): - targetNodeBytes = c.NodeKey1Bytes - default: - return nil, fmt.Errorf("node not participating in this channel") - } - - var targetNode *LightningNode - fetchNodeFunc := func(tx *bbolt.Tx) error { - // First grab the nodes bucket which stores the mapping from - // pubKey to node information. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - node, err := fetchLightningNode(nodes, targetNodeBytes[:]) - if err != nil { - return err - } - node.db = c.db - - targetNode = &node - - return nil - } - - // If the transaction is nil, then we'll need to create a new one, - // otherwise we can use the existing db transaction. - var err error - if tx == nil { - err = c.db.View(fetchNodeFunc) - } else { - err = fetchNodeFunc(tx) - } - - return targetNode, err } // ChannelAuthProof is the authentication proof (the signature portion) for a @@ -2610,117 +456,23 @@ func (c *ChannelEdgeInfo) FetchOtherNode(tx *bbolt.Tx, thisNodeKey []byte) (*Lig // nodeID1 || nodeID2 || bitcoinKey1|| bitcoinKey2 || 2-byte-feature-len || // features. type ChannelAuthProof struct { - // nodeSig1 is a cached instance of the first node signature. - nodeSig1 *btcec.Signature - // NodeSig1Bytes are the raw bytes of the first node signature encoded // in DER format. NodeSig1Bytes []byte - // nodeSig2 is a cached instance of the second node signature. - nodeSig2 *btcec.Signature - // NodeSig2Bytes are the raw bytes of the second node signature // encoded in DER format. NodeSig2Bytes []byte - // bitcoinSig1 is a cached instance of the first bitcoin signature. - bitcoinSig1 *btcec.Signature - // BitcoinSig1Bytes are the raw bytes of the first bitcoin signature // encoded in DER format. BitcoinSig1Bytes []byte - // bitcoinSig2 is a cached instance of the second bitcoin signature. - bitcoinSig2 *btcec.Signature - // BitcoinSig2Bytes are the raw bytes of the second bitcoin signature // encoded in DER format. BitcoinSig2Bytes []byte } -// Node1Sig is the signature using the identity key of the node that is first -// in a lexicographical ordering of the serialized public keys of the two nodes -// that created the channel. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) Node1Sig() (*btcec.Signature, error) { - if c.nodeSig1 != nil { - return c.nodeSig1, nil - } - - sig, err := btcec.ParseSignature(c.NodeSig1Bytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.nodeSig1 = sig - - return sig, nil -} - -// Node2Sig is the signature using the identity key of the node that is second -// in a lexicographical ordering of the serialized public keys of the two nodes -// that created the channel. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) Node2Sig() (*btcec.Signature, error) { - if c.nodeSig2 != nil { - return c.nodeSig2, nil - } - - sig, err := btcec.ParseSignature(c.NodeSig2Bytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.nodeSig2 = sig - - return sig, nil -} - -// BitcoinSig1 is the signature using the public key of the first node that was -// used in the channel's multi-sig output. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) BitcoinSig1() (*btcec.Signature, error) { - if c.bitcoinSig1 != nil { - return c.bitcoinSig1, nil - } - - sig, err := btcec.ParseSignature(c.BitcoinSig1Bytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.bitcoinSig1 = sig - - return sig, nil -} - -// BitcoinSig2 is the signature using the public key of the second node that -// was used in the channel's multi-sig output. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelAuthProof) BitcoinSig2() (*btcec.Signature, error) { - if c.bitcoinSig2 != nil { - return c.bitcoinSig2, nil - } - - sig, err := btcec.ParseSignature(c.BitcoinSig2Bytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.bitcoinSig2 = sig - - return sig, nil -} - // IsEmpty check is the authentication proof is empty Proof is empty if at // least one of the signatures are equal to nil. func (c *ChannelAuthProof) IsEmpty() bool { @@ -2742,9 +494,6 @@ type ChannelEdgePolicy struct { // use SetSigBytes instead to make sure that the cache is invalidated. SigBytes []byte - // sig is a cached fully parsed signature. - sig *btcec.Signature - // ChannelID is the unique channel ID for the channel. The first 3 // bytes are the block height, the next 3 the index within the block, // and the last 2 bytes are the output index for the channel. @@ -2794,35 +543,6 @@ type ChannelEdgePolicy struct { // and ensure we're able to make upgrades to the network in a forwards // compatible manner. ExtraOpaqueData []byte - - db *DB -} - -// Signature is a channel announcement signature, which is needed for proper -// edge policy announcement. -// -// NOTE: By having this method to access an attribute, we ensure we only need -// to fully deserialize the signature if absolutely necessary. -func (c *ChannelEdgePolicy) Signature() (*btcec.Signature, error) { - if c.sig != nil { - return c.sig, nil - } - - sig, err := btcec.ParseSignature(c.SigBytes, btcec.S256()) - if err != nil { - return nil, err - } - - c.sig = sig - - return sig, nil -} - -// SetSigBytes updates the signature and invalidates the cached parsed -// signature. -func (c *ChannelEdgePolicy) SetSigBytes(sig []byte) { - c.SigBytes = sig - c.sig = nil } // IsDisabled determines whether the edge has the disabled bit set. @@ -2831,488 +551,6 @@ func (c *ChannelEdgePolicy) IsDisabled() bool { lnwire.ChanUpdateDisabled } -// ComputeFee computes the fee to forward an HTLC of `amt` milli-satoshis over -// the passed active payment channel. This value is currently computed as -// specified in BOLT07, but will likely change in the near future. -func (c *ChannelEdgePolicy) ComputeFee( - amt lnwire.MilliSatoshi) lnwire.MilliSatoshi { - - return c.FeeBaseMSat + (amt*c.FeeProportionalMillionths)/feeRateParts -} - -// divideCeil divides dividend by factor and rounds the result up. -func divideCeil(dividend, factor lnwire.MilliSatoshi) lnwire.MilliSatoshi { - return (dividend + factor - 1) / factor -} - -// ComputeFeeFromIncoming computes the fee to forward an HTLC given the incoming -// amount. -func (c *ChannelEdgePolicy) ComputeFeeFromIncoming( - incomingAmt lnwire.MilliSatoshi) lnwire.MilliSatoshi { - - return incomingAmt - divideCeil( - feeRateParts*(incomingAmt-c.FeeBaseMSat), - feeRateParts+c.FeeProportionalMillionths, - ) -} - -// FetchChannelEdgesByOutpoint attempts to lookup the two directed edges for -// the channel identified by the funding outpoint. If the channel can't be -// found, then ErrEdgeNotFound is returned. A struct which houses the general -// information for the channel itself is returned as well as two structs that -// contain the routing policies for the channel in either direction. -func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint, -) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { - - var ( - edgeInfo *ChannelEdgeInfo - policy1 *ChannelEdgePolicy - policy2 *ChannelEdgePolicy - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - // First, grab the node bucket. This will be used to populate - // the Node pointers in each edge read from disk. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // Next, grab the edge bucket which stores the edges, and also - // the index itself so we can group the directed edges together - // logically. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // If the channel's outpoint doesn't exist within the outpoint - // index, then the edge does not exist. - chanIndex := edges.Bucket(channelPointBucket) - if chanIndex == nil { - return ErrGraphNoEdgesFound - } - var b bytes.Buffer - if err := writeOutpoint(&b, op); err != nil { - return err - } - chanID := chanIndex.Get(b.Bytes()) - if chanID == nil { - return ErrEdgeNotFound - } - - // If the channel is found to exists, then we'll first retrieve - // the general information for the channel. - edge, err := fetchChanEdgeInfo(edgeIndex, chanID) - if err != nil { - return err - } - edgeInfo = &edge - edgeInfo.db = c.db - - // Once we have the information about the channels' parameters, - // we'll fetch the routing policies for each for the directed - // edges. - e1, e2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, - ) - if err != nil { - return err - } - - policy1 = e1 - policy2 = e2 - return nil - }) - if err != nil { - return nil, nil, nil, err - } - - return edgeInfo, policy1, policy2, nil -} - -// FetchChannelEdgesByID attempts to lookup the two directed edges for the -// channel identified by the channel ID. If the channel can't be found, then -// ErrEdgeNotFound is returned. A struct which houses the general information -// for the channel itself is returned as well as two structs that contain the -// routing policies for the channel in either direction. -// -// ErrZombieEdge an be returned if the edge is currently marked as a zombie -// within the database. In this case, the ChannelEdgePolicy's will be nil, and -// the ChannelEdgeInfo will only include the public keys of each node. -func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64, -) (*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy, error) { - - var ( - edgeInfo *ChannelEdgeInfo - policy1 *ChannelEdgePolicy - policy2 *ChannelEdgePolicy - channelID [8]byte - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - // First, grab the node bucket. This will be used to populate - // the Node pointers in each edge read from disk. - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // Next, grab the edge bucket which stores the edges, and also - // the index itself so we can group the directed edges together - // logically. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - byteOrder.PutUint64(channelID[:], chanID) - - // Now, attempt to fetch edge. - edge, err := fetchChanEdgeInfo(edgeIndex, channelID[:]) - - // If it doesn't exist, we'll quickly check our zombie index to - // see if we've previously marked it as so. - if err == ErrEdgeNotFound { - // If the zombie index doesn't exist, or the edge is not - // marked as a zombie within it, then we'll return the - // original ErrEdgeNotFound error. - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex == nil { - return ErrEdgeNotFound - } - - isZombie, pubKey1, pubKey2 := isZombieEdge( - zombieIndex, chanID, - ) - if !isZombie { - return ErrEdgeNotFound - } - - // Otherwise, the edge is marked as a zombie, so we'll - // populate the edge info with the public keys of each - // party as this is the only information we have about - // it and return an error signaling so. - edgeInfo = &ChannelEdgeInfo{ - NodeKey1Bytes: pubKey1, - NodeKey2Bytes: pubKey2, - } - return ErrZombieEdge - } - - // Otherwise, we'll just return the error if any. - if err != nil { - return err - } - - edgeInfo = &edge - edgeInfo.db = c.db - - // Then we'll attempt to fetch the accompanying policies of this - // edge. - e1, e2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, channelID[:], c.db, - ) - if err != nil { - return err - } - - policy1 = e1 - policy2 = e2 - return nil - }) - if err == ErrZombieEdge { - return edgeInfo, nil, nil, err - } - if err != nil { - return nil, nil, nil, err - } - - return edgeInfo, policy1, policy2, nil -} - -// IsPublicNode is a helper method that determines whether the node with the -// given public key is seen as a public node in the graph from the graph's -// source node's point of view. -func (c *ChannelGraph) IsPublicNode(pubKey [33]byte) (bool, error) { - var nodeIsPublic bool - err := c.db.View(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNodesNotFound - } - ourPubKey := nodes.Get(sourceKey) - if ourPubKey == nil { - return ErrSourceNodeNotSet - } - node, err := fetchLightningNode(nodes, pubKey[:]) - if err != nil { - return err - } - - nodeIsPublic, err = node.isPublic(tx, ourPubKey) - return err - }) - if err != nil { - return false, err - } - - return nodeIsPublic, nil -} - -// genMultiSigP2WSH generates the p2wsh'd multisig script for 2 of 2 pubkeys. -func genMultiSigP2WSH(aPub, bPub []byte) ([]byte, error) { - if len(aPub) != 33 || len(bPub) != 33 { - return nil, fmt.Errorf("Pubkey size error. Compressed " + - "pubkeys only") - } - - // Swap to sort pubkeys if needed. Keys are sorted in lexicographical - // order. The signatures within the scriptSig must also adhere to the - // order, ensuring that the signatures for each public key appears in - // the proper order on the stack. - if bytes.Compare(aPub, bPub) == 1 { - aPub, bPub = bPub, aPub - } - - // First, we'll generate the witness script for the multi-sig. - bldr := txscript.NewScriptBuilder() - bldr.AddOp(txscript.OP_2) - bldr.AddData(aPub) // Add both pubkeys (sorted). - bldr.AddData(bPub) - bldr.AddOp(txscript.OP_2) - bldr.AddOp(txscript.OP_CHECKMULTISIG) - witnessScript, err := bldr.Script() - if err != nil { - return nil, err - } - - // With the witness script generated, we'll now turn it into a p2sh - // script: - // * OP_0 - bldr = txscript.NewScriptBuilder() - bldr.AddOp(txscript.OP_0) - scriptHash := sha256.Sum256(witnessScript) - bldr.AddData(scriptHash[:]) - - return bldr.Script() -} - -// EdgePoint couples the outpoint of a channel with the funding script that it -// creates. The FilteredChainView will use this to watch for spends of this -// edge point on chain. We require both of these values as depending on the -// concrete implementation, either the pkScript, or the out point will be used. -type EdgePoint struct { - // FundingPkScript is the p2wsh multi-sig script of the target channel. - FundingPkScript []byte - - // OutPoint is the outpoint of the target channel. - OutPoint wire.OutPoint -} - -// String returns a human readable version of the target EdgePoint. We return -// the outpoint directly as it is enough to uniquely identify the edge point. -func (e *EdgePoint) String() string { - return e.OutPoint.String() -} - -// ChannelView returns the verifiable edge information for each active channel -// within the known channel graph. The set of UTXO's (along with their scripts) -// returned are the ones that need to be watched on chain to detect channel -// closes on the resident blockchain. -func (c *ChannelGraph) ChannelView() ([]EdgePoint, error) { - var edgePoints []EdgePoint - if err := c.db.View(func(tx *bbolt.Tx) error { - // We're going to iterate over the entire channel index, so - // we'll need to fetch the edgeBucket to get to the index as - // it's a sub-bucket. - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - chanIndex := edges.Bucket(channelPointBucket) - if chanIndex == nil { - return ErrGraphNoEdgesFound - } - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrGraphNoEdgesFound - } - - // Once we have the proper bucket, we'll range over each key - // (which is the channel point for the channel) and decode it, - // accumulating each entry. - return chanIndex.ForEach(func(chanPointBytes, chanID []byte) error { - chanPointReader := bytes.NewReader(chanPointBytes) - - var chanPoint wire.OutPoint - err := readOutpoint(chanPointReader, &chanPoint) - if err != nil { - return err - } - - edgeInfo, err := fetchChanEdgeInfo( - edgeIndex, chanID, - ) - if err != nil { - return err - } - - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], - edgeInfo.BitcoinKey2Bytes[:], - ) - if err != nil { - return err - } - - edgePoints = append(edgePoints, EdgePoint{ - FundingPkScript: pkScript, - OutPoint: chanPoint, - }) - - return nil - }) - }); err != nil { - return nil, err - } - - return edgePoints, nil -} - -// NewChannelEdgePolicy returns a new blank ChannelEdgePolicy. -func (c *ChannelGraph) NewChannelEdgePolicy() *ChannelEdgePolicy { - return &ChannelEdgePolicy{db: c.db} -} - -// markEdgeZombie marks an edge as a zombie within our zombie index. The public -// keys should represent the node public keys of the two parties involved in the -// edge. -func markEdgeZombie(zombieIndex *bbolt.Bucket, chanID uint64, pubKey1, - pubKey2 [33]byte) error { - - var k [8]byte - byteOrder.PutUint64(k[:], chanID) - - var v [66]byte - copy(v[:33], pubKey1[:]) - copy(v[33:], pubKey2[:]) - - return zombieIndex.Put(k[:], v[:]) -} - -// MarkEdgeLive clears an edge from our zombie index, deeming it as live. -func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { - c.cacheMu.Lock() - defer c.cacheMu.Unlock() - - err := c.db.Update(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex == nil { - return nil - } - - var k [8]byte - byteOrder.PutUint64(k[:], chanID) - return zombieIndex.Delete(k[:]) - }) - if err != nil { - return err - } - - c.rejectCache.remove(chanID) - c.chanCache.remove(chanID) - - return nil -} - -// IsZombieEdge returns whether the edge is considered zombie. If it is a -// zombie, then the two node public keys corresponding to this edge are also -// returned. -func (c *ChannelGraph) IsZombieEdge(chanID uint64) (bool, [33]byte, [33]byte) { - var ( - isZombie bool - pubKey1, pubKey2 [33]byte - ) - - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex == nil { - return nil - } - - isZombie, pubKey1, pubKey2 = isZombieEdge(zombieIndex, chanID) - return nil - }) - if err != nil { - return false, [33]byte{}, [33]byte{} - } - - return isZombie, pubKey1, pubKey2 -} - -// isZombieEdge returns whether an entry exists for the given channel in the -// zombie index. If an entry exists, then the two node public keys corresponding -// to this edge are also returned. -func isZombieEdge(zombieIndex *bbolt.Bucket, - chanID uint64) (bool, [33]byte, [33]byte) { - - var k [8]byte - byteOrder.PutUint64(k[:], chanID) - - v := zombieIndex.Get(k[:]) - if v == nil { - return false, [33]byte{}, [33]byte{} - } - - var pubKey1, pubKey2 [33]byte - copy(pubKey1[:], v[:33]) - copy(pubKey2[:], v[33:]) - - return true, pubKey1, pubKey2 -} - -// NumZombies returns the current number of zombie channels in the graph. -func (c *ChannelGraph) NumZombies() (uint64, error) { - var numZombies uint64 - err := c.db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return nil - } - zombieIndex := edges.Bucket(zombieBucket) - if zombieIndex == nil { - return nil - } - - return zombieIndex.ForEach(func(_, _ []byte) error { - numZombies++ - return nil - }) - }) - if err != nil { - return 0, err - } - - return numZombies, nil -} - func putLightningNode(nodeBucket *bbolt.Bucket, aliasBucket *bbolt.Bucket, updateIndex *bbolt.Bucket, node *LightningNode) error { @@ -3548,84 +786,6 @@ func deserializeLightningNode(r io.Reader) (LightningNode, error) { return node, nil } -func putChanEdgeInfo(edgeIndex *bbolt.Bucket, edgeInfo *ChannelEdgeInfo, chanID [8]byte) error { - var b bytes.Buffer - - if _, err := b.Write(edgeInfo.NodeKey1Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.NodeKey2Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.BitcoinKey1Bytes[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.BitcoinKey2Bytes[:]); err != nil { - return err - } - - if err := wire.WriteVarBytes(&b, 0, edgeInfo.Features); err != nil { - return err - } - - authProof := edgeInfo.AuthProof - var nodeSig1, nodeSig2, bitcoinSig1, bitcoinSig2 []byte - if authProof != nil { - nodeSig1 = authProof.NodeSig1Bytes - nodeSig2 = authProof.NodeSig2Bytes - bitcoinSig1 = authProof.BitcoinSig1Bytes - bitcoinSig2 = authProof.BitcoinSig2Bytes - } - - if err := wire.WriteVarBytes(&b, 0, nodeSig1); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, nodeSig2); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig1); err != nil { - return err - } - if err := wire.WriteVarBytes(&b, 0, bitcoinSig2); err != nil { - return err - } - - if err := writeOutpoint(&b, &edgeInfo.ChannelPoint); err != nil { - return err - } - if err := binary.Write(&b, byteOrder, uint64(edgeInfo.Capacity)); err != nil { - return err - } - if _, err := b.Write(chanID[:]); err != nil { - return err - } - if _, err := b.Write(edgeInfo.ChainHash[:]); err != nil { - return err - } - - if len(edgeInfo.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { - return ErrTooManyExtraOpaqueBytes(len(edgeInfo.ExtraOpaqueData)) - } - err := wire.WriteVarBytes(&b, 0, edgeInfo.ExtraOpaqueData) - if err != nil { - return err - } - - return edgeIndex.Put(chanID[:], b.Bytes()) -} - -func fetchChanEdgeInfo(edgeIndex *bbolt.Bucket, - chanID []byte) (ChannelEdgeInfo, error) { - - edgeInfoBytes := edgeIndex.Get(chanID) - if edgeInfoBytes == nil { - return ChannelEdgeInfo{}, ErrEdgeNotFound - } - - edgeInfoReader := bytes.NewReader(edgeInfoBytes) - return deserializeChanEdgeInfo(edgeInfoReader) -} - func deserializeChanEdgeInfo(r io.Reader) (ChannelEdgeInfo, error) { var ( err error @@ -3856,47 +1016,6 @@ func fetchChanEdgePolicy(edges *bbolt.Bucket, chanID []byte, return ep, nil } -func fetchChanEdgePolicies(edgeIndex *bbolt.Bucket, edges *bbolt.Bucket, - nodes *bbolt.Bucket, chanID []byte, - db *DB) (*ChannelEdgePolicy, *ChannelEdgePolicy, error) { - - edgeInfo := edgeIndex.Get(chanID) - if edgeInfo == nil { - return nil, nil, ErrEdgeNotFound - } - - // The first node is contained within the first half of the edge - // information. We only propagate the error here and below if it's - // something other than edge non-existence. - node1Pub := edgeInfo[:33] - edge1, err := fetchChanEdgePolicy(edges, chanID, node1Pub, nodes) - if err != nil { - return nil, nil, err - } - - // As we may have a single direction of the edge but not the other, - // only fill in the database pointers if the edge is found. - if edge1 != nil { - edge1.db = db - edge1.Node.db = db - } - - // Similarly, the second node is contained within the latter - // half of the edge information. - node2Pub := edgeInfo[33:66] - edge2, err := fetchChanEdgePolicy(edges, chanID, node2Pub, nodes) - if err != nil { - return nil, nil, err - } - - if edge2 != nil { - edge2.db = db - edge2.Node.db = db - } - - return edge1, edge2, nil -} - func serializeChanEdgePolicy(w io.Writer, edge *ChannelEdgePolicy, to []byte) error { diff --git a/channeldb/migration_01_to_11/graph_test.go b/channeldb/migration_01_to_11/graph_test.go index 00a8a000..a65f0046 100644 --- a/channeldb/migration_01_to_11/graph_test.go +++ b/channeldb/migration_01_to_11/graph_test.go @@ -1,24 +1,13 @@ package migration_01_to_11 import ( - "bytes" - "crypto/sha256" - "fmt" "image/color" - "math" "math/big" prand "math/rand" "net" - "reflect" - "runtime" - "testing" "time" "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/chaincfg/chainhash" - "github.com/btcsuite/btcd/wire" - "github.com/coreos/bbolt" - "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lnwire" ) @@ -66,3132 +55,3 @@ func createTestVertex(db *DB) (*LightningNode, error) { return createLightningNode(db, priv) } - -func TestNodeInsertionAndDeletion(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test basic insertion/deletion for vertexes from the - // graph, so we'll create a test vertex to start with. - _, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - node := &LightningNode{ - HaveNodeAnnouncement: true, - AuthSigBytes: testSig.Serialize(), - LastUpdate: time.Unix(1232342, 0), - Color: color.RGBA{1, 2, 3, 0}, - Alias: "kek", - Features: testFeatures, - Addresses: testAddrs, - ExtraOpaqueData: []byte("extra new data"), - db: db, - } - copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) - - // First, insert the node into the graph DB. This should succeed - // without any errors. - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Next, fetch the node from the database to ensure everything was - // serialized properly. - dbNode, err := graph.FetchLightningNode(testPub) - if err != nil { - t.Fatalf("unable to locate node: %v", err) - } - - if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { - t.Fatalf("unable to query for node: %v", err) - } else if !exists { - t.Fatalf("node should be found but wasn't") - } - - // The two nodes should match exactly! - if err := compareNodes(node, dbNode); err != nil { - t.Fatalf("nodes don't match: %v", err) - } - - // Next, delete the node from the graph, this should purge all data - // related to the node. - if err := graph.DeleteLightningNode(testPub); err != nil { - t.Fatalf("unable to delete node; %v", err) - } - - // Finally, attempt to fetch the node again. This should fail as the - // node should have been deleted from the database. - _, err = graph.FetchLightningNode(testPub) - if err != ErrGraphNodeNotFound { - t.Fatalf("fetch after delete should fail!") - } -} - -// TestPartialNode checks that we can add and retrieve a LightningNode where -// where only the pubkey is known to the database. -func TestPartialNode(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We want to be able to insert nodes into the graph that only has the - // PubKey set. - _, testPub := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - node := &LightningNode{ - HaveNodeAnnouncement: false, - } - copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) - - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Next, fetch the node from the database to ensure everything was - // serialized properly. - dbNode, err := graph.FetchLightningNode(testPub) - if err != nil { - t.Fatalf("unable to locate node: %v", err) - } - - if _, exists, err := graph.HasLightningNode(dbNode.PubKeyBytes); err != nil { - t.Fatalf("unable to query for node: %v", err) - } else if !exists { - t.Fatalf("node should be found but wasn't") - } - - // The two nodes should match exactly! (with default values for - // LastUpdate and db set to satisfy compareNodes()) - node = &LightningNode{ - HaveNodeAnnouncement: false, - LastUpdate: time.Unix(0, 0), - db: db, - } - copy(node.PubKeyBytes[:], testPub.SerializeCompressed()) - - if err := compareNodes(node, dbNode); err != nil { - t.Fatalf("nodes don't match: %v", err) - } - - // Next, delete the node from the graph, this should purge all data - // related to the node. - if err := graph.DeleteLightningNode(testPub); err != nil { - t.Fatalf("unable to delete node: %v", err) - } - - // Finally, attempt to fetch the node again. This should fail as the - // node should have been deleted from the database. - _, err = graph.FetchLightningNode(testPub) - if err != ErrGraphNodeNotFound { - t.Fatalf("fetch after delete should fail!") - } -} - -func TestAliasLookup(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the alias index within the database, so first - // create a new test node. - testNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // Add the node to the graph's database, this should also insert an - // entry into the alias index for this node. - if err := graph.AddLightningNode(testNode); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Next, attempt to lookup the alias. The alias should exactly match - // the one which the test node was assigned. - nodePub, err := testNode.PubKey() - if err != nil { - t.Fatalf("unable to generate pubkey: %v", err) - } - dbAlias, err := graph.LookupAlias(nodePub) - if err != nil { - t.Fatalf("unable to find alias: %v", err) - } - if dbAlias != testNode.Alias { - t.Fatalf("aliases don't match, expected %v got %v", - testNode.Alias, dbAlias) - } - - // Ensure that looking up a non-existent alias results in an error. - node, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - nodePub, err = node.PubKey() - if err != nil { - t.Fatalf("unable to generate pubkey: %v", err) - } - _, err = graph.LookupAlias(nodePub) - if err != ErrNodeAliasNotFound { - t.Fatalf("alias lookup should fail for non-existent pubkey") - } -} - -func TestSourceNode(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the setting/getting of the source node, so we - // first create a fake node to use within the test. - testNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // Attempt to fetch the source node, this should return an error as the - // source node hasn't yet been set. - if _, err := graph.SourceNode(); err != ErrSourceNodeNotSet { - t.Fatalf("source node shouldn't be set in new graph") - } - - // Set the source the source node, this should insert the node into the - // database in a special way indicating it's the source node. - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // Retrieve the source node from the database, it should exactly match - // the one we set above. - sourceNode, err := graph.SourceNode() - if err != nil { - t.Fatalf("unable to fetch source node: %v", err) - } - if err := compareNodes(testNode, sourceNode); err != nil { - t.Fatalf("nodes don't match: %v", err) - } -} - -func TestEdgeInsertionDeletion(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the insertion/deletion of edges, so we create two - // vertexes to connect. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // In addition to the fake vertexes we create some fake channel - // identifiers. - chanID := uint64(prand.Int63()) - outpoint := wire.OutPoint{ - Hash: rev, - Index: 9, - } - - // Add the new edge to the database, this should proceed without any - // errors. - node1Pub, err := node1.PubKey() - if err != nil { - t.Fatalf("unable to generate node key: %v", err) - } - node2Pub, err := node2.PubKey() - if err != nil { - t.Fatalf("unable to generate node key: %v", err) - } - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: outpoint, - Capacity: 9000, - } - copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) - - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Ensure that both policies are returned as unknown (nil). - _, e1, e2, err := graph.FetchChannelEdgesByID(chanID) - if err != nil { - t.Fatalf("unable to fetch channel edge") - } - if e1 != nil || e2 != nil { - t.Fatalf("channel edges not unknown") - } - - // Next, attempt to delete the edge from the database, again this - // should proceed without any issues. - if err := graph.DeleteChannelEdges(chanID); err != nil { - t.Fatalf("unable to delete edge: %v", err) - } - - // Ensure that any query attempts to lookup the delete channel edge are - // properly deleted. - if _, _, _, err := graph.FetchChannelEdgesByOutpoint(&outpoint); err == nil { - t.Fatalf("channel edge not deleted") - } - if _, _, _, err := graph.FetchChannelEdgesByID(chanID); err == nil { - t.Fatalf("channel edge not deleted") - } - isZombie, _, _ := graph.IsZombieEdge(chanID) - if !isZombie { - t.Fatal("channel edge not marked as zombie") - } - - // Finally, attempt to delete a (now) non-existent edge within the - // database, this should result in an error. - err = graph.DeleteChannelEdges(chanID) - if err != ErrEdgeNotFound { - t.Fatalf("deleting a non-existent edge should fail!") - } -} - -func createEdge(height, txIndex uint32, txPosition uint16, outPointIndex uint32, - node1, node2 *LightningNode) (ChannelEdgeInfo, lnwire.ShortChannelID) { - - shortChanID := lnwire.ShortChannelID{ - BlockHeight: height, - TxIndex: txIndex, - TxPosition: txPosition, - } - outpoint := wire.OutPoint{ - Hash: rev, - Index: outPointIndex, - } - - node1Pub, _ := node1.PubKey() - node2Pub, _ := node2.PubKey() - edgeInfo := ChannelEdgeInfo{ - ChannelID: shortChanID.ToUint64(), - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: outpoint, - Capacity: 9000, - } - - copy(edgeInfo.NodeKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.NodeKey2Bytes[:], node2Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey1Bytes[:], node1Pub.SerializeCompressed()) - copy(edgeInfo.BitcoinKey2Bytes[:], node2Pub.SerializeCompressed()) - - return edgeInfo, shortChanID -} - -// TestDisconnectBlockAtHeight checks that the pruned state of the channel -// database is what we expect after calling DisconnectBlockAtHeight. -func TestDisconnectBlockAtHeight(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // We'd like to test the insertion/deletion of edges, so we create two - // vertexes to connect. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // In addition to the fake vertexes we create some fake channel - // identifiers. - var spendOutputs []*wire.OutPoint - var blockHash chainhash.Hash - copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) - - // Prune the graph a few times to make sure we have entries in the - // prune log. - _, err = graph.PruneGraph(spendOutputs, &blockHash, 155) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - var blockHash2 chainhash.Hash - copy(blockHash2[:], bytes.Repeat([]byte{2}, 32)) - - _, err = graph.PruneGraph(spendOutputs, &blockHash2, 156) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - - // We'll create 3 almost identical edges, so first create a helper - // method containing all logic for doing so. - - // Create an edge which has its block height at 156. - height := uint32(156) - edgeInfo, _ := createEdge(height, 0, 0, 0, node1, node2) - - // Create an edge with block height 157. We give it - // maximum values for tx index and position, to make - // sure our database range scan get edges from the - // entire range. - edgeInfo2, _ := createEdge( - height+1, math.MaxUint32&0x00ffffff, math.MaxUint16, 1, - node1, node2, - ) - - // Create a third edge, this with a block height of 155. - edgeInfo3, _ := createEdge(height-1, 0, 0, 2, node1, node2) - - // Now add all these new edges to the database. - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - if err := graph.AddChannelEdge(&edgeInfo2); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - if err := graph.AddChannelEdge(&edgeInfo3); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Call DisconnectBlockAtHeight, which should prune every channel - // that has a funding height of 'height' or greater. - removed, err := graph.DisconnectBlockAtHeight(uint32(height)) - if err != nil { - t.Fatalf("unable to prune %v", err) - } - - // The two edges should have been removed. - if len(removed) != 2 { - t.Fatalf("expected two edges to be removed from graph, "+ - "only %d were", len(removed)) - } - if removed[0].ChannelID != edgeInfo.ChannelID { - t.Fatalf("expected edge to be removed from graph") - } - if removed[1].ChannelID != edgeInfo2.ChannelID { - t.Fatalf("expected edge to be removed from graph") - } - - // The two first edges should be removed from the db. - _, _, has, isZombie, err := graph.HasChannelEdge(edgeInfo.ChannelID) - if err != nil { - t.Fatalf("unable to query for edge: %v", err) - } - if has { - t.Fatalf("edge1 was not pruned from the graph") - } - if isZombie { - t.Fatal("reorged edge1 should not be marked as zombie") - } - _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo2.ChannelID) - if err != nil { - t.Fatalf("unable to query for edge: %v", err) - } - if has { - t.Fatalf("edge2 was not pruned from the graph") - } - if isZombie { - t.Fatal("reorged edge2 should not be marked as zombie") - } - - // Edge 3 should not be removed. - _, _, has, isZombie, err = graph.HasChannelEdge(edgeInfo3.ChannelID) - if err != nil { - t.Fatalf("unable to query for edge: %v", err) - } - if !has { - t.Fatalf("edge3 was pruned from the graph") - } - if isZombie { - t.Fatal("edge3 was marked as zombie") - } - - // PruneTip should be set to the blockHash we specified for the block - // at height 155. - hash, h, err := graph.PruneTip() - if err != nil { - t.Fatalf("unable to get prune tip: %v", err) - } - if !blockHash.IsEqual(hash) { - t.Fatalf("expected best block to be %x, was %x", blockHash, hash) - } - if h != height-1 { - t.Fatalf("expected best block height to be %d, was %d", height-1, h) - } -} - -func assertEdgeInfoEqual(t *testing.T, e1 *ChannelEdgeInfo, - e2 *ChannelEdgeInfo) { - - if e1.ChannelID != e2.ChannelID { - t.Fatalf("chan id's don't match: %v vs %v", e1.ChannelID, - e2.ChannelID) - } - - if e1.ChainHash != e2.ChainHash { - t.Fatalf("chain hashes don't match: %v vs %v", e1.ChainHash, - e2.ChainHash) - } - - if !bytes.Equal(e1.NodeKey1Bytes[:], e2.NodeKey1Bytes[:]) { - t.Fatalf("nodekey1 doesn't match") - } - if !bytes.Equal(e1.NodeKey2Bytes[:], e2.NodeKey2Bytes[:]) { - t.Fatalf("nodekey2 doesn't match") - } - if !bytes.Equal(e1.BitcoinKey1Bytes[:], e2.BitcoinKey1Bytes[:]) { - t.Fatalf("bitcoinkey1 doesn't match") - } - if !bytes.Equal(e1.BitcoinKey2Bytes[:], e2.BitcoinKey2Bytes[:]) { - t.Fatalf("bitcoinkey2 doesn't match") - } - - if !bytes.Equal(e1.Features, e2.Features) { - t.Fatalf("features doesn't match: %x vs %x", e1.Features, - e2.Features) - } - - if !bytes.Equal(e1.AuthProof.NodeSig1Bytes, e2.AuthProof.NodeSig1Bytes) { - t.Fatalf("nodesig1 doesn't match: %v vs %v", - spew.Sdump(e1.AuthProof.NodeSig1Bytes), - spew.Sdump(e2.AuthProof.NodeSig1Bytes)) - } - if !bytes.Equal(e1.AuthProof.NodeSig2Bytes, e2.AuthProof.NodeSig2Bytes) { - t.Fatalf("nodesig2 doesn't match") - } - if !bytes.Equal(e1.AuthProof.BitcoinSig1Bytes, e2.AuthProof.BitcoinSig1Bytes) { - t.Fatalf("bitcoinsig1 doesn't match") - } - if !bytes.Equal(e1.AuthProof.BitcoinSig2Bytes, e2.AuthProof.BitcoinSig2Bytes) { - t.Fatalf("bitcoinsig2 doesn't match") - } - - if e1.ChannelPoint != e2.ChannelPoint { - t.Fatalf("channel point match: %v vs %v", e1.ChannelPoint, - e2.ChannelPoint) - } - - if e1.Capacity != e2.Capacity { - t.Fatalf("capacity doesn't match: %v vs %v", e1.Capacity, - e2.Capacity) - } - - if !bytes.Equal(e1.ExtraOpaqueData, e2.ExtraOpaqueData) { - t.Fatalf("extra data doesn't match: %v vs %v", - e2.ExtraOpaqueData, e2.ExtraOpaqueData) - } -} - -func createChannelEdge(db *DB, node1, node2 *LightningNode) (*ChannelEdgeInfo, - *ChannelEdgePolicy, *ChannelEdgePolicy) { - - var ( - firstNode *LightningNode - secondNode *LightningNode - ) - if bytes.Compare(node1.PubKeyBytes[:], node2.PubKeyBytes[:]) == -1 { - firstNode = node1 - secondNode = node2 - } else { - firstNode = node2 - secondNode = node1 - } - - // In addition to the fake vertexes we create some fake channel - // identifiers. - chanID := uint64(prand.Int63()) - outpoint := wire.OutPoint{ - Hash: rev, - Index: 9, - } - - // Add the new edge to the database, this should proceed without any - // errors. - edgeInfo := &ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: outpoint, - Capacity: 1000, - ExtraOpaqueData: []byte("new unknown feature"), - } - copy(edgeInfo.NodeKey1Bytes[:], firstNode.PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], secondNode.PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], firstNode.PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], secondNode.PubKeyBytes[:]) - - edge1 := &ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: chanID, - LastUpdate: time.Unix(433453, 0), - MessageFlags: 1, - ChannelFlags: 0, - TimeLockDelta: 99, - MinHTLC: 2342135, - MaxHTLC: 13928598, - FeeBaseMSat: 4352345, - FeeProportionalMillionths: 3452352, - Node: secondNode, - ExtraOpaqueData: []byte("new unknown feature2"), - db: db, - } - edge2 := &ChannelEdgePolicy{ - SigBytes: testSig.Serialize(), - ChannelID: chanID, - LastUpdate: time.Unix(124234, 0), - MessageFlags: 1, - ChannelFlags: 1, - TimeLockDelta: 99, - MinHTLC: 2342135, - MaxHTLC: 13928598, - FeeBaseMSat: 4352345, - FeeProportionalMillionths: 90392423, - Node: firstNode, - ExtraOpaqueData: []byte("new unknown feature1"), - db: db, - } - - return edgeInfo, edge1, edge2 -} - -func TestEdgeInfoUpdates(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the update of edges inserted into the database, so - // we create two vertexes to connect. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create an edge and add it to the db. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) - - // Make sure inserting the policy at this point, before the edge info - // is added, will fail. - if err := graph.UpdateEdgePolicy(edge1); err != ErrEdgeNotFound { - t.Fatalf("expected ErrEdgeNotFound, got: %v", err) - } - - // Add the edge info. - if err := graph.AddChannelEdge(edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - chanID := edgeInfo.ChannelID - outpoint := edgeInfo.ChannelPoint - - // Next, insert both edge policies into the database, they should both - // be inserted without any issues. - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Check for existence of the edge within the database, it should be - // found. - _, _, found, isZombie, err := graph.HasChannelEdge(chanID) - if err != nil { - t.Fatalf("unable to query for edge: %v", err) - } - if !found { - t.Fatalf("graph should have of inserted edge") - } - if isZombie { - t.Fatal("live edge should not be marked as zombie") - } - - // We should also be able to retrieve the channelID only knowing the - // channel point of the channel. - dbChanID, err := graph.ChannelID(&outpoint) - if err != nil { - t.Fatalf("unable to retrieve channel ID: %v", err) - } - if dbChanID != chanID { - t.Fatalf("chan ID's mismatch, expected %v got %v", dbChanID, - chanID) - } - - // With the edges inserted, perform some queries to ensure that they've - // been inserted properly. - dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID) - if err != nil { - t.Fatalf("unable to fetch channel by ID: %v", err) - } - if err := compareEdgePolicies(dbEdge1, edge1); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - if err := compareEdgePolicies(dbEdge2, edge2); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) - - // Next, attempt to query the channel edges according to the outpoint - // of the channel. - dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByOutpoint(&outpoint) - if err != nil { - t.Fatalf("unable to fetch channel by ID: %v", err) - } - if err := compareEdgePolicies(dbEdge1, edge1); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - if err := compareEdgePolicies(dbEdge2, edge2); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) -} - -func randEdgePolicy(chanID uint64, op wire.OutPoint, db *DB) *ChannelEdgePolicy { - update := prand.Int63() - - return newEdgePolicy(chanID, op, db, update) -} - -func newEdgePolicy(chanID uint64, op wire.OutPoint, db *DB, - updateTime int64) *ChannelEdgePolicy { - - return &ChannelEdgePolicy{ - ChannelID: chanID, - LastUpdate: time.Unix(updateTime, 0), - MessageFlags: 1, - ChannelFlags: 0, - TimeLockDelta: uint16(prand.Int63()), - MinHTLC: lnwire.MilliSatoshi(prand.Int63()), - MaxHTLC: lnwire.MilliSatoshi(prand.Int63()), - FeeBaseMSat: lnwire.MilliSatoshi(prand.Int63()), - FeeProportionalMillionths: lnwire.MilliSatoshi(prand.Int63()), - db: db, - } -} - -func TestGraphTraversal(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test some of the graph traversal capabilities within - // the DB, so we'll create a series of fake nodes to insert into the - // graph. - const numNodes = 20 - nodes := make([]*LightningNode, numNodes) - nodeIndex := map[string]struct{}{} - for i := 0; i < numNodes; i++ { - node, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - nodes[i] = node - nodeIndex[node.Alias] = struct{}{} - } - - // Add each of the nodes into the graph, they should be inserted - // without error. - for _, node := range nodes { - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - } - - // Iterate over each node as returned by the graph, if all nodes are - // reached, then the map created above should be empty. - err = graph.ForEachNode(nil, func(_ *bbolt.Tx, node *LightningNode) error { - delete(nodeIndex, node.Alias) - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(nodeIndex) != 0 { - t.Fatalf("all nodes not reached within ForEach") - } - - // Determine which node is "smaller", we'll need this in order to - // properly create the edges for the graph. - var firstNode, secondNode *LightningNode - if bytes.Compare(nodes[0].PubKeyBytes[:], nodes[1].PubKeyBytes[:]) == -1 { - firstNode = nodes[0] - secondNode = nodes[1] - } else { - firstNode = nodes[0] - secondNode = nodes[1] - } - - // Create 5 channels between the first two nodes we generated above. - const numChannels = 5 - chanIndex := map[uint64]struct{}{} - for i := 0; i < numChannels; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64(i + 1) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: op, - Capacity: 1000, - } - copy(edgeInfo.NodeKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], nodes[1].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], nodes[0].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], nodes[1].PubKeyBytes[:]) - err := graph.AddChannelEdge(&edgeInfo) - if err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create and add an edge with random data that points from - // node1 -> node2. - edge := randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 0 - edge.Node = secondNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Create another random edge that points from node2 -> node1 - // this time. - edge = randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 1 - edge.Node = firstNode - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - chanIndex[chanID] = struct{}{} - } - - // Iterate through all the known channels within the graph DB, once - // again if the map is empty that indicates that all edges have - // properly been reached. - err = graph.ForEachChannel(func(ei *ChannelEdgeInfo, _ *ChannelEdgePolicy, - _ *ChannelEdgePolicy) error { - - delete(chanIndex, ei.ChannelID) - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if len(chanIndex) != 0 { - t.Fatalf("all edges not reached within ForEach") - } - - // Finally, we want to test the ability to iterate over all the - // outgoing channels for a particular node. - numNodeChans := 0 - err = firstNode.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo, - outEdge, inEdge *ChannelEdgePolicy) error { - - // All channels between first and second node should have fully - // (both sides) specified policies. - if inEdge == nil || outEdge == nil { - return fmt.Errorf("channel policy not present") - } - - // Each should indicate that it's outgoing (pointed - // towards the second node). - if !bytes.Equal(outEdge.Node.PubKeyBytes[:], secondNode.PubKeyBytes[:]) { - return fmt.Errorf("wrong outgoing edge") - } - - // The incoming edge should also indicate that it's pointing to - // the origin node. - if !bytes.Equal(inEdge.Node.PubKeyBytes[:], firstNode.PubKeyBytes[:]) { - return fmt.Errorf("wrong outgoing edge") - } - - numNodeChans++ - return nil - }) - if err != nil { - t.Fatalf("for each failure: %v", err) - } - if numNodeChans != numChannels { - t.Fatalf("all edges for node not reached within ForEach: "+ - "expected %v, got %v", numChannels, numNodeChans) - } -} - -func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash, - blockHeight uint32) { - - pruneHash, pruneHeight, err := graph.PruneTip() - if err != nil { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: unable to fetch prune tip: %v", line, err) - } - if !bytes.Equal(blockHash[:], pruneHash[:]) { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line: %v, prune tips don't match, expected %x got %x", - line, blockHash, pruneHash) - } - if pruneHeight != blockHeight { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: prune heights don't match, expected %v "+ - "got %v", line, blockHeight, pruneHeight) - } -} - -func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { - numChans := 0 - if err := graph.ForEachChannel(func(*ChannelEdgeInfo, *ChannelEdgePolicy, - *ChannelEdgePolicy) error { - - numChans++ - return nil - }); err != nil { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: unable to scan channels: %v", line, err) - } - if numChans != n { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: expected %v chans instead have %v", line, - n, numChans) - } -} - -func assertNumNodes(t *testing.T, graph *ChannelGraph, n int) { - numNodes := 0 - err := graph.ForEachNode(nil, func(_ *bbolt.Tx, _ *LightningNode) error { - numNodes++ - return nil - }) - if err != nil { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: unable to scan nodes: %v", line, err) - } - - if numNodes != n { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: expected %v nodes, got %v", line, n, numNodes) - } -} - -func assertChanViewEqual(t *testing.T, a []EdgePoint, b []EdgePoint) { - if len(a) != len(b) { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chan views don't match", line) - } - - chanViewSet := make(map[wire.OutPoint]struct{}) - for _, op := range a { - chanViewSet[op.OutPoint] = struct{}{} - } - - for _, op := range b { - if _, ok := chanViewSet[op.OutPoint]; !ok { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chanPoint(%v) not found in first "+ - "view", line, op) - } - } -} - -func assertChanViewEqualChanPoints(t *testing.T, a []EdgePoint, b []*wire.OutPoint) { - if len(a) != len(b) { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chan views don't match", line) - } - - chanViewSet := make(map[wire.OutPoint]struct{}) - for _, op := range a { - chanViewSet[op.OutPoint] = struct{}{} - } - - for _, op := range b { - if _, ok := chanViewSet[*op]; !ok { - _, _, line, _ := runtime.Caller(1) - t.Fatalf("line %v: chanPoint(%v) not found in first "+ - "view", line, op) - } - } -} - -func TestGraphPruning(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // As initial set up for the test, we'll create a graph with 5 vertexes - // and enough edges to create a fully connected graph. The graph will - // be rather simple, representing a straight line. - const numNodes = 5 - graphNodes := make([]*LightningNode, numNodes) - for i := 0; i < numNodes; i++ { - node, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - graphNodes[i] = node - } - - // With the vertexes created, we'll next create a series of channels - // between them. - channelPoints := make([]*wire.OutPoint, 0, numNodes-1) - edgePoints := make([]EdgePoint, 0, numNodes-1) - for i := 0; i < numNodes-1; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - chanID := uint64(i + 1) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - channelPoints = append(channelPoints, &op) - - edgeInfo := ChannelEdgeInfo{ - ChannelID: chanID, - ChainHash: key, - AuthProof: &ChannelAuthProof{ - NodeSig1Bytes: testSig.Serialize(), - NodeSig2Bytes: testSig.Serialize(), - BitcoinSig1Bytes: testSig.Serialize(), - BitcoinSig2Bytes: testSig.Serialize(), - }, - ChannelPoint: op, - Capacity: 1000, - } - copy(edgeInfo.NodeKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) - copy(edgeInfo.NodeKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey1Bytes[:], graphNodes[i].PubKeyBytes[:]) - copy(edgeInfo.BitcoinKey2Bytes[:], graphNodes[i+1].PubKeyBytes[:]) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - pkScript, err := genMultiSigP2WSH( - edgeInfo.BitcoinKey1Bytes[:], edgeInfo.BitcoinKey2Bytes[:], - ) - if err != nil { - t.Fatalf("unable to gen multi-sig p2wsh: %v", err) - } - edgePoints = append(edgePoints, EdgePoint{ - FundingPkScript: pkScript, - OutPoint: op, - }) - - // Create and add an edge with random data that points from - // node_i -> node_i+1 - edge := randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 0 - edge.Node = graphNodes[i] - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Create another random edge that points from node_i+1 -> - // node_i this time. - edge = randEdgePolicy(chanID, op, db) - edge.ChannelFlags = 1 - edge.Node = graphNodes[i] - edge.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - } - - // With all the channel points added, we'll consult the graph to ensure - // it has the same channel view as the one we just constructed. - channelView, err := graph.ChannelView() - if err != nil { - t.Fatalf("unable to get graph channel view: %v", err) - } - assertChanViewEqual(t, channelView, edgePoints) - - // Now with our test graph created, we can test the pruning - // capabilities of the channel graph. - - // First we create a mock block that ends up closing the first two - // channels. - var blockHash chainhash.Hash - copy(blockHash[:], bytes.Repeat([]byte{1}, 32)) - blockHeight := uint32(1) - block := channelPoints[:2] - prunedChans, err := graph.PruneGraph(block, &blockHash, blockHeight) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - if len(prunedChans) != 2 { - t.Fatalf("incorrect number of channels pruned: "+ - "expected %v, got %v", 2, prunedChans) - } - - // Now ensure that the prune tip has been updated. - assertPruneTip(t, graph, &blockHash, blockHeight) - - // Count up the number of channels known within the graph, only 2 - // should be remaining. - assertNumChans(t, graph, 2) - - // Those channels should also be missing from the channel view. - channelView, err = graph.ChannelView() - if err != nil { - t.Fatalf("unable to get graph channel view: %v", err) - } - assertChanViewEqualChanPoints(t, channelView, channelPoints[2:]) - - // Next we'll create a block that doesn't close any channels within the - // graph to test the negative error case. - fakeHash := sha256.Sum256([]byte("test prune")) - nonChannel := &wire.OutPoint{ - Hash: fakeHash, - Index: 9, - } - blockHash = sha256.Sum256(blockHash[:]) - blockHeight = 2 - prunedChans, err = graph.PruneGraph( - []*wire.OutPoint{nonChannel}, &blockHash, blockHeight, - ) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - - // No channels should have been detected as pruned. - if len(prunedChans) != 0 { - t.Fatalf("channels were pruned but shouldn't have been") - } - - // Once again, the prune tip should have been updated. We should still - // see both channels and their participants, along with the source node. - assertPruneTip(t, graph, &blockHash, blockHeight) - assertNumChans(t, graph, 2) - assertNumNodes(t, graph, 4) - - // Finally, create a block that prunes the remainder of the channels - // from the graph. - blockHash = sha256.Sum256(blockHash[:]) - blockHeight = 3 - prunedChans, err = graph.PruneGraph( - channelPoints[2:], &blockHash, blockHeight, - ) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - - // The remainder of the channels should have been pruned from the - // graph. - if len(prunedChans) != 2 { - t.Fatalf("incorrect number of channels pruned: "+ - "expected %v, got %v", 2, len(prunedChans)) - } - - // The prune tip should be updated, no channels should be found, and - // only the source node should remain within the current graph. - assertPruneTip(t, graph, &blockHash, blockHeight) - assertNumChans(t, graph, 0) - assertNumNodes(t, graph, 1) - - // Finally, the channel view at this point in the graph should now be - // completely empty. Those channels should also be missing from the - // channel view. - channelView, err = graph.ChannelView() - if err != nil { - t.Fatalf("unable to get graph channel view: %v", err) - } - if len(channelView) != 0 { - t.Fatalf("channel view should be empty, instead have: %v", - channelView) - } -} - -// TestHighestChanID tests that we're able to properly retrieve the highest -// known channel ID in the database. -func TestHighestChanID(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // If we don't yet have any channels in the database, then we should - // get a channel ID of zero if we ask for the highest channel ID. - bestID, err := graph.HighestChanID() - if err != nil { - t.Fatalf("unable to get highest ID: %v", err) - } - if bestID != 0 { - t.Fatalf("best ID w/ no chan should be zero, is instead: %v", - bestID) - } - - // Next, we'll insert two channels into the database, with each channel - // connecting the same two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // The first channel with be at height 10, while the other will be at - // height 100. - edge1, _ := createEdge(10, 0, 0, 0, node1, node2) - edge2, chanID2 := createEdge(100, 0, 0, 0, node1, node2) - - if err := graph.AddChannelEdge(&edge1); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - if err := graph.AddChannelEdge(&edge2); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Now that the edges has been inserted, we'll query for the highest - // known channel ID in the database. - bestID, err = graph.HighestChanID() - if err != nil { - t.Fatalf("unable to get highest ID: %v", err) - } - - if bestID != chanID2.ToUint64() { - t.Fatalf("expected %v got %v for best chan ID: ", - chanID2.ToUint64(), bestID) - } - - // If we add another edge, then the current best chan ID should be - // updated as well. - edge3, chanID3 := createEdge(1000, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edge3); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - bestID, err = graph.HighestChanID() - if err != nil { - t.Fatalf("unable to get highest ID: %v", err) - } - - if bestID != chanID3.ToUint64() { - t.Fatalf("expected %v got %v for best chan ID: ", - chanID3.ToUint64(), bestID) - } -} - -// TestChanUpdatesInHorizon tests the we're able to properly retrieve all known -// channel updates within a specific time horizon. It also tests that upon -// insertion of a new edge, the edge update index is updated properly. -func TestChanUpdatesInHorizon(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // If we issue an arbitrary query before any channel updates are - // inserted in the database, we should get zero results. - chanUpdates, err := graph.ChanUpdatesInHorizon( - time.Unix(999, 0), time.Unix(9999, 0), - ) - if err != nil { - t.Fatalf("unable to updates for updates: %v", err) - } - if len(chanUpdates) != 0 { - t.Fatalf("expected 0 chan updates, instead got %v", - len(chanUpdates)) - } - - // We'll start by creating two nodes which will seed our test graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // We'll now create 10 channels between the two nodes, with update - // times 10 seconds after each other. - const numChans = 10 - startTime := time.Unix(1234, 0) - endTime := startTime - edges := make([]ChannelEdge, 0, numChans) - for i := 0; i < numChans; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - channel, chanID := createEdge( - uint32(i*10), 0, 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - edge1UpdateTime := endTime - edge2UpdateTime := edge1UpdateTime.Add(time.Second) - endTime = endTime.Add(time.Second * 10) - - edge1 := newEdgePolicy( - chanID.ToUint64(), op, db, edge1UpdateTime.Unix(), - ) - edge1.ChannelFlags = 0 - edge1.Node = node2 - edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edge2 := newEdgePolicy( - chanID.ToUint64(), op, db, edge2UpdateTime.Unix(), - ) - edge2.ChannelFlags = 1 - edge2.Node = node1 - edge2.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edges = append(edges, ChannelEdge{ - Info: &channel, - Policy1: edge1, - Policy2: edge2, - }) - } - - // With our channels loaded, we'll now start our series of queries. - queryCases := []struct { - start time.Time - end time.Time - - resp []ChannelEdge - }{ - // If we query for a time range that's strictly below our set - // of updates, then we'll get an empty result back. - { - start: time.Unix(100, 0), - end: time.Unix(200, 0), - }, - - // If we query for a time range that's well beyond our set of - // updates, we should get an empty set of results back. - { - start: time.Unix(99999, 0), - end: time.Unix(999999, 0), - }, - - // If we query for the start time, and 10 seconds directly - // after it, we should only get a single update, that first - // one. - { - start: time.Unix(1234, 0), - end: startTime.Add(time.Second * 10), - - resp: []ChannelEdge{edges[0]}, - }, - - // If we add 10 seconds past the first update, and then - // subtract 10 from the last update, then we should only get - // the 8 edges in the middle. - { - start: startTime.Add(time.Second * 10), - end: endTime.Add(-time.Second * 10), - - resp: edges[1:9], - }, - - // If we use the start and end time as is, we should get the - // entire range. - { - start: startTime, - end: endTime, - - resp: edges, - }, - } - for _, queryCase := range queryCases { - resp, err := graph.ChanUpdatesInHorizon( - queryCase.start, queryCase.end, - ) - if err != nil { - t.Fatalf("unable to query for updates: %v", err) - } - - if len(resp) != len(queryCase.resp) { - t.Fatalf("expected %v chans, got %v chans", - len(queryCase.resp), len(resp)) - - } - - for i := 0; i < len(resp); i++ { - chanExp := queryCase.resp[i] - chanRet := resp[i] - - assertEdgeInfoEqual(t, chanExp.Info, chanRet.Info) - - err := compareEdgePolicies(chanExp.Policy1, chanRet.Policy1) - if err != nil { - t.Fatal(err) - } - compareEdgePolicies(chanExp.Policy2, chanRet.Policy2) - if err != nil { - t.Fatal(err) - } - } - } -} - -// TestNodeUpdatesInHorizon tests that we're able to properly scan and retrieve -// the most recent node updates within a particular time horizon. -func TestNodeUpdatesInHorizon(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - startTime := time.Unix(1234, 0) - endTime := startTime - - // If we issue an arbitrary query before we insert any nodes into the - // database, then we shouldn't get any results back. - nodeUpdates, err := graph.NodeUpdatesInHorizon( - time.Unix(999, 0), time.Unix(9999, 0), - ) - if err != nil { - t.Fatalf("unable to query for node updates: %v", err) - } - if len(nodeUpdates) != 0 { - t.Fatalf("expected 0 node updates, instead got %v", - len(nodeUpdates)) - } - - // We'll create 10 node announcements, each with an update timestamp 10 - // seconds after the other. - const numNodes = 10 - nodeAnns := make([]LightningNode, 0, numNodes) - for i := 0; i < numNodes; i++ { - nodeAnn, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test vertex: %v", err) - } - - // The node ann will use the current end time as its last - // update them, then we'll add 10 seconds in order to create - // the proper update time for the next node announcement. - updateTime := endTime - endTime = updateTime.Add(time.Second * 10) - - nodeAnn.LastUpdate = updateTime - - nodeAnns = append(nodeAnns, *nodeAnn) - - if err := graph.AddLightningNode(nodeAnn); err != nil { - t.Fatalf("unable to add lightning node: %v", err) - } - } - - queryCases := []struct { - start time.Time - end time.Time - - resp []LightningNode - }{ - // If we query for a time range that's strictly below our set - // of updates, then we'll get an empty result back. - { - start: time.Unix(100, 0), - end: time.Unix(200, 0), - }, - - // If we query for a time range that's well beyond our set of - // updates, we should get an empty set of results back. - { - start: time.Unix(99999, 0), - end: time.Unix(999999, 0), - }, - - // If we skip he first time epoch with out start time, then we - // should get back every now but the first. - { - start: startTime.Add(time.Second * 10), - end: endTime, - - resp: nodeAnns[1:], - }, - - // If we query for the range as is, we should get all 10 - // announcements back. - { - start: startTime, - end: endTime, - - resp: nodeAnns, - }, - - // If we reduce the ending time by 10 seconds, then we should - // get all but the last node we inserted. - { - start: startTime, - end: endTime.Add(-time.Second * 10), - - resp: nodeAnns[:9], - }, - } - for _, queryCase := range queryCases { - resp, err := graph.NodeUpdatesInHorizon(queryCase.start, queryCase.end) - if err != nil { - t.Fatalf("unable to query for nodes: %v", err) - } - - if len(resp) != len(queryCase.resp) { - t.Fatalf("expected %v nodes, got %v nodes", - len(queryCase.resp), len(resp)) - - } - - for i := 0; i < len(resp); i++ { - err := compareNodes(&queryCase.resp[i], &resp[i]) - if err != nil { - t.Fatal(err) - } - } - } -} - -// TestFilterKnownChanIDs tests that we're able to properly perform the set -// differences of an incoming set of channel ID's, and those that we already -// know of on disk. -func TestFilterKnownChanIDs(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // If we try to filter out a set of channel ID's before we even know of - // any channels, then we should get the entire set back. - preChanIDs := []uint64{1, 2, 3, 4} - filteredIDs, err := graph.FilterKnownChanIDs(preChanIDs) - if err != nil { - t.Fatalf("unable to filter chan IDs: %v", err) - } - if !reflect.DeepEqual(preChanIDs, filteredIDs) { - t.Fatalf("chan IDs shouldn't have been filtered!") - } - - // We'll start by creating two nodes which will seed our test graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Next, we'll add 5 channel ID's to the graph, each of them having a - // block height 10 blocks after the previous. - const numChans = 5 - chanIDs := make([]uint64, 0, numChans) - for i := 0; i < numChans; i++ { - channel, chanID := createEdge( - uint32(i*10), 0, 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - chanIDs = append(chanIDs, chanID.ToUint64()) - } - - const numZombies = 5 - zombieIDs := make([]uint64, 0, numZombies) - for i := 0; i < numZombies; i++ { - channel, chanID := createEdge( - uint32(i*10+1), 0, 0, 0, node1, node2, - ) - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - err := graph.DeleteChannelEdges(channel.ChannelID) - if err != nil { - t.Fatalf("unable to mark edge zombie: %v", err) - } - - zombieIDs = append(zombieIDs, chanID.ToUint64()) - } - - queryCases := []struct { - queryIDs []uint64 - - resp []uint64 - }{ - // If we attempt to filter out all chanIDs we know of, the - // response should be the empty set. - { - queryIDs: chanIDs, - }, - // If we attempt to filter out all zombies that we know of, the - // response should be the empty set. - { - queryIDs: zombieIDs, - }, - - // If we query for a set of ID's that we didn't insert, we - // should get the same set back. - { - queryIDs: []uint64{99, 100}, - resp: []uint64{99, 100}, - }, - - // If we query for a super-set of our the chan ID's inserted, - // we should only get those new chanIDs back. - { - queryIDs: append(chanIDs, []uint64{99, 101}...), - resp: []uint64{99, 101}, - }, - } - - for _, queryCase := range queryCases { - resp, err := graph.FilterKnownChanIDs(queryCase.queryIDs) - if err != nil { - t.Fatalf("unable to filter chan IDs: %v", err) - } - - if !reflect.DeepEqual(resp, queryCase.resp) { - t.Fatalf("expected %v, got %v", spew.Sdump(queryCase.resp), - spew.Sdump(resp)) - } - } -} - -// TestFilterChannelRange tests that we're able to properly retrieve the full -// set of short channel ID's for a given block range. -func TestFilterChannelRange(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'll first populate our graph with two nodes. All channels created - // below will be made between these two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // If we try to filter a channel range before we have any channels - // inserted, we should get an empty slice of results. - resp, err := graph.FilterChannelRange(10, 100) - if err != nil { - t.Fatalf("unable to filter channels: %v", err) - } - if len(resp) != 0 { - t.Fatalf("expected zero chans, instead got %v", len(resp)) - } - - // To start, we'll create a set of channels, each mined in a block 10 - // blocks after the prior one. - startHeight := uint32(100) - endHeight := startHeight - const numChans = 10 - chanIDs := make([]uint64, 0, numChans) - for i := 0; i < numChans; i++ { - chanHeight := endHeight - channel, chanID := createEdge( - uint32(chanHeight), uint32(i+1), 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - chanIDs = append(chanIDs, chanID.ToUint64()) - - endHeight += 10 - } - - // With our channels inserted, we'll construct a series of queries that - // we'll execute below in order to exercise the features of the - // FilterKnownChanIDs method. - queryCases := []struct { - startHeight uint32 - endHeight uint32 - - resp []uint64 - }{ - // If we query for the entire range, then we should get the same - // set of short channel IDs back. - { - startHeight: startHeight, - endHeight: endHeight, - - resp: chanIDs, - }, - - // If we query for a range of channels right before our range, we - // shouldn't get any results back. - { - startHeight: 0, - endHeight: 10, - }, - - // If we only query for the last height (range wise), we should - // only get that last channel. - { - startHeight: endHeight - 10, - endHeight: endHeight - 10, - - resp: chanIDs[9:], - }, - - // If we query for just the first height, we should only get a - // single channel back (the first one). - { - startHeight: startHeight, - endHeight: startHeight, - - resp: chanIDs[:1], - }, - } - for i, queryCase := range queryCases { - resp, err := graph.FilterChannelRange( - queryCase.startHeight, queryCase.endHeight, - ) - if err != nil { - t.Fatalf("unable to issue range query: %v", err) - } - - if !reflect.DeepEqual(resp, queryCase.resp) { - t.Fatalf("case #%v: expected %v, got %v", i, - queryCase.resp, resp) - } - } -} - -// TestFetchChanInfos tests that we're able to properly retrieve the full set -// of ChannelEdge structs for a given set of short channel ID's. -func TestFetchChanInfos(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'll first populate our graph with two nodes. All channels created - // below will be made between these two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // We'll make 5 test channels, ensuring we keep track of which channel - // ID corresponds to a particular ChannelEdge. - const numChans = 5 - startTime := time.Unix(1234, 0) - endTime := startTime - edges := make([]ChannelEdge, 0, numChans) - edgeQuery := make([]uint64, 0, numChans) - for i := 0; i < numChans; i++ { - txHash := sha256.Sum256([]byte{byte(i)}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - channel, chanID := createEdge( - uint32(i*10), 0, 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - updateTime := endTime - endTime = updateTime.Add(time.Second * 10) - - edge1 := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), - ) - edge1.ChannelFlags = 0 - edge1.Node = node2 - edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edge2 := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), - ) - edge2.ChannelFlags = 1 - edge2.Node = node1 - edge2.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edges = append(edges, ChannelEdge{ - Info: &channel, - Policy1: edge1, - Policy2: edge2, - }) - - edgeQuery = append(edgeQuery, chanID.ToUint64()) - } - - // Add an additional edge that does not exist. The query should skip - // this channel and return only infos for the edges that exist. - edgeQuery = append(edgeQuery, 500) - - // Add an another edge to the query that has been marked as a zombie - // edge. The query should also skip this channel. - zombieChan, zombieChanID := createEdge( - 666, 0, 0, 0, node1, node2, - ) - if err := graph.AddChannelEdge(&zombieChan); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - err = graph.DeleteChannelEdges(zombieChan.ChannelID) - if err != nil { - t.Fatalf("unable to delete and mark edge zombie: %v", err) - } - edgeQuery = append(edgeQuery, zombieChanID.ToUint64()) - - // We'll now attempt to query for the range of channel ID's we just - // inserted into the database. We should get the exact same set of - // edges back. - resp, err := graph.FetchChanInfos(edgeQuery) - if err != nil { - t.Fatalf("unable to fetch chan edges: %v", err) - } - if len(resp) != len(edges) { - t.Fatalf("expected %v edges, instead got %v", len(edges), - len(resp)) - } - - for i := 0; i < len(resp); i++ { - err := compareEdgePolicies(resp[i].Policy1, edges[i].Policy1) - if err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - err = compareEdgePolicies(resp[i].Policy2, edges[i].Policy2) - if err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, resp[i].Info, edges[i].Info) - } -} - -// TestIncompleteChannelPolicies tests that a channel that only has a policy -// specified on one end is properly returned in ForEachChannel calls from -// both sides. -func TestIncompleteChannelPolicies(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // Create two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create channel between nodes. - txHash := sha256.Sum256([]byte{0}) - op := wire.OutPoint{ - Hash: txHash, - Index: 0, - } - - channel, chanID := createEdge( - uint32(0), 0, 0, 0, node1, node2, - ) - - if err := graph.AddChannelEdge(&channel); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Ensure that channel is reported with unknown policies. - - checkPolicies := func(node *LightningNode, expectedIn, expectedOut bool) { - calls := 0 - node.ForEachChannel(nil, func(_ *bbolt.Tx, _ *ChannelEdgeInfo, - outEdge, inEdge *ChannelEdgePolicy) error { - - if !expectedOut && outEdge != nil { - t.Fatalf("Expected no outgoing policy") - } - - if expectedOut && outEdge == nil { - t.Fatalf("Expected an outgoing policy") - } - - if !expectedIn && inEdge != nil { - t.Fatalf("Expected no incoming policy") - } - - if expectedIn && inEdge == nil { - t.Fatalf("Expected an incoming policy") - } - - calls++ - - return nil - }) - - if calls != 1 { - t.Fatalf("Expected only one callback call") - } - } - - checkPolicies(node2, false, false) - - // Only create an edge policy for node1 and leave the policy for node2 - // unknown. - updateTime := time.Unix(1234, 0) - - edgePolicy := newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), - ) - edgePolicy.ChannelFlags = 0 - edgePolicy.Node = node2 - edgePolicy.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - checkPolicies(node1, false, true) - checkPolicies(node2, true, false) - - // Create second policy and assert that both policies are reported - // as present. - edgePolicy = newEdgePolicy( - chanID.ToUint64(), op, db, updateTime.Unix(), - ) - edgePolicy.ChannelFlags = 1 - edgePolicy.Node = node1 - edgePolicy.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edgePolicy); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - checkPolicies(node1, true, true) - checkPolicies(node2, true, true) -} - -// TestChannelEdgePruningUpdateIndexDeletion tests that once edges are deleted -// from the graph, then their entries within the update index are also cleaned -// up. -func TestChannelEdgePruningUpdateIndexDeletion(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // We'll first populate our graph with two nodes. All channels created - // below will be made between these two nodes. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // With the two nodes created, we'll now create a random channel, as - // well as two edges in the database with distinct update times. - edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) - edge1.ChannelFlags = 0 - edge1.Node = node1 - edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - edge2 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) - edge2.ChannelFlags = 1 - edge2.Node = node2 - edge2.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // checkIndexTimestamps is a helper function that checks the edge update - // index only includes the given timestamps. - checkIndexTimestamps := func(timestamps ...uint64) { - timestampSet := make(map[uint64]struct{}) - for _, t := range timestamps { - timestampSet[t] = struct{}{} - } - - err := db.View(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrGraphNoEdgesFound - } - edgeUpdateIndex := edges.Bucket(edgeUpdateIndexBucket) - if edgeUpdateIndex == nil { - return ErrGraphNoEdgesFound - } - - numEntries := edgeUpdateIndex.Stats().KeyN - expectedEntries := len(timestampSet) - if numEntries != expectedEntries { - return fmt.Errorf("expected %v entries in the "+ - "update index, got %v", expectedEntries, - numEntries) - } - - return edgeUpdateIndex.ForEach(func(k, _ []byte) error { - t := byteOrder.Uint64(k[:8]) - if _, ok := timestampSet[t]; !ok { - return fmt.Errorf("found unexpected "+ - "timestamp "+"%d", t) - } - - return nil - }) - }) - if err != nil { - t.Fatal(err) - } - } - - // With both edges policies added, we'll make sure to check they exist - // within the edge update index. - checkIndexTimestamps( - uint64(edge1.LastUpdate.Unix()), - uint64(edge2.LastUpdate.Unix()), - ) - - // Now, we'll update the edge policies to ensure the old timestamps are - // removed from the update index. - edge1.ChannelFlags = 2 - edge1.LastUpdate = time.Now() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - edge2.ChannelFlags = 3 - edge2.LastUpdate = edge1.LastUpdate.Add(time.Hour) - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // With the policies updated, we should now be able to find their - // updated entries within the update index. - checkIndexTimestamps( - uint64(edge1.LastUpdate.Unix()), - uint64(edge2.LastUpdate.Unix()), - ) - - // Now we'll prune the graph, removing the edges, and also the update - // index entries from the database all together. - var blockHash chainhash.Hash - copy(blockHash[:], bytes.Repeat([]byte{2}, 32)) - _, err = graph.PruneGraph( - []*wire.OutPoint{&edgeInfo.ChannelPoint}, &blockHash, 101, - ) - if err != nil { - t.Fatalf("unable to prune graph: %v", err) - } - - // Finally, we'll check the database state one last time to conclude - // that we should no longer be able to locate _any_ entries within the - // edge update index. - checkIndexTimestamps() -} - -// TestPruneGraphNodes tests that unconnected vertexes are pruned via the -// PruneSyncState method. -func TestPruneGraphNodes(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - // We'll start off by inserting our source node, to ensure that it's - // the only node left after we prune the graph. - graph := db.ChannelGraph() - sourceNode, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create source node: %v", err) - } - if err := graph.SetSourceNode(sourceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - // With the source node inserted, we'll now add three nodes to the - // channel graph, at the end of the scenario, only two of these nodes - // should still be in the graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node3, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node3); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // We'll now add a new edge to the graph, but only actually advertise - // the edge of *one* of the nodes. - edgeInfo, chanID := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - // We'll now insert an advertised edge, but it'll only be the edge that - // points from the first to the second node. - edge1 := randEdgePolicy(chanID.ToUint64(), edgeInfo.ChannelPoint, db) - edge1.ChannelFlags = 0 - edge1.Node = node1 - edge1.SigBytes = testSig.Serialize() - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // We'll now initiate a around of graph pruning. - if err := graph.PruneGraphNodes(); err != nil { - t.Fatalf("unable to prune graph nodes: %v", err) - } - - // At this point, there should be 3 nodes left in the graph still: the - // source node (which can't be pruned), and node 1+2. Nodes 1 and two - // should still be left in the graph as there's half of an advertised - // edge between them. - assertNumNodes(t, graph, 3) - - // Finally, we'll ensure that node3, the only fully unconnected node as - // properly deleted from the graph and not another node in its place. - node3Pub, err := node3.PubKey() - if err != nil { - t.Fatalf("unable to fetch the pubkey of node3: %v", err) - } - if _, err := graph.FetchLightningNode(node3Pub); err == nil { - t.Fatalf("node 3 should have been deleted!") - } -} - -// TestAddChannelEdgeShellNodes tests that when we attempt to add a ChannelEdge -// to the graph, one or both of the nodes the edge involves aren't found in the -// database, then shell edges are created for each node if needed. -func TestAddChannelEdgeShellNodes(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // To start, we'll create two nodes, and only add one of them to the - // channel graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - // We'll now create an edge between the two nodes, as a result, node2 - // should be inserted into the database as a shell node. - edgeInfo, _ := createEdge(100, 0, 0, 0, node1, node2) - if err := graph.AddChannelEdge(&edgeInfo); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - - node1Pub, err := node1.PubKey() - if err != nil { - t.Fatalf("unable to parse node 1 pub: %v", err) - } - node2Pub, err := node2.PubKey() - if err != nil { - t.Fatalf("unable to parse node 2 pub: %v", err) - } - - // Ensure that node1 was inserted as a full node, while node2 only has - // a shell node present. - node1, err = graph.FetchLightningNode(node1Pub) - if err != nil { - t.Fatalf("unable to fetch node1: %v", err) - } - if !node1.HaveNodeAnnouncement { - t.Fatalf("have shell announcement for node1, shouldn't") - } - - node2, err = graph.FetchLightningNode(node2Pub) - if err != nil { - t.Fatalf("unable to fetch node2: %v", err) - } - if node2.HaveNodeAnnouncement { - t.Fatalf("should have shell announcement for node2, but is full") - } -} - -// TestNodePruningUpdateIndexDeletion tests that once a node has been removed -// from the channel graph, we also remove the entry from the update index as -// well. -func TestNodePruningUpdateIndexDeletion(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'll first populate our graph with a single node that will be - // removed shortly. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // We'll confirm that we can retrieve the node using - // NodeUpdatesInHorizon, using a time that's slightly beyond the last - // update time of our test node. - startTime := time.Unix(9, 0) - endTime := node1.LastUpdate.Add(time.Minute) - nodesInHorizon, err := graph.NodeUpdatesInHorizon(startTime, endTime) - if err != nil { - t.Fatalf("unable to fetch nodes in horizon: %v", err) - } - - // We should only have a single node, and that node should exactly - // match the node we just inserted. - if len(nodesInHorizon) != 1 { - t.Fatalf("should have 1 nodes instead have: %v", - len(nodesInHorizon)) - } - if err := compareNodes(node1, &nodesInHorizon[0]); err != nil { - t.Fatalf("nodes don't match: %v", err) - } - - // We'll now delete the node from the graph, this should result in it - // being removed from the update index as well. - nodePub, _ := node1.PubKey() - if err := graph.DeleteLightningNode(nodePub); err != nil { - t.Fatalf("unable to delete node: %v", err) - } - - // Now that the node has been deleted, we'll again query the nodes in - // the horizon. This time we should have no nodes at all. - nodesInHorizon, err = graph.NodeUpdatesInHorizon(startTime, endTime) - if err != nil { - t.Fatalf("unable to fetch nodes in horizon: %v", err) - } - - if len(nodesInHorizon) != 0 { - t.Fatalf("should have zero nodes instead have: %v", - len(nodesInHorizon)) - } -} - -// TestNodeIsPublic ensures that we properly detect nodes that are seen as -// public within the network graph. -func TestNodeIsPublic(t *testing.T) { - t.Parallel() - - // We'll start off the test by creating a small network of 3 - // participants with the following graph: - // - // Alice <-> Bob <-> Carol - // - // We'll need to create a separate database and channel graph for each - // participant to replicate real-world scenarios (private edges being in - // some graphs but not others, etc.). - aliceDB, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - aliceNode, err := createTestVertex(aliceDB) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - aliceGraph := aliceDB.ChannelGraph() - if err := aliceGraph.SetSourceNode(aliceNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - bobDB, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - bobNode, err := createTestVertex(bobDB) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - bobGraph := bobDB.ChannelGraph() - if err := bobGraph.SetSourceNode(bobNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - carolDB, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - carolNode, err := createTestVertex(carolDB) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - carolGraph := carolDB.ChannelGraph() - if err := carolGraph.SetSourceNode(carolNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } - - aliceBobEdge, _ := createEdge(10, 0, 0, 0, aliceNode, bobNode) - bobCarolEdge, _ := createEdge(10, 1, 0, 1, bobNode, carolNode) - - // After creating all of our nodes and edges, we'll add them to each - // participant's graph. - nodes := []*LightningNode{aliceNode, bobNode, carolNode} - edges := []*ChannelEdgeInfo{&aliceBobEdge, &bobCarolEdge} - dbs := []*DB{aliceDB, bobDB, carolDB} - graphs := []*ChannelGraph{aliceGraph, bobGraph, carolGraph} - for i, graph := range graphs { - for _, node := range nodes { - node.db = dbs[i] - if err := graph.AddLightningNode(node); err != nil { - t.Fatalf("unable to add node: %v", err) - } - } - for _, edge := range edges { - edge.db = dbs[i] - if err := graph.AddChannelEdge(edge); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - } - } - - // checkNodes is a helper closure that will be used to assert that the - // given nodes are seen as public/private within the given graphs. - checkNodes := func(nodes []*LightningNode, graphs []*ChannelGraph, - public bool) { - - t.Helper() - - for _, node := range nodes { - for _, graph := range graphs { - isPublic, err := graph.IsPublicNode(node.PubKeyBytes) - if err != nil { - t.Fatalf("unable to determine if pivot "+ - "is public: %v", err) - } - - switch { - case isPublic && !public: - t.Fatalf("expected %x to be private", - node.PubKeyBytes) - case !isPublic && public: - t.Fatalf("expected %x to be public", - node.PubKeyBytes) - } - } - } - } - - // Due to the way the edges were set up above, we'll make sure each node - // can correctly determine that every other node is public. - checkNodes(nodes, graphs, true) - - // Now, we'll remove the edge between Alice and Bob from everyone's - // graph. This will make Alice be seen as a private node as it no longer - // has any advertised edges. - for _, graph := range graphs { - err := graph.DeleteChannelEdges(aliceBobEdge.ChannelID) - if err != nil { - t.Fatalf("unable to remove edge: %v", err) - } - } - checkNodes( - []*LightningNode{aliceNode}, - []*ChannelGraph{bobGraph, carolGraph}, - false, - ) - - // We'll also make the edge between Bob and Carol private. Within Bob's - // and Carol's graph, the edge will exist, but it will not have a proof - // that allows it to be advertised. Within Alice's graph, we'll - // completely remove the edge as it is not possible for her to know of - // it without it being advertised. - for i, graph := range graphs { - err := graph.DeleteChannelEdges(bobCarolEdge.ChannelID) - if err != nil { - t.Fatalf("unable to remove edge: %v", err) - } - - if graph == aliceGraph { - continue - } - - bobCarolEdge.AuthProof = nil - bobCarolEdge.db = dbs[i] - if err := graph.AddChannelEdge(&bobCarolEdge); err != nil { - t.Fatalf("unable to add edge: %v", err) - } - } - - // With the modifications above, Bob should now be seen as a private - // node from both Alice's and Carol's perspective. - checkNodes( - []*LightningNode{bobNode}, - []*ChannelGraph{aliceGraph, carolGraph}, - false, - ) -} - -// TestDisabledChannelIDs ensures that the disabled channels within the -// disabledEdgePolicyBucket are managed properly and the list returned from -// DisabledChannelIDs is correct. -func TestDisabledChannelIDs(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - graph := db.ChannelGraph() - - // Create first node and add it to the graph. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Create second node and add it to the graph. - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - // Adding a new channel edge to the graph. - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - - if err := graph.AddChannelEdge(edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Ensure no disabled channels exist in the bucket on start. - disabledChanIds, err := graph.DisabledChannelIDs() - if err != nil { - t.Fatalf("unable to get disabled channel ids: %v", err) - } - if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) - } - - // Add one disabled policy and ensure the channel is still not in the - // disabled list. - edge1.ChannelFlags |= lnwire.ChanUpdateDisabled - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - disabledChanIds, err = graph.DisabledChannelIDs() - if err != nil { - t.Fatalf("unable to get disabled channel ids: %v", err) - } - if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) - } - - // Add second disabled policy and ensure the channel is now in the - // disabled list. - edge2.ChannelFlags |= lnwire.ChanUpdateDisabled - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - disabledChanIds, err = graph.DisabledChannelIDs() - if err != nil { - t.Fatalf("unable to get disabled channel ids: %v", err) - } - if len(disabledChanIds) != 1 || disabledChanIds[0] != edgeInfo.ChannelID { - t.Fatalf("expected disabled channel with id %v, "+ - "got %v", edgeInfo.ChannelID, disabledChanIds) - } - - // Delete the channel edge and ensure it is removed from the disabled list. - if err = graph.DeleteChannelEdges(edgeInfo.ChannelID); err != nil { - t.Fatalf("unable to delete channel edge: %v", err) - } - disabledChanIds, err = graph.DisabledChannelIDs() - if err != nil { - t.Fatalf("unable to get disabled channel ids: %v", err) - } - if len(disabledChanIds) > 0 { - t.Fatalf("expected empty disabled channels, got %v disabled channels", - len(disabledChanIds)) - } -} - -// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in -// the DB that indicates that it should support the htlc_maximum_value_msat -// field, but it is not part of the opaque data, then we'll handle it as it is -// unknown. It also checks that we are correctly able to overwrite it when we -// receive the proper update. -func TestEdgePolicyMissingMaxHtcl(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - - graph := db.ChannelGraph() - - // We'd like to test the update of edges inserted into the database, so - // we create two vertexes to connect. - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - if err := graph.AddLightningNode(node1); err != nil { - t.Fatalf("unable to add node: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test node: %v", err) - } - - edgeInfo, edge1, edge2 := createChannelEdge(db, node1, node2) - if err := graph.AddLightningNode(node2); err != nil { - t.Fatalf("unable to add node: %v", err) - } - if err := graph.AddChannelEdge(edgeInfo); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - chanID := edgeInfo.ChannelID - from := edge2.Node.PubKeyBytes[:] - to := edge1.Node.PubKeyBytes[:] - - // We'll remove the no max_htlc field from the first edge policy, and - // all other opaque data, and serialize it. - edge1.MessageFlags = 0 - edge1.ExtraOpaqueData = nil - - var b bytes.Buffer - err = serializeChanEdgePolicy(&b, edge1, to) - if err != nil { - t.Fatalf("unable to serialize policy") - } - - // Set the max_htlc field. The extra bytes added to the serialization - // will be the opaque data containing the serialized field. - edge1.MessageFlags = lnwire.ChanUpdateOptionMaxHtlc - edge1.MaxHTLC = 13928598 - var b2 bytes.Buffer - err = serializeChanEdgePolicy(&b2, edge1, to) - if err != nil { - t.Fatalf("unable to serialize policy") - } - - withMaxHtlc := b2.Bytes() - - // Remove the opaque data from the serialization. - stripped := withMaxHtlc[:len(b.Bytes())] - - // Attempting to deserialize these bytes should return an error. - r := bytes.NewReader(stripped) - err = db.View(func(tx *bbolt.Tx) error { - nodes := tx.Bucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - _, err = deserializeChanEdgePolicy(r, nodes) - if err != ErrEdgePolicyOptionalFieldNotFound { - t.Fatalf("expected "+ - "ErrEdgePolicyOptionalFieldNotFound, got %v", - err) - } - - return nil - }) - if err != nil { - t.Fatalf("error reading db: %v", err) - } - - // Put the stripped bytes in the DB. - err = db.Update(func(tx *bbolt.Tx) error { - edges := tx.Bucket(edgeBucket) - if edges == nil { - return ErrEdgeNotFound - } - - edgeIndex := edges.Bucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrEdgeNotFound - } - - var edgeKey [33 + 8]byte - copy(edgeKey[:], from) - byteOrder.PutUint64(edgeKey[33:], edge1.ChannelID) - - var scratch [8]byte - var indexKey [8 + 8]byte - copy(indexKey[:], scratch[:]) - byteOrder.PutUint64(indexKey[8:], edge1.ChannelID) - - updateIndex, err := edges.CreateBucketIfNotExists(edgeUpdateIndexBucket) - if err != nil { - return err - } - - if err := updateIndex.Put(indexKey[:], nil); err != nil { - return err - } - - return edges.Put(edgeKey[:], stripped) - }) - if err != nil { - t.Fatalf("error writing db: %v", err) - } - - // And add the second, unmodified edge. - if err := graph.UpdateEdgePolicy(edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - // Attempt to fetch the edge and policies from the DB. Since the policy - // we added is invalid according to the new format, it should be as we - // are not aware of the policy (indicated by the policy returned being - // nil) - dbEdgeInfo, dbEdge1, dbEdge2, err := graph.FetchChannelEdgesByID(chanID) - if err != nil { - t.Fatalf("unable to fetch channel by ID: %v", err) - } - - // The first edge should have a nil-policy returned - if dbEdge1 != nil { - t.Fatalf("expected db edge to be nil") - } - if err := compareEdgePolicies(dbEdge2, edge2); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) - - // Now add the original, unmodified edge policy, and make sure the edge - // policies then become fully populated. - if err := graph.UpdateEdgePolicy(edge1); err != nil { - t.Fatalf("unable to update edge: %v", err) - } - - dbEdgeInfo, dbEdge1, dbEdge2, err = graph.FetchChannelEdgesByID(chanID) - if err != nil { - t.Fatalf("unable to fetch channel by ID: %v", err) - } - if err := compareEdgePolicies(dbEdge1, edge1); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - if err := compareEdgePolicies(dbEdge2, edge2); err != nil { - t.Fatalf("edge doesn't match: %v", err) - } - assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) -} - -// assertNumZombies queries the provided ChannelGraph for NumZombies, and -// asserts that the returned number is equal to expZombies. -func assertNumZombies(t *testing.T, graph *ChannelGraph, expZombies uint64) { - t.Helper() - - numZombies, err := graph.NumZombies() - if err != nil { - t.Fatalf("unable to query number of zombies: %v", err) - } - - if numZombies != expZombies { - t.Fatalf("expected %d zombies, found %d", - expZombies, numZombies) - } -} - -// TestGraphZombieIndex ensures that we can mark edges correctly as zombie/live. -func TestGraphZombieIndex(t *testing.T) { - t.Parallel() - - // We'll start by creating our test graph along with a test edge. - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to create test database: %v", err) - } - graph := db.ChannelGraph() - - node1, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test vertex: %v", err) - } - node2, err := createTestVertex(db) - if err != nil { - t.Fatalf("unable to create test vertex: %v", err) - } - - // Swap the nodes if the second's pubkey is smaller than the first. - // Without this, the comparisons at the end will fail probabilistically. - if bytes.Compare(node2.PubKeyBytes[:], node1.PubKeyBytes[:]) < 0 { - node1, node2 = node2, node1 - } - - edge, _, _ := createChannelEdge(db, node1, node2) - if err := graph.AddChannelEdge(edge); err != nil { - t.Fatalf("unable to create channel edge: %v", err) - } - - // Since the edge is known the graph and it isn't a zombie, IsZombieEdge - // should not report the channel as a zombie. - isZombie, _, _ := graph.IsZombieEdge(edge.ChannelID) - if isZombie { - t.Fatal("expected edge to not be marked as zombie") - } - assertNumZombies(t, graph, 0) - - // If we delete the edge and mark it as a zombie, then we should expect - // to see it within the index. - err = graph.DeleteChannelEdges(edge.ChannelID) - if err != nil { - t.Fatalf("unable to mark edge as zombie: %v", err) - } - isZombie, pubKey1, pubKey2 := graph.IsZombieEdge(edge.ChannelID) - if !isZombie { - t.Fatal("expected edge to be marked as zombie") - } - if pubKey1 != node1.PubKeyBytes { - t.Fatalf("expected pubKey1 %x, got %x", node1.PubKeyBytes, - pubKey1) - } - if pubKey2 != node2.PubKeyBytes { - t.Fatalf("expected pubKey2 %x, got %x", node2.PubKeyBytes, - pubKey2) - } - assertNumZombies(t, graph, 1) - - // Similarly, if we mark the same edge as live, we should no longer see - // it within the index. - if err := graph.MarkEdgeLive(edge.ChannelID); err != nil { - t.Fatalf("unable to mark edge as live: %v", err) - } - isZombie, _, _ = graph.IsZombieEdge(edge.ChannelID) - if isZombie { - t.Fatal("expected edge to not be marked as zombie") - } - assertNumZombies(t, graph, 0) -} - -// compareNodes is used to compare two LightningNodes while excluding the -// Features struct, which cannot be compared as the semantics for reserializing -// the featuresMap have not been defined. -func compareNodes(a, b *LightningNode) error { - if a.LastUpdate != b.LastUpdate { - return fmt.Errorf("node LastUpdate doesn't match: expected %v, \n"+ - "got %v", a.LastUpdate, b.LastUpdate) - } - if !reflect.DeepEqual(a.Addresses, b.Addresses) { - return fmt.Errorf("Addresses doesn't match: expected %#v, \n "+ - "got %#v", a.Addresses, b.Addresses) - } - if !reflect.DeepEqual(a.PubKeyBytes, b.PubKeyBytes) { - return fmt.Errorf("PubKey doesn't match: expected %#v, \n "+ - "got %#v", a.PubKeyBytes, b.PubKeyBytes) - } - if !reflect.DeepEqual(a.Color, b.Color) { - return fmt.Errorf("Color doesn't match: expected %#v, \n "+ - "got %#v", a.Color, b.Color) - } - if !reflect.DeepEqual(a.Alias, b.Alias) { - return fmt.Errorf("Alias doesn't match: expected %#v, \n "+ - "got %#v", a.Alias, b.Alias) - } - if !reflect.DeepEqual(a.db, b.db) { - return fmt.Errorf("db doesn't match: expected %#v, \n "+ - "got %#v", a.db, b.db) - } - if !reflect.DeepEqual(a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) { - return fmt.Errorf("HaveNodeAnnouncement doesn't match: expected %#v, \n "+ - "got %#v", a.HaveNodeAnnouncement, b.HaveNodeAnnouncement) - } - if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { - return fmt.Errorf("extra data doesn't match: %v vs %v", - a.ExtraOpaqueData, b.ExtraOpaqueData) - } - - return nil -} - -// compareEdgePolicies is used to compare two ChannelEdgePolices using -// compareNodes, so as to exclude comparisons of the Nodes' Features struct. -func compareEdgePolicies(a, b *ChannelEdgePolicy) error { - if a.ChannelID != b.ChannelID { - return fmt.Errorf("ChannelID doesn't match: expected %v, "+ - "got %v", a.ChannelID, b.ChannelID) - } - if !reflect.DeepEqual(a.LastUpdate, b.LastUpdate) { - return fmt.Errorf("edge LastUpdate doesn't match: expected %#v, \n "+ - "got %#v", a.LastUpdate, b.LastUpdate) - } - if a.MessageFlags != b.MessageFlags { - return fmt.Errorf("MessageFlags doesn't match: expected %v, "+ - "got %v", a.MessageFlags, b.MessageFlags) - } - if a.ChannelFlags != b.ChannelFlags { - return fmt.Errorf("ChannelFlags doesn't match: expected %v, "+ - "got %v", a.ChannelFlags, b.ChannelFlags) - } - if a.TimeLockDelta != b.TimeLockDelta { - return fmt.Errorf("TimeLockDelta doesn't match: expected %v, "+ - "got %v", a.TimeLockDelta, b.TimeLockDelta) - } - if a.MinHTLC != b.MinHTLC { - return fmt.Errorf("MinHTLC doesn't match: expected %v, "+ - "got %v", a.MinHTLC, b.MinHTLC) - } - if a.MaxHTLC != b.MaxHTLC { - return fmt.Errorf("MaxHTLC doesn't match: expected %v, "+ - "got %v", a.MaxHTLC, b.MaxHTLC) - } - if a.FeeBaseMSat != b.FeeBaseMSat { - return fmt.Errorf("FeeBaseMSat doesn't match: expected %v, "+ - "got %v", a.FeeBaseMSat, b.FeeBaseMSat) - } - if a.FeeProportionalMillionths != b.FeeProportionalMillionths { - return fmt.Errorf("FeeProportionalMillionths doesn't match: "+ - "expected %v, got %v", a.FeeProportionalMillionths, - b.FeeProportionalMillionths) - } - if !bytes.Equal(a.ExtraOpaqueData, b.ExtraOpaqueData) { - return fmt.Errorf("extra data doesn't match: %v vs %v", - a.ExtraOpaqueData, b.ExtraOpaqueData) - } - if err := compareNodes(a.Node, b.Node); err != nil { - return err - } - if !reflect.DeepEqual(a.db, b.db) { - return fmt.Errorf("db doesn't match: expected %#v, \n "+ - "got %#v", a.db, b.db) - } - return nil -} - -// TestLightningNodeSigVerifcation checks that we can use the LightningNode's -// pubkey to verify signatures. -func TestLightningNodeSigVerification(t *testing.T) { - t.Parallel() - - // Create some dummy data to sign. - var data [32]byte - if _, err := prand.Read(data[:]); err != nil { - t.Fatalf("unable to read prand: %v", err) - } - - // Create private key and sign the data with it. - priv, err := btcec.NewPrivateKey(btcec.S256()) - if err != nil { - t.Fatalf("unable to crete priv key: %v", err) - } - - sign, err := priv.Sign(data[:]) - if err != nil { - t.Fatalf("unable to sign: %v", err) - } - - // Sanity check that the signature checks out. - if !sign.Verify(data[:], priv.PubKey()) { - t.Fatalf("signature doesn't check out") - } - - // Create a LightningNode from the same private key. - db, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - node, err := createLightningNode(db, priv) - if err != nil { - t.Fatalf("unable to create node: %v", err) - } - - // And finally check that we can verify the same signature from the - // pubkey returned from the lightning node. - nodePub, err := node.PubKey() - if err != nil { - t.Fatalf("unable to get pubkey: %v", err) - } - - if !sign.Verify(data[:], nodePub) { - t.Fatalf("unable to verify sig") - } -} - -// TestComputeFee tests fee calculation based on both in- and outgoing amt. -func TestComputeFee(t *testing.T) { - var ( - policy = ChannelEdgePolicy{ - FeeBaseMSat: 10000, - FeeProportionalMillionths: 30000, - } - outgoingAmt = lnwire.MilliSatoshi(1000000) - expectedFee = lnwire.MilliSatoshi(40000) - ) - - fee := policy.ComputeFee(outgoingAmt) - if fee != expectedFee { - t.Fatalf("expected fee %v, got %v", expectedFee, fee) - } - - fwdFee := policy.ComputeFeeFromIncoming(outgoingAmt + fee) - if fwdFee != expectedFee { - t.Fatalf("expected fee %v, but got %v", fee, fwdFee) - } -} diff --git a/channeldb/migration_01_to_11/invoice_test.go b/channeldb/migration_01_to_11/invoice_test.go deleted file mode 100644 index 795fe493..00000000 --- a/channeldb/migration_01_to_11/invoice_test.go +++ /dev/null @@ -1,694 +0,0 @@ -package migration_01_to_11 - -import ( - "crypto/rand" - "reflect" - "testing" - "time" - - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/lnwire" -) - -func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { - var pre [32]byte - if _, err := rand.Read(pre[:]); err != nil { - return nil, err - } - - i := &Invoice{ - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), - Terms: ContractTerm{ - PaymentPreimage: pre, - Value: value, - }, - Htlcs: map[CircuitKey]*InvoiceHTLC{}, - Expiry: 4000, - } - i.Memo = []byte("memo") - i.Receipt = []byte("receipt") - - // Create a random byte slice of MaxPaymentRequestSize bytes to be used - // as a dummy paymentrequest, and determine if it should be set based - // on one of the random bytes. - var r [MaxPaymentRequestSize]byte - if _, err := rand.Read(r[:]); err != nil { - return nil, err - } - if r[0]&1 == 0 { - i.PaymentRequest = r[:] - } else { - i.PaymentRequest = []byte("") - } - - return i, nil -} - -func TestInvoiceWorkflow(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - // Create a fake invoice which we'll use several times in the tests - // below. - fakeInvoice := &Invoice{ - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), - Htlcs: map[CircuitKey]*InvoiceHTLC{}, - } - fakeInvoice.Memo = []byte("memo") - fakeInvoice.Receipt = []byte("receipt") - fakeInvoice.PaymentRequest = []byte("") - copy(fakeInvoice.Terms.PaymentPreimage[:], rev[:]) - fakeInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000) - - paymentHash := fakeInvoice.Terms.PaymentPreimage.Hash() - - // Add the invoice to the database, this should succeed as there aren't - // any existing invoices within the database with the same payment - // hash. - if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != nil { - t.Fatalf("unable to find invoice: %v", err) - } - - // Attempt to retrieve the invoice which was just added to the - // database. It should be found, and the invoice returned should be - // identical to the one created above. - dbInvoice, err := db.LookupInvoice(paymentHash) - if err != nil { - t.Fatalf("unable to find invoice: %v", err) - } - if !reflect.DeepEqual(*fakeInvoice, dbInvoice) { - t.Fatalf("invoice fetched from db doesn't match original %v vs %v", - spew.Sdump(fakeInvoice), spew.Sdump(dbInvoice)) - } - - // The add index of the invoice retrieved from the database should now - // be fully populated. As this is the first index written to the DB, - // the addIndex should be 1. - if dbInvoice.AddIndex != 1 { - t.Fatalf("wrong add index: expected %v, got %v", 1, - dbInvoice.AddIndex) - } - - // Settle the invoice, the version retrieved from the database should - // now have the settled bit toggle to true and a non-default - // SettledDate - payAmt := fakeInvoice.Terms.Value * 2 - _, err = db.UpdateInvoice(paymentHash, getUpdateInvoice(payAmt)) - if err != nil { - t.Fatalf("unable to settle invoice: %v", err) - } - dbInvoice2, err := db.LookupInvoice(paymentHash) - if err != nil { - t.Fatalf("unable to fetch invoice: %v", err) - } - if dbInvoice2.Terms.State != ContractSettled { - t.Fatalf("invoice should now be settled but isn't") - } - if dbInvoice2.SettleDate.IsZero() { - t.Fatalf("invoice should have non-zero SettledDate but isn't") - } - - // Our 2x payment should be reflected, and also the settle index of 1 - // should also have been committed for this index. - if dbInvoice2.AmtPaid != payAmt { - t.Fatalf("wrong amt paid: expected %v, got %v", payAmt, - dbInvoice2.AmtPaid) - } - if dbInvoice2.SettleIndex != 1 { - t.Fatalf("wrong settle index: expected %v, got %v", 1, - dbInvoice2.SettleIndex) - } - - // Attempt to insert generated above again, this should fail as - // duplicates are rejected by the processing logic. - if _, err := db.AddInvoice(fakeInvoice, paymentHash); err != ErrDuplicateInvoice { - t.Fatalf("invoice insertion should fail due to duplication, "+ - "instead %v", err) - } - - // Attempt to look up a non-existent invoice, this should also fail but - // with a "not found" error. - var fakeHash [32]byte - if _, err := db.LookupInvoice(fakeHash); err != ErrInvoiceNotFound { - t.Fatalf("lookup should have failed, instead %v", err) - } - - // Add 10 random invoices. - const numInvoices = 10 - amt := lnwire.NewMSatFromSatoshis(1000) - invoices := make([]*Invoice, numInvoices+1) - invoices[0] = &dbInvoice2 - for i := 1; i < len(invoices)-1; i++ { - invoice, err := randInvoice(amt) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - hash := invoice.Terms.PaymentPreimage.Hash() - if _, err := db.AddInvoice(invoice, hash); err != nil { - t.Fatalf("unable to add invoice %v", err) - } - - invoices[i] = invoice - } - - // Perform a scan to collect all the active invoices. - dbInvoices, err := db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch all invoices: %v", err) - } - - // The retrieve list of invoices should be identical as since we're - // using big endian, the invoices should be retrieved in ascending - // order (and the primary key should be incremented with each - // insertion). - for i := 0; i < len(invoices)-1; i++ { - if !reflect.DeepEqual(*invoices[i], dbInvoices[i]) { - t.Fatalf("retrieved invoices don't match %v vs %v", - spew.Sdump(invoices[i]), - spew.Sdump(dbInvoices[i])) - } - } -} - -// TestInvoiceTimeSeries tests that newly added invoices invoices, as well as -// settled invoices are added to the database are properly placed in the add -// add or settle index which serves as an event time series. -func TestInvoiceAddTimeSeries(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - // We'll start off by creating 20 random invoices, and inserting them - // into the database. - const numInvoices = 20 - amt := lnwire.NewMSatFromSatoshis(1000) - invoices := make([]Invoice, numInvoices) - for i := 0; i < len(invoices); i++ { - invoice, err := randInvoice(amt) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - paymentHash := invoice.Terms.PaymentPreimage.Hash() - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice %v", err) - } - - invoices[i] = *invoice - } - - // With the invoices constructed, we'll now create a series of queries - // that we'll use to assert expected return values of - // InvoicesAddedSince. - addQueries := []struct { - sinceAddIndex uint64 - - resp []Invoice - }{ - // If we specify a value of zero, we shouldn't get any invoices - // back. - { - sinceAddIndex: 0, - }, - - // If we specify a value well beyond the number of inserted - // invoices, we shouldn't get any invoices back. - { - sinceAddIndex: 99999999, - }, - - // Using an index of 1 should result in all values, but the - // first one being returned. - { - sinceAddIndex: 1, - resp: invoices[1:], - }, - - // If we use an index of 10, then we should retrieve the - // reaming 10 invoices. - { - sinceAddIndex: 10, - resp: invoices[10:], - }, - } - - for i, query := range addQueries { - resp, err := db.InvoicesAddedSince(query.sinceAddIndex) - if err != nil { - t.Fatalf("unable to query: %v", err) - } - - if !reflect.DeepEqual(query.resp, resp) { - t.Fatalf("test #%v: expected %v, got %v", i, - spew.Sdump(query.resp), spew.Sdump(resp)) - } - } - - // We'll now only settle the latter half of each of those invoices. - for i := 10; i < len(invoices); i++ { - invoice := &invoices[i] - - paymentHash := invoice.Terms.PaymentPreimage.Hash() - - _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(0), - ) - if err != nil { - t.Fatalf("unable to settle invoice: %v", err) - } - } - - invoices, err = db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to fetch invoices: %v", err) - } - - // We'll slice off the first 10 invoices, as we only settled the last - // 10. - invoices = invoices[10:] - - // We'll now prepare an additional set of queries to ensure the settle - // time series has properly been maintained in the database. - settleQueries := []struct { - sinceSettleIndex uint64 - - resp []Invoice - }{ - // If we specify a value of zero, we shouldn't get any settled - // invoices back. - { - sinceSettleIndex: 0, - }, - - // If we specify a value well beyond the number of settled - // invoices, we shouldn't get any invoices back. - { - sinceSettleIndex: 99999999, - }, - - // Using an index of 1 should result in the final 10 invoices - // being returned, as we only settled those. - { - sinceSettleIndex: 1, - resp: invoices[1:], - }, - } - - for i, query := range settleQueries { - resp, err := db.InvoicesSettledSince(query.sinceSettleIndex) - if err != nil { - t.Fatalf("unable to query: %v", err) - } - - if !reflect.DeepEqual(query.resp, resp) { - t.Fatalf("test #%v: expected %v, got %v", i, - spew.Sdump(query.resp), spew.Sdump(resp)) - } - } -} - -// TestDuplicateSettleInvoice tests that if we add a new invoice and settle it -// twice, then the second time we also receive the invoice that we settled as a -// return argument. -func TestDuplicateSettleInvoice(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - db.now = func() time.Time { return time.Unix(1, 0) } - - // We'll start out by creating an invoice and writing it to the DB. - amt := lnwire.NewMSatFromSatoshis(1000) - invoice, err := randInvoice(amt) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - payHash := invoice.Terms.PaymentPreimage.Hash() - - if _, err := db.AddInvoice(invoice, payHash); err != nil { - t.Fatalf("unable to add invoice %v", err) - } - - // With the invoice in the DB, we'll now attempt to settle the invoice. - dbInvoice, err := db.UpdateInvoice( - payHash, getUpdateInvoice(amt), - ) - if err != nil { - t.Fatalf("unable to settle invoice: %v", err) - } - - // We'll update what we expect the settle invoice to be so that our - // comparison below has the correct assumption. - invoice.SettleIndex = 1 - invoice.Terms.State = ContractSettled - invoice.AmtPaid = amt - invoice.SettleDate = dbInvoice.SettleDate - invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{ - {}: { - Amt: amt, - AcceptTime: time.Unix(1, 0), - ResolveTime: time.Unix(1, 0), - State: HtlcStateSettled, - }, - } - - // We should get back the exact same invoice that we just inserted. - if !reflect.DeepEqual(dbInvoice, invoice) { - t.Fatalf("wrong invoice after settle, expected %v got %v", - spew.Sdump(invoice), spew.Sdump(dbInvoice)) - } - - // If we try to settle the invoice again, then we should get the very - // same invoice back, but with an error this time. - dbInvoice, err = db.UpdateInvoice( - payHash, getUpdateInvoice(amt), - ) - if err != ErrInvoiceAlreadySettled { - t.Fatalf("expected ErrInvoiceAlreadySettled") - } - - if dbInvoice == nil { - t.Fatalf("invoice from db is nil after settle!") - } - - invoice.SettleDate = dbInvoice.SettleDate - if !reflect.DeepEqual(dbInvoice, invoice) { - t.Fatalf("wrong invoice after second settle, expected %v got %v", - spew.Sdump(invoice), spew.Sdump(dbInvoice)) - } -} - -// TestQueryInvoices ensures that we can properly query the invoice database for -// invoices using different types of queries. -func TestQueryInvoices(t *testing.T) { - t.Parallel() - - db, cleanUp, err := makeTestDB() - defer cleanUp() - if err != nil { - t.Fatalf("unable to make test db: %v", err) - } - - // To begin the test, we'll add 50 invoices to the database. We'll - // assume that the index of the invoice within the database is the same - // as the amount of the invoice itself. - const numInvoices = 50 - for i := lnwire.MilliSatoshi(1); i <= numInvoices; i++ { - invoice, err := randInvoice(i) - if err != nil { - t.Fatalf("unable to create invoice: %v", err) - } - - paymentHash := invoice.Terms.PaymentPreimage.Hash() - - if _, err := db.AddInvoice(invoice, paymentHash); err != nil { - t.Fatalf("unable to add invoice: %v", err) - } - - // We'll only settle half of all invoices created. - if i%2 == 0 { - _, err := db.UpdateInvoice( - paymentHash, getUpdateInvoice(i), - ) - if err != nil { - t.Fatalf("unable to settle invoice: %v", err) - } - } - } - - // We'll then retrieve the set of all invoices and pending invoices. - // This will serve useful when comparing the expected responses of the - // query with the actual ones. - invoices, err := db.FetchAllInvoices(false) - if err != nil { - t.Fatalf("unable to retrieve invoices: %v", err) - } - pendingInvoices, err := db.FetchAllInvoices(true) - if err != nil { - t.Fatalf("unable to retrieve pending invoices: %v", err) - } - - // The test will consist of several queries along with their respective - // expected response. Each query response should match its expected one. - testCases := []struct { - query InvoiceQuery - expected []Invoice - }{ - // Fetch all invoices with a single query. - { - query: InvoiceQuery{ - NumMaxInvoices: numInvoices, - }, - expected: invoices, - }, - // Fetch all invoices with a single query, reversed. - { - query: InvoiceQuery{ - Reversed: true, - NumMaxInvoices: numInvoices, - }, - expected: invoices, - }, - // Fetch the first 25 invoices. - { - query: InvoiceQuery{ - NumMaxInvoices: numInvoices / 2, - }, - expected: invoices[:numInvoices/2], - }, - // Fetch the first 10 invoices, but this time iterating - // backwards. - { - query: InvoiceQuery{ - IndexOffset: 11, - Reversed: true, - NumMaxInvoices: numInvoices, - }, - expected: invoices[:10], - }, - // Fetch the last 40 invoices. - { - query: InvoiceQuery{ - IndexOffset: 10, - NumMaxInvoices: numInvoices, - }, - expected: invoices[10:], - }, - // Fetch all but the first invoice. - { - query: InvoiceQuery{ - IndexOffset: 1, - NumMaxInvoices: numInvoices, - }, - expected: invoices[1:], - }, - // Fetch one invoice, reversed, with index offset 3. This - // should give us the second invoice in the array. - { - query: InvoiceQuery{ - IndexOffset: 3, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[1:2], - }, - // Same as above, at index 2. - { - query: InvoiceQuery{ - IndexOffset: 2, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[0:1], - }, - // Fetch one invoice, at index 1, reversed. Since invoice#1 is - // the very first, there won't be any left in a reverse search, - // so we expect no invoices to be returned. - { - query: InvoiceQuery{ - IndexOffset: 1, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: nil, - }, - // Same as above, but don't restrict the number of invoices to - // 1. - { - query: InvoiceQuery{ - IndexOffset: 1, - Reversed: true, - NumMaxInvoices: numInvoices, - }, - expected: nil, - }, - // Fetch one invoice, reversed, with no offset set. We expect - // the last invoice in the response. - { - query: InvoiceQuery{ - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[numInvoices-1:], - }, - // Fetch one invoice, reversed, the offset set at numInvoices+1. - // We expect this to return the last invoice. - { - query: InvoiceQuery{ - IndexOffset: numInvoices + 1, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[numInvoices-1:], - }, - // Same as above, at offset numInvoices. - { - query: InvoiceQuery{ - IndexOffset: numInvoices, - Reversed: true, - NumMaxInvoices: 1, - }, - expected: invoices[numInvoices-2 : numInvoices-1], - }, - // Fetch one invoice, at no offset (same as offset 0). We - // expect the first invoice only in the response. - { - query: InvoiceQuery{ - NumMaxInvoices: 1, - }, - expected: invoices[:1], - }, - // Same as above, at offset 1. - { - query: InvoiceQuery{ - IndexOffset: 1, - NumMaxInvoices: 1, - }, - expected: invoices[1:2], - }, - // Same as above, at offset 2. - { - query: InvoiceQuery{ - IndexOffset: 2, - NumMaxInvoices: 1, - }, - expected: invoices[2:3], - }, - // Same as above, at offset numInvoices-1. Expect the last - // invoice to be returned. - { - query: InvoiceQuery{ - IndexOffset: numInvoices - 1, - NumMaxInvoices: 1, - }, - expected: invoices[numInvoices-1:], - }, - // Same as above, at offset numInvoices. No invoices should be - // returned, as there are no invoices after this offset. - { - query: InvoiceQuery{ - IndexOffset: numInvoices, - NumMaxInvoices: 1, - }, - expected: nil, - }, - // Fetch all pending invoices with a single query. - { - query: InvoiceQuery{ - PendingOnly: true, - NumMaxInvoices: numInvoices, - }, - expected: pendingInvoices, - }, - // Fetch the first 12 pending invoices. - { - query: InvoiceQuery{ - PendingOnly: true, - NumMaxInvoices: numInvoices / 4, - }, - expected: pendingInvoices[:len(pendingInvoices)/2], - }, - // Fetch the first 5 pending invoices, but this time iterating - // backwards. - { - query: InvoiceQuery{ - IndexOffset: 10, - PendingOnly: true, - Reversed: true, - NumMaxInvoices: numInvoices, - }, - // Since we seek to the invoice with index 10 and - // iterate backwards, there should only be 5 pending - // invoices before it as every other invoice within the - // index is settled. - expected: pendingInvoices[:5], - }, - // Fetch the last 15 invoices. - { - query: InvoiceQuery{ - IndexOffset: 20, - PendingOnly: true, - NumMaxInvoices: numInvoices, - }, - // Since we seek to the invoice with index 20, there are - // 30 invoices left. From these 30, only 15 of them are - // still pending. - expected: pendingInvoices[len(pendingInvoices)-15:], - }, - } - - for i, testCase := range testCases { - response, err := db.QueryInvoices(testCase.query) - if err != nil { - t.Fatalf("unable to query invoice database: %v", err) - } - - if !reflect.DeepEqual(response.Invoices, testCase.expected) { - t.Fatalf("test #%d: query returned incorrect set of "+ - "invoices: expcted %v, got %v", i, - spew.Sdump(response.Invoices), - spew.Sdump(testCase.expected)) - } - } -} - -// getUpdateInvoice returns an invoice update callback that, when called, -// settles the invoice with the given amount. -func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback { - return func(invoice *Invoice) (*InvoiceUpdateDesc, error) { - if invoice.Terms.State == ContractSettled { - return nil, ErrInvoiceAlreadySettled - } - - update := &InvoiceUpdateDesc{ - Preimage: invoice.Terms.PaymentPreimage, - State: ContractSettled, - Htlcs: map[CircuitKey]*HtlcAcceptDesc{ - {}: { - Amt: amt, - }, - }, - } - - return update, nil - } -} diff --git a/channeldb/migration_01_to_11/invoices.go b/channeldb/migration_01_to_11/invoices.go index 5f40454a..f60457ff 100644 --- a/channeldb/migration_01_to_11/invoices.go +++ b/channeldb/migration_01_to_11/invoices.go @@ -3,7 +3,6 @@ package migration_01_to_11 import ( "bytes" "encoding/binary" - "errors" "fmt" "io" "time" @@ -16,9 +15,6 @@ import ( ) var ( - // UnknownPreimage is an all-zeroes preimage that indicates that the - // preimage for this invoice is not yet known. - UnknownPreimage lntypes.Preimage // invoiceBucket is the name of the bucket within the database that // stores all data related to invoices no matter their final state. @@ -26,23 +22,6 @@ var ( // which is a monotonically increasing uint32. invoiceBucket = []byte("invoices") - // paymentHashIndexBucket is the name of the sub-bucket within the - // invoiceBucket which indexes all invoices by their payment hash. The - // payment hash is the sha256 of the invoice's payment preimage. This - // index is used to detect duplicates, and also to provide a fast path - // for looking up incoming HTLCs to determine if we're able to settle - // them fully. - // - // maps: payHash => invoiceKey - invoiceIndexBucket = []byte("paymenthashes") - - // numInvoicesKey is the name of key which houses the auto-incrementing - // invoice ID which is essentially used as a primary key. With each - // invoice inserted, the primary key is incremented by one. This key is - // stored within the invoiceIndexBucket. Within the invoiceBucket - // invoices are uniquely identified by the invoice ID. - numInvoicesKey = []byte("nik") - // addIndexBucket is an index bucket that we'll use to create a // monotonically increasing set of add indexes. Each time we add a new // invoice, this sequence number will be incremented and then populated @@ -62,21 +41,6 @@ var ( // // settleIndexNo => invoiceKey settleIndexBucket = []byte("invoice-settle-index") - - // ErrInvoiceAlreadySettled is returned when the invoice is already - // settled. - ErrInvoiceAlreadySettled = errors.New("invoice already settled") - - // ErrInvoiceAlreadyCanceled is returned when the invoice is already - // canceled. - ErrInvoiceAlreadyCanceled = errors.New("invoice already canceled") - - // ErrInvoiceAlreadyAccepted is returned when the invoice is already - // accepted. - ErrInvoiceAlreadyAccepted = errors.New("invoice already accepted") - - // ErrInvoiceStillOpen is returned when the invoice is still open. - ErrInvoiceStillOpen = errors.New("invoice still open") ) const ( @@ -237,18 +201,6 @@ type Invoice struct { // HtlcState defines the states an htlc paying to an invoice can be in. type HtlcState uint8 -const ( - // HtlcStateAccepted indicates the htlc is locked-in, but not resolved. - HtlcStateAccepted HtlcState = iota - - // HtlcStateCanceled indicates the htlc is canceled back to the - // sender. - HtlcStateCanceled - - // HtlcStateSettled indicates the htlc is settled. - HtlcStateSettled -) - // InvoiceHTLC contains details about an htlc paying to this invoice. type InvoiceHTLC struct { // Amt is the amount that is carried by this htlc. @@ -276,37 +228,6 @@ type InvoiceHTLC struct { State HtlcState } -// HtlcAcceptDesc describes the details of a newly accepted htlc. -type HtlcAcceptDesc struct { - // AcceptHeight is the block height at which this htlc was accepted. - AcceptHeight int32 - - // Amt is the amount that is carried by this htlc. - Amt lnwire.MilliSatoshi - - // Expiry is the expiry height of this htlc. - Expiry uint32 -} - -// InvoiceUpdateDesc describes the changes that should be applied to the -// invoice. -type InvoiceUpdateDesc struct { - // State is the new state that this invoice should progress to. - State ContractState - - // Htlcs describes the changes that need to be made to the invoice htlcs - // in the database. Htlc map entries with their value set should be - // added. If the map value is nil, the htlc should be canceled. - Htlcs map[CircuitKey]*HtlcAcceptDesc - - // Preimage must be set to the preimage when state is settled. - Preimage lntypes.Preimage -} - -// InvoiceUpdateCallback is a callback used in the db transaction to update the -// invoice. -type InvoiceUpdateCallback = func(invoice *Invoice) (*InvoiceUpdateDesc, error) - func validateInvoice(i *Invoice) error { if len(i.Memo) > MaxMemoSize { return fmt.Errorf("max length a memo is %v, and invoice "+ @@ -325,186 +246,6 @@ func validateInvoice(i *Invoice) error { return nil } -// AddInvoice inserts the targeted invoice into the database. If the invoice has -// *any* payment hashes which already exists within the database, then the -// insertion will be aborted and rejected due to the strict policy banning any -// duplicate payment hashes. A side effect of this function is that it sets -// AddIndex on newInvoice. -func (d *DB) AddInvoice(newInvoice *Invoice, paymentHash lntypes.Hash) ( - uint64, error) { - - if err := validateInvoice(newInvoice); err != nil { - return 0, err - } - - var invoiceAddIndex uint64 - err := d.Update(func(tx *bbolt.Tx) error { - invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) - if err != nil { - return err - } - - invoiceIndex, err := invoices.CreateBucketIfNotExists( - invoiceIndexBucket, - ) - if err != nil { - return err - } - addIndex, err := invoices.CreateBucketIfNotExists( - addIndexBucket, - ) - if err != nil { - return err - } - - // Ensure that an invoice an identical payment hash doesn't - // already exist within the index. - if invoiceIndex.Get(paymentHash[:]) != nil { - return ErrDuplicateInvoice - } - - // If the current running payment ID counter hasn't yet been - // created, then create it now. - var invoiceNum uint32 - invoiceCounter := invoiceIndex.Get(numInvoicesKey) - if invoiceCounter == nil { - var scratch [4]byte - byteOrder.PutUint32(scratch[:], invoiceNum) - err := invoiceIndex.Put(numInvoicesKey, scratch[:]) - if err != nil { - return err - } - } else { - invoiceNum = byteOrder.Uint32(invoiceCounter) - } - - newIndex, err := putInvoice( - invoices, invoiceIndex, addIndex, newInvoice, invoiceNum, - paymentHash, - ) - if err != nil { - return err - } - - invoiceAddIndex = newIndex - return nil - }) - if err != nil { - return 0, err - } - - return invoiceAddIndex, err -} - -// InvoicesAddedSince can be used by callers to seek into the event time series -// of all the invoices added in the database. The specified sinceAddIndex -// should be the highest add index that the caller knows of. This method will -// return all invoices with an add index greater than the specified -// sinceAddIndex. -// -// NOTE: The index starts from 1, as a result. We enforce that specifying a -// value below the starting index value is a noop. -func (d *DB) InvoicesAddedSince(sinceAddIndex uint64) ([]Invoice, error) { - var newInvoices []Invoice - - // If an index of zero was specified, then in order to maintain - // backwards compat, we won't send out any new invoices. - if sinceAddIndex == 0 { - return newInvoices, nil - } - - var startIndex [8]byte - byteOrder.PutUint64(startIndex[:], sinceAddIndex) - - err := d.DB.View(func(tx *bbolt.Tx) error { - invoices := tx.Bucket(invoiceBucket) - if invoices == nil { - return ErrNoInvoicesCreated - } - - addIndex := invoices.Bucket(addIndexBucket) - if addIndex == nil { - return ErrNoInvoicesCreated - } - - // We'll now run through each entry in the add index starting - // at our starting index. We'll continue until we reach the - // very end of the current key space. - invoiceCursor := addIndex.Cursor() - - // We'll seek to the starting index, then manually advance the - // cursor in order to skip the entry with the since add index. - invoiceCursor.Seek(startIndex[:]) - addSeqNo, invoiceKey := invoiceCursor.Next() - - for ; addSeqNo != nil && bytes.Compare(addSeqNo, startIndex[:]) > 0; addSeqNo, invoiceKey = invoiceCursor.Next() { - - // For each key found, we'll look up the actual - // invoice, then accumulate it into our return value. - invoice, err := fetchInvoice(invoiceKey, invoices) - if err != nil { - return err - } - - newInvoices = append(newInvoices, invoice) - } - - return nil - }) - switch { - // If no invoices have been created, then we'll return the empty set of - // invoices. - case err == ErrNoInvoicesCreated: - - case err != nil: - return nil, err - } - - return newInvoices, nil -} - -// LookupInvoice attempts to look up an invoice according to its 32 byte -// payment hash. If an invoice which can settle the HTLC identified by the -// passed payment hash isn't found, then an error is returned. Otherwise, the -// full invoice is returned. Before setting the incoming HTLC, the values -// SHOULD be checked to ensure the payer meets the agreed upon contractual -// terms of the payment. -func (d *DB) LookupInvoice(paymentHash [32]byte) (Invoice, error) { - var invoice Invoice - err := d.View(func(tx *bbolt.Tx) error { - invoices := tx.Bucket(invoiceBucket) - if invoices == nil { - return ErrNoInvoicesCreated - } - invoiceIndex := invoices.Bucket(invoiceIndexBucket) - if invoiceIndex == nil { - return ErrNoInvoicesCreated - } - - // Check the invoice index to see if an invoice paying to this - // hash exists within the DB. - invoiceNum := invoiceIndex.Get(paymentHash[:]) - if invoiceNum == nil { - return ErrInvoiceNotFound - } - - // An invoice matching the payment hash has been found, so - // retrieve the record of the invoice itself. - i, err := fetchInvoice(invoiceNum, invoices) - if err != nil { - return err - } - invoice = i - - return nil - }) - if err != nil { - return invoice, err - } - - return invoice, nil -} - // FetchAllInvoices returns all invoices currently stored within the database. // If the pendingOnly param is true, then only unsettled invoices will be // returned, skipping all invoices that are fully settled. @@ -549,343 +290,6 @@ func (d *DB) FetchAllInvoices(pendingOnly bool) ([]Invoice, error) { return invoices, nil } -// InvoiceQuery represents a query to the invoice database. The query allows a -// caller to retrieve all invoices starting from a particular add index and -// limit the number of results returned. -type InvoiceQuery struct { - // IndexOffset is the offset within the add indices to start at. This - // can be used to start the response at a particular invoice. - IndexOffset uint64 - - // NumMaxInvoices is the maximum number of invoices that should be - // starting from the add index. - NumMaxInvoices uint64 - - // PendingOnly, if set, returns unsettled invoices starting from the - // add index. - PendingOnly bool - - // Reversed, if set, indicates that the invoices returned should start - // from the IndexOffset and go backwards. - Reversed bool -} - -// InvoiceSlice is the response to a invoice query. It includes the original -// query, the set of invoices that match the query, and an integer which -// represents the offset index of the last item in the set of returned invoices. -// This integer allows callers to resume their query using this offset in the -// event that the query's response exceeds the maximum number of returnable -// invoices. -type InvoiceSlice struct { - InvoiceQuery - - // Invoices is the set of invoices that matched the query above. - Invoices []Invoice - - // FirstIndexOffset is the index of the first element in the set of - // returned Invoices above. Callers can use this to resume their query - // in the event that the slice has too many events to fit into a single - // response. - FirstIndexOffset uint64 - - // LastIndexOffset is the index of the last element in the set of - // returned Invoices above. Callers can use this to resume their query - // in the event that the slice has too many events to fit into a single - // response. - LastIndexOffset uint64 -} - -// QueryInvoices allows a caller to query the invoice database for invoices -// within the specified add index range. -func (d *DB) QueryInvoices(q InvoiceQuery) (InvoiceSlice, error) { - resp := InvoiceSlice{ - InvoiceQuery: q, - } - - err := d.View(func(tx *bbolt.Tx) error { - // If the bucket wasn't found, then there aren't any invoices - // within the database yet, so we can simply exit. - invoices := tx.Bucket(invoiceBucket) - if invoices == nil { - return ErrNoInvoicesCreated - } - invoiceAddIndex := invoices.Bucket(addIndexBucket) - if invoiceAddIndex == nil { - return ErrNoInvoicesCreated - } - - // keyForIndex is a helper closure that retrieves the invoice - // key for the given add index of an invoice. - keyForIndex := func(c *bbolt.Cursor, index uint64) []byte { - var keyIndex [8]byte - byteOrder.PutUint64(keyIndex[:], index) - _, invoiceKey := c.Seek(keyIndex[:]) - return invoiceKey - } - - // nextKey is a helper closure to determine what the next - // invoice key is when iterating over the invoice add index. - nextKey := func(c *bbolt.Cursor) ([]byte, []byte) { - if q.Reversed { - return c.Prev() - } - return c.Next() - } - - // We'll be using a cursor to seek into the database and return - // a slice of invoices. We'll need to determine where to start - // our cursor depending on the parameters set within the query. - c := invoiceAddIndex.Cursor() - invoiceKey := keyForIndex(c, q.IndexOffset+1) - - // If the query is specifying reverse iteration, then we must - // handle a few offset cases. - if q.Reversed { - switch q.IndexOffset { - - // This indicates the default case, where no offset was - // specified. In that case we just start from the last - // invoice. - case 0: - _, invoiceKey = c.Last() - - // This indicates the offset being set to the very - // first invoice. Since there are no invoices before - // this offset, and the direction is reversed, we can - // return without adding any invoices to the response. - case 1: - return nil - - // Otherwise we start iteration at the invoice prior to - // the offset. - default: - invoiceKey = keyForIndex(c, q.IndexOffset-1) - } - } - - // If we know that a set of invoices exists, then we'll begin - // our seek through the bucket in order to satisfy the query. - // We'll continue until either we reach the end of the range, or - // reach our max number of invoices. - for ; invoiceKey != nil; _, invoiceKey = nextKey(c) { - // If our current return payload exceeds the max number - // of invoices, then we'll exit now. - if uint64(len(resp.Invoices)) >= q.NumMaxInvoices { - break - } - - invoice, err := fetchInvoice(invoiceKey, invoices) - if err != nil { - return err - } - - // Skip any settled invoices if the caller is only - // interested in unsettled. - if q.PendingOnly && - invoice.Terms.State == ContractSettled { - - continue - } - - // At this point, we've exhausted the offset, so we'll - // begin collecting invoices found within the range. - resp.Invoices = append(resp.Invoices, invoice) - } - - // If we iterated through the add index in reverse order, then - // we'll need to reverse the slice of invoices to return them in - // forward order. - if q.Reversed { - numInvoices := len(resp.Invoices) - for i := 0; i < numInvoices/2; i++ { - opposite := numInvoices - i - 1 - resp.Invoices[i], resp.Invoices[opposite] = - resp.Invoices[opposite], resp.Invoices[i] - } - } - - return nil - }) - if err != nil && err != ErrNoInvoicesCreated { - return resp, err - } - - // Finally, record the indexes of the first and last invoices returned - // so that the caller can resume from this point later on. - if len(resp.Invoices) > 0 { - resp.FirstIndexOffset = resp.Invoices[0].AddIndex - resp.LastIndexOffset = resp.Invoices[len(resp.Invoices)-1].AddIndex - } - - return resp, nil -} - -// UpdateInvoice attempts to update an invoice corresponding to the passed -// payment hash. If an invoice matching the passed payment hash doesn't exist -// within the database, then the action will fail with a "not found" error. -// -// The update is performed inside the same database transaction that fetches the -// invoice and is therefore atomic. The fields to update are controlled by the -// supplied callback. -func (d *DB) UpdateInvoice(paymentHash lntypes.Hash, - callback InvoiceUpdateCallback) (*Invoice, error) { - - var updatedInvoice *Invoice - err := d.Update(func(tx *bbolt.Tx) error { - invoices, err := tx.CreateBucketIfNotExists(invoiceBucket) - if err != nil { - return err - } - invoiceIndex, err := invoices.CreateBucketIfNotExists( - invoiceIndexBucket, - ) - if err != nil { - return err - } - settleIndex, err := invoices.CreateBucketIfNotExists( - settleIndexBucket, - ) - if err != nil { - return err - } - - // Check the invoice index to see if an invoice paying to this - // hash exists within the DB. - invoiceNum := invoiceIndex.Get(paymentHash[:]) - if invoiceNum == nil { - return ErrInvoiceNotFound - } - - updatedInvoice, err = d.updateInvoice( - paymentHash, invoices, settleIndex, invoiceNum, - callback, - ) - - return err - }) - - return updatedInvoice, err -} - -// InvoicesSettledSince can be used by callers to catch up any settled invoices -// they missed within the settled invoice time series. We'll return all known -// settled invoice that have a settle index higher than the passed -// sinceSettleIndex. -// -// NOTE: The index starts from 1, as a result. We enforce that specifying a -// value below the starting index value is a noop. -func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) { - var settledInvoices []Invoice - - // If an index of zero was specified, then in order to maintain - // backwards compat, we won't send out any new invoices. - if sinceSettleIndex == 0 { - return settledInvoices, nil - } - - var startIndex [8]byte - byteOrder.PutUint64(startIndex[:], sinceSettleIndex) - - err := d.DB.View(func(tx *bbolt.Tx) error { - invoices := tx.Bucket(invoiceBucket) - if invoices == nil { - return ErrNoInvoicesCreated - } - - settleIndex := invoices.Bucket(settleIndexBucket) - if settleIndex == nil { - return ErrNoInvoicesCreated - } - - // We'll now run through each entry in the add index starting - // at our starting index. We'll continue until we reach the - // very end of the current key space. - invoiceCursor := settleIndex.Cursor() - - // We'll seek to the starting index, then manually advance the - // cursor in order to skip the entry with the since add index. - invoiceCursor.Seek(startIndex[:]) - seqNo, invoiceKey := invoiceCursor.Next() - - for ; seqNo != nil && bytes.Compare(seqNo, startIndex[:]) > 0; seqNo, invoiceKey = invoiceCursor.Next() { - - // For each key found, we'll look up the actual - // invoice, then accumulate it into our return value. - invoice, err := fetchInvoice(invoiceKey, invoices) - if err != nil { - return err - } - - settledInvoices = append(settledInvoices, invoice) - } - - return nil - }) - if err != nil { - return nil, err - } - - return settledInvoices, nil -} - -func putInvoice(invoices, invoiceIndex, addIndex *bbolt.Bucket, - i *Invoice, invoiceNum uint32, paymentHash lntypes.Hash) ( - uint64, error) { - - // Create the invoice key which is just the big-endian representation - // of the invoice number. - var invoiceKey [4]byte - byteOrder.PutUint32(invoiceKey[:], invoiceNum) - - // Increment the num invoice counter index so the next invoice bares - // the proper ID. - var scratch [4]byte - invoiceCounter := invoiceNum + 1 - byteOrder.PutUint32(scratch[:], invoiceCounter) - if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil { - return 0, err - } - - // Add the payment hash to the invoice index. This will let us quickly - // identify if we can settle an incoming payment, and also to possibly - // allow a single invoice to have multiple payment installations. - err := invoiceIndex.Put(paymentHash[:], invoiceKey[:]) - if err != nil { - return 0, err - } - - // Next, we'll obtain the next add invoice index (sequence - // number), so we can properly place this invoice within this - // event stream. - nextAddSeqNo, err := addIndex.NextSequence() - if err != nil { - return 0, err - } - - // With the next sequence obtained, we'll updating the event series in - // the add index bucket to map this current add counter to the index of - // this new invoice. - var seqNoBytes [8]byte - byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo) - if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil { - return 0, err - } - - i.AddIndex = nextAddSeqNo - - // Finally, serialize the invoice itself to be written to the disk. - var buf bytes.Buffer - if err := serializeInvoice(&buf, i); err != nil { - return 0, err - } - - if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil { - return 0, err - } - - return nextAddSeqNo, nil -} - // serializeInvoice serializes an invoice to a writer. // // Note: this function is in use for a migration. Before making changes that @@ -1006,17 +410,6 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error { return nil } -func fetchInvoice(invoiceNum []byte, invoices *bbolt.Bucket) (Invoice, error) { - invoiceBytes := invoices.Get(invoiceNum) - if invoiceBytes == nil { - return Invoice{}, ErrInvoiceNotFound - } - - invoiceReader := bytes.NewReader(invoiceBytes) - - return deserializeInvoice(invoiceReader) -} - func deserializeInvoice(r io.Reader) (Invoice, error) { var err error invoice := Invoice{} @@ -1155,166 +548,3 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { return htlcs, nil } - -// copySlice allocates a new slice and copies the source into it. -func copySlice(src []byte) []byte { - dest := make([]byte, len(src)) - copy(dest, src) - return dest -} - -// copyInvoice makes a deep copy of the supplied invoice. -func copyInvoice(src *Invoice) *Invoice { - dest := Invoice{ - Memo: copySlice(src.Memo), - Receipt: copySlice(src.Receipt), - PaymentRequest: copySlice(src.PaymentRequest), - FinalCltvDelta: src.FinalCltvDelta, - CreationDate: src.CreationDate, - SettleDate: src.SettleDate, - Terms: src.Terms, - AddIndex: src.AddIndex, - SettleIndex: src.SettleIndex, - AmtPaid: src.AmtPaid, - Htlcs: make( - map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), - ), - } - - for k, v := range src.Htlcs { - dest.Htlcs[k] = v - } - - return &dest -} - -// updateInvoice fetches the invoice, obtains the update descriptor from the -// callback and applies the updates in a single db transaction. -func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucket, - invoiceNum []byte, callback InvoiceUpdateCallback) (*Invoice, error) { - - invoice, err := fetchInvoice(invoiceNum, invoices) - if err != nil { - return nil, err - } - - preUpdateState := invoice.Terms.State - - // Create deep copy to prevent any accidental modification in the - // callback. - copy := copyInvoice(&invoice) - - // Call the callback and obtain the update descriptor. - update, err := callback(copy) - if err != nil { - return &invoice, err - } - - // Update invoice state. - invoice.Terms.State = update.State - - now := d.now() - - // Update htlc set. - for key, htlcUpdate := range update.Htlcs { - htlc, ok := invoice.Htlcs[key] - - // No update means the htlc needs to be canceled. - if htlcUpdate == nil { - if !ok { - return nil, fmt.Errorf("unknown htlc %v", key) - } - if htlc.State != HtlcStateAccepted { - return nil, fmt.Errorf("can only cancel " + - "accepted htlcs") - } - - htlc.State = HtlcStateCanceled - htlc.ResolveTime = now - invoice.AmtPaid -= htlc.Amt - - continue - } - - // Add new htlc paying to the invoice. - if ok { - return nil, fmt.Errorf("htlc %v already exists", key) - } - htlc = &InvoiceHTLC{ - Amt: htlcUpdate.Amt, - Expiry: htlcUpdate.Expiry, - AcceptHeight: uint32(htlcUpdate.AcceptHeight), - AcceptTime: now, - } - if preUpdateState == ContractSettled { - htlc.State = HtlcStateSettled - htlc.ResolveTime = now - } else { - htlc.State = HtlcStateAccepted - } - - invoice.Htlcs[key] = htlc - invoice.AmtPaid += htlc.Amt - } - - // If invoice moved to the settled state, update settle index and settle - // time. - if preUpdateState != invoice.Terms.State && - invoice.Terms.State == ContractSettled { - - if update.Preimage.Hash() != hash { - return nil, fmt.Errorf("preimage does not match") - } - invoice.Terms.PaymentPreimage = update.Preimage - - // Settle all accepted htlcs. - for _, htlc := range invoice.Htlcs { - if htlc.State != HtlcStateAccepted { - continue - } - - htlc.State = HtlcStateSettled - htlc.ResolveTime = now - } - - err := setSettleFields(settleIndex, invoiceNum, &invoice, now) - if err != nil { - return nil, err - } - } - - var buf bytes.Buffer - if err := serializeInvoice(&buf, &invoice); err != nil { - return nil, err - } - - if err := invoices.Put(invoiceNum[:], buf.Bytes()); err != nil { - return nil, err - } - - return &invoice, nil -} - -func setSettleFields(settleIndex *bbolt.Bucket, invoiceNum []byte, - invoice *Invoice, now time.Time) error { - - // Now that we know the invoice hasn't already been settled, we'll - // update the settle index so we can place this settle event in the - // proper location within our time series. - nextSettleSeqNo, err := settleIndex.NextSequence() - if err != nil { - return err - } - - var seqNoBytes [8]byte - byteOrder.PutUint64(seqNoBytes[:], nextSettleSeqNo) - if err := settleIndex.Put(seqNoBytes[:], invoiceNum); err != nil { - return err - } - - invoice.Terms.State = ContractSettled - invoice.SettleDate = now - invoice.SettleIndex = nextSettleSeqNo - - return nil -} diff --git a/channeldb/migration_01_to_11/nodes.go b/channeldb/migration_01_to_11/nodes.go deleted file mode 100644 index f40359e8..00000000 --- a/channeldb/migration_01_to_11/nodes.go +++ /dev/null @@ -1,316 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "io" - "net" - "time" - - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/wire" - "github.com/coreos/bbolt" -) - -var ( - // nodeInfoBucket stores metadata pertaining to nodes that we've had - // direct channel-based correspondence with. This bucket allows one to - // query for all open channels pertaining to the node by exploring each - // node's sub-bucket within the openChanBucket. - nodeInfoBucket = []byte("nib") -) - -// LinkNode stores metadata related to node's that we have/had a direct -// channel open with. Information such as the Bitcoin network the node -// advertised, and its identity public key are also stored. Additionally, this -// struct and the bucket its stored within have store data similar to that of -// Bitcoin's addrmanager. The TCP address information stored within the struct -// can be used to establish persistent connections will all channel -// counterparties on daemon startup. -// -// TODO(roasbeef): also add current OnionKey plus rotation schedule? -// TODO(roasbeef): add bitfield for supported services -// * possibly add a wire.NetAddress type, type -type LinkNode struct { - // Network indicates the Bitcoin network that the LinkNode advertises - // for incoming channel creation. - Network wire.BitcoinNet - - // IdentityPub is the node's current identity public key. Any - // channel/topology related information received by this node MUST be - // signed by this public key. - IdentityPub *btcec.PublicKey - - // LastSeen tracks the last time this node was seen within the network. - // A node should be marked as seen if the daemon either is able to - // establish an outgoing connection to the node or receives a new - // incoming connection from the node. This timestamp (stored in unix - // epoch) may be used within a heuristic which aims to determine when a - // channel should be unilaterally closed due to inactivity. - // - // TODO(roasbeef): replace with block hash/height? - // * possibly add a time-value metric into the heuristic? - LastSeen time.Time - - // Addresses is a list of IP address in which either we were able to - // reach the node over in the past, OR we received an incoming - // authenticated connection for the stored identity public key. - Addresses []net.Addr - - db *DB -} - -// NewLinkNode creates a new LinkNode from the provided parameters, which is -// backed by an instance of channeldb. -func (db *DB) NewLinkNode(bitNet wire.BitcoinNet, pub *btcec.PublicKey, - addrs ...net.Addr) *LinkNode { - - return &LinkNode{ - Network: bitNet, - IdentityPub: pub, - LastSeen: time.Now(), - Addresses: addrs, - db: db, - } -} - -// UpdateLastSeen updates the last time this node was directly encountered on -// the Lightning Network. -func (l *LinkNode) UpdateLastSeen(lastSeen time.Time) error { - l.LastSeen = lastSeen - - return l.Sync() -} - -// AddAddress appends the specified TCP address to the list of known addresses -// this node is/was known to be reachable at. -func (l *LinkNode) AddAddress(addr net.Addr) error { - for _, a := range l.Addresses { - if a.String() == addr.String() { - return nil - } - } - - l.Addresses = append(l.Addresses, addr) - - return l.Sync() -} - -// Sync performs a full database sync which writes the current up-to-date data -// within the struct to the database. -func (l *LinkNode) Sync() error { - - // Finally update the database by storing the link node and updating - // any relevant indexes. - return l.db.Update(func(tx *bbolt.Tx) error { - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return ErrLinkNodesNotFound - } - - return putLinkNode(nodeMetaBucket, l) - }) -} - -// putLinkNode serializes then writes the encoded version of the passed link -// node into the nodeMetaBucket. This function is provided in order to allow -// the ability to re-use a database transaction across many operations. -func putLinkNode(nodeMetaBucket *bbolt.Bucket, l *LinkNode) error { - // First serialize the LinkNode into its raw-bytes encoding. - var b bytes.Buffer - if err := serializeLinkNode(&b, l); err != nil { - return err - } - - // Finally insert the link-node into the node metadata bucket keyed - // according to the its pubkey serialized in compressed form. - nodePub := l.IdentityPub.SerializeCompressed() - return nodeMetaBucket.Put(nodePub, b.Bytes()) -} - -// DeleteLinkNode removes the link node with the given identity from the -// database. -func (db *DB) DeleteLinkNode(identity *btcec.PublicKey) error { - return db.Update(func(tx *bbolt.Tx) error { - return db.deleteLinkNode(tx, identity) - }) -} - -func (db *DB) deleteLinkNode(tx *bbolt.Tx, identity *btcec.PublicKey) error { - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return ErrLinkNodesNotFound - } - - pubKey := identity.SerializeCompressed() - return nodeMetaBucket.Delete(pubKey) -} - -// FetchLinkNode attempts to lookup the data for a LinkNode based on a target -// identity public key. If a particular LinkNode for the passed identity public -// key cannot be found, then ErrNodeNotFound if returned. -func (db *DB) FetchLinkNode(identity *btcec.PublicKey) (*LinkNode, error) { - var linkNode *LinkNode - err := db.View(func(tx *bbolt.Tx) error { - node, err := fetchLinkNode(tx, identity) - if err != nil { - return err - } - - linkNode = node - return nil - }) - - return linkNode, err -} - -func fetchLinkNode(tx *bbolt.Tx, targetPub *btcec.PublicKey) (*LinkNode, error) { - // First fetch the bucket for storing node metadata, bailing out early - // if it hasn't been created yet. - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return nil, ErrLinkNodesNotFound - } - - // If a link node for that particular public key cannot be located, - // then exit early with an ErrNodeNotFound. - pubKey := targetPub.SerializeCompressed() - nodeBytes := nodeMetaBucket.Get(pubKey) - if nodeBytes == nil { - return nil, ErrNodeNotFound - } - - // Finally, decode and allocate a fresh LinkNode object to be returned - // to the caller. - nodeReader := bytes.NewReader(nodeBytes) - return deserializeLinkNode(nodeReader) -} - -// TODO(roasbeef): update link node addrs in server upon connection - -// FetchAllLinkNodes starts a new database transaction to fetch all nodes with -// whom we have active channels with. -func (db *DB) FetchAllLinkNodes() ([]*LinkNode, error) { - var linkNodes []*LinkNode - err := db.View(func(tx *bbolt.Tx) error { - nodes, err := db.fetchAllLinkNodes(tx) - if err != nil { - return err - } - - linkNodes = nodes - return nil - }) - if err != nil { - return nil, err - } - - return linkNodes, nil -} - -// fetchAllLinkNodes uses an existing database transaction to fetch all nodes -// with whom we have active channels with. -func (db *DB) fetchAllLinkNodes(tx *bbolt.Tx) ([]*LinkNode, error) { - nodeMetaBucket := tx.Bucket(nodeInfoBucket) - if nodeMetaBucket == nil { - return nil, ErrLinkNodesNotFound - } - - var linkNodes []*LinkNode - err := nodeMetaBucket.ForEach(func(k, v []byte) error { - if v == nil { - return nil - } - - nodeReader := bytes.NewReader(v) - linkNode, err := deserializeLinkNode(nodeReader) - if err != nil { - return err - } - - linkNodes = append(linkNodes, linkNode) - return nil - }) - if err != nil { - return nil, err - } - - return linkNodes, nil -} - -func serializeLinkNode(w io.Writer, l *LinkNode) error { - var buf [8]byte - - byteOrder.PutUint32(buf[:4], uint32(l.Network)) - if _, err := w.Write(buf[:4]); err != nil { - return err - } - - serializedID := l.IdentityPub.SerializeCompressed() - if _, err := w.Write(serializedID); err != nil { - return err - } - - seenUnix := uint64(l.LastSeen.Unix()) - byteOrder.PutUint64(buf[:], seenUnix) - if _, err := w.Write(buf[:]); err != nil { - return err - } - - numAddrs := uint32(len(l.Addresses)) - byteOrder.PutUint32(buf[:4], numAddrs) - if _, err := w.Write(buf[:4]); err != nil { - return err - } - - for _, addr := range l.Addresses { - if err := serializeAddr(w, addr); err != nil { - return err - } - } - - return nil -} - -func deserializeLinkNode(r io.Reader) (*LinkNode, error) { - var ( - err error - buf [8]byte - ) - - node := &LinkNode{} - - if _, err := io.ReadFull(r, buf[:4]); err != nil { - return nil, err - } - node.Network = wire.BitcoinNet(byteOrder.Uint32(buf[:4])) - - var pub [33]byte - if _, err := io.ReadFull(r, pub[:]); err != nil { - return nil, err - } - node.IdentityPub, err = btcec.ParsePubKey(pub[:], btcec.S256()) - if err != nil { - return nil, err - } - - if _, err := io.ReadFull(r, buf[:]); err != nil { - return nil, err - } - node.LastSeen = time.Unix(int64(byteOrder.Uint64(buf[:])), 0) - - if _, err := io.ReadFull(r, buf[:4]); err != nil { - return nil, err - } - numAddrs := byteOrder.Uint32(buf[:4]) - - node.Addresses = make([]net.Addr, numAddrs) - for i := uint32(0); i < numAddrs; i++ { - addr, err := deserializeAddr(r) - if err != nil { - return nil, err - } - node.Addresses[i] = addr - } - - return node, nil -} diff --git a/channeldb/migration_01_to_11/nodes_test.go b/channeldb/migration_01_to_11/nodes_test.go deleted file mode 100644 index 481dc5bd..00000000 --- a/channeldb/migration_01_to_11/nodes_test.go +++ /dev/null @@ -1,140 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "net" - "testing" - "time" - - "github.com/btcsuite/btcd/btcec" - "github.com/btcsuite/btcd/wire" -) - -func TestLinkNodeEncodeDecode(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - // First we'll create some initial data to use for populating our test - // LinkNode instances. - _, pub1 := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - _, pub2 := btcec.PrivKeyFromBytes(btcec.S256(), rev[:]) - addr1, err := net.ResolveTCPAddr("tcp", "10.0.0.1:9000") - if err != nil { - t.Fatalf("unable to create test addr: %v", err) - } - addr2, err := net.ResolveTCPAddr("tcp", "10.0.0.2:9000") - if err != nil { - t.Fatalf("unable to create test addr: %v", err) - } - - // Create two fresh link node instances with the above dummy data, then - // fully sync both instances to disk. - node1 := cdb.NewLinkNode(wire.MainNet, pub1, addr1) - node2 := cdb.NewLinkNode(wire.TestNet3, pub2, addr2) - if err := node1.Sync(); err != nil { - t.Fatalf("unable to sync node: %v", err) - } - if err := node2.Sync(); err != nil { - t.Fatalf("unable to sync node: %v", err) - } - - // Fetch all current link nodes from the database, they should exactly - // match the two created above. - originalNodes := []*LinkNode{node2, node1} - linkNodes, err := cdb.FetchAllLinkNodes() - if err != nil { - t.Fatalf("unable to fetch nodes: %v", err) - } - for i, node := range linkNodes { - if originalNodes[i].Network != node.Network { - t.Fatalf("node networks don't match: expected %v, got %v", - originalNodes[i].Network, node.Network) - } - - originalPubkey := originalNodes[i].IdentityPub.SerializeCompressed() - dbPubkey := node.IdentityPub.SerializeCompressed() - if !bytes.Equal(originalPubkey, dbPubkey) { - t.Fatalf("node pubkeys don't match: expected %x, got %x", - originalPubkey, dbPubkey) - } - if originalNodes[i].LastSeen.Unix() != node.LastSeen.Unix() { - t.Fatalf("last seen timestamps don't match: expected %v got %v", - originalNodes[i].LastSeen.Unix(), node.LastSeen.Unix()) - } - if originalNodes[i].Addresses[0].String() != node.Addresses[0].String() { - t.Fatalf("addresses don't match: expected %v, got %v", - originalNodes[i].Addresses, node.Addresses) - } - } - - // Next, we'll exercise the methods to append additional IP - // addresses, and also to update the last seen time. - if err := node1.UpdateLastSeen(time.Now()); err != nil { - t.Fatalf("unable to update last seen: %v", err) - } - if err := node1.AddAddress(addr2); err != nil { - t.Fatalf("unable to update addr: %v", err) - } - - // Fetch the same node from the database according to its public key. - node1DB, err := cdb.FetchLinkNode(pub1) - if err != nil { - t.Fatalf("unable to find node: %v", err) - } - - // Both the last seen timestamp and the list of reachable addresses for - // the node should be updated. - if node1DB.LastSeen.Unix() != node1.LastSeen.Unix() { - t.Fatalf("last seen timestamps don't match: expected %v got %v", - node1.LastSeen.Unix(), node1DB.LastSeen.Unix()) - } - if len(node1DB.Addresses) != 2 { - t.Fatalf("wrong length for node1 addresses: expected %v, got %v", - 2, len(node1DB.Addresses)) - } - if node1DB.Addresses[0].String() != addr1.String() { - t.Fatalf("wrong address for node: expected %v, got %v", - addr1.String(), node1DB.Addresses[0].String()) - } - if node1DB.Addresses[1].String() != addr2.String() { - t.Fatalf("wrong address for node: expected %v, got %v", - addr2.String(), node1DB.Addresses[1].String()) - } -} - -func TestDeleteLinkNode(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - _, pubKey := btcec.PrivKeyFromBytes(btcec.S256(), key[:]) - addr := &net.TCPAddr{ - IP: net.ParseIP("127.0.0.1"), - Port: 1337, - } - linkNode := cdb.NewLinkNode(wire.TestNet3, pubKey, addr) - if err := linkNode.Sync(); err != nil { - t.Fatalf("unable to write link node to db: %v", err) - } - - if _, err := cdb.FetchLinkNode(pubKey); err != nil { - t.Fatalf("unable to find link node: %v", err) - } - - if err := cdb.DeleteLinkNode(pubKey); err != nil { - t.Fatalf("unable to delete link node from db: %v", err) - } - - if _, err := cdb.FetchLinkNode(pubKey); err == nil { - t.Fatal("should not have found link node in db, but did") - } -} diff --git a/channeldb/migration_01_to_11/options.go b/channeldb/migration_01_to_11/options.go index c3cc2c4a..03b287e0 100644 --- a/channeldb/migration_01_to_11/options.go +++ b/channeldb/migration_01_to_11/options.go @@ -39,24 +39,3 @@ func DefaultOptions() Options { // OptionModifier is a function signature for modifying the default Options. type OptionModifier func(*Options) - -// OptionSetRejectCacheSize sets the RejectCacheSize to n. -func OptionSetRejectCacheSize(n int) OptionModifier { - return func(o *Options) { - o.RejectCacheSize = n - } -} - -// OptionSetChannelCacheSize sets the ChannelCacheSize to n. -func OptionSetChannelCacheSize(n int) OptionModifier { - return func(o *Options) { - o.ChannelCacheSize = n - } -} - -// OptionSetSyncFreelist allows the database to sync its freelist. -func OptionSetSyncFreelist(b bool) OptionModifier { - return func(o *Options) { - o.NoFreelistSync = !b - } -} diff --git a/channeldb/migration_01_to_11/payment_control.go b/channeldb/migration_01_to_11/payment_control.go index 83b1649a..7b069d24 100644 --- a/channeldb/migration_01_to_11/payment_control.go +++ b/channeldb/migration_01_to_11/payment_control.go @@ -1,373 +1,9 @@ package migration_01_to_11 import ( - "bytes" - "encoding/binary" - "errors" - "fmt" - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/routing/route" ) -var ( - // ErrAlreadyPaid signals we have already paid this payment hash. - ErrAlreadyPaid = errors.New("invoice is already paid") - - // ErrPaymentInFlight signals that payment for this payment hash is - // already "in flight" on the network. - ErrPaymentInFlight = errors.New("payment is in transition") - - // ErrPaymentNotInitiated is returned if payment wasn't initiated in - // switch. - ErrPaymentNotInitiated = errors.New("payment isn't initiated") - - // ErrPaymentAlreadySucceeded is returned in the event we attempt to - // change the status of a payment already succeeded. - ErrPaymentAlreadySucceeded = errors.New("payment is already succeeded") - - // ErrPaymentAlreadyFailed is returned in the event we attempt to - // re-fail a failed payment. - ErrPaymentAlreadyFailed = errors.New("payment has already failed") - - // ErrUnknownPaymentStatus is returned when we do not recognize the - // existing state of a payment. - ErrUnknownPaymentStatus = errors.New("unknown payment status") - - // errNoAttemptInfo is returned when no attempt info is stored yet. - errNoAttemptInfo = errors.New("unable to find attempt info for " + - "inflight payment") -) - -// PaymentControl implements persistence for payments and payment attempts. -type PaymentControl struct { - db *DB -} - -// NewPaymentControl creates a new instance of the PaymentControl. -func NewPaymentControl(db *DB) *PaymentControl { - return &PaymentControl{ - db: db, - } -} - -// InitPayment checks or records the given PaymentCreationInfo with the DB, -// making sure it does not already exist as an in-flight payment. Then this -// method returns successfully, the payment is guranteeed to be in the InFlight -// state. -func (p *PaymentControl) InitPayment(paymentHash lntypes.Hash, - info *PaymentCreationInfo) error { - - var b bytes.Buffer - if err := serializePaymentCreationInfo(&b, info); err != nil { - return err - } - infoBytes := b.Bytes() - - var updateErr error - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - bucket, err := createPaymentBucket(tx, paymentHash) - if err != nil { - return err - } - - // Get the existing status of this payment, if any. - paymentStatus := fetchPaymentStatus(bucket) - - switch paymentStatus { - - // We allow retrying failed payments. - case StatusFailed: - - // This is a new payment that is being initialized for the - // first time. - case StatusUnknown: - - // We already have an InFlight payment on the network. We will - // disallow any new payments. - case StatusInFlight: - updateErr = ErrPaymentInFlight - return nil - - // We've already succeeded a payment to this payment hash, - // forbid the switch from sending another. - case StatusSucceeded: - updateErr = ErrAlreadyPaid - return nil - - default: - updateErr = ErrUnknownPaymentStatus - return nil - } - - // Obtain a new sequence number for this payment. This is used - // to sort the payments in order of creation, and also acts as - // a unique identifier for each payment. - sequenceNum, err := nextPaymentSequence(tx) - if err != nil { - return err - } - - err = bucket.Put(paymentSequenceKey, sequenceNum) - if err != nil { - return err - } - - // Add the payment info to the bucket, which contains the - // static information for this payment - err = bucket.Put(paymentCreationInfoKey, infoBytes) - if err != nil { - return err - } - - // We'll delete any lingering attempt info to start with, in - // case we are initializing a payment that was attempted - // earlier, but left in a state where we could retry. - err = bucket.Delete(paymentAttemptInfoKey) - if err != nil { - return err - } - - // Also delete any lingering failure info now that we are - // re-attempting. - return bucket.Delete(paymentFailInfoKey) - }) - if err != nil { - return err - } - - return updateErr -} - -// RegisterAttempt atomically records the provided PaymentAttemptInfo to the -// DB. -func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, - attempt *PaymentAttemptInfo) error { - - // Serialize the information before opening the db transaction. - var a bytes.Buffer - if err := serializePaymentAttemptInfo(&a, attempt); err != nil { - return err - } - attemptBytes := a.Bytes() - - var updateErr error - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - bucket, err := fetchPaymentBucket(tx, paymentHash) - if err == ErrPaymentNotInitiated { - updateErr = ErrPaymentNotInitiated - return nil - } else if err != nil { - return err - } - - // We can only register attempts for payments that are - // in-flight. - if err := ensureInFlight(bucket); err != nil { - updateErr = err - return nil - } - - // Add the payment attempt to the payments bucket. - return bucket.Put(paymentAttemptInfoKey, attemptBytes) - }) - if err != nil { - return err - } - - return updateErr -} - -// Success transitions a payment into the Succeeded state. After invoking this -// method, InitPayment should always return an error to prevent us from making -// duplicate payments to the same payment hash. The provided preimage is -// atomically saved to the DB for record keeping. -func (p *PaymentControl) Success(paymentHash lntypes.Hash, - preimage lntypes.Preimage) (*route.Route, error) { - - var ( - updateErr error - route *route.Route - ) - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - bucket, err := fetchPaymentBucket(tx, paymentHash) - if err == ErrPaymentNotInitiated { - updateErr = ErrPaymentNotInitiated - return nil - } else if err != nil { - return err - } - - // We can only mark in-flight payments as succeeded. - if err := ensureInFlight(bucket); err != nil { - updateErr = err - return nil - } - - // Record the successful payment info atomically to the - // payments record. - err = bucket.Put(paymentSettleInfoKey, preimage[:]) - if err != nil { - return err - } - - // Retrieve attempt info for the notification. - attempt, err := fetchPaymentAttempt(bucket) - if err != nil { - return err - } - - route = &attempt.Route - - return nil - }) - if err != nil { - return nil, err - } - - return route, updateErr -} - -// Fail transitions a payment into the Failed state, and records the reason the -// payment failed. After invoking this method, InitPayment should return nil on -// its next call for this payment hash, allowing the switch to make a -// subsequent payment. -func (p *PaymentControl) Fail(paymentHash lntypes.Hash, - reason FailureReason) (*route.Route, error) { - - var ( - updateErr error - route *route.Route - ) - err := p.db.Batch(func(tx *bbolt.Tx) error { - // Reset the update error, to avoid carrying over an error - // from a previous execution of the batched db transaction. - updateErr = nil - - bucket, err := fetchPaymentBucket(tx, paymentHash) - if err == ErrPaymentNotInitiated { - updateErr = ErrPaymentNotInitiated - return nil - } else if err != nil { - return err - } - - // We can only mark in-flight payments as failed. - if err := ensureInFlight(bucket); err != nil { - updateErr = err - return nil - } - - // Put the failure reason in the bucket for record keeping. - v := []byte{byte(reason)} - err = bucket.Put(paymentFailInfoKey, v) - if err != nil { - return err - } - - // Retrieve attempt info for the notification, if available. - attempt, err := fetchPaymentAttempt(bucket) - if err != nil && err != errNoAttemptInfo { - return err - } - if err != errNoAttemptInfo { - route = &attempt.Route - } - - return nil - }) - if err != nil { - return nil, err - } - - return route, updateErr -} - -// FetchPayment returns information about a payment from the database. -func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) ( - *Payment, error) { - - var payment *Payment - err := p.db.View(func(tx *bbolt.Tx) error { - bucket, err := fetchPaymentBucket(tx, paymentHash) - if err != nil { - return err - } - - payment, err = fetchPayment(bucket) - - return err - }) - if err != nil { - return nil, err - } - - return payment, nil -} - -// createPaymentBucket creates or fetches the sub-bucket assigned to this -// payment hash. -func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( - *bbolt.Bucket, error) { - - payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) - if err != nil { - return nil, err - } - - return payments.CreateBucketIfNotExists(paymentHash[:]) -} - -// fetchPaymentBucket fetches the sub-bucket assigned to this payment hash. If -// the bucket does not exist, it returns ErrPaymentNotInitiated. -func fetchPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( - *bbolt.Bucket, error) { - - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil, ErrPaymentNotInitiated - } - - bucket := payments.Bucket(paymentHash[:]) - if bucket == nil { - return nil, ErrPaymentNotInitiated - } - - return bucket, nil - -} - -// nextPaymentSequence returns the next sequence number to store for a new -// payment. -func nextPaymentSequence(tx *bbolt.Tx) ([]byte, error) { - payments, err := tx.CreateBucketIfNotExists(paymentsRootBucket) - if err != nil { - return nil, err - } - - seq, err := payments.NextSequence() - if err != nil { - return nil, err - } - - b := make([]byte, 8) - binary.BigEndian.PutUint64(b, seq) - return b, nil -} - // fetchPaymentStatus fetches the payment status of the payment. If the payment // isn't found, it will default to "StatusUnknown". func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus { @@ -385,113 +21,3 @@ func fetchPaymentStatus(bucket *bbolt.Bucket) PaymentStatus { return StatusUnknown } - -// ensureInFlight checks whether the payment found in the given bucket has -// status InFlight, and returns an error otherwise. This should be used to -// ensure we only mark in-flight payments as succeeded or failed. -func ensureInFlight(bucket *bbolt.Bucket) error { - paymentStatus := fetchPaymentStatus(bucket) - - switch { - - // The payment was indeed InFlight, return. - case paymentStatus == StatusInFlight: - return nil - - // Our records show the payment as unknown, meaning it never - // should have left the switch. - case paymentStatus == StatusUnknown: - return ErrPaymentNotInitiated - - // The payment succeeded previously. - case paymentStatus == StatusSucceeded: - return ErrPaymentAlreadySucceeded - - // The payment was already failed. - case paymentStatus == StatusFailed: - return ErrPaymentAlreadyFailed - - default: - return ErrUnknownPaymentStatus - } -} - -// fetchPaymentAttempt fetches the payment attempt from the bucket. -func fetchPaymentAttempt(bucket *bbolt.Bucket) (*PaymentAttemptInfo, error) { - attemptData := bucket.Get(paymentAttemptInfoKey) - if attemptData == nil { - return nil, errNoAttemptInfo - } - - r := bytes.NewReader(attemptData) - return deserializePaymentAttemptInfo(r) -} - -// InFlightPayment is a wrapper around a payment that has status InFlight. -type InFlightPayment struct { - // Info is the PaymentCreationInfo of the in-flight payment. - Info *PaymentCreationInfo - - // Attempt contains information about the last payment attempt that was - // made to this payment hash. - // - // NOTE: Might be nil. - Attempt *PaymentAttemptInfo -} - -// FetchInFlightPayments returns all payments with status InFlight. -func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { - var inFlights []*InFlightPayment - err := p.db.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil - } - - return payments.ForEach(func(k, _ []byte) error { - bucket := payments.Bucket(k) - if bucket == nil { - return fmt.Errorf("non bucket element") - } - - // If the status is not InFlight, we can return early. - paymentStatus := fetchPaymentStatus(bucket) - if paymentStatus != StatusInFlight { - return nil - } - - var ( - inFlight = &InFlightPayment{} - err error - ) - - // Get the CreationInfo. - b := bucket.Get(paymentCreationInfoKey) - if b == nil { - return fmt.Errorf("unable to find creation " + - "info for inflight payment") - } - - r := bytes.NewReader(b) - inFlight.Info, err = deserializePaymentCreationInfo(r) - if err != nil { - return err - } - - // Now get the attempt info. It could be that there is - // no attempt info yet. - inFlight.Attempt, err = fetchPaymentAttempt(bucket) - if err != nil && err != errNoAttemptInfo { - return err - } - - inFlights = append(inFlights, inFlight) - return nil - }) - }) - if err != nil { - return nil, err - } - - return inFlights, nil -} diff --git a/channeldb/migration_01_to_11/payment_control_test.go b/channeldb/migration_01_to_11/payment_control_test.go deleted file mode 100644 index 9868475e..00000000 --- a/channeldb/migration_01_to_11/payment_control_test.go +++ /dev/null @@ -1,550 +0,0 @@ -package migration_01_to_11 - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" - "io/ioutil" - "reflect" - "testing" - "time" - - "github.com/btcsuite/fastsha256" - "github.com/coreos/bbolt" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/lntypes" - "github.com/lightningnetwork/lnd/routing/route" -) - -func initDB() (*DB, error) { - tempPath, err := ioutil.TempDir("", "switchdb") - if err != nil { - return nil, err - } - - db, err := Open(tempPath) - if err != nil { - return nil, err - } - - return db, err -} - -func genPreimage() ([32]byte, error) { - var preimage [32]byte - if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { - return preimage, err - } - return preimage, nil -} - -func genInfo() (*PaymentCreationInfo, *PaymentAttemptInfo, - lntypes.Preimage, error) { - - preimage, err := genPreimage() - if err != nil { - return nil, nil, preimage, fmt.Errorf("unable to "+ - "generate preimage: %v", err) - } - - rhash := fastsha256.Sum256(preimage[:]) - return &PaymentCreationInfo{ - PaymentHash: rhash, - Value: 1, - CreationDate: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte("hola"), - }, - &PaymentAttemptInfo{ - PaymentID: 1, - SessionKey: priv, - Route: testRoute, - }, preimage, nil -} - -// TestPaymentControlSwitchFail checks that payment status returns to Failed -// status after failing, and that InitPayment allows another HTLC for the -// same payment hash. -func TestPaymentControlSwitchFail(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, attempt, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate StatusInFlight. - err = pControl.InitPayment(info.PaymentHash, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, - nil, - ) - - // Fail the payment, which should moved it to Failed. - failReason := FailureReasonNoRoute - _, err = pControl.Fail(info.PaymentHash, failReason) - if err != nil { - t.Fatalf("unable to fail payment hash: %v", err) - } - - // Verify the status is indeed Failed. - assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) - assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, - &failReason, - ) - - // Sends the htlc again, which should succeed since the prior payment - // failed. - err = pControl.InitPayment(info.PaymentHash, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, - nil, - ) - - // Record a new attempt. - attempt.PaymentID = 2 - err = pControl.RegisterAttempt(info.PaymentHash, attempt) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, - nil, - ) - - // Verifies that status was changed to StatusSucceeded. - var route *route.Route - route, err = pControl.Success(info.PaymentHash, preimg) - if err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) - } - - err = assertRouteEqual(route, &attempt.Route) - if err != nil { - t.Fatalf("unexpected route returned: %v vs %v: %v", - spew.Sdump(attempt.Route), spew.Sdump(*route), err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) - assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) - - // Attempt a final payment, which should now fail since the prior - // payment succeed. - err = pControl.InitPayment(info.PaymentHash, info) - if err != ErrAlreadyPaid { - t.Fatalf("unable to send htlc message: %v", err) - } -} - -// TestPaymentControlSwitchDoubleSend checks the ability of payment control to -// prevent double sending of htlc message, when message is in StatusInFlight. -func TestPaymentControlSwitchDoubleSend(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, attempt, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate base status and move it to - // StatusInFlight and verifies that it was changed. - err = pControl.InitPayment(info.PaymentHash, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, nil, lntypes.Preimage{}, - nil, - ) - - // Try to initiate double sending of htlc message with the same - // payment hash, should result in error indicating that payment has - // already been sent. - err = pControl.InitPayment(info.PaymentHash, info) - if err != ErrPaymentInFlight { - t.Fatalf("payment control wrong behaviour: " + - "double sending must trigger ErrPaymentInFlight error") - } - - // Record an attempt. - err = pControl.RegisterAttempt(info.PaymentHash, attempt) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, lntypes.Preimage{}, - nil, - ) - - // Sends base htlc message which initiate StatusInFlight. - err = pControl.InitPayment(info.PaymentHash, info) - if err != ErrPaymentInFlight { - t.Fatalf("payment control wrong behaviour: " + - "double sending must trigger ErrPaymentInFlight error") - } - - // After settling, the error should be ErrAlreadyPaid. - if _, err := pControl.Success(info.PaymentHash, preimg); err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) - } - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) - assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) - - err = pControl.InitPayment(info.PaymentHash, info) - if err != ErrAlreadyPaid { - t.Fatalf("unable to send htlc message: %v", err) - } -} - -// TestPaymentControlSuccessesWithoutInFlight checks that the payment -// control will disallow calls to Success when no payment is in flight. -func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, _, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Attempt to complete the payment should fail. - _, err = pControl.Success(info.PaymentHash, preimg) - if err != ErrPaymentNotInitiated { - t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) - assertPaymentInfo( - t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, - nil, - ) -} - -// TestPaymentControlFailsWithoutInFlight checks that a strict payment -// control will disallow calls to Fail when no payment is in flight. -func TestPaymentControlFailsWithoutInFlight(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - info, _, _, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Calling Fail should return an error. - _, err = pControl.Fail(info.PaymentHash, FailureReasonNoRoute) - if err != ErrPaymentNotInitiated { - t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusUnknown) - assertPaymentInfo( - t, db, info.PaymentHash, nil, nil, lntypes.Preimage{}, nil, - ) -} - -// TestPaymentControlDeleteNonInFlight checks that calling DeletaPayments only -// deletes payments from the database that are not in-flight. -func TestPaymentControlDeleteNonInFligt(t *testing.T) { - t.Parallel() - - db, err := initDB() - if err != nil { - t.Fatalf("unable to init db: %v", err) - } - - pControl := NewPaymentControl(db) - - payments := []struct { - failed bool - success bool - }{ - { - failed: true, - success: false, - }, - { - failed: false, - success: true, - }, - { - failed: false, - success: false, - }, - } - - for _, p := range payments { - info, attempt, preimg, err := genInfo() - if err != nil { - t.Fatalf("unable to generate htlc message: %v", err) - } - - // Sends base htlc message which initiate StatusInFlight. - err = pControl.InitPayment(info.PaymentHash, info) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - err = pControl.RegisterAttempt(info.PaymentHash, attempt) - if err != nil { - t.Fatalf("unable to send htlc message: %v", err) - } - - if p.failed { - // Fail the payment, which should moved it to Failed. - failReason := FailureReasonNoRoute - _, err = pControl.Fail(info.PaymentHash, failReason) - if err != nil { - t.Fatalf("unable to fail payment hash: %v", err) - } - - // Verify the status is indeed Failed. - assertPaymentStatus(t, db, info.PaymentHash, StatusFailed) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, - lntypes.Preimage{}, &failReason, - ) - } else if p.success { - // Verifies that status was changed to StatusSucceeded. - _, err := pControl.Success(info.PaymentHash, preimg) - if err != nil { - t.Fatalf("error shouldn't have been received, got: %v", err) - } - - assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, preimg, nil, - ) - } else { - assertPaymentStatus(t, db, info.PaymentHash, StatusInFlight) - assertPaymentInfo( - t, db, info.PaymentHash, info, attempt, - lntypes.Preimage{}, nil, - ) - } - } - - // Delete payments. - if err := db.DeletePayments(); err != nil { - t.Fatal(err) - } - - // This should leave the in-flight payment. - dbPayments, err := db.FetchPayments() - if err != nil { - t.Fatal(err) - } - - if len(dbPayments) != 1 { - t.Fatalf("expected one payment, got %d", len(dbPayments)) - } - - status := dbPayments[0].Status - if status != StatusInFlight { - t.Fatalf("expected in-fligth status, got %v", status) - } -} - -func assertPaymentStatus(t *testing.T, db *DB, - hash [32]byte, expStatus PaymentStatus) { - - t.Helper() - - var paymentStatus = StatusUnknown - err := db.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil - } - - bucket := payments.Bucket(hash[:]) - if bucket == nil { - return nil - } - - // Get the existing status of this payment, if any. - paymentStatus = fetchPaymentStatus(bucket) - return nil - }) - if err != nil { - t.Fatalf("unable to fetch payment status: %v", err) - } - - if paymentStatus != expStatus { - t.Fatalf("payment status mismatch: expected %v, got %v", - expStatus, paymentStatus) - } -} - -func checkPaymentCreationInfo(bucket *bbolt.Bucket, c *PaymentCreationInfo) error { - b := bucket.Get(paymentCreationInfoKey) - switch { - case b == nil && c == nil: - return nil - case b == nil: - return fmt.Errorf("expected creation info not found") - case c == nil: - return fmt.Errorf("unexpected creation info found") - } - - r := bytes.NewReader(b) - c2, err := deserializePaymentCreationInfo(r) - if err != nil { - return err - } - if !reflect.DeepEqual(c, c2) { - return fmt.Errorf("PaymentCreationInfos don't match: %v vs %v", - spew.Sdump(c), spew.Sdump(c2)) - } - - return nil -} - -func checkPaymentAttemptInfo(bucket *bbolt.Bucket, a *PaymentAttemptInfo) error { - b := bucket.Get(paymentAttemptInfoKey) - switch { - case b == nil && a == nil: - return nil - case b == nil: - return fmt.Errorf("expected attempt info not found") - case a == nil: - return fmt.Errorf("unexpected attempt info found") - } - - r := bytes.NewReader(b) - a2, err := deserializePaymentAttemptInfo(r) - if err != nil { - return err - } - - return assertRouteEqual(&a.Route, &a2.Route) -} - -func checkSettleInfo(bucket *bbolt.Bucket, preimg lntypes.Preimage) error { - zero := lntypes.Preimage{} - b := bucket.Get(paymentSettleInfoKey) - switch { - case b == nil && preimg == zero: - return nil - case b == nil: - return fmt.Errorf("expected preimage not found") - case preimg == zero: - return fmt.Errorf("unexpected preimage found") - } - - var pre2 lntypes.Preimage - copy(pre2[:], b[:]) - if preimg != pre2 { - return fmt.Errorf("Preimages don't match: %x vs %x", - preimg, pre2) - } - - return nil -} - -func checkFailInfo(bucket *bbolt.Bucket, failReason *FailureReason) error { - b := bucket.Get(paymentFailInfoKey) - switch { - case b == nil && failReason == nil: - return nil - case b == nil: - return fmt.Errorf("expected fail info not found") - case failReason == nil: - return fmt.Errorf("unexpected fail info found") - } - - failReason2 := FailureReason(b[0]) - if *failReason != failReason2 { - return fmt.Errorf("Failure infos don't match: %v vs %v", - *failReason, failReason2) - } - - return nil -} - -func assertPaymentInfo(t *testing.T, db *DB, hash lntypes.Hash, - c *PaymentCreationInfo, a *PaymentAttemptInfo, s lntypes.Preimage, - f *FailureReason) { - - t.Helper() - - err := db.View(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil && c == nil { - return nil - } - if payments == nil { - return fmt.Errorf("sent payments not found") - } - - bucket := payments.Bucket(hash[:]) - if bucket == nil && c == nil { - return nil - } - - if bucket == nil { - return fmt.Errorf("payment not found") - } - - if err := checkPaymentCreationInfo(bucket, c); err != nil { - return err - } - - if err := checkPaymentAttemptInfo(bucket, a); err != nil { - return err - } - - if err := checkSettleInfo(bucket, s); err != nil { - return err - } - - if err := checkFailInfo(bucket, f); err != nil { - return err - } - return nil - }) - if err != nil { - t.Fatalf("assert payment info failed: %v", err) - } - -} diff --git a/channeldb/migration_01_to_11/payments.go b/channeldb/migration_01_to_11/payments.go index fd3db5a1..d34cd6e9 100644 --- a/channeldb/migration_01_to_11/payments.go +++ b/channeldb/migration_01_to_11/payments.go @@ -375,48 +375,6 @@ func fetchPayment(bucket *bbolt.Bucket) (*Payment, error) { return p, nil } -// DeletePayments deletes all completed and failed payments from the DB. -func (db *DB) DeletePayments() error { - return db.Update(func(tx *bbolt.Tx) error { - payments := tx.Bucket(paymentsRootBucket) - if payments == nil { - return nil - } - - var deleteBuckets [][]byte - err := payments.ForEach(func(k, _ []byte) error { - bucket := payments.Bucket(k) - if bucket == nil { - // We only expect sub-buckets to be found in - // this top-level bucket. - return fmt.Errorf("non bucket element in " + - "payments bucket") - } - - // If the status is InFlight, we cannot safely delete - // the payment information, so we return early. - paymentStatus := fetchPaymentStatus(bucket) - if paymentStatus == StatusInFlight { - return nil - } - - deleteBuckets = append(deleteBuckets, k) - return nil - }) - if err != nil { - return err - } - - for _, k := range deleteBuckets { - if err := payments.DeleteBucket(k); err != nil { - return err - } - } - - return nil - }) -} - func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { var scratch [8]byte diff --git a/channeldb/migration_01_to_11/payments_test.go b/channeldb/migration_01_to_11/payments_test.go index 07307941..c5584079 100644 --- a/channeldb/migration_01_to_11/payments_test.go +++ b/channeldb/migration_01_to_11/payments_test.go @@ -2,55 +2,17 @@ package migration_01_to_11 import ( "bytes" - "errors" "fmt" "math/rand" - "reflect" - "testing" "time" "github.com/btcsuite/btcd/btcec" - "github.com/davecgh/go-spew/spew" - "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/routing/route" - "github.com/lightningnetwork/lnd/tlv" ) var ( priv, _ = btcec.NewPrivateKey(btcec.S256()) pub = priv.PubKey() - - tlvBytes = []byte{1, 2, 3} - tlvEncoder = tlv.StubEncoder(tlvBytes) - testHop1 = &route.Hop{ - PubKeyBytes: route.NewVertex(pub), - ChannelID: 12345, - OutgoingTimeLock: 111, - AmtToForward: 555, - TLVRecords: []tlv.Record{ - tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil), - tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil), - }, - } - - testHop2 = &route.Hop{ - PubKeyBytes: route.NewVertex(pub), - ChannelID: 12345, - OutgoingTimeLock: 111, - AmtToForward: 555, - LegacyPayload: true, - } - - testRoute = route.Route{ - TotalTimeLock: 123, - TotalAmount: 1234567, - SourcePubKey: route.NewVertex(pub), - Hops: []*route.Hop{ - testHop1, - testHop2, - }, - } ) func makeFakePayment() *outgoingPayment { @@ -81,27 +43,6 @@ func makeFakePayment() *outgoingPayment { return fakePayment } -func makeFakeInfo() (*PaymentCreationInfo, *PaymentAttemptInfo) { - var preimg lntypes.Preimage - copy(preimg[:], rev[:]) - - c := &PaymentCreationInfo{ - PaymentHash: preimg.Hash(), - Value: 1000, - // Use single second precision to avoid false positive test - // failures due to the monotonic time component. - CreationDate: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte(""), - } - - a := &PaymentAttemptInfo{ - PaymentID: 44, - SessionKey: priv, - Route: testRoute, - } - return c, a -} - // randomBytes creates random []byte with length in range [minLen, maxLen) func randomBytes(minLen, maxLen int) ([]byte, error) { randBuf := make([]byte, minLen+rand.Intn(maxLen-minLen)) @@ -165,160 +106,3 @@ func makeRandomFakePayment() (*outgoingPayment, error) { return fakePayment, nil } - -func TestSentPaymentSerialization(t *testing.T) { - t.Parallel() - - c, s := makeFakeInfo() - - var b bytes.Buffer - if err := serializePaymentCreationInfo(&b, c); err != nil { - t.Fatalf("unable to serialize creation info: %v", err) - } - - newCreationInfo, err := deserializePaymentCreationInfo(&b) - if err != nil { - t.Fatalf("unable to deserialize creation info: %v", err) - } - - if !reflect.DeepEqual(c, newCreationInfo) { - t.Fatalf("Payments do not match after "+ - "serialization/deserialization %v vs %v", - spew.Sdump(c), spew.Sdump(newCreationInfo), - ) - } - - b.Reset() - if err := serializePaymentAttemptInfo(&b, s); err != nil { - t.Fatalf("unable to serialize info: %v", err) - } - - newAttemptInfo, err := deserializePaymentAttemptInfo(&b) - if err != nil { - t.Fatalf("unable to deserialize info: %v", err) - } - - // First we verify all the records match up porperly, as they aren't - // able to be properly compared using reflect.DeepEqual. - err = assertRouteEqual(&s.Route, &newAttemptInfo.Route) - if err != nil { - t.Fatalf("Routes do not match after "+ - "serialization/deserialization: %v", err) - } - - // Clear routes to allow DeepEqual to compare the remaining fields. - newAttemptInfo.Route = route.Route{} - s.Route = route.Route{} - - if !reflect.DeepEqual(s, newAttemptInfo) { - s.SessionKey.Curve = nil - newAttemptInfo.SessionKey.Curve = nil - t.Fatalf("Payments do not match after "+ - "serialization/deserialization %v vs %v", - spew.Sdump(s), spew.Sdump(newAttemptInfo), - ) - } -} - -// assertRouteEquals compares to routes for equality and returns an error if -// they are not equal. -func assertRouteEqual(a, b *route.Route) error { - err := assertRouteHopRecordsEqual(a, b) - if err != nil { - return err - } - - // TLV records have already been compared and need to be cleared to - // properly compare the remaining fields using DeepEqual. - copyRouteNoHops := func(r *route.Route) *route.Route { - copy := *r - copy.Hops = make([]*route.Hop, len(r.Hops)) - for i, hop := range r.Hops { - hopCopy := *hop - hopCopy.TLVRecords = nil - copy.Hops[i] = &hopCopy - } - return © - } - - if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) { - return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", - spew.Sdump(a), spew.Sdump(b)) - } - - return nil -} - -func assertRouteHopRecordsEqual(r1, r2 *route.Route) error { - if len(r1.Hops) != len(r2.Hops) { - return errors.New("route hop count mismatch") - } - - for i := 0; i < len(r1.Hops); i++ { - records1 := r1.Hops[i].TLVRecords - records2 := r2.Hops[i].TLVRecords - if len(records1) != len(records2) { - return fmt.Errorf("route record count for hop %v "+ - "mismatch", i) - } - - for j := 0; j < len(records1); j++ { - expectedRecord := records1[j] - newRecord := records2[j] - - err := assertHopRecordsEqual(expectedRecord, newRecord) - if err != nil { - return fmt.Errorf("route record mismatch: %v", err) - } - } - } - - return nil -} - -func assertHopRecordsEqual(h1, h2 tlv.Record) error { - if h1.Type() != h2.Type() { - return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(), - h2.Type()) - } - - var b bytes.Buffer - if err := h2.Encode(&b); err != nil { - return fmt.Errorf("unable to encode record: %v", err) - } - - if !bytes.Equal(b.Bytes(), tlvBytes) { - return fmt.Errorf("wrong raw record: expected %x, got %x", - tlvBytes, b.Bytes()) - } - - if h1.Size() != h2.Size() { - return fmt.Errorf("wrong size: expected %v, "+ - "got %v", h1.Size(), h2.Size()) - } - - return nil -} - -func TestRouteSerialization(t *testing.T) { - t.Parallel() - - var b bytes.Buffer - if err := SerializeRoute(&b, testRoute); err != nil { - t.Fatal(err) - } - - r := bytes.NewReader(b.Bytes()) - route2, err := DeserializeRoute(r) - if err != nil { - t.Fatal(err) - } - - // First we verify all the records match up porperly, as they aren't - // able to be properly compared using reflect.DeepEqual. - err = assertRouteEqual(&testRoute, &route2) - if err != nil { - t.Fatalf("routes not equal: \n%v vs \n%v", - spew.Sdump(testRoute), spew.Sdump(route2)) - } -} diff --git a/channeldb/migration_01_to_11/reject_cache.go b/channeldb/migration_01_to_11/reject_cache.go deleted file mode 100644 index c54d78a8..00000000 --- a/channeldb/migration_01_to_11/reject_cache.go +++ /dev/null @@ -1,95 +0,0 @@ -package migration_01_to_11 - -// rejectFlags is a compact representation of various metadata stored by the -// reject cache about a particular channel. -type rejectFlags uint8 - -const ( - // rejectFlagExists is a flag indicating whether the channel exists, - // i.e. the channel is open and has a recent channel update. If this - // flag is not set, the channel is either a zombie or unknown. - rejectFlagExists rejectFlags = 1 << iota - - // rejectFlagZombie is a flag indicating whether the channel is a - // zombie, i.e. the channel is open but has no recent channel updates. - rejectFlagZombie -) - -// packRejectFlags computes the rejectFlags corresponding to the passed boolean -// values indicating whether the edge exists or is a zombie. -func packRejectFlags(exists, isZombie bool) rejectFlags { - var flags rejectFlags - if exists { - flags |= rejectFlagExists - } - if isZombie { - flags |= rejectFlagZombie - } - - return flags -} - -// unpack returns the booleans packed into the rejectFlags. The first indicates -// if the edge exists in our graph, the second indicates if the edge is a -// zombie. -func (f rejectFlags) unpack() (bool, bool) { - return f&rejectFlagExists == rejectFlagExists, - f&rejectFlagZombie == rejectFlagZombie -} - -// rejectCacheEntry caches frequently accessed information about a channel, -// including the timestamps of its latest edge policies and whether or not the -// channel exists in the graph. -type rejectCacheEntry struct { - upd1Time int64 - upd2Time int64 - flags rejectFlags -} - -// rejectCache is an in-memory cache used to improve the performance of -// HasChannelEdge. It caches information about the whether or channel exists, as -// well as the most recent timestamps for each policy (if they exists). -type rejectCache struct { - n int - edges map[uint64]rejectCacheEntry -} - -// newRejectCache creates a new rejectCache with maximum capacity of n entries. -func newRejectCache(n int) *rejectCache { - return &rejectCache{ - n: n, - edges: make(map[uint64]rejectCacheEntry, n), - } -} - -// get returns the entry from the cache for chanid, if it exists. -func (c *rejectCache) get(chanid uint64) (rejectCacheEntry, bool) { - entry, ok := c.edges[chanid] - return entry, ok -} - -// insert adds the entry to the reject cache. If an entry for chanid already -// exists, it will be replaced with the new entry. If the entry doesn't exists, -// it will be inserted to the cache, performing a random eviction if the cache -// is at capacity. -func (c *rejectCache) insert(chanid uint64, entry rejectCacheEntry) { - // If entry exists, replace it. - if _, ok := c.edges[chanid]; ok { - c.edges[chanid] = entry - return - } - - // Otherwise, evict an entry at random and insert. - if len(c.edges) == c.n { - for id := range c.edges { - delete(c.edges, id) - break - } - } - c.edges[chanid] = entry -} - -// remove deletes an entry for chanid from the cache, if it exists. -func (c *rejectCache) remove(chanid uint64) { - delete(c.edges, chanid) -} diff --git a/channeldb/migration_01_to_11/reject_cache_test.go b/channeldb/migration_01_to_11/reject_cache_test.go deleted file mode 100644 index e15e0a10..00000000 --- a/channeldb/migration_01_to_11/reject_cache_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package migration_01_to_11 - -import ( - "reflect" - "testing" -) - -// TestRejectCache checks the behavior of the rejectCache with respect to insertion, -// eviction, and removal of cache entries. -func TestRejectCache(t *testing.T) { - const cacheSize = 100 - - // Create a new reject cache with the configured max size. - c := newRejectCache(cacheSize) - - // As a sanity check, assert that querying the empty cache does not - // return an entry. - _, ok := c.get(0) - if ok { - t.Fatalf("reject cache should be empty") - } - - // Now, fill up the cache entirely. - for i := uint64(0); i < cacheSize; i++ { - c.insert(i, entryForInt(i)) - } - - // Assert that the cache has all of the entries just inserted, since no - // eviction should occur until we try to surpass the max size. - assertHasEntries(t, c, 0, cacheSize) - - // Now, insert a new element that causes the cache to evict an element. - c.insert(cacheSize, entryForInt(cacheSize)) - - // Assert that the cache has this last entry, as the cache should evict - // some prior element and not the newly inserted one. - assertHasEntries(t, c, cacheSize, cacheSize) - - // Iterate over all inserted elements and construct a set of the evicted - // elements. - evicted := make(map[uint64]struct{}) - for i := uint64(0); i < cacheSize+1; i++ { - _, ok := c.get(i) - if !ok { - evicted[i] = struct{}{} - } - } - - // Assert that exactly one element has been evicted. - numEvicted := len(evicted) - if numEvicted != 1 { - t.Fatalf("expected one evicted entry, got: %d", numEvicted) - } - - // Remove the highest item which initially caused the eviction and - // reinsert the element that was evicted prior. - c.remove(cacheSize) - for i := range evicted { - c.insert(i, entryForInt(i)) - } - - // Since the removal created an extra slot, the last insertion should - // not have caused an eviction and the entries for all channels in the - // original set that filled the cache should be present. - assertHasEntries(t, c, 0, cacheSize) - - // Finally, reinsert the existing set back into the cache and test that - // the cache still has all the entries. If the randomized eviction were - // happening on inserts for existing cache items, we expect this to fail - // with high probability. - for i := uint64(0); i < cacheSize; i++ { - c.insert(i, entryForInt(i)) - } - assertHasEntries(t, c, 0, cacheSize) - -} - -// assertHasEntries queries the reject cache for all channels in the range [start, -// end), asserting that they exist and their value matches the entry produced by -// entryForInt. -func assertHasEntries(t *testing.T, c *rejectCache, start, end uint64) { - t.Helper() - - for i := start; i < end; i++ { - entry, ok := c.get(i) - if !ok { - t.Fatalf("reject cache should contain chan %d", i) - } - - expEntry := entryForInt(i) - if !reflect.DeepEqual(entry, expEntry) { - t.Fatalf("entry mismatch, want: %v, got: %v", - expEntry, entry) - } - } -} - -// entryForInt generates a unique rejectCacheEntry given an integer. -func entryForInt(i uint64) rejectCacheEntry { - exists := i%2 == 0 - isZombie := i%3 == 0 - return rejectCacheEntry{ - upd1Time: int64(2 * i), - upd2Time: int64(2*i + 1), - flags: packRejectFlags(exists, isZombie), - } -} diff --git a/channeldb/migration_01_to_11/waitingproof.go b/channeldb/migration_01_to_11/waitingproof.go deleted file mode 100644 index 64729116..00000000 --- a/channeldb/migration_01_to_11/waitingproof.go +++ /dev/null @@ -1,251 +0,0 @@ -package migration_01_to_11 - -import ( - "encoding/binary" - "sync" - - "io" - - "bytes" - - "github.com/coreos/bbolt" - "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/lnwire" -) - -var ( - // waitingProofsBucketKey byte string name of the waiting proofs store. - waitingProofsBucketKey = []byte("waitingproofs") - - // ErrWaitingProofNotFound is returned if waiting proofs haven't been - // found by db. - ErrWaitingProofNotFound = errors.New("waiting proofs haven't been " + - "found") - - // ErrWaitingProofAlreadyExist is returned if waiting proofs haven't been - // found by db. - ErrWaitingProofAlreadyExist = errors.New("waiting proof with such " + - "key already exist") -) - -// WaitingProofStore is the bold db map-like storage for half announcement -// signatures. The one responsibility of this storage is to be able to -// retrieve waiting proofs after client restart. -type WaitingProofStore struct { - // cache is used in order to reduce the number of redundant get - // calls, when object isn't stored in it. - cache map[WaitingProofKey]struct{} - db *DB - mu sync.RWMutex -} - -// NewWaitingProofStore creates new instance of proofs storage. -func NewWaitingProofStore(db *DB) (*WaitingProofStore, error) { - s := &WaitingProofStore{ - db: db, - cache: make(map[WaitingProofKey]struct{}), - } - - if err := s.ForAll(func(proof *WaitingProof) error { - s.cache[proof.Key()] = struct{}{} - return nil - }); err != nil && err != ErrWaitingProofNotFound { - return nil, err - } - - return s, nil -} - -// Add adds new waiting proof in the storage. -func (s *WaitingProofStore) Add(proof *WaitingProof) error { - s.mu.Lock() - defer s.mu.Unlock() - - err := s.db.Update(func(tx *bbolt.Tx) error { - var err error - var b bytes.Buffer - - // Get or create the bucket. - bucket, err := tx.CreateBucketIfNotExists(waitingProofsBucketKey) - if err != nil { - return err - } - - // Encode the objects and place it in the bucket. - if err := proof.Encode(&b); err != nil { - return err - } - - key := proof.Key() - - return bucket.Put(key[:], b.Bytes()) - }) - if err != nil { - return err - } - - // Knowing that the write succeeded, we can now update the in-memory - // cache with the proof's key. - s.cache[proof.Key()] = struct{}{} - - return nil -} - -// Remove removes the proof from storage by its key. -func (s *WaitingProofStore) Remove(key WaitingProofKey) error { - s.mu.Lock() - defer s.mu.Unlock() - - if _, ok := s.cache[key]; !ok { - return ErrWaitingProofNotFound - } - - err := s.db.Update(func(tx *bbolt.Tx) error { - // Get or create the top bucket. - bucket := tx.Bucket(waitingProofsBucketKey) - if bucket == nil { - return ErrWaitingProofNotFound - } - - return bucket.Delete(key[:]) - }) - if err != nil { - return err - } - - // Since the proof was successfully deleted from the store, we can now - // remove it from the in-memory cache. - delete(s.cache, key) - - return nil -} - -// ForAll iterates thought all waiting proofs and passing the waiting proof -// in the given callback. -func (s *WaitingProofStore) ForAll(cb func(*WaitingProof) error) error { - return s.db.View(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(waitingProofsBucketKey) - if bucket == nil { - return ErrWaitingProofNotFound - } - - // Iterate over objects buckets. - return bucket.ForEach(func(k, v []byte) error { - // Skip buckets fields. - if v == nil { - return nil - } - - r := bytes.NewReader(v) - proof := &WaitingProof{} - if err := proof.Decode(r); err != nil { - return err - } - - return cb(proof) - }) - }) -} - -// Get returns the object which corresponds to the given index. -func (s *WaitingProofStore) Get(key WaitingProofKey) (*WaitingProof, error) { - proof := &WaitingProof{} - - s.mu.RLock() - defer s.mu.RUnlock() - - if _, ok := s.cache[key]; !ok { - return nil, ErrWaitingProofNotFound - } - - err := s.db.View(func(tx *bbolt.Tx) error { - bucket := tx.Bucket(waitingProofsBucketKey) - if bucket == nil { - return ErrWaitingProofNotFound - } - - // Iterate over objects buckets. - v := bucket.Get(key[:]) - if v == nil { - return ErrWaitingProofNotFound - } - - r := bytes.NewReader(v) - return proof.Decode(r) - }) - - return proof, err -} - -// WaitingProofKey is the proof key which uniquely identifies the waiting -// proof object. The goal of this key is distinguish the local and remote -// proof for the same channel id. -type WaitingProofKey [9]byte - -// WaitingProof is the storable object, which encapsulate the half proof and -// the information about from which side this proof came. This structure is -// needed to make channel proof exchange persistent, so that after client -// restart we may receive remote/local half proof and process it. -type WaitingProof struct { - *lnwire.AnnounceSignatures - isRemote bool -} - -// NewWaitingProof constructs a new waiting prof instance. -func NewWaitingProof(isRemote bool, proof *lnwire.AnnounceSignatures) *WaitingProof { - return &WaitingProof{ - AnnounceSignatures: proof, - isRemote: isRemote, - } -} - -// OppositeKey returns the key which uniquely identifies opposite waiting proof. -func (p *WaitingProof) OppositeKey() WaitingProofKey { - var key [9]byte - binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64()) - - if !p.isRemote { - key[8] = 1 - } - return key -} - -// Key returns the key which uniquely identifies waiting proof. -func (p *WaitingProof) Key() WaitingProofKey { - var key [9]byte - binary.BigEndian.PutUint64(key[:8], p.ShortChannelID.ToUint64()) - - if p.isRemote { - key[8] = 1 - } - return key -} - -// Encode writes the internal representation of waiting proof in byte stream. -func (p *WaitingProof) Encode(w io.Writer) error { - if err := binary.Write(w, byteOrder, p.isRemote); err != nil { - return err - } - - if err := p.AnnounceSignatures.Encode(w, 0); err != nil { - return err - } - - return nil -} - -// Decode reads the data from the byte stream and initializes the -// waiting proof object with it. -func (p *WaitingProof) Decode(r io.Reader) error { - if err := binary.Read(r, byteOrder, &p.isRemote); err != nil { - return err - } - - msg := &lnwire.AnnounceSignatures{} - if err := msg.Decode(r, 0); err != nil { - return err - } - - (*p).AnnounceSignatures = msg - return nil -} diff --git a/channeldb/migration_01_to_11/waitingproof_test.go b/channeldb/migration_01_to_11/waitingproof_test.go deleted file mode 100644 index 968f1157..00000000 --- a/channeldb/migration_01_to_11/waitingproof_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package migration_01_to_11 - -import ( - "testing" - - "reflect" - - "github.com/go-errors/errors" - "github.com/lightningnetwork/lnd/lnwire" -) - -// TestWaitingProofStore tests add/get/remove functions of the waiting proof -// storage. -func TestWaitingProofStore(t *testing.T) { - t.Parallel() - - db, cleanup, err := makeTestDB() - if err != nil { - t.Fatalf("failed to make test database: %s", err) - } - defer cleanup() - - proof1 := NewWaitingProof(true, &lnwire.AnnounceSignatures{ - NodeSignature: wireSig, - BitcoinSignature: wireSig, - }) - - store, err := NewWaitingProofStore(db) - if err != nil { - t.Fatalf("unable to create the waiting proofs storage: %v", - err) - } - - if err := store.Add(proof1); err != nil { - t.Fatalf("unable add proof to storage: %v", err) - } - - proof2, err := store.Get(proof1.Key()) - if err != nil { - t.Fatalf("unable retrieve proof from storage: %v", err) - } - if !reflect.DeepEqual(proof1, proof2) { - t.Fatal("wrong proof retrieved") - } - - if _, err := store.Get(proof1.OppositeKey()); err != ErrWaitingProofNotFound { - t.Fatalf("proof shouldn't be found: %v", err) - } - - if err := store.Remove(proof1.Key()); err != nil { - t.Fatalf("unable remove proof from storage: %v", err) - } - - if err := store.ForAll(func(proof *WaitingProof) error { - return errors.New("storage should be empty") - }); err != nil && err != ErrWaitingProofNotFound { - t.Fatal(err) - } -} diff --git a/channeldb/migration_01_to_11/witness_cache.go b/channeldb/migration_01_to_11/witness_cache.go deleted file mode 100644 index 69de1054..00000000 --- a/channeldb/migration_01_to_11/witness_cache.go +++ /dev/null @@ -1,229 +0,0 @@ -package migration_01_to_11 - -import ( - "fmt" - - "github.com/coreos/bbolt" - "github.com/lightningnetwork/lnd/lntypes" -) - -var ( - // ErrNoWitnesses is an error that's returned when no new witnesses have - // been added to the WitnessCache. - ErrNoWitnesses = fmt.Errorf("no witnesses") - - // ErrUnknownWitnessType is returned if a caller attempts to - ErrUnknownWitnessType = fmt.Errorf("unknown witness type") -) - -// WitnessType is enum that denotes what "type" of witness is being -// stored/retrieved. As the WitnessCache itself is agnostic and doesn't enforce -// any structure on added witnesses, we use this type to partition the -// witnesses on disk, and also to know how to map a witness to its look up key. -type WitnessType uint8 - -var ( - // Sha256HashWitness is a witness that is simply the pre image to a - // hash image. In order to map to its key, we'll use sha256. - Sha256HashWitness WitnessType = 1 -) - -// toDBKey is a helper method that maps a witness type to the key that we'll -// use to store it within the database. -func (w WitnessType) toDBKey() ([]byte, error) { - switch w { - - case Sha256HashWitness: - return []byte{byte(w)}, nil - - default: - return nil, ErrUnknownWitnessType - } -} - -var ( - // witnessBucketKey is the name of the bucket that we use to store all - // witnesses encountered. Within this bucket, we'll create a sub-bucket for - // each witness type. - witnessBucketKey = []byte("byte") -) - -// WitnessCache is a persistent cache of all witnesses we've encountered on the -// network. In the case of multi-hop, multi-step contracts, a cache of all -// witnesses can be useful in the case of partial contract resolution. If -// negotiations break down, we may be forced to locate the witness for a -// portion of the contract on-chain. In this case, we'll then add that witness -// to the cache so the incoming contract can fully resolve witness. -// Additionally, as one MUST always use a unique witness on the network, we may -// use this cache to detect duplicate witnesses. -// -// TODO(roasbeef): need expiry policy? -// * encrypt? -type WitnessCache struct { - db *DB -} - -// NewWitnessCache returns a new instance of the witness cache. -func (d *DB) NewWitnessCache() *WitnessCache { - return &WitnessCache{ - db: d, - } -} - -// witnessEntry is a key-value struct that holds each key -> witness pair, used -// when inserting records into the cache. -type witnessEntry struct { - key []byte - witness []byte -} - -// AddSha256Witnesses adds a batch of new sha256 preimages into the witness -// cache. This is an alias for AddWitnesses that uses Sha256HashWitness as the -// preimages' witness type. -func (w *WitnessCache) AddSha256Witnesses(preimages ...lntypes.Preimage) error { - // Optimistically compute the preimages' hashes before attempting to - // start the db transaction. - entries := make([]witnessEntry, 0, len(preimages)) - for i := range preimages { - hash := preimages[i].Hash() - entries = append(entries, witnessEntry{ - key: hash[:], - witness: preimages[i][:], - }) - } - - return w.addWitnessEntries(Sha256HashWitness, entries) -} - -// addWitnessEntries inserts the witnessEntry key-value pairs into the cache, -// using the appropriate witness type to segment the namespace of possible -// witness types. -func (w *WitnessCache) addWitnessEntries(wType WitnessType, - entries []witnessEntry) error { - - // Exit early if there are no witnesses to add. - if len(entries) == 0 { - return nil - } - - return w.db.Batch(func(tx *bbolt.Tx) error { - witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) - if err != nil { - return err - } - - witnessTypeBucketKey, err := wType.toDBKey() - if err != nil { - return err - } - witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists( - witnessTypeBucketKey, - ) - if err != nil { - return err - } - - for _, entry := range entries { - err = witnessTypeBucket.Put(entry.key, entry.witness) - if err != nil { - return err - } - } - - return nil - }) -} - -// LookupSha256Witness attempts to lookup the preimage for a sha256 hash. If -// the witness isn't found, ErrNoWitnesses will be returned. -func (w *WitnessCache) LookupSha256Witness(hash lntypes.Hash) (lntypes.Preimage, error) { - witness, err := w.lookupWitness(Sha256HashWitness, hash[:]) - if err != nil { - return lntypes.Preimage{}, err - } - - return lntypes.MakePreimage(witness) -} - -// lookupWitness attempts to lookup a witness according to its type and also -// its witness key. In the case that the witness isn't found, ErrNoWitnesses -// will be returned. -func (w *WitnessCache) lookupWitness(wType WitnessType, witnessKey []byte) ([]byte, error) { - var witness []byte - err := w.db.View(func(tx *bbolt.Tx) error { - witnessBucket := tx.Bucket(witnessBucketKey) - if witnessBucket == nil { - return ErrNoWitnesses - } - - witnessTypeBucketKey, err := wType.toDBKey() - if err != nil { - return err - } - witnessTypeBucket := witnessBucket.Bucket(witnessTypeBucketKey) - if witnessTypeBucket == nil { - return ErrNoWitnesses - } - - dbWitness := witnessTypeBucket.Get(witnessKey) - if dbWitness == nil { - return ErrNoWitnesses - } - - witness = make([]byte, len(dbWitness)) - copy(witness[:], dbWitness) - - return nil - }) - if err != nil { - return nil, err - } - - return witness, nil -} - -// DeleteSha256Witness attempts to delete a sha256 preimage identified by hash. -func (w *WitnessCache) DeleteSha256Witness(hash lntypes.Hash) error { - return w.deleteWitness(Sha256HashWitness, hash[:]) -} - -// deleteWitness attempts to delete a particular witness from the database. -func (w *WitnessCache) deleteWitness(wType WitnessType, witnessKey []byte) error { - return w.db.Batch(func(tx *bbolt.Tx) error { - witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) - if err != nil { - return err - } - - witnessTypeBucketKey, err := wType.toDBKey() - if err != nil { - return err - } - witnessTypeBucket, err := witnessBucket.CreateBucketIfNotExists( - witnessTypeBucketKey, - ) - if err != nil { - return err - } - - return witnessTypeBucket.Delete(witnessKey) - }) -} - -// DeleteWitnessClass attempts to delete an *entire* class of witnesses. After -// this function return with a non-nil error, -func (w *WitnessCache) DeleteWitnessClass(wType WitnessType) error { - return w.db.Batch(func(tx *bbolt.Tx) error { - witnessBucket, err := tx.CreateBucketIfNotExists(witnessBucketKey) - if err != nil { - return err - } - - witnessTypeBucketKey, err := wType.toDBKey() - if err != nil { - return err - } - - return witnessBucket.DeleteBucket(witnessTypeBucketKey) - }) -} diff --git a/channeldb/migration_01_to_11/witness_cache_test.go b/channeldb/migration_01_to_11/witness_cache_test.go deleted file mode 100644 index 92836abe..00000000 --- a/channeldb/migration_01_to_11/witness_cache_test.go +++ /dev/null @@ -1,238 +0,0 @@ -package migration_01_to_11 - -import ( - "crypto/sha256" - "testing" - - "github.com/lightningnetwork/lnd/lntypes" -) - -// TestWitnessCacheSha256Retrieval tests that we're able to add and lookup new -// sha256 preimages to the witness cache. -func TestWitnessCacheSha256Retrieval(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - wCache := cdb.NewWitnessCache() - - // We'll be attempting to add then lookup two simple sha256 preimages - // within this test. - preimage1 := lntypes.Preimage(rev) - preimage2 := lntypes.Preimage(key) - - preimages := []lntypes.Preimage{preimage1, preimage2} - hashes := []lntypes.Hash{preimage1.Hash(), preimage2.Hash()} - - // First, we'll attempt to add the preimages to the database. - err = wCache.AddSha256Witnesses(preimages...) - if err != nil { - t.Fatalf("unable to add witness: %v", err) - } - - // With the preimages stored, we'll now attempt to look them up. - for i, hash := range hashes { - preimage := preimages[i] - - // We should get back the *exact* same preimage as we originally - // stored. - dbPreimage, err := wCache.LookupSha256Witness(hash) - if err != nil { - t.Fatalf("unable to look up witness: %v", err) - } - - if preimage != dbPreimage { - t.Fatalf("witnesses don't match: expected %x, got %x", - preimage[:], dbPreimage[:]) - } - } -} - -// TestWitnessCacheSha256Deletion tests that we're able to delete a single -// sha256 preimage, and also a class of witnesses from the cache. -func TestWitnessCacheSha256Deletion(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - wCache := cdb.NewWitnessCache() - - // We'll start by adding two preimages to the cache. - preimage1 := lntypes.Preimage(key) - hash1 := preimage1.Hash() - - preimage2 := lntypes.Preimage(rev) - hash2 := preimage2.Hash() - - if err := wCache.AddSha256Witnesses(preimage1); err != nil { - t.Fatalf("unable to add witness: %v", err) - } - - if err := wCache.AddSha256Witnesses(preimage2); err != nil { - t.Fatalf("unable to add witness: %v", err) - } - - // We'll now delete the first preimage. If we attempt to look it up, we - // should get ErrNoWitnesses. - err = wCache.DeleteSha256Witness(hash1) - if err != nil { - t.Fatalf("unable to delete witness: %v", err) - } - _, err = wCache.LookupSha256Witness(hash1) - if err != ErrNoWitnesses { - t.Fatalf("expected ErrNoWitnesses instead got: %v", err) - } - - // Next, we'll attempt to delete the entire witness class itself. When - // we try to lookup the second preimage, we should again get - // ErrNoWitnesses. - if err := wCache.DeleteWitnessClass(Sha256HashWitness); err != nil { - t.Fatalf("unable to delete witness class: %v", err) - } - _, err = wCache.LookupSha256Witness(hash2) - if err != ErrNoWitnesses { - t.Fatalf("expected ErrNoWitnesses instead got: %v", err) - } -} - -// TestWitnessCacheUnknownWitness tests that we get an error if we attempt to -// query/add/delete an unknown witness. -func TestWitnessCacheUnknownWitness(t *testing.T) { - t.Parallel() - - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - wCache := cdb.NewWitnessCache() - - // We'll attempt to add a new, undefined witness type to the database. - // We should get an error. - err = wCache.legacyAddWitnesses(234, key[:]) - if err != ErrUnknownWitnessType { - t.Fatalf("expected ErrUnknownWitnessType, got %v", err) - } -} - -// TestAddSha256Witnesses tests that insertion using AddSha256Witnesses behaves -// identically to the insertion via the generalized interface. -func TestAddSha256Witnesses(t *testing.T) { - cdb, cleanUp, err := makeTestDB() - if err != nil { - t.Fatalf("unable to make test database: %v", err) - } - defer cleanUp() - - wCache := cdb.NewWitnessCache() - - // We'll start by adding a witnesses to the cache using the generic - // AddWitnesses method. - witness1 := rev[:] - preimage1 := lntypes.Preimage(rev) - hash1 := preimage1.Hash() - - witness2 := key[:] - preimage2 := lntypes.Preimage(key) - hash2 := preimage2.Hash() - - var ( - witnesses = [][]byte{witness1, witness2} - preimages = []lntypes.Preimage{preimage1, preimage2} - hashes = []lntypes.Hash{hash1, hash2} - ) - - err = wCache.legacyAddWitnesses(Sha256HashWitness, witnesses...) - if err != nil { - t.Fatalf("unable to add witness: %v", err) - } - - for i, hash := range hashes { - preimage := preimages[i] - - dbPreimage, err := wCache.LookupSha256Witness(hash) - if err != nil { - t.Fatalf("unable to lookup witness: %v", err) - } - - // Assert that the retrieved witness matches the original. - if dbPreimage != preimage { - t.Fatalf("retrieved witness mismatch, want: %x, "+ - "got: %x", preimage, dbPreimage) - } - - // We'll now delete the witness, as we'll be reinserting it - // using the specialized AddSha256Witnesses method. - err = wCache.DeleteSha256Witness(hash) - if err != nil { - t.Fatalf("unable to delete witness: %v", err) - } - } - - // Now, add the same witnesses using the type-safe interface for - // lntypes.Preimages.. - err = wCache.AddSha256Witnesses(preimages...) - if err != nil { - t.Fatalf("unable to add sha256 preimage: %v", err) - } - - // Finally, iterate over the keys and assert that the returned witnesses - // match the original witnesses. This asserts that the specialized - // insertion method behaves identically to the generalized interface. - for i, hash := range hashes { - preimage := preimages[i] - - dbPreimage, err := wCache.LookupSha256Witness(hash) - if err != nil { - t.Fatalf("unable to lookup witness: %v", err) - } - - // Assert that the retrieved witness matches the original. - if dbPreimage != preimage { - t.Fatalf("retrieved witness mismatch, want: %x, "+ - "got: %x", preimage, dbPreimage) - } - } -} - -// legacyAddWitnesses adds a batch of new witnesses of wType to the witness -// cache. The type of the witness will be used to map each witness to the key -// that will be used to look it up. All witnesses should be of the same -// WitnessType. -// -// NOTE: Previously this method exposed a generic interface for adding -// witnesses, which has since been deprecated in favor of a strongly typed -// interface for each witness class. We keep this method around to assert the -// correctness of specialized witness adding methods. -func (w *WitnessCache) legacyAddWitnesses(wType WitnessType, - witnesses ...[]byte) error { - - // Optimistically compute the witness keys before attempting to start - // the db transaction. - entries := make([]witnessEntry, 0, len(witnesses)) - for _, witness := range witnesses { - // Map each witness to its key by applying the appropriate - // transformation for the given witness type. - switch wType { - case Sha256HashWitness: - key := sha256.Sum256(witness) - entries = append(entries, witnessEntry{ - key: key[:], - witness: witness, - }) - default: - return ErrUnknownWitnessType - } - } - - return w.addWitnessEntries(wType, entries) -} From 4486a06b1ad1d1d18643b2f08a4ec8da0ea40fcd Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Thu, 31 Oct 2019 11:25:04 +0100 Subject: [PATCH 6/6] migration_01_to_11: remove version checking for migration tests --- channeldb/migration_01_to_11/db.go | 87 ------------------- channeldb/migration_01_to_11/error.go | 9 -- channeldb/migration_01_to_11/meta.go | 41 --------- channeldb/migration_01_to_11/meta_test.go | 25 ++---- .../migration_11_invoices_test.go | 9 -- .../migration_01_to_11/migrations_test.go | 37 +------- 6 files changed, 6 insertions(+), 202 deletions(-) diff --git a/channeldb/migration_01_to_11/db.go b/channeldb/migration_01_to_11/db.go index 623b33bc..116f06a4 100644 --- a/channeldb/migration_01_to_11/db.go +++ b/channeldb/migration_01_to_11/db.go @@ -21,11 +21,6 @@ const ( // up-to-date version of the database. type migration func(tx *bbolt.Tx) error -type version struct { - number uint32 - migration migration -} - var ( // Big endian is the preferred byte order, due to cursor scans over // integer keys iterating in order. @@ -220,89 +215,7 @@ func (d *DB) FetchClosedChannels(pendingOnly bool) ([]*ChannelCloseSummary, erro return chanSummaries, nil } -// syncVersions function is used for safe db version synchronization. It -// applies migration functions to the current database and recovers the -// previous state of db if at least one error/panic appeared during migration. -func (d *DB) syncVersions(versions []version) error { - meta, err := d.FetchMeta(nil) - if err != nil { - if err == ErrMetaNotFound { - meta = &Meta{} - } else { - return err - } - } - - latestVersion := getLatestDBVersion(versions) - log.Infof("Checking for schema update: latest_version=%v, "+ - "db_version=%v", latestVersion, meta.DbVersionNumber) - - switch { - - // If the database reports a higher version that we are aware of, the - // user is probably trying to revert to a prior version of lnd. We fail - // here to prevent reversions and unintended corruption. - case meta.DbVersionNumber > latestVersion: - log.Errorf("Refusing to revert from db_version=%d to "+ - "lower version=%d", meta.DbVersionNumber, - latestVersion) - return ErrDBReversion - - // If the current database version matches the latest version number, - // then we don't need to perform any migrations. - case meta.DbVersionNumber == latestVersion: - return nil - } - - log.Infof("Performing database schema migration") - - // Otherwise, we fetch the migrations which need to applied, and - // execute them serially within a single database transaction to ensure - // the migration is atomic. - migrations, migrationVersions := getMigrationsToApply( - versions, meta.DbVersionNumber, - ) - return d.Update(func(tx *bbolt.Tx) error { - for i, migration := range migrations { - if migration == nil { - continue - } - - log.Infof("Applying migration #%v", migrationVersions[i]) - - if err := migration(tx); err != nil { - log.Infof("Unable to apply migration #%v", - migrationVersions[i]) - return err - } - } - - meta.DbVersionNumber = latestVersion - return putMeta(meta, tx) - }) -} - // ChannelGraph returns a new instance of the directed channel graph. func (d *DB) ChannelGraph() *ChannelGraph { return d.graph } - -func getLatestDBVersion(versions []version) uint32 { - return versions[len(versions)-1].number -} - -// getMigrationsToApply retrieves the migration function that should be -// applied to the database. -func getMigrationsToApply(versions []version, version uint32) ([]migration, []uint32) { - migrations := make([]migration, 0, len(versions)) - migrationVersions := make([]uint32, 0, len(versions)) - - for _, v := range versions { - if v.number > version { - migrations = append(migrations, v.migration) - migrationVersions = append(migrationVersions, v.number) - } - } - - return migrations, migrationVersions -} diff --git a/channeldb/migration_01_to_11/error.go b/channeldb/migration_01_to_11/error.go index 232aaa2b..d096ae8b 100644 --- a/channeldb/migration_01_to_11/error.go +++ b/channeldb/migration_01_to_11/error.go @@ -5,11 +5,6 @@ import ( ) var ( - - // ErrDBReversion is returned when detecting an attempt to revert to a - // prior database version. - ErrDBReversion = fmt.Errorf("channel db cannot revert to prior version") - // ErrNoInvoicesCreated is returned when we don't have invoices in // our database to return. ErrNoInvoicesCreated = fmt.Errorf("there are no existing invoices") @@ -18,10 +13,6 @@ var ( // created. ErrNoPaymentsCreated = fmt.Errorf("there are no existing payments") - // ErrMetaNotFound is returned when meta bucket hasn't been - // created. - ErrMetaNotFound = fmt.Errorf("unable to locate meta information") - // ErrGraphNotFound is returned when at least one of the components of // graph doesn't exist. ErrGraphNotFound = fmt.Errorf("graph bucket not initialized") diff --git a/channeldb/migration_01_to_11/meta.go b/channeldb/migration_01_to_11/meta.go index a8f9bd41..3abcc0d0 100644 --- a/channeldb/migration_01_to_11/meta.go +++ b/channeldb/migration_01_to_11/meta.go @@ -18,47 +18,6 @@ type Meta struct { DbVersionNumber uint32 } -// FetchMeta fetches the meta data from boltdb and returns filled meta -// structure. -func (d *DB) FetchMeta(tx *bbolt.Tx) (*Meta, error) { - meta := &Meta{} - - err := d.View(func(tx *bbolt.Tx) error { - return fetchMeta(meta, tx) - }) - if err != nil { - return nil, err - } - - return meta, nil -} - -// fetchMeta is an internal helper function used in order to allow callers to -// re-use a database transaction. See the publicly exported FetchMeta method -// for more information. -func fetchMeta(meta *Meta, tx *bbolt.Tx) error { - metaBucket := tx.Bucket(metaBucket) - if metaBucket == nil { - return ErrMetaNotFound - } - - data := metaBucket.Get(dbVersionKey) - if data == nil { - meta.DbVersionNumber = 0 - } else { - meta.DbVersionNumber = byteOrder.Uint32(data) - } - - return nil -} - -// PutMeta writes the passed instance of the database met-data struct to disk. -func (d *DB) PutMeta(meta *Meta) error { - return d.Update(func(tx *bbolt.Tx) error { - return putMeta(meta, tx) - }) -} - // putMeta is an internal helper function used in order to allow callers to // re-use a database transaction. See the publicly exported PutMeta method for // more information. diff --git a/channeldb/migration_01_to_11/meta_test.go b/channeldb/migration_01_to_11/meta_test.go index be1af2f9..587116e1 100644 --- a/channeldb/migration_01_to_11/meta_test.go +++ b/channeldb/migration_01_to_11/meta_test.go @@ -3,6 +3,7 @@ package migration_01_to_11 import ( "testing" + "github.com/coreos/bbolt" "github.com/go-errors/errors" ) @@ -31,24 +32,6 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), // with test data. beforeMigration(cdb) - // Create test meta info with zero database version and put it on disk. - // Than creating the version list pretending that new version was added. - meta := &Meta{DbVersionNumber: 0} - if err := cdb.PutMeta(meta); err != nil { - t.Fatalf("unable to store meta data: %v", err) - } - - versions := []version{ - { - number: 0, - migration: nil, - }, - { - number: 1, - migration: migrationFunc, - }, - } - defer func() { if r := recover(); r != nil { err = errors.New(r) @@ -65,8 +48,10 @@ func applyMigration(t *testing.T, beforeMigration, afterMigration func(d *DB), afterMigration(cdb) }() - // Sync with the latest version - applying migration function. - err = cdb.syncVersions(versions) + // Apply migration. + err = cdb.Update(func(tx *bbolt.Tx) error { + return migrationFunc(tx) + }) if err != nil { log.Error(err) } diff --git a/channeldb/migration_01_to_11/migration_11_invoices_test.go b/channeldb/migration_01_to_11/migration_11_invoices_test.go index 31cfe48f..0776458d 100644 --- a/channeldb/migration_01_to_11/migration_11_invoices_test.go +++ b/channeldb/migration_01_to_11/migration_11_invoices_test.go @@ -88,15 +88,6 @@ func TestMigrateInvoices(t *testing.T) { // Verify that all invoices were migrated. afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("migration 'invoices' wasn't applied") - } - dbInvoices, err := d.FetchAllInvoices(false) if err != nil { t.Fatalf("unable to fetch invoices: %v", err) diff --git a/channeldb/migration_01_to_11/migrations_test.go b/channeldb/migration_01_to_11/migrations_test.go index 598832c2..980b029c 100644 --- a/channeldb/migration_01_to_11/migrations_test.go +++ b/channeldb/migration_01_to_11/migrations_test.go @@ -135,15 +135,6 @@ func TestPaymentStatusesMigration(t *testing.T) { // Verify that the created payment status is "Completed" for our one // fake payment. afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("migration 'paymentStatusesMigration' wasn't applied") - } - // Check that our completed payments were migrated. paymentStatus, err := d.fetchPaymentStatus(paymentHash) if err != nil { @@ -404,15 +395,6 @@ func TestMigrateOptionalChannelCloseSummaryFields(t *testing.T) { // After the migration it should be found in the new format. afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("migration wasn't applied") - } - // We generate the new serialized version, to check // against what is found in the DB. var b bytes.Buffer @@ -521,16 +503,8 @@ func TestMigrateGossipMessageStoreKeys(t *testing.T) { // 2. We can find the message under its new key. // 3. The message matches the original. afterMigration := func(db *DB) { - meta, err := db.FetchMeta(nil) - if err != nil { - t.Fatalf("unable to fetch db version: %v", err) - } - if meta.DbVersionNumber != 1 { - t.Fatalf("migration should have succeeded but didn't") - } - var rawMsg []byte - err = db.View(func(tx *bbolt.Tx) error { + err := db.View(func(tx *bbolt.Tx) error { messageStore := tx.Bucket(messageStoreBucket) if messageStore == nil { return errors.New("message store bucket not " + @@ -617,15 +591,6 @@ func TestOutgoingPaymentsMigration(t *testing.T) { // Verify that all payments were migrated. afterMigrationFunc := func(d *DB) { - meta, err := d.FetchMeta(nil) - if err != nil { - t.Fatal(err) - } - - if meta.DbVersionNumber != 1 { - t.Fatal("migration 'paymentStatusesMigration' wasn't applied") - } - sentPayments, err := d.fetchPaymentsMigration9() if err != nil { t.Fatalf("unable to fetch sent payments: %v", err)